PyTorch:一篇使用技巧汇总

简介: 一些常用的基础操作

设定 tensor 默认的 dtype:torch.set_default_tensor_type(torch.DoubleTensor)

Pytorch 有八个类型:

Daya type dtype Tensor types
32-bit 浮点 torch.float32 or torch.float torch.*.FloatTensor
64-bit 浮点 torch.float64 or torch.double torch.*.DoubleTensor
16-bit 浮点 torch.float16 or torch.half torch.*.HalfTensor
8-bit 整型(无符号) torch.uint8 torch.*.ByteTensor
8-bit 整型(有符号) torch.int8 torch.*.CharTensor
16-bit 整型(有符号) torch.int16 or torch.short torch.*.ShortTensor
32-bit 整型(有符号) torch.int32 or torch.int torch.*.IntTensor
64-bit 整型(有符号) torch.int64 or torch.long torch.*.LongTensor

保存模型:

def save_checkpoint(model, optimizer, scheduler, save_path):
    # 如果还有其它变量想要保存,也可以添加
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
    }, save_path)

# 加载模型
checkpoint = torch.load(pretrain_model_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
...

打印模型的梯度:

# 打印梯度
for name, parameters in model.named_parameters():
    print('{}\'s grad is:\n{}\n'.format(name, parameters.grad))

使用梯度衰减策略:

# 指数衰减
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
# 阶梯衰减
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
# 自定义间隔衰减
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[400], gamma=0.5)

梯度截断:

def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

自定义激活函数示例:

class OutExp(nn.Module):
    def __init__(self):
        super(OutExp, self).__init__()

    def forward(self, x):
        x = -torch.exp(x)
        return x

修改模型某一层参数:nn.Parameter()

# 修改第 2 层的 bias(`layer` 是模型定义时给的名称)
model.layer[2].bias = nn.Parameter(torch.tensor([-0.01, -0.4], device=device, requires_grad=True))

模型参数初始化:

# 自定义权重初始化
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight, gain=0.1)
        nn.init.constant_(m.bias, 0)
    # 也可以判断是否为 conv2d,使用相应的初始化方式
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    # 是否为批归一化层
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# 模型应用函数
model.apply(weight_init)
目录
相关文章
|
1月前
|
人工智能
2025数字人短视频工具TOP5榜单:从入门到进阶的必备神器 
随着人工智能技术的快速发展,数字人短视频工具正成为内容创作领域的重要助力。从入门级简单操作到进阶专业应用,各类工具功能各异。本文将为您揭晓2025年最值得关注的五款数字人工具,助您轻松选择最适合的创作伙伴。
|
9月前
|
算法 安全 Go
公司局域网管理系统里的 Go 语言 Bloom Filter 算法,太值得深挖了
本文探讨了如何利用 Go 语言中的 Bloom Filter 算法提升公司局域网管理系统的性能。Bloom Filter 是一种高效的空间节省型数据结构,适用于快速判断元素是否存在于集合中。文中通过具体代码示例展示了如何在 Go 中实现 Bloom Filter,并应用于局域网的 IP 访问控制,显著提高系统响应速度和安全性。随着网络规模扩大和技术进步,持续优化算法和结合其他安全技术将是企业维持网络竞争力的关键。
202 2
公司局域网管理系统里的 Go 语言 Bloom Filter 算法,太值得深挖了
|
自然语言处理 机器人 Python
ChatGPT使用学习:ChatPaper安装到测试详细教程(一文包会)
ChatPaper是一个基于文本生成技术的智能研究论文工具,能够根据用户输入进行智能回复和互动。它支持快速下载、阅读论文,并通过分析论文的关键信息帮助用户判断是否需要深入了解。用户可以通过命令行或网页界面操作,进行论文搜索、下载、总结等。
336 1
ChatGPT使用学习:ChatPaper安装到测试详细教程(一文包会)
|
机器学习/深度学习 人工智能 自动驾驶
深度学习之自适应控制器设计
人工智能基于深度学习的自适应控制器设计在自动化系统、机器人控制、工业制造、无人驾驶等领域中有着广泛应用。自适应控制器借助深度学习模型的强大特征提取和学习能力,能够在未知或动态变化的环境中对系统进行实时调节,从而提升系统的响应速度、稳定性和控制精度。
425 1
|
关系型数据库 MySQL Linux
成功解决:2003 -Can‘t connect toMySQL server on ‘10.1.46.42(10060 “Unknown error“) 使用navicate连接虚拟机出错
这篇文章记录了在CentOS 7系统上安装并配置MySQL后,使用Navicat尝试进行远程连接但失败的问题。问题的主要原因是虚拟机的防火墙没有关闭。文章详细介绍了如何检查防火墙的状态,如何临时关闭它,以及如何禁止防火墙在系统启动时自动启动。当防火墙处于开启状态时,远程连接无法成功;关闭或禁用防火墙后,远程连接便能成功建立。
成功解决:2003 -Can‘t connect toMySQL server on ‘10.1.46.42(10060 “Unknown error“) 使用navicate连接虚拟机出错
|
定位技术 网络虚拟化 数据中心
VLAN与VXLAN技术解析:仅一字之差的深远区别
通过深入了解VLAN与VXLAN的技术细节和应用场景,网络工程师可以根据具体需求选择最合适的技术来优化网络架构。对于现代网络环境,尤其是大规模和多变的网络结构,理解并合理运用这些技术是提高网络效率和安全性的关键。
435 1
|
编解码 JavaScript 前端开发
vue cli3 PC端适配
【8月更文挑战第12天】
221 3
|
图形学 开发者
U3D小游戏开发秘籍:实战代码优化与性能提升技巧
【7月更文第13天】Unity 3D(U3D)作为游戏开发界的瑞士军刀,以其强大的灵活性和跨平台能力,让无数创意化为指尖上的精彩。但对于初涉U3D的小游戏开发者而言,如何高效构建项目,确保流畅体验,是一门必修课。本文将深入浅出,结合实战代码示例,分享一系列优化与提升U3D小游戏性能的宝贵技巧。
358 1
|
SQL 关系型数据库 MySQL
drds兼容性与迁移问题
drds兼容性与迁移问题
248 6
|
安全 数据安全/隐私保护 Android开发
三款远程控制软件对比,5大挑选指标:安全、稳定、易用、兼容、功能
众所周知,随着科技的发展和时代的进步,远程控制软件在日常工作中的运用所处可见,社会对远程控制软件的需求也在日益增长,为用户提供了便捷的远程访问和控制计算机的能力,使其在现代化社会中扮演着重要的角色。因此,市面上也出现了很多相似功能的产品,但是他们各有不同的特点和功能。本文主要介绍的是ToDesk、向日葵、TeamViewer三款远程控制软件间的不同之处,我将从安全、稳定、易用、兼容、功能几个方面来进行对比,帮助大家选择适合自己需求的远程办公软件。
564 1
三款远程控制软件对比,5大挑选指标:安全、稳定、易用、兼容、功能