[pytorch]FixMatch代码详解(超详细)

FixMatch代码详解-训练过程

  • 参数 default parameters
  • 数据产生 generate data
  • 构建模型 Build the model
  • 训练参数设置 Training parameter settings
    • weight decay(权值衰减)
    • 学习率衰减(learning rate decay)
    • 指数移动平均(EMA)model
  • 训练过程 training process
  • 运行结果 result

上一篇大概讲了数据加载的过程,这一篇更进一步,分析一下训练是怎样进行的
上一篇链接: [pytorch]FixMatch代码详解-数据加载

思维导图如下链接,非常详细的写出了代码的整体框架
思维导图

[pytorch]FixMatch代码详解(超详细)_第1张图片

参数 default parameters

数据集链接
4000个带标签的数据集,也就是每个类400张带标签的数据

所有的参数我都默认使用作者给出的例子:

python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5

其运行时每个参数的值如下:

INFO - __main__ -   {'T': 1, 'amp': False, 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device(type='cuda', index=0), 'ema_decay': 0.999, 'eval_step': 1024, 'expand_labels': True, 'gpu_id': 0, 'lambda_u': 1, 'local_rank': -1, 'lr': 0.03, 'mu': 7, 'n_gpu': 1, 'nesterov': True, 'no_progress': False, 'num_labeled': 4000, 'num_workers': 4, 'opt_level': 'O1', 'out': 'results/[email protected]', 'resume': '', 'seed': 5, 'start_epoch': 0, 'threshold': 0.95, 'total_steps': 1048576, 'use_ema': True, 'warmup': 0, 'wdecay': 0.0005, 'world_size': 1}

然后我们将这些参数带入,看看每一步是怎样运行的.

数据产生 generate data

首先,是产生带标签和不带标签数据的索引,其在cifar.py文件中的代码分析见上篇

base_dataset = datasets.CIFAR10(
        './CIFAR10', train=True, download=True)
labels = base_dataset.targets
label_per_class = 4000 // 10
labels = np.array(labels)
labeled_idx = []
# unlabeled data: all data (https://github.com/kekmodel/FixMatch-pytorch/issues/10)
unlabeled_idx = np.array(range(len(labels)))
for i in range(10):
    idx = np.where(labels == i)[0]
    idx = np.random.choice(idx, label_per_class, False)
    labeled_idx.extend(idx)
labeled_idx = np.array(labeled_idx)
print('number labeled_idx =',len(labeled_idx))
assert len(labeled_idx) == 4000

if True or 4000 < 64:
    num_expand_x = math.ceil(
        64 * 1024 / 4000)  #16.384 = 17
    labeled_idx = np.hstack([labeled_idx for _ in range(num_expand_x)])
np.random.shuffle(labeled_idx)
print('number labeled_idx = ',len(labeled_idx))
print('number unlabeled_idx =', len(unlabeled_idx))
train_labeled_idxs = labeled_idx
train_unlabeled_idxs = unlabeled_idx

结果如下,不带标签的数据使用了所有的数据,而带标签的数据经过数据扩增之后为68000个

number labeled_idx = 4000
number labeled_idx =  68000
number unlabeled_idx = 50000

让我们看一下图片的变化
首先,是不带任何变化的原始数据图像:

train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transforms.ToTensor())
train_iter = iter(train_labeled_dataset)
# 可视化方法,重复执行可得到不同的图片数据
imgs, label = next(train_iter)
print(image.size) # (32, 32)
image = transforms.ToPILImage()(imgs).convert('RGB')
image.show()
print(label)

[pytorch]FixMatch代码详解(超详细)_第2张图片
然后,我们使用不带数据增强的变化,也就是作者对验证集使用的图像变化. ToTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1). 注意图片大小没有变化,只是我截图的时候放大了图片.

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)])
train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transform_val)
train_iter = iter(train_labeled_dataset)

imgs, label = next(train_iter)
print(image.size) # (32, 32)
image = transforms.ToPILImage()(imgs).convert('RGB')
image.show()
print(label)

