深度学习之PyTorch物体检测实战(2)

1.神经网络工具torch.nn

from torch import  optim
import  torch
from torch import  nn
import torch.nn.functional as F
# Three-layer perceptron (multilayer perceptron with two hidden layers).
class MLP(nn.Module):
    """Fully connected classifier: in_dim -> hid_dim1 -> hid_dim2 -> out_dim.

    Args:
        in_dim: number of input features per sample.
        hid_dim1: width of the first hidden layer.
        hid_dim2: width of the second hidden layer.
        out_dim: number of output scores (e.g. classes).

    forward(x) returns raw, unnormalized scores (logits). No activation is
    applied to the output layer: losses such as ``F.cross_entropy`` apply
    log-softmax themselves and need logits that can be negative.
    """

    def __init__(self, in_dim, hid_dim1, hid_dim2, out_dim):
        super().__init__()
        # Module indices 0, 2, 4 are the Linear layers (same as before the
        # trailing ReLU was removed), so existing state_dicts still load.
        self.layer = nn.Sequential(
            nn.Linear(in_dim, hid_dim1),
            nn.ReLU(),
            nn.Linear(hid_dim1, hid_dim2),
            nn.ReLU(),
            nn.Linear(hid_dim2, out_dim),
        )

    def forward(self, x):
        """Map inputs of shape (..., in_dim) to logits of shape (..., out_dim)."""
        return self.layer(x)

# Build the MLP for flattened 28x28 inputs and 10 output classes.
model = MLP(28 * 28, 300, 200, 10)

optimizer = optim.SGD(model.parameters(), lr=0.1)
# Random stand-in batch: 10 flattened 28x28 samples with integer class labels.
data = torch.randn(10, 28 * 28)
label = torch.tensor([1, 0, 4, 7, 9, 3, 4, 5, 3, 2], dtype=torch.long)
for _ in range(100):
    output = model(data)  # logits, shape (10, 10)
    loss = F.cross_entropy(output, label)
    print(loss.item())
    optimizer.zero_grad()
    # A fresh graph is built by the forward pass on every iteration, so
    # retain_graph=True is unnecessary; it only delays freeing saved tensors.
    loss.backward()
    optimizer.step()

2.模型处理

from torch import nn
from torchvision import models
import torch
# Load torchvision's VGG16 with pretrained weights (legacy `pretrained=` flag;
# newer torchvision releases prefer the `weights=` argument).
vgg = models.vgg16(pretrained=True)
# Inspect the model: the tail of the conv feature extractor (modules 24 onward)
# and how many modules make up the classifier head.
feature_tail = vgg.features[24:]
n_classifier_modules = len(vgg.classifier)
print(feature_tail)
print(n_classifier_modules)


# Freeze the first 10 modules of the convolutional feature extractor so their
# parameters are not updated during fine-tuning. The VGG model object itself
# is not indexable; its conv stack is the nn.Sequential `vgg.features`.
# NOTE(review): the original comment said "first three layers"; for
# torchvision's VGG16, modules 0-9 should span the first two conv blocks
# (four conv layers plus ReLUs/pools) — confirm with print(vgg.features).
for p in vgg.features[:10].parameters():
    p.requires_grad = False


# Option 1: save only the model parameters (state_dict).
torch.save(model.state_dict(), 'parameter.pkl')
# Load: rebuild the same architecture first (the MLP built above), then copy
# the saved parameters into it.
model = MLP(28 * 28, 300, 200, 10)
model.load_state_dict(torch.load('parameter.pkl'))

# Option 2: save the complete model object (pickled; the MLP class must be
# importable when loading).
torch.save(model, 'model.pkl')
# Load
# NOTE: this unpickles arbitrary Python objects — only load trusted files.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True and will refuse
# a full model object; pass weights_only=False there (trusted files only).
model = torch.load('model.pkl')

3.数据处理

#数据加载
from torch.utils.data import Dataset
class my_Data(Dataset):
    """Minimal map-style dataset skeleton.

    Args:
        image_path: location of the images (format is dataset-specific).
        annotation_path: location of the annotations (format is dataset-specific).
        transform: optional callable applied to each image in ``__getitem__``.

    Samples are kept in ``self.samples`` as ``(image, target)`` pairs. Filling
    that list from ``image_path`` / ``annotation_path`` is dataset-specific
    and left as a TODO.
    """

    def __init__(self, image_path, annotation_path, transform=None):
        # Read the data: remember where it lives and how to transform it.
        self.image_path = image_path
        self.annotation_path = annotation_path
        self.transform = transform
        self.samples = []  # TODO: populate with (image, target) pairs

    def __len__(self):
        # Dataset length.
        return len(self.samples)

    def __getitem__(self, item):
        # Fetch the sample at index `item`. An out-of-range index raises
        # IndexError, which is what ends a plain `for ... in dataset` loop.
        image, target = self.samples[item]
        if self.transform is not None:
            image = self.transform(image)
        return image, target


dataset = my_Data('image_path', 'annotation_path')
for data in dataset:
    print(data)

#数据增强
from torchvision import  transforms
# Augmentation pipeline applied per image:
#   Resize(256)            -> shorter side scaled to 256, aspect ratio kept
#   RandomHorizontalFlip() -> mirror left/right with probability 0.5
#   ToTensor()             -> tensor with values in [0, 1]
#   Normalize(0.5, 0.5)    -> (x - 0.5) / 0.5 per channel, i.e. [-1, 1]
# NOTE(review): the flip only touches whatever my_Data passes to the
# transform (presumably the image alone); box annotations would need to be
# flipped too — confirm. Images with different aspect ratios also come out
# at different sizes, which the DataLoader's default collate cannot stack.
train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)
dataset = my_Data('image_path', 'annotation_path', transform=train_transform)

#数据读取
from torch.utils.data import DataLoader
# Batch the dataset: 4 samples per batch, reshuffled each epoch, loaded by 4
# worker processes.
# NOTE: with num_workers > 0 on platforms that spawn workers (Windows, macOS),
# this code must run under an `if __name__ == "__main__":` guard.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
# One epoch is len(dataloader) batches; bounding the loop by it keeps
# next() from running past the end of the iterator (StopIteration).
iters_per_epoch = len(dataloader)
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
    data = next(data_iter)

你可能感兴趣的:(深度学习之PyTorch物体检测实战(2))