[Python] How to Build a Simple PyTorch Model (Basic Version)
Table of contents
- 1. Introduction
- 2. Solution
- 2.1 Building the model
- 2.2 Reading the data
- 2.3 Main procedure and training
- 3. Results
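1. Introduction
The known functional relationship, as implemented by the model's forward method in section 2.1, is:

$$PV = b_0 + b_1 T_a + \frac{b_2}{b_3 + b_4 V_w}$$

As an example requirement:
- We have a little over 10,000 data samples of the form (PV, Ta, Vw).
- We need to fit the five parameters b0 through b4.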
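To make the formula concrete, the relation used by the model's forward method in section 2.1, PV = b0 + b1*Ta + b2 / (b3 + b4*Vw), can be evaluated directly in NumPy. This is only a minimal sketch: the helper name `pv_formula` and the coefficient values are made up for illustration and are not fitted results. It also shows a property worth keeping in mind when reading fitted values: scaling b2, b3 and b4 by the same nonzero factor leaves the prediction unchanged, so the five parameters are not uniquely determined by the data.

```python
import numpy as np


def pv_formula(Ta, Vw, b):
    """Evaluate PV = b0 + b1*Ta + b2 / (b3 + b4*Vw) elementwise."""
    b0, b1, b2, b3, b4 = b
    return b0 + b1 * Ta + b2 / (b3 + b4 * Vw)


# Hypothetical coefficients, chosen only so the arithmetic is easy to follow.
b_example = (1.0, 0.5, 2.0, 1.0, 0.5)
Ta = np.array([10.0, 20.0])
Vw = np.array([0.0, 2.0])

pv = pv_formula(Ta, Vw, b_example)
print(pv)  # values 8.0 and 12.0

# Doubling b2, b3 and b4 together gives exactly the same predictions.
b_scaled = (1.0, 0.5, 4.0, 2.0, 1.0)
pv_scaled = pv_formula(Ta, Vw, b_scaled)
print(np.allclose(pv, pv_scaled))  # True
```

In practice this means two training runs can report different b2, b3 and b4 while describing the same curve; comparing predictions is more meaningful than comparing those three values one by one.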

2. Solution
2.1 Building the model
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Five learnable scalars with their initial guesses.
        # nn.Parameter sets requires_grad=True by default.
        self.b0 = nn.Parameter(torch.tensor([0.0]))
        self.b1 = nn.Parameter(torch.tensor([0.1]))
        self.b2 = nn.Parameter(torch.tensor([0.1]))
        self.b3 = nn.Parameter(torch.tensor([1.0]))
        self.b4 = nn.Parameter(torch.tensor([0.1]))

    def forward(self, T, V):
        # P_hat = b0 + b1*T + b2 / (b3 + b4*V), evaluated elementwise
        P_hat = self.b0 + self.b1 * T + self.b2 / (self.b3 + self.b4 * V)
        return P_hat
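A quick sanity check of the module can help before training. The sketch below repeats the class so that it runs on its own, feeds in two made-up samples (the input values are illustrative, not from the dataset), and lists the parameters that the optimizer will later receive.

```python
import torch
import torch.nn as nn


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.b0 = nn.Parameter(torch.tensor([0.0]))
        self.b1 = nn.Parameter(torch.tensor([0.1]))
        self.b2 = nn.Parameter(torch.tensor([0.1]))
        self.b3 = nn.Parameter(torch.tensor([1.0]))
        self.b4 = nn.Parameter(torch.tensor([0.1]))

    def forward(self, T, V):
        return self.b0 + self.b1 * T + self.b2 / (self.b3 + self.b4 * V)


model = Demo()

# Two made-up samples; the shape-(1,) parameters broadcast over the batch.
T = torch.tensor([10.0, 20.0])
V = torch.tensor([0.0, 10.0])
P_hat = model(T, V)

# Parameters are registered in the order they were assigned in __init__.
names = [name for name, _ in model.named_parameters()]
print(names)        # ['b0', 'b1', 'b2', 'b3', 'b4']
print(P_hat.shape)  # torch.Size([2])
```

With the initial values above, the first sample evaluates to 0 + 0.1*10 + 0.1/(1 + 0) = 1.1 and the second to 0 + 0.1*20 + 0.1/(1 + 1) = 2.05.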
2.2 Reading the data
import pandas as pd
import numpy as np


def read_data(file_path):
    # Expects a CSV file with the columns 'Ta', 'Vw' and 'PV'
    df = pd.read_csv(file_path)
    T = np.array(df['Ta'], dtype=float)
    V = np.array(df['Vw'], dtype=float)
    P = np.array(df['PV'], dtype=float)
    return T, V, P
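To see what read_data expects and returns, here is a small self-contained sketch that writes a two-row CSV to a temporary directory and reads it back. The rows are made up for illustration; only the column names PV, Ta and Vw come from the code above.

```python
import os
import tempfile

import numpy as np
import pandas as pd


def read_data(file_path):
    df = pd.read_csv(file_path)
    T = np.array(df['Ta'], dtype=float)
    V = np.array(df['Vw'], dtype=float)
    P = np.array(df['PV'], dtype=float)
    return T, V, P


# Made-up rows; a real file would hold the measured samples.
csv_text = "PV,Ta,Vw\n1.5,10,2\n2.5,20,4\n"

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.csv")
    with open(path, "w") as f:
        f.write(csv_text)
    T, V, P = read_data(path)

# Each result is a 1-D float64 array with one entry per CSV row.
print(T, V, P)
print(T.dtype)
```

Because the arrays come back as float64, they need an explicit conversion to float32 before being passed to a model whose parameters are float32.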
2.3 Main procedure and training
import torch
import torch.nn as nn
from utils import read_data  # read_data from section 2.2, assumed to be saved in utils.py
from model import Demo       # Demo from section 2.1, saved in model.py

# read data
file_path = './data/data.csv'
T, V, P = read_data(file_path)
T = torch.tensor(T, dtype=torch.float32)
V = torch.tensor(V, dtype=torch.float32)
P = torch.tensor(P, dtype=torch.float32)
# train : valid = 4 : 1
n = T.shape[0]
perm = torch.randperm(n)  # random permutation of the sample indices
spl = int(n * 0.8)
train_idx, valid_idx = perm[:spl], perm[spl:]
T_train, V_train, P_train = T[train_idx], V[train_idx], P[train_idx]
T_valid, V_valid, P_valid = T[valid_idx], V[valid_idx], P[valid_idx]
# Model definition
model = Demo()
# Loss function definition
loss_fn = nn.MSELoss(reduction="mean")
# Basic hyperparameters
N_epoch = 500
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Start from infinity so the first validation pass always records a result
min_loss = float("inf")
B = (0.0, 0.0, 0.0, 0.0, 0.0)
model.train()
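for epoch in range(N_epoch):
    # train: one full-batch gradient step per epoch
    P_hat = model(T_train, V_train)
    loss = loss_fn(P_hat, P_train)
    print(f"epoch {epoch}...train_loss: {loss.item()}, min_valid_loss: {min_loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # valid: every 20 epochs and on the last epoch
    if epoch % 20 == 0 or epoch == N_epoch - 1:
        model.eval()
        with torch.no_grad():
            P_hat = model(T_valid, V_valid)
            valid_loss = loss_fn(P_hat, P_valid).item()
        if valid_loss < min_loss:
            min_loss = valid_loss
            # .item() copies each value; model.b0.data[0] would be a view
            # that keeps changing as training continues
            B = tuple(p.item() for p in (model.b0, model.b1, model.b2, model.b3, model.b4))
        model.train()
# Fitted parameters at the lowest validation loss
print(B)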
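One subtlety when snapshotting the best parameters inside the loop: indexing a parameter with `.data[0]` returns a tensor that shares storage with the parameter, and optimizers update parameters in place, so such a snapshot silently follows later training steps. The sketch below demonstrates this with a single parameter and one SGD step (SGD with lr=1.0 is used here only to make the arithmetic obvious; it is not the optimizer from the training script).

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.tensor([0.1]))
opt = torch.optim.SGD([p], lr=1.0)

view_snapshot = p.data[0]   # 0-dim tensor sharing storage with p
value_snapshot = p.item()   # plain Python float, an independent copy

loss = p.sum()              # d(loss)/dp = 1
loss.backward()
opt.step()                  # in-place update: p = 0.1 - 1.0 * 1 = -0.9

print(view_snapshot.item())  # about -0.9: the "snapshot" moved with p
print(value_snapshot)        # about 0.1: the copy kept the old value
```

Storing `p.item()` (or `p.detach().clone()` for larger tensors) keeps the values from the epoch with the lowest validation loss rather than the values from the final epoch.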
3. Results
