PyTorch常用代码集合

一、固定随机种子 

def set_seeds(seed=42):
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every CUDA device), and configures cuDNN for deterministic execution.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    random.seed(seed)
    # Hash randomization of the *current* interpreter is fixed at startup, so
    # this only influences child processes started afterwards that read the env.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds all GPUs (covers the current device too); safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    # deterministic=True restricts cuDNN to deterministic algorithms, and
    # benchmark=False stops its auto-tuner from choosing kernels per run --
    # both are needed for repeatable GPU results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seeds()

二、对数据集进行标准化

from torchvision import transforms as T
# Preprocessing: convert the image to a tensor, then standardize each channel
# as (x - mean) / std.
# NOTE(review): these per-channel statistics are presumably precomputed on the
# training images -- confirm they still match the dataset in use.
transform = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(
            mean=[0.625, 0.448, 0.688],
            std=[0.131, 0.177, 0.101],
        ),
    ]
)

三、数据增强

import albumentations as A

# Training-time augmentation pipeline (albumentations).
# NOTE(review): NEW_SIZE is defined elsewhere -- presumably the target side
# length in pixels; confirm.
trfm = A.Compose([
    A.Resize(NEW_SIZE, NEW_SIZE),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),

    # Photometric jitter: with probability 0.3 apply exactly one child.
    # Inside OneOf each child's p only acts as its relative selection weight.
    # RandomContrast / RandomBrightness were deprecated and later removed from
    # albumentations; the two RandomBrightnessContrast instances below reproduce
    # their defaults (limit=0.2, p=0.5) with the other factor switched off.
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2),
        A.RandomGamma(),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),
        A.ColorJitter(brightness=0.07, contrast=0.07,
                      saturation=0.1, hue=0.1, p=0.3),
    ], p=0.3),

    # Geometric distortions -- currently DISABLED (p=0.0), kept for experiments.
    # NOTE(review): `alpha_affine` is deprecated in newer albumentations
    # releases; drop it if the installed version warns or rejects it.
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        A.GridDistortion(),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
    ], p=0.0),
    A.ShiftScaleRotate(),
])

四、每隔N个训练样本抽取一个验证样本

# Split the dataset `ds` into training and validation subsets using the first
# field of each entry in `ds.slices`: every sample with ds.slices[i][0] == 7
# goes to validation, all others to training (index order is kept in both).
# NOTE(review): despite the section title, this is a hold-out by key value, not
# "one validation sample every 7 training samples". ds.slices[i][0] presumably
# identifies the source image a sample was cut from -- confirm against ds.
train_idx, valid_idx = [], []
for i in range(len(ds)):
    (valid_idx if ds.slices[i][0] == 7 else train_idx).append(i)


train_ds = D.Subset(ds, train_idx)
valid_ds = D.Subset(ds, valid_idx)

# Only the training loader shuffles; num_workers=0 loads in the main process.
trn_loader = D.DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
)
val_loader = D.DataLoader(
    valid_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
)

五、修改FCN-ResNet50分割模型的head

# 此处以resnet50为例。
# 其中pth文件假设已经下载到某个文件夹。
# 加载权重后,将head改为1个类别的分类器

def get_model():
    model = torchvision.models.segmentation.fcn_resnet50(False)
    
    pth = torch.load("../input/pretrain-coco-weights-pytorch/fcn_resnet50_coco-1167a1af.pth")
    for key in ["aux_classifier.0.weight", "aux_classifier.1.weight", "aux_classifier.1.bias", "aux_classifier.1.running_mean", "aux_classifier.1.running_var", "aux_classifier.1.num_batches_tracked", "aux_classifier.4.weight", "aux_classifier.4.bias"]:
        del pth[key]
    
    model.classifier[4] = nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
    return model

你可能感兴趣的:(#,pytorch,人工智能,python)