PyTorch detach().numpy()

The following excerpt from a GAN training loop accumulates the discriminator and generator losses per epoch:

    import torch as t

    # G, D, optimG, optimD, criterion, train_loader, label_Real, label_Fake,
    # EPOCH, BATCH_SIZE and Len_Z are assumed to be defined earlier in the script.
    for epoch in range(EPOCH):
        sum_D = 0
        sum_G = 0
        for step, (images, imagesLabel) in enumerate(train_loader):
            print(step)
            # random noise fed to the generator
            G_ideas = t.randn((BATCH_SIZE, Len_Z, 1, 1))

            G_paintings = G(G_ideas)
            prob_artist0 = D(images)       # D should assign a high probability to real images
            prob_artist1 = D(G_paintings)  # ...and a low probability to generated ones
            p0 = t.squeeze(prob_artist0)
            p1 = t.squeeze(prob_artist1)

            errD_real = criterion(p0, label_Real)
            errD_fake = criterion(p1, label_Fake)

            errD = errD_fake + errD_real
            errG = criterion(p1, label_Real)

            # detach() cuts the link to the computation graph, so the running
            # sums hold plain numbers instead of graph-carrying tensors
            sum_D = sum_D + errD.detach().numpy()
            sum_G = sum_G + errG.detach().numpy()

            # update the discriminator; retain_graph=True keeps the graph alive
            # because errG is backpropagated through the same forward pass below
            optimD.zero_grad()
            errD.backward(retain_graph=True)
            optimD.step()

            # update the generator
            optimG.zero_grad()
            errG.backward(retain_graph=True)
            optimG.step()

During today's experiment I first wrote sum_D = sum_D + errD and watched memory usage climb rapidly. Changing it to sum_D = sum_D + errD.detach().numpy() solved the problem. The first form effectively keeps adding nodes to the autograd graph: each accumulated loss still references the computation graph of its iteration, so the graph keeps growing and memory consumption rises with it.
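A minimal, self-contained sketch of the same pitfall (the model, loss_fn and random data below are illustrative placeholders, not taken from the GAN code above): accumulating the raw loss tensor keeps every iteration's autograd graph reachable, while accumulating a detached value stores only the number.

    import torch

    model = torch.nn.Linear(10, 1)
    loss_fn = torch.nn.MSELoss()

    running_loss_bad = 0   # tensor accumulation: keeps graphs alive
    running_loss_ok = 0.0  # plain-number accumulation

    for _ in range(100):
        x = torch.randn(32, 10)
        y = torch.randn(32, 1)
        loss = loss_fn(model(x), y)

        # BAD: running_loss_bad becomes a tensor whose history references every
        # iteration's graph, so none of those graphs can be freed.
        running_loss_bad = running_loss_bad + loss

        # OK: detach() drops the graph reference; numpy() yields a plain value.
        running_loss_ok = running_loss_ok + loss.detach().numpy()

With the detached form, each iteration's graph is freed as soon as loss is reassigned; loss.item() is an equivalent way to pull out the scalar when NumPy is not needed.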