[pytorch]FixMatch代码详解(超详细)_第3张图片
然后我们看看带数据的图片所使用的数据增强(两次)

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)
transform_labeled = transforms.Compose([
    transforms.RandomHorizontalFlip(), #Horizontally flip the given image randomly with a given probability.
    transforms.RandomCrop(size=32,
                          padding=int(32*0.125),
                          padding_mode='reflect'),
    transforms.ToTensor(),
    transforms.Normalize(mean=cifar10_mean, std=cifar10_std)
])
train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=transform_labeled)
train_iter = iter(train_labeled_dataset)

imgs, label = next(train_iter)
print(image.size) # (32, 32)
image = transforms.ToPILImage()(imgs).convert('RGB')
image.show()
print(label) # 2 

[pytorch]FixMatch代码详解(超详细)_第4张图片[pytorch]FixMatch代码详解(超详细)_第5张图片

对于不带数据的标签,我们有两种数据增强,弱增强和强增强. 强增强操作在论文中的描述.
[pytorch]FixMatch代码详解(超详细)_第6张图片
在这里插入图片描述

class TransformFixMatch(object):
    def __init__(self, mean, std):
        self.weak = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect')])
        self.strong = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(size=32,
                                  padding=int(32*0.125),
                                  padding_mode='reflect'),
            RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)])

    def __call__(self, x):
        weak = self.weak(x)
        strong = self.strong(x)
        return self.normalize(weak), self.normalize(strong)
# 强增强的操作。在randaugment.py文件中
def fixmatch_augment_pool():
    # FixMatch paper
    augs = [(AutoContrast, None, None),
            (Brightness, 0.9, 0.05),
            (Color, 0.9, 0.05),
            (Contrast, 0.9, 0.05),
            (Equalize, None, None),
            (Identity, None, None),
            (Posterize, 4, 4),
            (Rotate, 30, 0),
            (Sharpness, 0.9, 0.05),
            (ShearX, 0.3, 0),
            (ShearY, 0.3, 0),
            (Solarize, 256, 0),
            (TranslateX, 0.3, 0),
            (TranslateY, 0.3, 0)]
    return augs
    
class RandAugmentMC(object):
    def __init__(self, n, m):
        assert n >= 1
        assert 1 <= m <= 10
        self.n = n
        self.m = m
        self.augment_pool = fixmatch_augment_pool()

    def __call__(self, img):
        ops = random.choices(self.augment_pool, k=self.n)
        for op, max_v, bias in ops:
            v = np.random.randint(1, self.m)
            if random.random() < 0.5:
                img = op(img, v=v, max_v=max_v, bias=bias)
        img = CutoutAbs(img, int(32*0.5))
        return img
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

