FastAPI + ONNX 部署机器学习模型最佳实践

简介: 本文介绍了如何结合FastAPI和ONNX实现机器学习模型的高效部署。面对模型兼容性、性能瓶颈、服务稳定性和安全性等挑战,FastAPI与ONNX提供了高性能、易于开发维护、跨框架支持和活跃社区的优势。通过将模型转换为ONNX格式、构建FastAPI应用、进行性能优化及考虑安全性,可以简化部署流程,提升推理性能,确保服务的可靠性与安全性。最后,以手写数字识别模型为例,展示了完整的部署过程,帮助读者更好地理解和应用这些技术。

引言

随着人工智能的迅猛发展,将训练好的模型部署到生产环境中,为用户提供实时预测服务,已成为众多企业和开发者关注的重点。然而,模型部署并非易事,涉及到模型格式转换、服务框架选择、性能优化等多个方面。本篇文章将介绍如何结合 FastAPIONNX,实现机器学习模型的高效部署,并分享其中的最佳实践。

背景介绍 🎨

机器学习模型的部署,常常会遇到以下挑战:

  • 模型兼容性:不同的深度学习框架(如 TensorFlow、PyTorch)有各自的模型格式,直接部署可能会有兼容性问题,导致部署困难。
  • 性能瓶颈:模型推理速度直接影响用户体验和系统资源消耗,性能优化至关重要。
  • 服务稳定性:需要确保服务在高并发情况下的稳定性和可靠性,否则可能会崩溃。
  • 安全性:需要防范潜在的安全风险,如输入数据的验证、攻击防护等,保障应用安全。

看到这里,可能有人会问:“有没有一种简单的方法,可以解决这些问题呢?”答案就是——FastAPI + ONNX

为什么选择 FastAPI 与 ONNX

  • 高性能:FastAPI 与 ONNX Runtime 的组合,提供了高效的推理和响应速度,让你的服务飞起来!
  • 易于开发和维护:FastAPI 简洁的代码结构和自动文档生成功能,大大降低了开发和维护的成本,不再为繁琐的配置烦恼。
  • 跨框架支持:ONNX 支持多种主流的深度学习框架,方便模型的转换和部署,再也不用陷入框架之争。
  • 社区活跃:两个项目都有活跃的社区支持,丰富的资源和教程,遇到问题有人帮,进步之路不孤单。

最佳实践 🛠️

1.模型转换为 ONNX 格式

模型转换是部署的第一步。将训练好的模型转换为 ONNX 格式,可以提高模型的兼容性和性能。

PyTorch 模型转换

假设你有一个训练好的 PyTorch 模型,将其转换为 ONNX 格式呢只需几行代码,如下:

import torch
import torch.onnx

# 加载训练好的模型
model = torch.load('model.pth')
model.eval()

# 定义一个输入张量(示例输入)
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, 'model.onnx', 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")
AI 代码解读

TensorFlow 模型转换

对于 TensorFlow 的模型,也是类似的操作。

import tensorflow as tf
import tf2onnx

# 加载训练好的模型
model = tf.keras.models.load_model('model.h5')

