PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

简介: PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

输出结果

image.png


核心代码

#PyTorch:采用sklearn 工具生成这样的合成数据集+利用PyTorch实现简单合成数据集上的线性回归进行数据分析

from sklearn.datasets import make_regression

import seaborn as sns

import pandas as pd

import matplotlib.pyplot as plt

sns.set()

x_train, y_train, W_target = make_regression(n_samples=100, n_features=1, noise=10, coef = True)

df = pd.DataFrame(data = {'X':x_train.ravel(), 'Y':y_train.ravel()})

sns.lmplot(x='X', y='Y', data=df, fit_reg=True)

plt.show()

x_torch = torch.FloatTensor(x_train)

y_torch = torch.FloatTensor(y_train)

y_torch = y_torch.view(y_torch.size()[0], 1)

class LinearRegression(torch.nn.Module):  #定义LR的类。torch.nn库构建模型

   #PyTorch 的 nn 库中有大量有用的模块,其中一个就是线性模块。如名字所示,它对输入执行线性变换,即线性回归。

   def __init__(self, input_size, output_size):

       super(LinearRegression, self).__init__()

       self.linear = torch.nn.Linear(input_size, output_size)  

   def forward(self, x):

       return self.linear(x)

model = LinearRegression(1, 1)

criterion = torch.nn.MSELoss() #训练线性回归,我们需要从 nn 库中添加合适的损失函数。对于线性回归,我们将使用 MSELoss()——均方差损失函数

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)#还需要使用优化函数(SGD),并运行与之前示例类似的反向传播。本质上,我们重复上文定义的 train() 函数中的步骤。

#不能直接使用该函数的原因是我们实现它的目的是分类而不是回归,以及我们使用交叉熵损失和最大元素的索引作为模型预测。而对于线性回归,我们使用线性层的输出作为预测。

for epoch in range(50):

   data, target = Variable(x_torch), Variable(y_torch)

   output = model(data)

   optimizer.zero_grad()

   loss = criterion(output, target)

   loss.backward()

   optimizer.step()

predicted = model(Variable(x_torch)).data.numpy()

#打印出原始数据和适合 PyTorch 的线性回归

plt.plot(x_train, y_train, 'o', label='Original data')

plt.plot(x_train, predicted, label='Fitted line')

plt.legend()

plt.title(u'Py:PyTorch实现简单合成数据集上的线性回归进行数据分析')

plt.show()


相关文章
|
2月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
3月前
|
JSON 数据挖掘 API
结合数据分析工具,深入挖掘淘宝API接口的商业价值
随着电子商务的蓬勃发展,淘宝作为国内领先的电商平台,不仅为消费者提供了便捷的购物环境,同时也为开发者和数据分析师提供了丰富的数据资源。通过有效地调用淘宝API接口获取商品详情,再结合数据分析工具进行深入的数据挖掘,可以为商家、市场分析师及研究人员等带来巨大的商业价值
|
4月前
|
机器学习/深度学习 存储 SQL
15个超级棒的外文免费数据集,学习数据分析不愁没有数据用了!
15个超级棒的外文免费数据集,学习数据分析不愁没有数据用了!
|
4月前
|
数据采集 数据可视化 数据挖掘
2019 年排名前6的数据分析工具
2019 年排名前6的数据分析工具
|
30天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
56 0
|
3月前
|
数据可视化 数据挖掘 Java
提升代码质量与效率的利器——SonarQube静态代码分析工具从数据到洞察:探索Python数据分析与科学计算库
在现代软件开发中,保证代码质量是至关重要的。本文将介绍SonarQube静态代码分析工具的概念及其实践应用。通过使用SonarQube,开发团队可以及时发现和修复代码中的问题,提高代码质量,从而加速开发过程并减少后期维护成本。 在当今信息爆炸的时代,数据分析和科学计算成为了决策和创新的核心。本文将介绍Python中强大的数据分析与科学计算库,包括NumPy、Pandas和Matplotlib,帮助读者快速掌握这些工具的基本用法和应用场景。无论是数据处理、可视化还是统计分析,Python提供了丰富的功能和灵活性,使得数据分析变得更加简便高效。
|
30天前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
53 0
|
3月前
|
数据挖掘 数据安全/隐私保护 Python
【Python数据分析】<数据分析工具>基于Excel的数据分析
【1月更文挑战第22天】【Python数据分析】<数据分析工具>基于Excel的数据分析
|
5天前
|
机器学习/深度学习 数据挖掘 计算机视觉
python数据分析工具SciPy
【4月更文挑战第15天】SciPy是Python的开源库,用于数学、科学和工程计算,基于NumPy扩展了优化、线性代数、积分、插值、特殊函数、信号处理、图像处理和常微分方程求解等功能。它包含优化、线性代数、积分、信号和图像处理等多个模块。通过SciPy,可以方便地执行各种科学计算任务。例如,计算高斯分布的PDF,需要结合NumPy使用。要安装SciPy,可以使用`pip install scipy`命令。这个库极大地丰富了Python在科学计算领域的应用。
11 1
|
6天前
|
数据可视化 数据挖掘 Python
Python中数据分析工具Matplotlib
【4月更文挑战第14天】Matplotlib是Python的数据可视化库,能生成多种图表,如折线图、柱状图等。以下是一个绘制简单折线图的代码示例: ```python import matplotlib.pyplot as plt x = [1, 2, 3, 4, 5] y = [2, 4, 6, 8, 10] plt.figure() plt.plot(x, y) plt.title('简单折线图') plt.xlabel('X轴') plt.ylabel('Y轴') plt.show() ```
10 1