【PyTorch基础教程5】Pytorch完整小栗子(学不会来打我啊)

简介: 之前在【Pytorch基础教程1】也跑过线性模型的代码(没用框架),这次让我们以该模型为基础用pytorch走一遍一个完整流程。

一、PyTorch四部曲

之前在【Pytorch基础教程1】也跑过线性模型的代码(没用框架),这次让我们以该模型为基础用pytorch走一遍一个完整流程。


image.png

二、细节

(1)loss算出来要是一个标量,否则用不了backward。

(2)一般model是继承nn.Module,也可以通过继承Functions构建自己的Module(但是要自己设计反向传播函数)。

image.png

(3)self.linear()是一个可调用对象(callable),类似下图有__call__成员函数。


(4)只要是要调用计算图,都需要继承module类。

(5)过程:求y;求loss;求backward;更新。

image.png

三、线性模型

上次我们手工进行梯度下降,这次用pytorch实现,详看注释。

# -*- coding: utf-8 -*-
"""
Created on Sun Oct 17 21:51:40 2021
@author: 86493
"""
import torch
import torch.nn as nn 
import matplotlib.pyplot as plt
# x和y数据必须是矩阵,所以如[1.0]
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
losslst = []
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 实例化一个linear对象
        self.linear = nn.Linear(1, 1)
    def forward(self, x):
        # 可调用的对象,pythonic
        y_pred = self.linear(x)
        return y_pred
model = LinearModel()
# 这里的MSE不除以N
# criterion = torch.nn.MSELoss(size_average=False)
criterion = torch.nn.MSELoss(reduction = 'sum')
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
# 训练
for epoch in range(100):
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    # 打印loss对象会自动调用__str__(),不会产生计算图
    print(epoch, loss.item())
    losslst.append(loss.item())
    optimizer.zero_grad()
    # 梯度归零后反向传播
    loss.backward()
    optimizer.step()
# 画图
plt.plot(range(100), losslst)
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.show()
# 输出weight和bias    
# 不用item也行,但就是矩阵[[]] 
print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())
print('-' *60)
# Test model
# 输入是1×1矩阵,输出也是1×1矩阵
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

image.png

关于nn.Linear的更多介绍可以查看官方文档,这里贴一个官方文档的栗子:

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
相关文章
|
16天前
|
Android开发 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(二十)(2)
PyTorch 2.2 中文官方教程(二十)
43 0
PyTorch 2.2 中文官方教程(二十)(2)
|
16天前
|
iOS开发 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(二十)(1)
PyTorch 2.2 中文官方教程(二十)
45 0
PyTorch 2.2 中文官方教程(二十)(1)
|
PyTorch 算法框架/工具 并行计算
PyTorch 2.2 中文官方教程(十九)(4)
PyTorch 2.2 中文官方教程(十九)
27 0
|
16天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(4)
PyTorch 2.2 中文官方教程(十八)
54 1
|
16天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十八)(3)
PyTorch 2.2 中文官方教程(十八)
28 1
PyTorch 2.2 中文官方教程(十八)(3)
|
16天前
|
并行计算 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十七)(4)
PyTorch 2.2 中文官方教程(十七)
25 2
PyTorch 2.2 中文官方教程(十七)(4)
|
16天前
|
PyTorch 算法框架/工具 机器学习/深度学习
PyTorch 2.2 中文官方教程(十七)(2)
PyTorch 2.2 中文官方教程(十七)
38 1
PyTorch 2.2 中文官方教程(十七)(2)
|
16天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十五)(1)
PyTorch 2.2 中文官方教程(十五)
46 1
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十四)(4)
PyTorch 2.2 中文官方教程(十四)
64 1
PyTorch 2.2 中文官方教程(十四)(4)
|
16天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 2.2 中文官方教程(十四)(2)
PyTorch 2.2 中文官方教程(十四)
50 1