# 转换为 ONNX 格式
spec = (tf.TensorSpec(model.inputs[0].shape, dtype=tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, 
                                            input_signature=spec, 
                                            opset=11, 
                                            output_path=output_path)

print("✅ 模型已成功转换为 ONNX 格式!")
AI 代码解读

验证转换后的模型

转换完成后,别忘了验证一下模型是否正常工作!

import onnx
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("✅ 模型格式验证通过!")

# 使用 ONNX Runtime 进行推理
ort_session = ort.InferenceSession('model.onnx')

# 准备输入数据
import numpy as np
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 运行推理
outputs = ort_session.run(None, {'input': input_data})

print("输出结果:", outputs)
AI 代码解读

示例输出:
1.png

2.构建 FastAPI 应用

现在,我们来创建一个基于 FastAPI 的应用,将模型部署为一个 API 服务。

安装依赖

首先,安装必要的依赖包:

pip install fastapi uvicorn[standard] onnxruntime
AI 代码解读

定义 FastAPI 应用

编写应用主文件 main.py

from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
from pydantic import BaseModel

app = FastAPI(title="机器学习模型部署 API 🚀")

# 加载 ONNX 模型
ort_session = ort.InferenceSession('model.onnx')

# 定义输入数据模型
class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    # 将输入数据转换为 numpy 数组
    input_array = np.array(input_data.data).astype(np.float32)

    # 进行推理
    outputs = ort_session.run(None, {
   'input': input_array})

    # 返回结果
    return {
   "prediction": outputs[0].tolist()}
AI 代码解读

运行应用

使用uvicorn启动应用:

uvicorn main:app --host 0.0.0.0 --port 8000
AI 代码解读

测试接口

可以使用 curl 或其他工具测试一下接口是否正常工作:

curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{
    "data": [0.5, 0.3, 0.2]
}'
AI 代码解读

示例输出:

{
   
    "prediction": [[0.1, 0.9]]
}
AI 代码解读

至此我们的 API 已经可以正常工作了!

3.性能优化

性能对于一个服务来说至关重要,这里介绍一些优化技巧。

模型优化

  • 使用模型优化工具:ONNX 提供了模型优化工具,可简化和加速模型。
python -m onnxruntime.tools.optimizer_cli --input model.onnx --output model_optimized.onnx --optimization_level all
AI 代码解读
  • 量化模型:通过模型量化,将浮点数精度降低,减小模型大小,加速推理。
python -m onnxruntime.quantization.quantize --input model.onnx --output model_quant.onnx --per_channel
AI 代码解读

推理加速

  • 使用 GPU 加速:如果有 GPU 资源,可以使用 GPU 提供商提升推理速度。
ort_session = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
AI 代码解读
  • 多线程或多进程:根据服务器性能,调整并发数,充分利用硬件资源。

4.安全性考虑

安全是服务的底线,我们需要考虑以下几点。

输入验证

  • 数据格式验证:使用 Pydantic 模型,确保输入数据的格式和类型正确。
  • 异常处理:捕获可能的异常,如数据维度错误,返回友好的错误信息。
@app.post("/predict")
async def predict(input_data: InputData):
    try:
        # 输入验证
        input_array = np.array(input_data.data).astype(np.float32)
        # 检查输入维度(根据模型需求调整)
        if input_array.shape != (1, 3, 224, 224):
            return {
   "error": "输入数据维度不正确"}
        # 进行推理
        outputs = ort_session.run(None, {
   'input': input_array})
        return {
   "prediction": outputs[0].tolist()}
    except Exception as e:
        return {
   "error": str(e)}
AI 代码解读

安全防护

  • 限制请求频率:通过中间件或网关,防止恶意请求和 DDoS 攻击。
  • SSL/HTTPS:在生产环境中,确保通信的安全性。

案例示例 🎯

下面以一个手写数字识别模型为例,展示完整的部署过程。

1.模型训练与转换

# 使用 MNIST 数据集训练模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义简单的神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc(x)
        return x

# 创建模型实例
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 加载数据集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# 训练模型(这里只训练一个 epoch 作示例)
for epoch in range(1):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# 保存模型
torch.save(model.state_dict(), 'mnist.pth')

print("✅ 模型训练完成并已保存!")

# 加载模型并转换为 ONNX
model = Net()
model.load_state_dict(torch.load('mnist.pth'))
model.eval()

dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, 'mnist.onnx', input_names=['input'], output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")
AI 代码解读

2.构建 FastAPI 应用

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import onnxruntime as ort
import numpy as np

app = FastAPI(title="手写数字识别 API 🖊️")
ort_session = ort.InferenceSession('mnist.onnx')

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # 读取上传的图片
    image = Image.open(file.file).convert('L')
    # 图片预处理
    image = image.resize((28, 28))
    image_data = np.array(image).astype(np.float32).reshape(1, 1, 28, 28)
    # 归一化
    image_data /= 255.0
    # 推理
    outputs = ort_session.run(None, {
   'input': image_data})
    prediction = np.argmax(outputs[0])
    # 返回结果
    return {
   "prediction": int(prediction)}
AI 代码解读

3.测试

使用工具(如 cURL、Postman)发送请求,验证接口功能。

curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@test_digit.png;type=image/png"
AI 代码解读

示例输出:

    {
   
        "prediction": 7
    }
AI 代码解读

小结 ⚡️

通过遵循上述最佳实践,我们可以简化部署流程,提高模型的推理性能,增强服务的可靠性和安全性。当然,在实际应用中,我们还需要根据具体情况进行优化和调整,希望本篇文章可以对各位读者有所帮助!

目录
打赏
0
20
20
1
153
分享
相关文章
【新模型速递】PAI一键云上零门槛部署DeepSeek-V3-0324、Qwen2.5-VL-32B
PAI-Model Gallery 集成国内外 AI 开源社区中优质的预训练模型,涵盖了 LLM、AIGC、CV、NLP 等各个领域,用户可以通过 PAI 以零代码方式实现从训练到部署再到推理的全过程,获得更快、更高效、更便捷的 AI 开发和应用体验。 现阿里云PAI-Model Gallery已同步接入DeepSeek-V3-0324、Qwen2.5-VL-32B-Instruct两大新模型,提供企业级部署方案。
AI训练师入行指南(三):机器学习算法和模型架构选择
从淘金到雕琢,将原始数据炼成智能珠宝!本文带您走进数字珠宝工坊,用算法工具打磨数据金砂。从基础的经典算法到精密的深度学习模型,结合电商、医疗、金融等场景实战,手把手教您选择合适工具,打造价值连城的智能应用。掌握AutoML改装套件与模型蒸馏术,让复杂问题迎刃而解。握紧算法刻刀,为数字世界雕刻文明!
18 6
云上一键部署通义千问 QwQ-32B 模型,阿里云 PAI 最佳实践
3月6日阿里云发布并开源了全新推理模型通义千问 QwQ-32B,在一系列权威基准测试中,千问QwQ-32B模型表现异常出色,几乎完全超越了OpenAI-o1-mini,性能比肩Deepseek-R1,且部署成本大幅降低。并集成了与智能体 Agent 相关的能力,够在使用工具的同时进行批判性思考,并根据环境反馈调整推理过程。阿里云人工智能平台 PAI-Model Gallery 现已经支持一键部署 QwQ-32B,本实践带您部署体验专属 QwQ-32B模型服务。
DistilQwen2.5蒸馏小模型在PAI-ModelGallery的训练、评测、压缩及部署实践
DistilQwen2.5 是阿里云人工智能平台 PAI 推出的全新蒸馏大语言模型系列。通过黑盒化和白盒化蒸馏结合的自研蒸馏链路,DistilQwen2.5各个尺寸的模型在多个基准测试数据集上比原始 Qwen2.5 模型有明显效果提升。这一系列模型在移动设备、边缘计算等资源受限的环境中具有更高的性能,在较小参数规模下,显著降低了所需的计算资源和推理时长。阿里云的人工智能平台 PAI,作为一站式的机器学习和深度学习平台,对 DistilQwen2.5 模型系列提供了全面的技术支持。本文详细介绍在 PAI 平台使用 DistilQwen2.5 蒸馏小模型的全链路最佳实践。
阿里万相重磅开源,人工智能平台PAI一键部署教程来啦
阿里云视频生成大模型万相2.1(Wan)重磅开源!Wan2.1 在处理复杂运动、还原真实物理规律、提升影视质感以及优化指令遵循方面具有显著的优势,轻松实现高质量的视频生成。同时,万相还支持业内领先的中英文文字特效生成,满足广告、短视频等领域的创意需求。阿里云人工智能平台 PAI-Model Gallery 现已经支持一键部署阿里万相重磅开源的4个模型,可获得您的专属阿里万相服务。
基于机器学习的数据分析:PLC采集的生产数据预测设备故障模型
本文介绍如何利用Python和Scikit-learn构建基于PLC数据的设备故障预测模型。通过实时采集温度、振动、电流等参数,进行数据预处理和特征提取,选择合适的机器学习模型(如随机森林、XGBoost),并优化模型性能。文章还分享了边缘计算部署方案及常见问题排查,强调模型预测应结合定期维护,确保系统稳定运行。
103 0
全网首发 | PAI Model Gallery一键部署阶跃星辰Step-Video-T2V、Step-Audio-Chat模型
Step-Video-T2V 是一个最先进的 (SoTA) 文本转视频预训练模型,具有 300 亿个参数,能够生成高达 204 帧的视频;Step-Audio 则是行业内首个产品级的开源语音交互模型,通过结合 130B 参数的大语言模型,语音识别模型与语音合成模型,实现了端到端的文本、语音对话生成,能和用户自然地进行高质量对话。PAI Model Gallery 已支持阶跃星辰最新发布的 Step-Video-T2V 文生视频模型与 Step-Audio-Chat 大语言模型的一键部署,本文将详细介绍具体操作步骤。
多元线性回归:机器学习中的经典模型探讨
多元线性回归是统计学和机器学习中广泛应用的回归分析方法,通过分析多个自变量与因变量之间的关系,帮助理解和预测数据行为。本文深入探讨其理论背景、数学原理、模型构建及实际应用,涵盖房价预测、销售预测和医疗研究等领域。文章还讨论了多重共线性、过拟合等挑战,并展望了未来发展方向,如模型压缩与高效推理、跨模态学习和自监督学习。通过理解这些内容,读者可以更好地运用多元线性回归解决实际问题。
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构
K-means聚类算法是机器学习中常用的一种聚类方法,通过将数据集划分为K个簇来简化数据结构。本文介绍了K-means算法的基本原理,包括初始化、数据点分配与簇中心更新等步骤,以及如何在Python中实现该算法,最后讨论了其优缺点及应用场景。
255 6
基于机器学习的人脸识别算法matlab仿真,对比GRNN,PNN,DNN以及BP四种网络
本项目展示了人脸识别算法的运行效果(无水印),基于MATLAB2022A开发。核心程序包含详细中文注释及操作视频。理论部分介绍了广义回归神经网络(GRNN)、概率神经网络(PNN)、深度神经网络(DNN)和反向传播(BP)神经网络在人脸识别中的应用,涵盖各算法的结构特点与性能比较。

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等