FastAPI + ONNX 部署机器学习模型最佳实践

简介: 本文介绍了如何结合FastAPI和ONNX实现机器学习模型的高效部署。面对模型兼容性、性能瓶颈、服务稳定性和安全性等挑战,FastAPI与ONNX提供了高性能、易于开发维护、跨框架支持和活跃社区的优势。通过将模型转换为ONNX格式、构建FastAPI应用、进行性能优化及考虑安全性,可以简化部署流程,提升推理性能,确保服务的可靠性与安全性。最后,以手写数字识别模型为例,展示了完整的部署过程,帮助读者更好地理解和应用这些技术。

引言

随着人工智能的迅猛发展,将训练好的模型部署到生产环境中,为用户提供实时预测服务,已成为众多企业和开发者关注的重点。然而,模型部署并非易事,涉及到模型格式转换、服务框架选择、性能优化等多个方面。本篇文章将介绍如何结合 FastAPIONNX,实现机器学习模型的高效部署,并分享其中的最佳实践。

背景介绍 🎨

机器学习模型的部署,常常会遇到以下挑战:

  • 模型兼容性:不同的深度学习框架(如 TensorFlow、PyTorch)有各自的模型格式,直接部署可能会有兼容性问题,导致部署困难。
  • 性能瓶颈:模型推理速度直接影响用户体验和系统资源消耗,性能优化至关重要。
  • 服务稳定性:需要确保服务在高并发情况下的稳定性和可靠性,否则可能会崩溃。
  • 安全性:需要防范潜在的安全风险,如输入数据的验证、攻击防护等,保障应用安全。

看到这里,可能有人会问:“有没有一种简单的方法,可以解决这些问题呢?”答案就是——FastAPI + ONNX

为什么选择 FastAPI 与 ONNX

  • 高性能:FastAPI 与 ONNX Runtime 的组合,提供了高效的推理和响应速度,让你的服务飞起来!
  • 易于开发和维护:FastAPI 简洁的代码结构和自动文档生成功能,大大降低了开发和维护的成本,不再为繁琐的配置烦恼。
  • 跨框架支持:ONNX 支持多种主流的深度学习框架,方便模型的转换和部署,再也不用陷入框架之争。
  • 社区活跃:两个项目都有活跃的社区支持,丰富的资源和教程,遇到问题有人帮,进步之路不孤单。

最佳实践 🛠️

1.模型转换为 ONNX 格式

模型转换是部署的第一步。将训练好的模型转换为 ONNX 格式,可以提高模型的兼容性和性能。

PyTorch 模型转换

假设你有一个训练好的 PyTorch 模型,将其转换为 ONNX 格式呢只需几行代码,如下:

import torch
import torch.onnx

# 加载训练好的模型
model = torch.load('model.pth')
model.eval()

# 定义一个输入张量(示例输入)
dummy_input = torch.randn(1, 3, 224, 224)

# 导出为 ONNX 格式
torch.onnx.export(model, dummy_input, 'model.onnx', 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")

TensorFlow 模型转换

对于 TensorFlow 的模型,也是类似的操作。

import tensorflow as tf
import tf2onnx

# 加载训练好的模型
model = tf.keras.models.load_model('model.h5')

# 转换为 ONNX 格式
spec = (tf.TensorSpec(model.inputs[0].shape, dtype=tf.float32, name="input"),)
output_path = "model.onnx"
model_proto, _ = tf2onnx.convert.from_keras(model, 
                                            input_signature=spec, 
                                            opset=11, 
                                            output_path=output_path)

print("✅ 模型已成功转换为 ONNX 格式!")

验证转换后的模型

转换完成后,别忘了验证一下模型是否正常工作!

import onnx
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
print("✅ 模型格式验证通过!")

# 使用 ONNX Runtime 进行推理
ort_session = ort.InferenceSession('model.onnx')

# 准备输入数据
import numpy as np
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 运行推理
outputs = ort_session.run(None, {'input': input_data})

print("输出结果:", outputs)

示例输出:
1.png

2.构建 FastAPI 应用

现在,我们来创建一个基于 FastAPI 的应用,将模型部署为一个 API 服务。

安装依赖

首先,安装必要的依赖包:

pip install fastapi uvicorn[standard] onnxruntime

定义 FastAPI 应用

编写应用主文件 main.py

from fastapi import FastAPI
import onnxruntime as ort
import numpy as np
from pydantic import BaseModel

app = FastAPI(title="机器学习模型部署 API 🚀")

# 加载 ONNX 模型
ort_session = ort.InferenceSession('model.onnx')

# 定义输入数据模型
class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    # 将输入数据转换为 numpy 数组
    input_array = np.array(input_data.data).astype(np.float32)

    # 进行推理
    outputs = ort_session.run(None, {
   'input': input_array})

    # 返回结果
    return {
   "prediction": outputs[0].tolist()}

运行应用

使用uvicorn启动应用:

uvicorn main:app --host 0.0.0.0 --port 8000

测试接口

可以使用 curl 或其他工具测试一下接口是否正常工作:

curl -X POST "http://localhost:8000/predict" -H "Content-Type: application/json" -d '{
    "data": [0.5, 0.3, 0.2]
}'

