torch.Tensor() =torchFloatTensor()
生成单精度浮点类型张量
torch.tensor() 普通函数
PyTorch Fashion
1、prepare dataset
2、design model using Class # 目的是为了前向传播forward,即计算y hat(预测值)
3、Construct loss and optimizer (using PyTorch API)
其中,计算loss是为了进行反向传播,optimizer是为了更新梯度。
4、Training cycle (forward,backward,update)
numpy中的广播
广播(broadcast)是numpy中经常使用的一个技能点,他能够对不同形状的数组进行各种方式的计算。
模型:线性单元
模型内必须实现以下两个函数
class LinearModel(torch.nn.Module):#继承moudle
def __init__(self):#构造函数
super(LinearModel,self).__init__()
self.linear = torch.nn.Linear(1, 1)
#构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b
#torch.nn.Linear是pytorch中的一个类,(1,1)构造对象,包含权重和偏置
def forward(self, x):
y_pred = self.linear(x))#callable可调用对象,计算y=wx+b
return y_pred
model = LinearModel()#实例化模型
如果要使用一个可调用对象,那么在类的声明的时候要定义一个 call()函数就OK了,就像这样
class Foobar:
def __init__(self):
pass
def __call__(self,*args,**kwargs):
pass
其中参数*args代表把前面n个参数变成n元组,**kwargsd会把参数变成一个词典,举个例子:
def func(*args,**kwargs):
print(args)
print(kwargs)
#调用一下
func(1,2,3,4,x=3,y=5)
结果:
(1,2,3,4)
{‘x’:3,‘y’:5}
损失函数
criterion = torch.nn.MSELoss(reduction = 'sum')
优化器
优化器通过参数知道要对那些权重做优化,也知道学习了lr
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)
# model.parameters()自动完成参数的初始化操作,这个地方我可能理解错了
import torch
import matplotlib.pyplot as plt
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.linear = torch.nn.Linear(1, 1)
def forward(self, x):
y_pred = self.linear(x)
return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
epoch_list = []
loss_list = []
for epoch in range(1000):
y_pred = model(x_data)
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
epoch_list.append(epoch)
loss_list.append(loss.item())
optimizer.zero_grad()
loss.backward()
optimizer.step()#更新
print('w= ', model.linear.weight.item())
print('b= ', model.linear.bias.item())
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred= ', y_test.data)
plt.plot(epoch_list, loss_list)
plt.show()