PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析-阿里云开发者社区

开发者社区> 一个处女座的程序猿> 正文

PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

简介: PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析
+关注继续查看

输出结果

image.png


核心代码

#PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

from sklearn.datasets import make_regression

import seaborn as sns

import pandas as pd

import matplotlib.pyplot as plt

sns.set()

x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)

df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})

sns.lmplot(x='X', y='Y', data=df, fit_reg=True)

plt.show()

x_torch = torch.FloatTensor(x_train)

y_torch = torch.FloatTensor(y_train)

y_torch = y_torch.view(y_torch.size()[0], 1)

class LinearRegression(torch.nn.Module):  #定义LR的类。torch.nn库构建模型

   #PyTorch 的 nn 库中有大量有用的模块,其中一个就是线性模块。如名字所示,它对输入执行线性变换,即线性回归。

   def __init__(self, input_size, output_size):

       super(LinearRegression, self).__init__()

       self.linear = torch.nn.Linear(input_size, output_size)  

   def forward(self, x):

       return self.linear(x)

model = LinearRegression(1, 1)

criterion = torch.nn.MSELoss() #训练线性回归,我们需要从 nn 库中添加合适的损失函数。对于线性回归,我们将使用 MSELoss()——均方差损失函数

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#还需要使用优化函数(SGD),并运行与之前示例类似的反向传播。本质上,我们重复上文定义的 train() 函数中的步骤。

#不能直接使用该函数的原因是我们实现它的目的是分类而不是回归,以及我们使用交叉熵损失和最大元素的索引作为模型预测。而对于线性回归,我们使用线性层的输出作为预测。

for epoch in range(50):

   data, target = Variable(x_torch), Variable(y_torch)

   output = model(data)

   optimizer.zero_grad()

   loss = criterion(output, target)

   loss.backward()

   optimizer.step()

predicted = model(Variable(x_torch)).data.numpy()

#打印出原始数据和适合 PyTorch 的线性回归

plt.plot(x_train, y_train, 'o', label='Original data')

plt.plot(x_train, predicted, label='Fitted line')

plt.legend()

plt.title(u'Py:PyTorch实现简单合成数据集上的线性回归进行数据分析')

plt.show()


版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。

相关文章
kubernetes RBAC实战 kubernetes 用户角色访问控制,dashboard访问,kubectl配置生成
kubernetes RBAC实战 环境准备 先用kubeadm安装好kubernetes集群,[包地址在此](https://market.aliyun.com/products/56014009/cmxz022571.
1758 0
Charted – 自动化的可视化数据生成工具
  Charted 是一个让数据自动生成可视化图表的工具。只需要提供一个数据文件的链接,它就能返回一个美丽的,可共享的图表。Charted 不会存储任何数据。它只是获取和让链接提供的数据可视化。     在线演示      插件下载   您可能感兴趣的相关文章 网站开发中很有用...
749 0
MyEclipse 从数据库反向生成Hibernate实体类
         第一个大步骤 window-->open Perspective-->MyEclipse Java Persistence 进行了上面的 操作后会出现一个视图DB Brower:MyEclipse Derby,点击右键新建一个在出现的面板中,driver templat...
799 0
推荐14款非常有用的 CSS 网格系统生成工具
今天这篇文章向大家推荐14款非常有用的 CSS 网格系统生成工具,它们能够帮助你构建适合你网站项目的 CSS 网格系统。一个系统化、结构合理的布局使得能够更快更轻松的组织网站的内容。网格系统为网页设计师们提供了一种快速构造网页内容布局的方法,帮助设计师们节省了大量的时间和精力。
522 0
+关注
一个处女座的程序猿
国内互联网圈知名博主、人工智能领域优秀创作者,全球最大中文IT社区博客专家、CSDN开发者联盟生态成员、中国开源社区专家、华为云社区专家、51CTO社区专家、Python社区专家等,曾受邀采访和评审十多次。仅在国内的CSDN平台,博客文章浏览量超过2500万,拥有超过57万的粉丝。
1701
文章
0
问答
文章排行榜
最热
最新
相关电子书
更多
文娱运维技术
立即下载
《SaaS模式云原生数据仓库应用场景实践》
立即下载
《看见新力量:二》电子书
立即下载