Pytorch模型训练与在线部署

本文涉及的产品
模型训练 PAI-DLC,5000CU*H 3个月
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
交互式建模 PAI-DSW,5000CU*H 3个月
简介: 本文以CIFAR10数据集为例,通过自定义神经元网络,完成模型的训练,并通过Flask完成模型的在线部署与调用,考略到实际生产模型高并发调用的述求,使用service_streamer提升模型在线并发能力。

一、训练模型

# 1. 加载并标准化数据集importtorchimporttorchvisionimporttorchvision.transformsastransformsimportsslssl._create_default_https_context=ssl._create_unverified_contexttransform=transforms.Compose(
    [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size=4trainset=torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader=torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=0)
testset=torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader=torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=0)
classes= ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 作图展示部分数据集样例importmatplotlib.pyplotaspltimportnumpyasnpdefimshow(img):
img=img/2+0.5npimg=img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 随机获取部分样例数据dataiter=iter(trainloader)
images, labels=next(dataiter)
# show imagesimshow(torchvision.utils.make_grid(images))
# print labelsprint(' '.join(f'{classes[labels[j]]:5s}'forjinrange(batch_size)))
# 2. 定义神经网络importtorch.nnasnnimporttorch.nn.functionalasFclassNet(nn.Module):
def__init__(self):
super().__init__()
self.conv1=nn.Conv2d(3, 6, 5)
self.pool=nn.MaxPool2d(2, 2)
self.conv2=nn.Conv2d(6, 16, 5)
self.fc1=nn.Linear(16*5*5, 120)
self.fc2=nn.Linear(120, 84)
self.fc3=nn.Linear(84, 10)
defforward(self, x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=torch.flatten(x, 1) # flatten all dimensions except batchx=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
returnxnet=Net()
# 3. 定义损失函数和优化器importtorch.optimasoptimcriterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 4. 训练神经网络forepochinrange(2):  # loop over the dataset multiple timesrunning_loss=0.0fori, datainenumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]inputs, labels=data# zero the parameter gradientsoptimizer.zero_grad()
# forward + backward + optimizeoutputs=net(inputs)
loss=criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statisticsrunning_loss+=loss.item()
ifi%2000==1999:    # print every 2000 mini-batchesprint(f'[{epoch+1}, {i+1:5d}] loss: {running_loss/2000:.3f}')
running_loss=0.0print('Finished Training')
# 保存模型PATH='./cifar_net.pth'torch.save(net.state_dict(), PATH)

二、使用本地图片测试模型

importtorch.nnasnnimporttorch.nn.functionalasFimporttorchimporttorchvision.transformsastransformsimportiofromPILimportImageclassNet(nn.Module):
def__init__(self):
super().__init__()
self.conv1=nn.Conv2d(3, 6, 5)
self.pool=nn.MaxPool2d(2, 2)
self.conv2=nn.Conv2d(6, 16, 5)
self.fc1=nn.Linear(16*5*5, 120)
self.fc2=nn.Linear(120, 84)
self.fc3=nn.Linear(84, 10)
defforward(self, x):
x=self.pool(F.relu(self.conv1(x)))
x=self.pool(F.relu(self.conv2(x)))
x=torch.flatten(x, 1) # flatten all dimensions except batchx=F.relu(self.fc1(x))
x=F.relu(self.fc2(x))
x=self.fc3(x)
returnx# 加载网络模型参数PATH='./cifar_net.pth'net=Net()
net.load_state_dict(torch.load(PATH))
transform=transforms.Compose(
    [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
deftransform_image(image_bytes):
my_transforms=transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(32),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
image=Image.open(io.BytesIO(image_bytes))
returnmy_transforms(image).unsqueeze(0)
file=open('cat.jpg', 'rb')
img_bytes=file.read()
tensor=transform_image(image_bytes=img_bytes)
outputs=net(tensor)
classes= ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
_, predicted=torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'forjinrange(1)))
  • 运行效果

图片.png

三、Flask 在线模型服务

fromflaskimportFlaskimportsslssl._create_default_https_context=ssl._create_unverified_contextimportioimportjsonimporttorchfromtorchvisionimportmodelsfromtorchvisionimporttransformsfromPILimportImageapp=Flask(__name__)
imagenet_class_index=json.load(open('./imagenet_class_index.json'))
model=models.densenet121(pretrained=True)
model.eval()
device='cpu'deftransform_image(image_bytes):
my_transforms=transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
image=Image.open(io.BytesIO(image_bytes))
returnmy_transforms(image).unsqueeze(0)
defget_prediction(image_bytes):
tensor=transform_image(image_bytes=image_bytes)
outputs=model.forward(tensor)
_, y_hat=outputs.max(1)
predicted_idx=str(y_hat.item())
returnimagenet_class_index[predicted_idx]
@app.route('/predict', methods=['POST'])
defpredict():
ifrequest.method=='POST':
file=request.files['file']
img_bytes=file.read()
class_id, class_name=get_prediction(image_bytes=img_bytes)
returnjsonify({'class_id': class_id, 'class_name': class_name})
defbatch_prediction(image_bytes_batch):
image_tensors= [transform_image(image_bytes=image_bytes) forimage_bytesinimage_bytes_batch]
tensor=torch.cat(image_tensors).to(device)
outputs=model.forward(tensor)
_, y_hat=outputs.max(1)
predicted_ids=y_hat.tolist()
return [imagenet_class_index[str(i)] foriinpredicted_ids]
fromflaskimportjsonify, requestfromservice_streamerimportThreadedStreamerstreamer=ThreadedStreamer(batch_prediction, batch_size=64)
@app.route('/stream_predict', methods=['POST'])
defstream_predict():
ifrequest.method=='POST':
file=request.files['file']
img_bytes=file.read()
class_id, class_name=streamer.predict([img_bytes])[0]
returnjsonify({'class_id': class_id, 'class_name': class_name})
if__name__=='__main__':
app.run()

四、调用在线模型

importrequestsresp=requests.post("http://localhost:5000/stream_predict",
files={"file": open('dog.jpg','rb')})
print(resp.json())
  • 链接效果

图片.png

参考链接

Vision Recognition Service with Flask and service streamer

通过带Flask的REST API在Python中部署PyTorch

TRAINING A CLASSIFIER

相关实践学习
使用PAI-EAS一键部署ChatGLM及LangChain应用
本场景中主要介绍如何使用模型在线服务(PAI-EAS)部署ChatGLM的AI-Web应用以及启动WebUI进行模型推理,并通过LangChain集成自己的业务数据。
机器学习概览及常见算法
机器学习(Machine Learning, ML)是人工智能的核心,专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能,它是使计算机具有智能的根本途径,其应用遍及人工智能的各个领域。 本课程将带你入门机器学习,掌握机器学习的概念和常用的算法。
相关文章
|
3天前
|
机器学习/深度学习 监控 API
基于云计算的机器学习模型部署与优化
【8月更文第17天】随着云计算技术的发展,越来越多的数据科学家和工程师开始使用云平台来部署和优化机器学习模型。本文将介绍如何在主要的云计算平台上部署机器学习模型,并讨论模型优化策略,如模型压缩、超参数调优以及分布式训练。
17 2
|
4天前
|
机器学习/深度学习 JSON API
【Python奇迹】FastAPI框架大显神通:一键部署机器学习模型,让数据预测飞跃至Web舞台,震撼开启智能服务新纪元!
【8月更文挑战第16天】在数据驱动的时代,高效部署机器学习模型至关重要。FastAPI凭借其高性能与灵活性,成为搭建模型API的理想选择。本文详述了从环境准备、模型训练到使用FastAPI部署的全过程。首先,确保安装了Python及相关库(fastapi、uvicorn、scikit-learn)。接着,以线性回归为例,构建了一个预测房价的模型。通过定义FastAPI端点,实现了基于房屋大小预测价格的功能,并介绍了如何运行服务器及测试API。最终,用户可通过HTTP请求获取预测结果,极大地提升了模型的实用性和集成性。
14 1
|
11天前
|
机器学习/深度学习 API 网络架构
"解锁机器学习超级能力!Databricks携手Mlflow,让模型训练与部署上演智能风暴,一触即发,点燃你的数据科学梦想!"
【8月更文挑战第9天】机器学习模型的训练与部署流程复杂,涵盖数据准备、模型训练、性能评估及部署等步骤。本文详述如何借助Databricks与Mlflow的强大组合来管理这一流程。首先需在Databricks环境内安装Mlflow库。接着,利用Mlflow跟踪功能记录训练过程中的参数与性能指标。最后,通过Mlflow提供的模型服务功能,采用REST API或Docker容器等方式部署模型。这一流程充分利用了Databricks的数据处理能力和Mlflow的生命周期管理优势。
31 7
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
20 1
|
6天前
|
机器学习/深度学习 人工智能 关系型数据库
【机器学习】Qwen2大模型原理、训练及推理部署实战
【机器学习】Qwen2大模型原理、训练及推理部署实战
40 0
【机器学习】Qwen2大模型原理、训练及推理部署实战
|
1月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8620 3
|
10天前
|
机器学习/深度学习 人工智能 自然语言处理
基于PAI 低代码实现大语言模型微调和部署
【8月更文挑战第10天】基于PAI 低代码实现大语言模型微调和部署
|
13天前
|
人工智能 异构计算
基于PAI-EAS一键部署ChatGLM及LangChain应用
【8月更文挑战第7天】基于PAI-EAS一键部署ChatGLM及LangChain应用
|
20天前
|
机器学习/深度学习 自然语言处理 数据挖掘
机器学习不再是梦!PyTorch助你轻松驾驭复杂数据分析场景
【7月更文挑战第31天】机器学习已深深嵌入日常生活,从智能推荐到自动驾驶皆为其应用。PyTorch作为一个开源库,凭借简洁API、动态计算图及GPU加速能力,降低了学习门槛并提高了开发效率。通过一个使用PyTorch构建简单CNN识别MNIST手写数字的例子,展现了如何快速搭建神经网络。随着技能提升,开发者能运用PyTorch及其丰富的生态系统(如torchvision、torchtext和torchaudio)应对复杂场景,如自然语言处理和强化学习。掌握PyTorch,意味着掌握了数据时代的关键技能。
10 1
|
6天前
|
机器学习/深度学习 数据采集 物联网
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
【机器学习】Google开源大模型Gemma2:原理、微调训练及推理部署实战
25 0

热门文章

最新文章

相关产品

  • 人工智能平台 PAI