示例输出:

{
   
    "prediction": [[0.1, 0.9]]
}

至此我们的 API 已经可以正常工作了!

3.性能优化

性能对于一个服务来说至关重要,这里介绍一些优化技巧。

模型优化

  • 使用模型优化工具:ONNX 提供了模型优化工具,可简化和加速模型。
python -m onnxruntime.tools.optimizer_cli --input model.onnx --output model_optimized.onnx --optimization_level all
  • 量化模型:通过模型量化,将浮点数精度降低,减小模型大小,加速推理。
python -m onnxruntime.quantization.quantize --input model.onnx --output model_quant.onnx --per_channel

推理加速

  • 使用 GPU 加速:如果有 GPU 资源,可以使用 GPU 提供商提升推理速度。
ort_session = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
  • 多线程或多进程:根据服务器性能,调整并发数,充分利用硬件资源。

4.安全性考虑

安全是服务的底线,我们需要考虑以下几点。

输入验证

  • 数据格式验证:使用 Pydantic 模型,确保输入数据的格式和类型正确。
  • 异常处理:捕获可能的异常,如数据维度错误,返回友好的错误信息。
@app.post("/predict")
async def predict(input_data: InputData):
    try:
        # 输入验证
        input_array = np.array(input_data.data).astype(np.float32)
        # 检查输入维度(根据模型需求调整)
        if input_array.shape != (1, 3, 224, 224):
            return {
   "error": "输入数据维度不正确"}
        # 进行推理
        outputs = ort_session.run(None, {
   'input': input_array})
        return {
   "prediction": outputs[0].tolist()}
    except Exception as e:
        return {
   "error": str(e)}

安全防护

  • 限制请求频率:通过中间件或网关,防止恶意请求和 DDoS 攻击。
  • SSL/HTTPS:在生产环境中,确保通信的安全性。

案例示例 🎯

下面以一个手写数字识别模型为例,展示完整的部署过程。

1.模型训练与转换

# 使用 MNIST 数据集训练模型
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# 定义简单的神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc(x)
        return x

# 创建模型实例
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 加载数据集
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('.', train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True)

# 训练模型(这里只训练一个 epoch 作示例)
for epoch in range(1):
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# 保存模型
torch.save(model.state_dict(), 'mnist.pth')

print("✅ 模型训练完成并已保存!")

# 加载模型并转换为 ONNX
model = Net()
model.load_state_dict(torch.load('mnist.pth'))
model.eval()

dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, 'mnist.onnx', input_names=['input'], output_names=['output'])

print("✅ 模型已成功转换为 ONNX 格式!")

2.构建 FastAPI 应用

from fastapi import FastAPI, File, UploadFile
from PIL import Image
import onnxruntime as ort
import numpy as np

