PyTorch:一篇使用技巧汇总

简介: 一些常用的基础操作

设定 tensor 默认的 dtype:torch.set_default_tensor_type(torch.DoubleTensor)

Pytorch 有八个类型:

Daya type dtype Tensor types
32-bit 浮点 torch.float32 or torch.float torch.*.FloatTensor
64-bit 浮点 torch.float64 or torch.double torch.*.DoubleTensor
16-bit 浮点 torch.float16 or torch.half torch.*.HalfTensor
8-bit 整型(无符号) torch.uint8 torch.*.ByteTensor
8-bit 整型(有符号) torch.int8 torch.*.CharTensor
16-bit 整型(有符号) torch.int16 or torch.short torch.*.ShortTensor
32-bit 整型(有符号) torch.int32 or torch.int torch.*.IntTensor
64-bit 整型(有符号) torch.int64 or torch.long torch.*.LongTensor

保存模型:

def save_checkpoint(model, optimizer, scheduler, save_path):
    # 如果还有其它变量想要保存,也可以添加
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, save_path)

# 加载模型
checkpoint = torch.load(pretrain_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
...

打印模型的梯度:

# 打印梯度
for name, parameters in model.named_parameters():
    print('{}\'s grad is:\n{}\n'.format(name, parameters.grad))

使用梯度衰减策略:

# 指数衰减
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# 阶梯衰减
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
# 自定义间隔衰减
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400], gamma=0.5)

梯度截断:

def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

自定义激活函数示例:

class OutExp(nn.Module):
    def __init__(self):
        super(OutExp, self).__init__()

    def forward(self, x):
        x = -torch.exp(x)
        return x

修改模型某一层参数:nn.Parameter()

# 修改第 2 层的 bias(`layer` 是模型定义时给的名称)
model.layer[2].bias = nn.Parameter(torch.tensor([-0.01, -0.4], device=device, requires_grad=True))

模型参数初始化:

# 自定义权重初始化
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=0.1)
        nn.init.constant_(m.bias, 0)
    # 也可以判断是否为 conv2d,使用相应的初始化方式
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    # 是否为批归一化层
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# 模型应用函数
model.apply(weight_init)
目录
相关文章
|
3月前
|
人工智能 自然语言处理 大数据
互联网医院智能导诊系统的技术实现原理
互联网医院智能导诊系统利用人工智能与大数据技术,通过自然语言处理、医学知识图谱、多模态交互等技术,实现患者症状的智能识别与科室匹配,提升挂号效率与准确率,优化就医流程。
159 10
|
3月前
|
数据采集 监控 调度
干货分享“用 多线程 爬取数据”:单线程 + 协程的效率反超 3 倍,这才是 Python 异步的正确打开方式
在 Python 爬虫中,多线程因 GIL 和切换开销效率低下,而协程通过用户态调度实现高并发,大幅提升爬取效率。本文详解协程原理、实战对比多线程性能,并提供最佳实践,助你掌握异步爬虫核心技术。
|
11月前
|
安全 网络安全 数据安全/隐私保护
SSL/TLS证书**是一种用于加密网络通信的数字证书
SSL/TLS证书**是一种用于加密网络通信的数字证书
485 6
|
5月前
|
人工智能 运维 网络安全
重构门店网络:从“打补丁“到“造地基“的跨越
传统网络架构正在威胁门店数字化转型,其“三大致命矛盾”架构老化、业务爆发、新兴技术卡壳等问题日益严重。传统网络的“人肉运维”模式效率低下,人肉容灾能力不足。随着云化需求的增加,传统网络架构无法适配云计算、AI应用等新兴技术,云化受阻。
|
机器学习/深度学习 数据可视化 TensorFlow
深入探索TensorBoard:使用可视化工具提升模型调试与优化的效率和效果
【8月更文挑战第31天】在深度学习领域,理解和优化复杂的神经网络模型充满挑战。TensorBoard作为TensorFlow的强大可视化工具,能帮助我们清晰地展示模型结构、激活值、损失函数变化等关键信息,从而更高效地调试和优化模型。
371 0
|
应用服务中间件 Apache
Tomcat国内镜像下载地址【速度超快】
Tomcat国内镜像下载地址【速度超快】
3477 0
|
Prometheus Kubernetes Cloud Native
Flagger(应用自动发布)介绍和原理剖析
## 简介 [Flagger](https://github.com/weaveworks/flagger)是一个能使运行在k8s体系上的应用发布流程全自动(无人参与)的工具, 它能减少发布的人为关注时间, 并且在发布过程中能自动识别一些风险(例如:RT,成功率,自定义metrics)并回滚. ## 主要特性 ![features](https://intranetproxy.ali
4834 0
@RequestMapping详解
在我们的Java web开发中也有一个同样神奇的法宝,可以为我们节省好多时间和代码,从而实现浏览器与服务器之间的映射,它就是——RequestMapping注解,下面我们一起来了解一下吧。
600 0
@RequestMapping详解
|
Java 容器
sprigboot中过滤器执行顺序源码解读
本文主要是搞清楚对于同一请求在springboot项目中自定义的filter和jar包中的filter的执行顺序是如何指定的。
sprigboot中过滤器执行顺序源码解读
|
机器学习/深度学习 人工智能 计算机视觉
深度学习经典网络解析图像分类篇(七):ResNet
 如果说你对深度学习略有了解,那你一定听过大名鼎鼎的ResNet,正所谓ResNet 一出,谁与争锋?现如今2022年,依旧作为各大CV任务的backbone,比如ResNet-50、ResNet-101等。ResNet是2015年的ImageNet大规模视觉识别竞赛(ImageNet Large Scale Visual Recognition Challenge, ILSVRC)中获得了图像分类和物体识别的冠军,是中国人何恺明、张祥雨、任少卿、孙剑在微软亚洲研究院(AI黄埔军校)的研究成果。
724 0