【RL工具类】强化学习常用函数工具类(Python代码)

简介: 【RL工具类】强化学习常用函数工具类(Python代码)

@[toc]


一、注意事项

  • 设置中文字体,注意需要根据自己电脑情况更改字体路径,否则可能会报错

二、代码

# -*-coding:utf-8-*-

import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import json
import random
import torch
import pandas as pd

from matplotlib.font_manager import FontProperties  # 导入字体模块


# 设置中文字体,注意需要根据自己电脑情况更改字体路径,否则还是默认的字体
def chinese_font():
    try:
        font = FontProperties(
            # 系统字体路径
            fname='C:\\Windows\\Fonts\\方正粗黑宋简体.ttf', size=14)
    except:
        font = None
    return font


# 中文画图
def plot_rewards_cn(rewards, cfg, path=None, tag='train'):
    sns.set()
    plt.figure()
    plt.title(u"{}环境下{}算法的学习曲线".format(cfg['env_name'],
                                       cfg['algo_name']), fontproperties=chinese_font())
    plt.xlabel(u'回合数', fontproperties=chinese_font())
    plt.plot(rewards)
    plt.plot(smooth(rewards))
    plt.legend(('奖励', '滑动平均奖励',), loc="best", prop=chinese_font())
    if cfg['save_fig']:
        plt.savefig(f"{path}/{tag}ing_curve_cn.png")
    if cfg['show_fig']:
        plt.show()


# 用于平滑曲线,类似于Tensorboard中的smooth
def smooth(data, weight=0.9):
    '''
    Args:
        data (List):输入数据
        weight (Float): 平滑权重,处于0-1之间,数值越高说明越平滑,一般取0.9

    Returns:
        smoothed (List): 平滑后的数据
    '''
    last = data[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in data:
        smoothed_val = last * weight + (1 - weight) * point  # 计算平滑值
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_rewards(rewards, cfg, path=None, tag='train'):
    sns.set()
    plt.figure()  # 创建一个图形实例,方便同时多画几个图
    plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
    plt.xlabel('epsiodes')
    plt.plot(rewards, label='rewards')
    plt.plot(smooth(rewards), label='smoothed')
    plt.legend()
    if cfg['save_fig']:
        plt.savefig(f"{path}/{tag}ing_curve.png")
    if cfg['show_fig']:
        plt.show()


def plot_losses(losses, algo="DQN", save=True, path='./'):
    sns.set()
    plt.figure()
    plt.title("loss curve of {}".format(algo))
    plt.xlabel('epsiodes')
    plt.plot(losses, label='rewards')
    plt.legend()
    if save:
        plt.savefig(path + "losses_curve")
    plt.show()


# 保存奖励
def save_results(res_dic, tag='train', path=None):
    '''
    '''
    Path(path).mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(res_dic)
    df.to_csv(f"{path}/{tag}ing_results.csv", index=None)
    print('结果已保存: ' + f"{path}/{tag}ing_results.csv")


# 创建文件夹
def make_dir(*paths):
    for path in paths:
        Path(path).mkdir(parents=True, exist_ok=True)


# 删除目录下所有空文件夹
def del_empty_dir(*paths):
    for path in paths:
        dirs = os.listdir(path)
        for dir in dirs:
            if not os.listdir(os.path.join(path, dir)):
                os.removedirs(os.path.join(path, dir))


class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


# 保存参数
def save_args(args, path=None):
    Path(path).mkdir(parents=True, exist_ok=True)
    with open(f"{path}/params.json", 'w') as fp:
        json.dump(args, fp, cls=NpEncoder)
    print("参数已保存: " + f"{path}/params.json")


# 为所有随机因素设置一个统一的种子
def all_seed(env, seed=520):
    # 环境种子设置
    env.seed(seed)
    # numpy随机数种子设置
    np.random.seed(seed)
    # python自带随机数种子设置
    random.seed(seed)
    # CPU种子设置
    torch.manual_seed(seed)
    # GPU种子设置
    torch.cuda.manual_seed(seed)
    # python scripts种子设置
    os.environ['PYTHONHASHSEED'] = str(seed)
    # cudnn的配置
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
目录
相关文章
|
7天前
|
数据库连接 开发者 Python
Python进阶宝典:十个实用技巧提升代码效率
Python进阶宝典:十个实用技巧提升代码效率
16 0
|
7天前
|
数据采集 数据格式 Python
享一些可以提高数据采集准确性的 Python 代码
这段Python代码示例提供了几个实用功能以提升数据采集的准确性:数据源验证、去除重复值、数据范围检查和数据格式验证。通过这些工具,可以确保所采集的数据在合理范围内且格式正确,有效提高了数据的质量。示例展示了如何使用这些功能进行数据清理与验证。
|
2天前
|
开发工具 git Python
通过Python脚本git pull 自动重试拉取代码
通过Python脚本git pull 自动重试拉取代码
83 4
|
4天前
|
对象存储 Python
Python代码解读-理解-定义一个User类的基本写法
以上描述清晰地阐述了如何在Python中定义 `User`类的基本方法以及如何创建和使用该类的实例。这是面向对象编程中的核心概念,是紧密结合抽象和实现,封装数据并提供操作数据的接口。由于用简单通用的语言易于理解,这样的解释对于初学者而言应该是友好且有帮助的。
13 4
|
2天前
|
Shell Python 容器
Python模块是其代码组织和重用的基本方式。
【8月更文挑战第18天】Python模块是其代码组织和重用的基本方式。
7 1
|
4天前
|
存储 缓存 算法
Python中的hash函数
Python中的hash函数
|
6天前
|
Python
Python学习笔记---函数
这篇文章是一份Python函数学习的笔记,涵盖了使用函数的优势、内置函数的调用、自定义函数的定义、函数参数的不同类型(必须参数、关键字参数、默认参数、可变参数)、有返回值和无返回值的函数、形参和实参、变量作用域、返回函数、递归函数、匿名函数、偏函数以及输入和输出函数等多个函数相关的主题。
|
6天前
|
Python
安装notepad++ 安装Python Python环境变量的数值。怎样在notepad++上运行Python的代码
这篇文章提供了在notepad++上安装和配置Python环境的详细步骤,包括安装Python、配置环境变量、在notepad++中设置Python语言和快捷编译方式,以及解决可能遇到的一些问题。
安装notepad++ 安装Python Python环境变量的数值。怎样在notepad++上运行Python的代码
|
4天前
|
Python
Python生成Thinkphp6代码工具类
Python生成Thinkphp6代码工具类
8 0
|
4天前
|
Python
Python常用工具类
Python常用工具类
6 0