pytorch深度学习总结1

最近使用pytorch踩过的一些坑,记录一下,偏应用。

1.图片加载

pytorch中的datasets.ImageFolder函数直接可以读取自己的图片的数据集。
数据集存放:
把每一类的图片放到一个文件夹里面,加载时地址只用写到类别文件夹的上一级目录。例如下图中dataset文件夹存放了4个类别的图片,那么图片加载时写入的地址就是** ‘F:\dataset’** 。datasets.ImageFolder会自动根据文件夹类别给数据打上标签。
pytorch深度学习总结1_第1张图片

from torchvision import datasets, transforms
import torch
import os
ef load_data(root_path, dir, batch_size, phase):
    transform_dict = {
        'tar':transforms.Compose(
            [transforms.Resize(224),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225]),])}
    data = datasets.ImageFolder(root=os.path.join(root_path,dir), transform=transform_dict[phase])  ##即各类别文件夹所在目录的上一级目录
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True,drop_last=False, num_workers=4)
    #设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。drop_last=False不丢弃
    return data_loader

上面这段示例代码可以当作模板使用,但其实我还有一个问题没搞懂,就是transforms.Normalize标准化时输入的平均值和标准差为什么不是0.5,而是一堆奇怪的小数,有哪位大佬路过可以帮忙解答一下。

2.模型搭建

pytorch的模型搭建这里,没有什么特别要记录的,网上很多例子,照着搭自己的模型就行。只有一点,是在我真正开始动手操作的时候才发现的,之前照着书学习的时候没发现或者看到了没有注意就略过了。

在pytorch中搭建一个My_Net类作为自己的模型,在调用时按照下面流程调用传入自己的数据就行,它会直接执行My_Net类中的forward 函数,完成前向传播,不需要单独调用forward函数。

class My_Net(nn.Module):
    def __init__(self,):
        super(My_Net,self).__init__()
        ……
    def forward(data):
        ……
        return result
    def my_loss(  ):
        ……
        return loss
 #调用
 data
 model = My_Net()
 result=model(data)

你可能感兴趣的:(深度学习,pytorch)