一、张量的变形
1.张量的形状变换
代码
import torch
import numpy as np


def main():
    """Demonstrate changing a tensor's shape with ``Tensor.view``.

    ``view`` only rearranges the existing elements, so the requested shape
    must contain the same number of elements as the source tensor
    (here 4 * 6 == 3 * 8 == 24 * 1 == 24).
    """
    t = torch.randn(4, 6)
    print(t.shape)
    # 24 elements rearranged as 3 rows x 8 columns.
    t1 = t.view(3, 8)
    print(t1.size())
    # -1 asks PyTorch to infer that dimension: 24 / 1 -> shape (24, 1).
    t2 = t.view(-1, 1)
    print(t2.shape)
    print('----------------------------------------')
    t = torch.randn(12, 4, 4, 3)
    print(t.shape)
    # Keep the first dimension (12) and flatten the remaining 4 * 4 * 3 = 48.
    tt = t.view(12, 4 * 4 * 3)
    print(tt.shape)


if __name__ == '__main__':
    main()
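结果（t 为随机张量，数值每次不同，但形状是确定的）：
torch.Size([4, 6])
torch.Size([3, 8])
torch.Size([24, 1])
----------------------------------------
torch.Size([12, 4, 4, 3])
torch.Size([12, 48])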
2. 去除张量中大小为1的维度
代码
import torch
import numpy as np


def main():
    """Demonstrate removing size-1 dimensions with ``torch.squeeze``."""
    t = torch.randn(4, 6)
    # Insert a leading dimension of size 1: (4, 6) -> (1, 4, 6).
    t = t.view(1, 4, 6)
    print(t.size())
    # Called without a ``dim`` argument, squeeze drops every size-1 dimension.
    t1 = torch.squeeze(t)
    print(t1.shape)


if __name__ == '__main__':
    main()
结果：
torch.Size([1, 4, 6])
torch.Size([4, 6])
二、张量的自动微分
代码
import torch
import numpy as np


def main():
    """Demonstrate autograd: grad_fn tracking, backward(), detach() and no_grad()."""
    t = torch.ones(2, 2, requires_grad=True)
    # Results of operations on a tensor with requires_grad=True record the
    # operation that produced them in ``grad_fn``.
    y = t + 5
    print(y.grad_fn)
    print(y.requires_grad)
    z = y * 2
    print(z.grad_fn)
    print(z.requires_grad)
    # Reduce to a scalar so backward() can be called without a gradient argument.
    out = z.mean()
    out.backward()
    # detach() yields a tensor that no longer requires grad.
    result = out.detach()
    print(result.requires_grad)
    # out = mean(2 * (t + 5)) over 4 elements, so d(out)/dt = 2 / 4 = 0.5 each.
    print(t.grad)
    # Operations inside no_grad() are not tracked, even though t requires grad.
    with torch.no_grad():
        yy = t + 2
        print(yy.requires_grad)


if __name__ == '__main__':
    main()
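结果（grad_fn 对象的内存地址每次运行都不同）：
<AddBackward0 object at 0x...>
True
<MulBackward0 object at 0x...>
True
False
tensor([[0.5000, 0.5000],
        [0.5000, 0.5000]])
False
其中 out = mean(2·(t+5))，共 4 个元素，因此 out 对 t 中每个元素的导数为 2/4 = 0.5。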