【RL工具类】强化学习常用函数工具类(Python代码)

简介: 【RL工具类】强化学习常用函数工具类(Python代码)

@[toc]


一、注意事项

  • 设置中文字体,注意需要根据自己电脑情况更改字体路径,否则可能会报错

二、代码

# -*-coding:utf-8-*-

import os
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import json
import random
import torch
import pandas as pd

from matplotlib.font_manager import FontProperties  # 导入字体模块


# 设置中文字体,注意需要根据自己电脑情况更改字体路径,否则还是默认的字体
def chinese_font():
    try:
        font = FontProperties(
            # 系统字体路径
            fname='C:\\Windows\\Fonts\\方正粗黑宋简体.ttf', size=14)
    except:
        font = None
    return font


# 中文画图
def plot_rewards_cn(rewards, cfg, path=None, tag='train'):
    sns.set()
    plt.figure()
    plt.title(u"{}环境下{}算法的学习曲线".format(cfg['env_name'],
                                       cfg['algo_name']), fontproperties=chinese_font())
    plt.xlabel(u'回合数', fontproperties=chinese_font())
    plt.plot(rewards)
    plt.plot(smooth(rewards))
    plt.legend(('奖励', '滑动平均奖励',), loc="best", prop=chinese_font())
    if cfg['save_fig']:
        plt.savefig(f"{path}/{tag}ing_curve_cn.png")
    if cfg['show_fig']:
        plt.show()


# 用于平滑曲线,类似于Tensorboard中的smooth
def smooth(data, weight=0.9):
    '''
    Args:
        data (List):输入数据
        weight (Float): 平滑权重,处于0-1之间,数值越高说明越平滑,一般取0.9

    Returns:
        smoothed (List): 平滑后的数据
    '''
    last = data[0]  # First value in the plot (first timestep)
    smoothed = list()
    for point in data:
        smoothed_val = last * weight + (1 - weight) * point  # 计算平滑值
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed


def plot_rewards(rewards, cfg, path=None, tag='train'):
    sns.set()
    plt.figure()  # 创建一个图形实例,方便同时多画几个图
    plt.title(f"{tag}ing curve on {cfg['device']} of {cfg['algo_name']} for {cfg['env_name']}")
    plt.xlabel('epsiodes')
    plt.plot(rewards, label='rewards')
    plt.plot(smooth(rewards), label='smoothed')
    plt.legend()
    if cfg['save_fig']:
        plt.savefig(f"{path}/{tag}ing_curve.png")
    if cfg['show_fig']:
        plt.show()


def plot_losses(losses, algo="DQN", save=True, path='./'):
    sns.set()
    plt.figure()
    plt.title("loss curve of {}".format(algo))
    plt.xlabel('epsiodes')
    plt.plot(losses, label='rewards')
    plt.legend()
    if save:
        plt.savefig(path + "losses_curve")
    plt.show()


# 保存奖励
def save_results(res_dic, tag='train', path=None):
    '''
    '''
    Path(path).mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(res_dic)
    df.to_csv(f"{path}/{tag}ing_results.csv", index=None)
    print('结果已保存: ' + f"{path}/{tag}ing_results.csv")


# 创建文件夹
def make_dir(*paths):
    for path in paths:
        Path(path).mkdir(parents=True, exist_ok=True)


# 删除目录下所有空文件夹
def del_empty_dir(*paths):
    for path in paths:
        dirs = os.listdir(path)
        for dir in dirs:
            if not os.listdir(os.path.join(path, dir)):
                os.removedirs(os.path.join(path, dir))


class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


# 保存参数
def save_args(args, path=None):
    Path(path).mkdir(parents=True, exist_ok=True)
    with open(f"{path}/params.json", 'w') as fp:
        json.dump(args, fp, cls=NpEncoder)
    print("参数已保存: " + f"{path}/params.json")


# 为所有随机因素设置一个统一的种子
def all_seed(env, seed=520):
    # 环境种子设置
    env.seed(seed)
    # numpy随机数种子设置
    np.random.seed(seed)
    # python自带随机数种子设置
    random.seed(seed)
    # CPU种子设置
    torch.manual_seed(seed)
    # GPU种子设置
    torch.cuda.manual_seed(seed)
    # python scripts种子设置
    os.environ['PYTHONHASHSEED'] = str(seed)
    # cudnn的配置
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
目录
相关文章
|
1月前
|
存储 JavaScript Java
(Python基础)新时代语言!一起学习Python吧!(四):dict字典和set类型;切片类型、列表生成式;map和reduce迭代器;filter过滤函数、sorted排序函数;lambda函数
dict字典 Python内置了字典:dict的支持,dict全称dictionary,在其他语言中也称为map,使用键-值(key-value)存储,具有极快的查找速度。 我们可以通过声明JS对象一样的方式声明dict
158 1
|
1月前
|
算法 Java Docker
(Python基础)新时代语言!一起学习Python吧!(三):IF条件判断和match匹配;Python中的循环:for...in、while循环;循环操作关键字;Python函数使用方法
IF 条件判断 使用if语句,对条件进行判断 true则执行代码块缩进语句 false则不执行代码块缩进语句,如果有else 或 elif 则进入相应的规则中执行
243 1
|
1月前
|
Java 数据处理 索引
(numpy)Python做数据处理必备框架!(二):ndarray切片的使用与运算;常见的ndarray函数:平方根、正余弦、自然对数、指数、幂等运算;统计函数:方差、均值、极差;比较函数...
ndarray切片 索引从0开始 索引/切片类型 描述/用法 基本索引 通过整数索引直接访问元素。 行/列切片 使用冒号:切片语法选择行或列的子集 连续切片 从起始索引到结束索引按步长切片 使用slice函数 通过slice(start,stop,strp)定义切片规则 布尔索引 通过布尔条件筛选满足条件的元素。支持逻辑运算符 &、|。
137 0
|
1月前
|
测试技术 Python
Python装饰器:为你的代码施展“魔法”
Python装饰器:为你的代码施展“魔法”
232 100
|
1月前
|
开发者 Python
Python列表推导式:一行代码的艺术与力量
Python列表推导式:一行代码的艺术与力量
337 95
|
2月前
|
设计模式 缓存 监控
Python装饰器:优雅增强函数功能
Python装饰器:优雅增强函数功能
264 101
|
2月前
|
Python
Python的简洁之道:5个让代码更优雅的技巧
Python的简洁之道:5个让代码更优雅的技巧
229 104
|
2月前
|
开发者 Python
Python神技:用列表推导式让你的代码更优雅
Python神技:用列表推导式让你的代码更优雅
423 99
|
1月前
|
缓存 Python
Python装饰器:为你的代码施展“魔法
Python装饰器:为你的代码施展“魔法
149 88
|
1月前
|
监控 机器人 编译器
如何将python代码打包成exe文件---PyInstaller打包之神
PyInstaller可将Python程序打包为独立可执行文件,无需用户安装Python环境。它自动分析代码依赖,整合解释器、库及资源,支持一键生成exe,方便分发。使用pip安装后,通过简单命令即可完成打包,适合各类项目部署。

推荐镜像

更多