【技术分享】强化学习中使用seaborn绘制带有均值Reward的图片

简介: 【技术分享】强化学习中使用seaborn绘制带有均值Reward的图片

1.Seaborn介绍

matplotlib是python最常见的绘图包,强大之处不言而喻。然而在数据科学领域,可视化库Seaborn也是重量级的存在。由于matplotlib比较底层,想要绘制漂亮的图非常麻烦,需要写大量的代码。


Seaborn是在matplotlib基础上进行了高级API封装,图表装饰更加容易,你可以用更少的代码做出更美观的图。同时,Seaborn高度兼容了numy、pandas、scipy等库,使得数据可视化更加方便快捷。


2.Seaborn绘图代码

代码:

import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def get_data():
    '''
    获取数据
    '''
    baseline1 = np.array([[18, 20, 19, 18, 13, 4, 1],[20, 17, 12, 9, 3, 0, 0],[20, 20, 20, 12, 5, 3, 0]])
    algorithm1 = np.array([[18, 19, 18, 19, 20, 15, 14],[19, 20, 18, 16, 20, 15, 9],[19, 20, 20, 20, 17, 10, 0]]) 
    algorithm2 = np.array([[20, 20, 20, 20, 19, 17, 4],[20, 20, 20, 20, 20, 19, 7],[19, 20, 20, 19, 19, 15, 2]]) 
    algorithm3 = np.array([[20, 20, 20, 20, 19, 17, 12],[18, 20, 19, 18, 13, 4, 1], [20, 19, 18, 17, 13, 2, 0]])    
    return baseline1, algorithm1, algorithm2, algorithm3
data = get_data()
label = ['algo1', 'algo2', 'algo3', 'algo4']
df=[]
for i in range(len(data)):
    df.append(pd.DataFrame(data[i]).melt(var_name='episode',value_name='loss'))
    df[i]['algo']= label[i] 
df=pd.concat(df) # 合并
plt.figure(figsize=(10, 6))
sns.lineplot(x="episode", y="loss", hue="algo", style="algo",data=df)
plt.title("algorithm loss")
plt.show()

绘图结果:

f511d7715636637d45ef05c44e2ceb3c_1354b518cfa441e58699696e5bc0cfec.png


目录
相关文章
|
9月前
|
算法 数据可视化 数据挖掘
使用Python实现K均值聚类算法
使用Python实现K均值聚类算法
84 1
|
9月前
|
PyTorch 算法框架/工具 Python
Python 量化投资(一):滑动均值、布林带、MACD、RSI、KDJ、OBV
Python 量化投资(一):滑动均值、布林带、MACD、RSI、KDJ、OBV
173 0
|
9月前
|
算法 计算机视觉 Python
OpenCV均值、中值滤波器的讲解及实战应用(附Python源码)
OpenCV均值、中值滤波器的讲解及实战应用(附Python源码)
868 0
|
机器学习/深度学习 自然语言处理 算法
【机器学习实战】10分钟学会Python怎么用K均值K-means进行聚类(九)
【机器学习实战】10分钟学会Python怎么用K均值K-means进行聚类(九)
292 0
|
算法 数据挖掘 Python
AIGC背后的技术分析 | K均值聚类算法Python实现
本篇介绍K均值聚类算法实现。
197 0
|
数据可视化 Python
python移动窗口求股票预测误差均值
python移动窗口求股票预测误差均值
123 0
python移动窗口求股票预测误差均值
Python, Numpy求 list 数组均值,方差,标准差
Python, Numpy求 list 数组均值,方差,标准差
|
资源调度 Python
Python:怎么画出均值和置信区间的图
在统计学上,置信区间是从已观测到的数据中统计出来的一个估计。它给出了未知参数可能落在的区域。而通俗的讲,就是我们去估计一个参数(大部分情况是一个平均值或期望),但是估计一定会有误差,所以置信区间就告诉我们,这个平均值的误差范围。
571 0
Python:怎么画出均值和置信区间的图
|
机器学习/深度学习 算法 数据挖掘
100天搞定机器学习|day44 k均值聚类数学推导与python实现
100天搞定机器学习|day44 k均值聚类数学推导与python实现
100天搞定机器学习|day44 k均值聚类数学推导与python实现
|
Go Python
CSP 202104-2 邻域均值 python 二维前缀和
CSP 202104-2 邻域均值 python 二维前缀和
CSP 202104-2 邻域均值 python 二维前缀和

热门文章

最新文章