app = FastAPI(title="手写数字识别 API 🖊️")
ort_session = ort.InferenceSession('mnist.onnx')

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    # 读取上传的图片
    image = Image.open(file.file).convert('L')
    # 图片预处理
    image = image.resize((28, 28))
    image_data = np.array(image).astype(np.float32).reshape(1, 1, 28, 28)
    # 归一化
    image_data /= 255.0
    # 推理
    outputs = ort_session.run(None, {
   'input': image_data})
    prediction = np.argmax(outputs[0])
    # 返回结果
    return {
   "prediction": int(prediction)}

3.测试

使用工具(如 cURL、Postman)发送请求,验证接口功能。

curl -X POST "http://localhost:8000/predict" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@test_digit.png;type=image/png"

示例输出:

    {
   
        "prediction": 7
    }

小结 ⚡️

通过遵循上述最佳实践,我们可以简化部署流程,提高模型的推理性能,增强服务的可靠性和安全性。当然,在实际应用中,我们还需要根据具体情况进行优化和调整,希望本篇文章可以对各位读者有所帮助!

目录
相关文章
|
12天前
|
机器学习/深度学习 存储 设计模式
特征时序化建模:基于特征缓慢变化维度历史追踪的机器学习模型性能优化方法
本文探讨了数据基础设施设计中常见的一个问题:数据仓库或数据湖仓中的表格缺乏构建高性能机器学习模型所需的历史记录,导致模型性能受限。为解决这一问题,文章介绍了缓慢变化维度(SCD)技术,特别是Type II类型的应用。通过SCD,可以有效追踪维度表的历史变更,确保模型训练数据包含完整的时序信息,从而提升预测准确性。文章还从数据工程师、数据科学家和产品经理的不同视角提供了实施建议,强调历史数据追踪对提升模型性能和业务洞察的重要性,并建议采用渐进式策略逐步引入SCD设计模式。
25 8
特征时序化建模:基于特征缓慢变化维度历史追踪的机器学习模型性能优化方法
|
15天前
|
机器学习/深度学习 人工智能 算法
机器学习算法的优化与改进:提升模型性能的策略与方法
机器学习算法的优化与改进:提升模型性能的策略与方法
111 13
机器学习算法的优化与改进:提升模型性能的策略与方法
|
7天前
|
机器学习/深度学习 人工智能 自然语言处理
云上一键部署 DeepSeek-V3 模型,阿里云 PAI-Model Gallery 最佳实践
本文介绍了如何在阿里云 PAI 平台上一键部署 DeepSeek-V3 模型,通过这一过程,用户能够轻松地利用 DeepSeek-V3 模型进行实时交互和 API 推理,从而加速 AI 应用的开发和部署。
|
NoSQL 测试技术 Redis
FastAPI(八十四)实战开发《在线课程学习系统》--接口测试(下)
FastAPI(八十四)实战开发《在线课程学习系统》--接口测试(下)
FastAPI(八十四)实战开发《在线课程学习系统》--接口测试(下)
|
存储 测试技术 数据安全/隐私保护
FastAPI(八十三)实战开发《在线课程学习系统》--注册接口单元测试
FastAPI(八十三)实战开发《在线课程学习系统》--注册接口单元测试
FastAPI(八十三)实战开发《在线课程学习系统》--注册接口单元测试
|
测试技术 数据安全/隐私保护
FastAPI(八十四)实战开发《在线课程学习系统》--接口测试(上)
FastAPI(八十四)实战开发《在线课程学习系统》--接口测试(上)
FastAPI(八十二)实战开发《在线课程学习系统》接口开发-- 课程上架下架
FastAPI(八十二)实战开发《在线课程学习系统》接口开发-- 课程上架下架
|
NoSQL Redis 数据库
FastAPI(八十一)实战开发《在线课程学习系统》接口开发-- 推荐课程列表与课程点赞
FastAPI(八十一)实战开发《在线课程学习系统》接口开发-- 推荐课程列表与课程点赞
FastAPI(八十)实战开发《在线课程学习系统》接口开发-- 课程列表
FastAPI(八十)实战开发《在线课程学习系统》接口开发-- 课程列表
FastAPI(七十九)实战开发《在线课程学习系统》接口开发-- 加入课程和退出课程
FastAPI(七十九)实战开发《在线课程学习系统》接口开发-- 加入课程和退出课程
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等