开发者社区> 问答> 正文

如何加速航路热图

与searborn facetgrid热图问题斗争缓慢。我已经扩展了以前的问题的数据集,感谢@Diziet Asahi提供了facetgrid问题的解决方案。 现在,我有20x20的网格,每个网格中有625个点需要映射。即使是一个层little1,也需要很长时间才能得到输出。我有成千上万个真实数据的小层。 我的代码如下:

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap

print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

ltt= ['little1']

methods=["m" + str(i) for i in range(1,21)]
labels=["l" + str(i) for i in range(1,20)]

times = range(0,100,4)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
data['nw_score'] = np.random.choice([0,1],data.shape[0])

数据输出:

Out[36]: 
            ltt method labels  dtsi  rtsi  nw_score
0       little1     m1     l1     0     0         1
1       little1     m1     l1     0     4         0
2       little1     m1     l1     0     8         0
3       little1     m1     l1     0    12         1
4       little1     m1     l1     0    16         0
        ...    ...    ...   ...   ...       ...
237495  little1    m20    l19    96    80         0
237496  little1    m20    l19    96    84         1
237497  little1    m20    l19    96    88         0
237498  little1    m20    l19    96    92         0
237499  little1    m20    l19    96    96         1

[237500 rows x 6 columns]

绘制和定义刻面函数:

labels_fill = {0:'red',1:'blue'}

del methods
del labels

def facet(data,color):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    g = sns.heatmap(data, cmap=ListedColormap(['red', 'blue']), cbar=False,annot=True)

for lt in data.ltt.unique():
    with sns.plotting_context(font_scale=5.5):
        g = sns.FacetGrid(data[data.ltt==lt],row="labels", col="method", size=2, aspect=1,margin_titles=False)
        g = g.map_dataframe(facet)
        g.add_legend()
        g.set_titles(template="")

        for ax,method in zip(g.axes[0,:],data.method.unique()):
            ax.set_title(method, fontweight='bold', fontsize=12)
        for ax,label in zip(g.axes[:,0],data.labels.unique()):
            ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        g.fig.suptitle(lt, fontweight='bold', fontsize=12)
        g.fig.tight_layout()
        g.fig.subplots_adjust(top=0.8) # make some room for the title

        g.savefig(lt+'.png', dpi=300)

我在一段时间后停止了代码,我们可以看到网格正在一个接一个地填充,这非常耗时。生成这张热图的速度慢得令人难以忍受。 我想知道有没有更好的方法来加快这个过程? 提前谢谢! 问题来源StackOverflow 地址:/questions/59380853/how-to-speed-up-seaborn-heatmaps

展开
收起
kun坤 2019-12-28 14:08:51 455 0
1 条回答
写回答
取消 提交回答
  • Seaborn是缓慢的。 如果您使用的是matplotlib而不是seaborn,则每个图大约需要半分钟。这仍然很长,但是考虑到你产生了一个~12000x12000像素的图形,这是一种期望。

    import time
    import pandas as pd
    import numpy as np
    import itertools
    import seaborn as sns
    from matplotlib.colors import ListedColormap
    import matplotlib.pyplot as plt
    
    print("seaborn version {}".format(sns.__version__))
    # R expand.grid() function in Python
    # https://stackoverflow.com/a/12131385/1135316
    def expandgrid(*itrs):
       product = list(itertools.product(*itrs))
       return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}
    
    ltt= ['little1']
    
    methods=["m" + str(i) for i in range(1,21)]
    #methods=['method 1', 'method 2', 'method 3', 'method 4']
    #labels = ['label1','label2']
    labels=["l" + str(i) for i in range(1,20)]
    
    times = range(0,100,4)
    data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
    data.columns = ['ltt','method','labels','dtsi','rtsi']
    #data['nw_score'] = np.random.sample(data.shape[0])
    data['nw_score'] = np.random.choice([0,1],data.shape[0])
    
    labels_fill = {0:'red',1:'blue'}
    
    del methods
    del labels
    
    
    cmap=ListedColormap(['red', 'blue'])
    
    def facet(data, ax):
        data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
        ax.imshow(data, cmap=cmap)
    
    def myfacetgrid(data, row, col, figure=None):
        rows = np.unique(data[row].values)  
        cols = np.unique(data[col].values)
    
        fig, axs = plt.subplots(len(rows), len(cols), 
                                figsize=(2*len(cols)+1, 2*len(rows)+1))
    
    
        for i, r in enumerate(rows):
            row_data = data[data[row] == r]
            for j, c in enumerate(cols):
                this_data = row_data[row_data[col] == c]
                facet(this_data, axs[i,j])
        return fig, axs
    
    
    for lt in data.ltt.unique():
    
        with sns.plotting_context(font_scale=5.5):
            t = time.time()
            fig, axs = myfacetgrid(data[data.ltt==lt], row="labels", col="method")
            print(time.time()-t)
            for ax,method in zip(axs[0,:],data.method.unique()):
                ax.set_title(method, fontweight='bold', fontsize=12)
            for ax,label in zip(axs[:,0],data.labels.unique()):
                ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
            print(time.time()-t)
            fig.suptitle(lt, fontweight='bold', fontsize=12)
            fig.tight_layout()
            fig.subplots_adjust(top=0.8) # make some room for the title
            print(time.time()-t)
            fig.savefig(lt+'.png', dpi=300)
            print(time.time()-t)
    

    在这里,计时分为~6秒创建facetgrid, ~7秒优化网格布局(通过tight_layout—请考虑去掉它!)和15秒绘制图形。

    2019-12-28 14:09:00
    赞同 展开评论 打赏
问答分类:
问答地址:
问答排行榜
最热
最新

相关电子书

更多
FindPixel三维实景重建云上实践 立即下载
驾驭时空中国⾸辆⾃动驾驶低速电动⻋发布 立即下载
液体活检,精准分析 立即下载