mlflow

main.py

# -*- encoding: utf-8 -*-
import torch
from torch import nn
import mlflow
import mlflow.pytorch

x = torch.tensor([[1.0], [2.0], [3.0]])
y = torch.tensor([[2.0], [4.0], [6.0]])


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # single-feature linear regression: y = w * x + b
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = Model()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

num_epochs = 100

for epoch in range(num_epochs):
    # forward
    y_predict = model(x)
    # compute loss
    loss = loss_fn(y_predict, y)
    print(epoch, loss.item())
    # update parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training: run inference without tracking gradients
with torch.no_grad():
    for i in [4.0, 5.0, 6.0]:
        hour_var = torch.tensor([i])
        print('predict (after training): {},{}'.format(i, model(hour_var).item()))

# log the epoch count and the trained model to MLflow
with mlflow.start_run() as run:
    mlflow.log_param('epoch', num_epochs)
    mlflow.pytorch.log_model(model, 'model_example')

conda.yaml

name: pytorch-example
channels:
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
dependencies:
  - numpy==1.18.1
  - tqdm==4.32.1
  - scikit-learn==0.21.1
  - pytorch==1.4.0
  - pip:
    - torchtext==0.4.0
    - torchvision==0.5.0
    - mlflow

MLproject

name: pytorch-example
conda_env: conda.yaml

entry_points:
  main:
    command: |
      python main.py

Train the model and save it:

mlflow run pytorch_example

Start the service:

mlflow models serve -m runs:/937f655626b54bdeb905d326e0bd4c47/model_example -h 172.16.1.120 -p 1234

MLmodel

artifact_path: model_example
flavors:
  python_function:
    data: data
    env: conda.yaml
    loader_module: mlflow.pytorch
    pickle_module_name: mlflow.pytorch.pickle_module
    python_version: 3.6.2
  pytorch:
    model_data: data
    pytorch_version: 1.4.0
run_id: 937f655626b54bdeb905d326e0bd4c47
utc_time_created: '2020-04-20 13:31:54.677497'

HTTP POST request:

  • URL: http://172.16.1.120:1234/invocations
  • Body: JSON that can be parsed into a pd.DataFrame
{
    "columns": ["x1"],
    "data": [2.0, 3.0]
}
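
The request above can be sent from Python as well. The following is only a minimal sketch using the standard library: the helper names `build_request` and `predict` are hypothetical, and the `application/json` content type is an assumption that matches the MLflow 1.x scoring server this post was written against. Newer MLflow releases may expect a different body layout, so check the documentation for the version you run.

```python
import json
import urllib.request

# Address taken from the post; adjust to your own host and port.
INVOCATIONS_URL = "http://172.16.1.120:1234/invocations"


def build_request(columns, data, url=INVOCATIONS_URL):
    """Build a POST request whose JSON body has "columns" and "data" keys.

    Hypothetical helper: it only packages the payload shown in the post.
    """
    body = json.dumps({"columns": columns, "data": data}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=body,
        # Assumption: plain application/json is accepted by the server.
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def predict(columns, data, url=INVOCATIONS_URL, timeout=10):
    """Send the request and return the decoded JSON response."""
    request = build_request(columns, data, url)
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Example call (only works while the model server is reachable):
# predict(["x1"], [2.0, 3.0])
```

Building the request is kept separate from sending it so that the payload can be inspected or logged without a running server.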
