Model deployment is the key step that turns a trained model into a running application. It involves exporting the model to a portable format (TorchScript, ONNX), serving it behind an HTTP API, packaging it for mobile platforms, and optimizing it for inference.
This chapter uses ResNet18 for image classification and walks through the complete deployment workflow.
1. Load the Pretrained Model
import torch
import torchvision
# Load the pretrained model
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model.eval()
# Example input
dummy_input = torch.rand(1, 3, 224, 224)
2. Export to TorchScript
# Option 1: trace the execution path (works for models without data-dependent control flow)
traced_model = torch.jit.trace(model, dummy_input)
torch.jit.save(traced_model, "resnet18_traced.pt")
# Option 2: compile the model source directly (handles if/for control flow)
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, "resnet18_scripted.pt")
# Load and test
loaded_model = torch.jit.load("resnet18_traced.pt")
output = loaded_model(dummy_input)
print("TorchScript输出形状:", output.shape) # 应输出torch.Size([1, 1000])
3. Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "resnet18.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"},
    },
)
# Validate the ONNX model
import onnx
onnx_model = onnx.load("resnet18.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX模型输入输出:")
print(onnx_model.graph.input)
print(onnx_model.graph.output)
4. Serve with FastAPI
import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
import io
import numpy as np
import torchvision.transforms as transforms
app = FastAPI()
# Load the TorchScript model exported earlier
model = torch.jit.load("resnet18_traced.pt")
# Image preprocessing (standard ImageNet statistics)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
@app.post("/predict")
async def predict(image: UploadFile = File(...)):
# 读取并预处理图像
image_data = await image.read()
img = Image.open(io.BytesIO(image_data)).convert("RGB")
tensor = preprocess(img).unsqueeze(0)
# 执行推理
with torch.no_grad():
output = model(tensor)
# 获取预测结果
_, pred = torch.max(output, 1)
return {"class_id": int(pred)}
# Run with: uvicorn main:app --reload
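In practice a bare class index is rarely useful on its own. A possible extension of the endpoint (a sketch, assuming torchvision >= 0.13 so the pretrained weights carry their ImageNet category names; the /predict_top5 route is made up for illustration):

from torchvision.models import ResNet18_Weights

# Human-readable class names shipped with the pretrained weights
categories = ResNet18_Weights.IMAGENET1K_V1.meta["categories"]

@app.post("/predict_top5")
async def predict_top5(image: UploadFile = File(...)):
    image_data = await image.read()
    img = Image.open(io.BytesIO(image_data)).convert("RGB")
    tensor = preprocess(img).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(tensor), dim=1)
    top_p, top_i = probs.topk(5, dim=1)
    return {
        "top5": [
            {"label": categories[int(i)], "prob": round(float(p), 4)}
            for p, i in zip(top_p[0], top_i[0])
        ]
    }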
Test the endpoint with a sample image:
import requests
# Download a test image
url = "https://images.unsplash.com/photo-1517849845537-4d257902454a?auto=format&fit=crop&w=224&q=80"
response = requests.get(url)
with open("test_dog.jpg", "wb") as f:
f.write(response.content)
# Send a prediction request
with open("test_dog.jpg", "rb") as f:
    files = {"image": f}
    response = requests.post("http://localhost:8000/predict", files=files)
print("Prediction:", response.json())  # expect the corresponding ImageNet class ID
5. Mobile Deployment
iOS (Core ML):
import coremltools as ct
# Convert from PyTorch via a traced model
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
)
mlmodel.save("ResNet18.mlmodel")
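Note that recent coremltools releases default to the newer ML Program backend, which must be saved as an .mlpackage bundle instead of a .mlmodel file. If the save above fails on your version, a variant worth trying (a sketch, assuming coremltools >= 6):

# Explicitly request the ML Program backend and the matching file extension
mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    convert_to="mlprogram",
)
mlmodel.save("ResNet18.mlpackage")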
// Android example (Java):
Module module = Module.load(assetFilePath(this, "resnet18_traced.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
        bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
        TensorImageUtils.TORCHVISION_NORM_STD_RGB
);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
6. Model Quantization
# Dynamic quantization: int8-quantize the nn.Linear layers (only the fc head in ResNet18)
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt")
# Compare sizes on disk: parameter counts are misleading here, because the
# quantized Linear weights are repacked and no longer appear in .parameters()
import os
print("Original model size (MB):", os.path.getsize("resnet18_traced.pt") / 1e6)
print("Quantized model size (MB):", os.path.getsize("resnet18_quantized.pt") / 1e6)
7. Inference with ONNX Runtime
import onnxruntime
ort_session = onnxruntime.InferenceSession("resnet18.onnx")
ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
ort_outputs = ort_session.run(None, ort_inputs)
print("ONNX Runtime output shape:", ort_outputs[0].shape)
8. Debugging and Monitoring Tips
If ONNX export fails on an unsupported operator, try pinning the opset version:
torch.onnx.export(..., opset_version=13)
Visualize the exported graph with Netron:
pip install netron
netron resnet18.onnx
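If the netron package is installed, the viewer can also be launched from a script rather than the command line (a sketch; netron.start serves the graph on a local port and opens a browser tab):

import netron
netron.start("resnet18.onnx")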
Add timing middleware to the FastAPI app to expose per-request latency:
import time

@app.middleware("http")
async def add_process_time(request, call_next):
    start_time = time.time()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.time() - start_time)
    return response
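On the client side the header can be read straight off the response, which separates server-side processing time from network latency (a small usage sketch):

import requests

with open("test_dog.jpg", "rb") as f:
    r = requests.post("http://localhost:8000/predict", files={"image": f})
print("Server processing time:", r.headers.get("X-Process-Time"), "s")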
Key takeaways:
- Export models with TorchScript (trace/script) and ONNX for portable deployment
- Serve predictions over HTTP with FastAPI
- Target mobile platforms via Core ML and the PyTorch Android API
- Shrink and speed up models with dynamic quantization
Coming up next:
Part 6 will go deeper into the PyTorch ecosystem, covering distributed training and multi-GPU acceleration strategies for industrial-scale training efficiency!