train_labeled_dataset = CIFAR10SSL(
        './data', train_labeled_idxs, train=True,
        transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
train_iter = iter(train_labeled_dataset)

(inputs_u_w, inputs_u_s), _ = next(train_iter)
print(inputs_u_s.size) # (32, 32)
image = transforms.ToPILImage()(inputs_u_s).convert('RGB')
image.show()

弱增强的图像结果(两次):
[pytorch]FixMatch代码详解(超详细)_第7张图片[pytorch]FixMatch代码详解(超详细)_第8张图片
强增强的结果(运行四次):

[pytorch]FixMatch代码详解(超详细)_第9张图片[pytorch]FixMatch代码详解(超详细)_第10张图片[pytorch]FixMatch代码详解(超详细)_第11张图片[pytorch]FixMatch代码详解(超详细)_第12张图片
所以,产生的带标签/不带标签/验证集的dataset类及dataloader如下:

labeled_dataset = CIFAR10SSL(
    './data', train_labeled_idxs, train=True,
    transform=transform_labeled)
# len = 68000
unlabeled_dataset = CIFAR10SSL(
    './data', train_unlabeled_idxs, train=True,
    transform=TransformFixMatch(mean=cifar10_mean, std=cifar10_std))
# len = 50000
test_dataset = datasets.CIFAR10(
    './data', train=False, transform=transform_val, download=False)
# len = 10000
train_sampler = RandomSampler
labeled_trainloader = DataLoader(
        labeled_dataset,
        sampler=train_sampler(labeled_dataset),
        batch_size=64,
        num_workers=4,
        drop_last=True)
# len = 6800/64 = 1062.5 (drop_last=True) = 1062 
unlabeled_trainloader = DataLoader(
    unlabeled_dataset,
    sampler=train_sampler(unlabeled_dataset),
    batch_size=64*7, # mu coefficient of unlabeled batch size 原文中的超参数μ 
    num_workers=4,
    drop_last=True)
# len = 50000/(64*7) = 111
test_loader = DataLoader(
    test_dataset,
    sampler=SequentialSampler(test_dataset),
    batch_size=64,
    num_workers=7)
# len = 10000/64 = 156.25(drop_last=False)= 157 

构建模型 Build the model

def create_model():
    import models.wideresnet as models
    model = models.build_wideresnet(depth=28,
                                    widen_factor=2,
                                    dropout=0,
                                    num_classes=10)
    return model
    
model = create_model()
#print(model)
#for p in model.parameters():
#    print(p.numel())
total_num = sum(p.numel() for p in model.parameters())
print(total_num) # 1467610 模型总参数

训练参数设置 Training parameter settings

在参数设置时,有许多模型训练的tricks. 我简单的说一下他们的设置. 这里是作者的一些结论.
[pytorch]FixMatch代码详解(超详细)_第13张图片
[pytorch]FixMatch代码详解(超详细)_第14张图片

weight decay(权值衰减)

weight decay(权值衰减)其目的是防止过拟合。在损失函数中,weight decay是放在正则项(regularization)前面的一个系数,正则项一般指示模型的复杂度,所以weight decay的作用是调节模型复杂度对损失函数的影响,若weight decay很大,则复杂的模型损失函数的值也就大。
同时,作者也提到了要使用SGD优化器。

# weight decay default=5e-4
no_decay = ['bias', 'bn']
grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(
            nd in n for nd in no_decay)], 'weight_decay': 5e-4},
        {'params': [p for n, p in model.named_parameters() if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
optimizer = optim.SGD(grouped_parameters, lr=0.03,
                      momentum=0.9, nesterov=True)

除了 bias和bn层,其他层使用weight decay.
[pytorch]FixMatch代码详解(超详细)_第15张图片
[pytorch]FixMatch代码详解(超详细)_第16张图片

学习率衰减(learning rate decay)

正如作者在原文中提到的,对于学习率调整,我们使用余弦学习率衰减. 同时还加上了Warmup操作. 学习率一开始很小,在到达设定的num_warmup_steps前,学习率慢慢增大,最后达到设定的学习率的值。之后,使用余弦学习率衰减,其公式如上面原文中提到的。

def get_cosine_schedule_with_warmup(optimizer,
                                    num_warmup_steps,
                                    num_training_steps,
                                    num_cycles=7./16.,
                                    last_epoch=-1):
    def _lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        no_progress = float(current_step - num_warmup_steps) / \
            float(max(1, num_training_steps - num_warmup_steps))
        return max(0., math.cos(math.pi * num_cycles * no_progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)

scheduler = get_cosine_schedule_with_warmup(optimizer, 0, 2**20)

学习率是神经网络训练中最重要的超参数之一,针对学习率的优化方式很多,Warmup是其中的一种
(一)、什么是Warmup?
Warmup是在ResNet论文中提到的一种学习率预热的方法,它在训练开始的时候先选择使用一个较小的学习率,训练了一些epoches或者steps(比如4个epoches,10000steps),再修改为预先设置的学习来进行训练。

(二)、为什么使用Warmup?
由于刚开始训练时,模型的权重(weights)是随机初始化的,此时若选择一个较大的学习率,可能带来模型的不稳定(振荡),选择Warmup预热学习率的方式,可以使得开始训练的几个epoches或者一些steps内学习率较小,在预热的小学习率下,模型可以慢慢趋于稳定,等模型相对稳定后再选择预先设置的学习率进行训练,使得模型收敛速度变得更快,模型效果更佳。

ExampleExample:Resnet论文中使用一个110层的ResNet在cifar10上训练时,先用0.01的学习率训练直到训练误差低于80%(大概训练了400个steps),然后使用0.1的学习率进行训练。

自定义调整:自定义调整学习率 LambdaLR。
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
[pytorch]FixMatch代码详解(超详细)_第17张图片

指数移动平均(EMA)model

This algorithm is one of the most important algorithms currently in usage. From financial time series, signal processing to neural networks , it is being used quite extensively. Basically any data that is in a sequence.
We mostly use this algorithm to reduce the noise in noisy time-series data. The term we use for this is called “smoothing” the data.
The way we achieve this is by essentially weighing the number of observations and using their average. This is called as Moving Average.
In deep learning, the EMA (Exponential Moving Average) method is often used to average the parameters of the model in order to improve the test index and increase the robustness of the model.
在深度学习中,经常会使用EMA(指数移动平均)这个方法对模型的参数做平均,以求提高测试指标并增加模型鲁棒。
这个技巧我也不是很懂,可以看别人的文章介绍: 【炼丹技巧】指数移动平均(EMA)的原理及PyTorch实现

训练过程 training process

[pytorch]FixMatch代码详解(超详细)_第18张图片

都写在注释中了,每一步的过程很清楚

# 准备
epochs = math.ceil(2**20/ 1024) #1024 总epoch
start_epoch = 0
test_accs = []
end = time.time() #返回当前时间的时间戳
def interleave(x, size):
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])
def de_interleave(x, size):
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])
labeled_iter = iter(labeled_trainloader)
unlabeled_iter = iter(unlabeled_trainloader)
model.train()
for epoch in range(start_epoch, epochs):
    #batch_time = AverageMeter()#它仅用于计算和存储一些统计信息,例如关于损失的统计信息。
    #data_time = AverageMeter()
    #losses = AverageMeter()
    #losses_x = AverageMeter()
    #losses_u = AverageMeter()
    #mask_probs = AverageMeter()
    p_bar = tqdm(range(1024))
    for batch_idx in range(1024):
        
        # 使用iter(next)读取指定次数的batch,而不通过Dataloader。Dataloader的长度也不同。
        try:
            inputs_x, targets_x = labeled_iter.next()
            #print(inputs_x.shape) # torch.Size([64, 3, 32, 32])
            #print(targets_x.shape) # torch.Size([64])
            #print(targets_x)
        except:  # 当循环结束时,重新开始循环
            labeled_iter = iter(labeled_trainloader)
            inputs_x, targets_x = labeled_iter.next()
        try:
            (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next()
            #print(inputs_u_w.shape) #torch.Size([448, 3, 32, 32])
            #print(inputs_u_s.shape) #torch.Size([448, 3, 32, 32])
        except:
            unlabeled_iter = iter(unlabeled_trainloader)
            (inputs_u_w, inputs_u_s), _ = unlabeled_iter.next()
        # print(time.time() - end) # data_time = 200秒左右 读取一组数据的时间
        
        
        batch_size = inputs_x.shape[0] #64
        new_data = interleave(
                torch.cat((inputs_x, inputs_u_w, inputs_u_s)), 2*7+1) #'mu': 7
        # print(new_data.shape) torch.Size([960, 3, 32, 32]) 448+448+64 64*(2*7+1) 将数据合并一起
        inputs = new_data.to(device)
        targets_x = targets_x.to(device)
        
        
        logits = model(inputs)
        #print(logits.shape) #torch.Size([960, 10])
        logits = de_interleave(logits, 2*7+1)
        #print(logits.shape) #torch.Size([960, 10])
        logits_x = logits[:batch_size]
        #print(logits_x.shape) #torch.Size([64, 10])
        logits_u_w, logits_u_s = logits[batch_size:].chunk(2)
        #print(logits_u_w.shape) #torch.Size([448, 10]) 
        
        #通过weak_augment样本计算伪标记pseudo label和mask,
        #其中,mask用来筛选哪些样本最大预测概率超过阈值,可以拿来使用,哪些不能使用
        
        Lx = F.cross_entropy(logits_x, targets_x, reduction='mean') #带标签数据的loss
        #print(Lx) # tensor(2.6575, device='cuda:0', grad_fn=)
        pseudo_label = torch.softmax(logits_u_w.detach()/1, dim=-1) #输出变成概率
        # pseudo label temperature = 1 原来的softmax函数是T = 1的特例。 
        # T越高,softmax的output probability distribution越趋于平滑,其分布的熵越大,
        # 负标签携带的信息会被相对地放大,模型训练将更加关注负标签。
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        #print(max_probs.shape) # torch.Size([448]) 448个最大概率值
        #print(targets_u.shape) # torch.Size([448]) 448个伪标签的值
        #print(targets_u) #tensor([3, 5, 1 ....], device='cuda:0')
        mask = max_probs.ge(0.95).float() #'threshold': 0.95
        # torch.ge(a,b)逐个元素比较a,b的大小
        # print(mask.shape) #torch.Size([448]) 448个0/1
        # print(F.cross_entropy(logits_u_s, targets_u,reduction='none')) # reduction='none'不求平均,返回448个值
        Lu = (F.cross_entropy(logits_u_s, targets_u,
                                  reduction='none') * mask).mean() #不带标签数据的loss,其中通过mask进行样本筛选
        #print(Lu) #tensor(0., device='cuda:0', grad_fn=)
                
        loss = Lx + 1 * Lu # 'lambda_u': 1 #完整损失函数
        
        
        print(time.time() - end) # 3439秒 batch_time 计算完一组数据的时间
        end = time.time() #为下一轮做准备
        print(mask.mean().item()) # mask_probs = mask的均值 代表超过threshold的个数比例

运行结果 result

首先先来看一下程序的运行结果
开头

(torch) liyihao@liyihao-Precision-5820-Tower:~/LI/FixMatch-pytorch-master$ python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/test1
02/16/2022 17:48:12 - WARNING - __main__ -   Process rank: -1, device: cuda:0, n_gpu: 1, distributed training: False, 16-bits training: False
02/16/2022 17:48:12 - INFO - __main__ -   {'T': 1, 'amp': False, 'arch': 'wideresnet', 'batch_size': 64, 'dataset': 'cifar10', 'device': device(type='cuda', index=0), 'ema_decay': 0.999, 'eval_step': 1024, 'expand_labels': True, 'gpu_id': 0, 'lambda_u': 1, 'local_rank': -1, 'lr': 0.03, 'mu': 7, 'n_gpu': 1, 'nesterov': True, 'no_progress': False, 'num_labeled': 4000, 'num_workers': 4, 'opt_level': 'O1', 'out': 'results/test1', 'resume': '', 'seed': 5, 'start_epoch': 0, 'threshold': 0.95, 'total_steps': 1048576, 'use_ema': True, 'warmup': 0, 'wdecay': 0.0005, 'world_size': 1}
Files already downloaded and verified
02/16/2022 17:48:14 - INFO - models.wideresnet -   Model: WideResNet 28x2
02/16/2022 17:48:14 - INFO - __main__ -   Total params: 1.47M
02/16/2022 17:48:18 - INFO - __main__ -   ***** Running training *****
02/16/2022 17:48:18 - INFO - __main__ -     Task = cifar10@4000
02/16/2022 17:48:18 - INFO - __main__ -     Num Epochs = 1024
02/16/2022 17:48:18 - INFO - __main__ -     Batch size per GPU = 64
02/16/2022 17:48:18 - INFO - __main__ -     Total train batch size = 64
02/16/2022 17:48:18 - INFO - __main__ -     Total optimization steps = 1048576
Train Epoch: 1/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.045s. Batch: 0.207s. Loss: 1.2336. Loss_x: 1.1920. Loss_u: 0.0416. Mask: 0.07. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 1.8805. top1: 31.71. top5: 81.31. : 100%|██████████████████| 157/157 [00:02<00:00, 77.27it/s]
02/16/2022 17:51:52 - INFO - __main__ -   top-1 acc: 31.71
02/16/2022 17:51:52 - INFO - __main__ -   top-5 acc: 81.31
02/16/2022 17:51:52 - INFO - __main__ -   Best top-1 acc: 31.71
02/16/2022 17:51:52 - INFO - __main__ -   Mean top-1 acc: 31.71

Train Epoch: 2/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.046s. Batch: 0.206s. Loss: 0.7871. Loss_x: 0.6212. Loss_u: 0.1659. Mask: 0.31. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 0.9442. top1: 66.99. top5: 97.58. : 100%|██████████████████| 157/157 [00:01<00:00, 80.80it/s]
02/16/2022 17:55:22 - INFO - __main__ -   top-1 acc: 66.99
02/16/2022 17:55:22 - INFO - __main__ -   top-5 acc: 97.58
02/16/2022 17:55:22 - INFO - __main__ -   Best top-1 acc: 66.99
02/16/2022 17:55:22 - INFO - __main__ -   Mean top-1 acc: 49.35

Train Epoch: 3/1024. Iter: 1024/1024. LR: 0.0300. Data: 0.045s. Batch: 0.206s. Loss: 0.5908. Loss_x: 0.3215. Loss_u: 0.2692. Mask: 0.50. : 100%|| 102
Test Iter:  157/ 157. Data: 0.005s. Batch: 0.012s. Loss: 0.6990. top1: 75.80. top5: 98.54. : 100%|██████████████████| 157/157 [00:02<00:00, 77.19it/s]
02/16/2022 17:58:53 - INFO - __main__ -   top-1 acc: 75.80
02/16/2022 17:58:53 - INFO - __main__ -   top-5 acc: 98.54
02/16/2022 17:58:54 - INFO - __main__ -   Best top-1 acc: 75.80
02/16/2022 17:58:54 - INFO - __main__ -   Mean top-1 acc: 58.17

运行了100多个epoch之后

Train Epoch: 150/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.157s. Loss: 0.2174. Loss_x: 0.0090. Loss_u: 0.2084. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.009s. Loss: 0.2418. top1: 94.17. top5: 99.87. : 100%|██████████████████| 157/157 [00:01<00:00, 99.77it/s]
02/17/2022 00:54:18 - INFO - __main__ -   top-1 acc: 94.17
02/17/2022 00:54:18 - INFO - __main__ -   top-5 acc: 99.87
02/17/2022 00:54:18 - INFO - __main__ -   Best top-1 acc: 94.28
02/17/2022 00:54:18 - INFO - __main__ -   Mean top-1 acc: 94.03

Train Epoch: 151/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.018s. Batch: 0.158s. Loss: 0.2118. Loss_x: 0.0066. Loss_u: 0.2052. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.010s. Loss: 0.2393. top1: 94.37. top5: 99.91. : 100%|██████████████████| 157/157 [00:01<00:00, 89.00it/s]
02/17/2022 00:57:00 - INFO - __main__ -   top-1 acc: 94.37
02/17/2022 00:57:00 - INFO - __main__ -   top-5 acc: 99.91
02/17/2022 00:57:00 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 00:57:00 - INFO - __main__ -   Mean top-1 acc: 94.05

Train Epoch: 152/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.158s. Loss: 0.2209. Loss_x: 0.0097. Loss_u: 0.2113. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.004s. Batch: 0.009s. Loss: 0.2414. top1: 94.19. top5: 99.86. : 100%|█████████████████| 157/157 [00:01<00:00, 100.27it/s]
02/17/2022 00:59:41 - INFO - __main__ -   top-1 acc: 94.19
02/17/2022 00:59:41 - INFO - __main__ -   top-5 acc: 99.86
02/17/2022 00:59:41 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 00:59:41 - INFO - __main__ -   Mean top-1 acc: 94.06

Train Epoch: 153/1024. Iter: 1024/1024. LR: 0.0294. Data: 0.017s. Batch: 0.159s. Loss: 0.2210. Loss_x: 0.0110. Loss_u: 0.2100. Mask: 0.90. : 100%|| 1
Test Iter:  157/ 157. Data: 0.003s. Batch: 0.009s. Loss: 0.2439. top1: 94.07. top5: 99.87. : 100%|█████████████████| 157/157 [00:01<00:00, 101.27it/s]
02/17/2022 01:02:24 - INFO - __main__ -   top-1 acc: 94.07
02/17/2022 01:02:24 - INFO - __main__ -   top-5 acc: 99.87
02/17/2022 01:02:24 - INFO - __main__ -   Best top-1 acc: 94.37
02/17/2022 01:02:24 - INFO - __main__ -   Mean top-1 acc: 94.06

tensorboard看看各个参数的变化
验证集的参数变化
[pytorch]FixMatch代码详解(超详细)_第19张图片
训练集上的参数变化
[pytorch]FixMatch代码详解(超详细)_第20张图片

[pytorch]FixMatch代码详解(超详细)_第21张图片

你可能感兴趣的:(半监督学习,pytorch,深度学习,python)