【PyTorch】识别黑白图中的服装图案

识别黑白图中的服装图案

  • 1 数据集
    • 1.1代码处理数据集
    • 1.2代码处理数据集
  • 2 制作批次数据
  • 3 构建并训练模型

1 数据集

Fashion-MNIST
链接: https://github.com/zalandoresearch/fashion-mnist
【PyTorch】识别黑白图中的服装图案_第1张图片

1.1代码处理数据集

GitHub直接下载

1.2代码处理数据集

torchvision库可以直接对Fashion-MNIST数据集进行下载,需要指定好数据集路径。调用torchvision库中的datasets.FashionMNIST()方法进行数据集下载,download参数为True表明需要从网络下载,train设置为True,从training.pt创建数据集,否则从test.pt创建。
transform.ToTensor()
【PyTorch】识别黑白图中的服装图案_第2张图片
ToTensor将图片转为Pytorch支持的形状([通道,高,宽]),同时也将图片的数值归一化0-1的小数。
pylab结合了numpy和matplotlib.pyplot,既可以画图又可以进行简单计算。

import torchvision
import torchvision.transforms as tranforms
import pylab

data_dir='./fashion-mnist'
tranform=tranforms.Compose([tranforms.ToTensor()])
train_dataset=torchvision.datasets.FashionMNIST(data_dir,train=True,transform=tranform,download=True)

'''读取和显示图片'''
print('训练数据集条数',len(train_dataset))
val_dataset=torchvision.datasets.FashionMNIST(root=data_dir,train=False,transform=tranform)
print('测试数据集条数',len(val_dataset))
im=train_dataset[0][0].numpy()
im=im.reshape(-1,28)
pylab.imshow(im)
pylab.show()
print('该图片的标签为:',train_dataset[0][1])

【PyTorch】识别黑白图中的服装图案_第3张图片

2 制作批次数据

class DataLoader(dataset,batch_size=1,shuffle=Flase,sampler=None,num_workers=0,
collate_fn=,pin_memory=False,drop_last=False,timeout=0,
worker_init_fn=None,multiprocessing_context=None)
(1)dataset:待加载的数据集
(2)batch_size:每批次数据加载的样本数量,默认是1
(3)shuffle:是否把样本的数据打乱,默认False
(4)sampler:接收一个采样器对象,用于按照指定的样本提取策略从数据集中提取样本。
(5)num_workers:设置加载数据的额外进程数量,默认为0,即用主进程加载数据
(6)collate_fn:接受一个自定义函数,即对数据二次加工
(7)pin_memory:内存寄存,表示在数据返回前是否将数据复制到CUDA内存中
(8)drop_last:是否丢弃最后不能被batch_size整除的数据
(9)timeout:读取数据的超时时间,如果超过超时时间还没有读到数据,系统报错
(10)worker_init_fn:每个子进程的初始化函数,在加载数据之前进行
(11)multiprocessing_context:用于多进程处理的配置参数
多个采样器子类:
SequentialSampler:按照原有的顺序采样
RandomSampler:随机采样,可以设置是否重复采样
SubsetRandomSampler:按照指定的集合或索引列表进行随机采样
WeightedRandomSampler:按照指定的概率采样
BatchSampler:按照指定的批次索引采样

batch_size=10
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
test_loader=torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False)

3 构建并训练模型

两个卷积层结合三个全连接层.
这篇文章讲解了MaxPool2d,推荐新手可以学习了解一下https://blog.csdn.net/qq_44864833/article/details/125513812

class myConNet(torch.nn.Module):
    def __init__(self):
        super(myConNet, self).__init__()
        #定义卷积层
        self.conv1=torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5)
        self.conv2=torch.nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5)
        #定义全连接层
        self.fc1=torch.nn.Linear(in_features=12*4*4,out_features=120)
        self.fc2=torch.nn.Linear(in_features=120,out_features=60)
        self.out=torch.nn.Linear(in_features=60,out_features=10)
        #是因为有10个类别,所以是10维

    def forward(self,t):#搭建正向结构
        t=self.conv1(t)
        t=F.relu(t)
        t=F.max_pool2d(t,kernel_size=2,stride=2)
        t=self.conv2(t)
        t=F.relu(t)
        t=F.max_pool2d(t,kernel_size=2,stride=2)
        t=t.reshape(-1,12*4*4)
        t=self.fc1(t)
        t=F.relu(t)
        #第二层全连接
        t=self.fc2(t)
        t=F.relu(t)
        #第三层全连接
        t=self.out(t)
        return t

if __name__=='__main__':
    network=myConNet()
    #指定设备
    device=torch.device("cuda:0")
    print(device)
    network.to(device)#将模型对象转储到GPU设备上
    print(network)#打印网络结构
    criterion=torch.nn.CrossEntropyLoss()
    optimizer=torch.optim.Adam(network.parameters(),lr=0.01)
    for epoch in range(num_epoch):
        running_loss=0.0
        for i,data in enumerate(train_loader,0):
            inputs,labels=data
            inputs,labels=inputs.to(device),labels.to(device)
            optimizer.zero_grad()#清空之前的梯度
            outputs=network(inputs)
            loss=criterion(outputs,labels)#计算损失
            loss.backward()#反向传播
            optimizer.step()#更新参数
            running_loss+=loss.item()
            if i%1000==999:
                print('[%d,%5d] loss: %.3f' % (epoch+1,i+1,running_loss/2000))
                running_loss=0.0
    print('finish training')
    torch.save(network.state_dict(),'./FashionMNIST.pth')#保存模型

输出:
【PyTorch】识别黑白图中的服装图案_第4张图片
【PyTorch】识别黑白图中的服装图案_第5张图片

参考:《PyTorch深度学习和图神经网络》

你可能感兴趣的:(Pytoch,深度学习,pytorch,cnn)