莫烦python视频学习笔记 视频链接https://www.bilibili.com/video/BV1Vx411j7kT?from=search&seid=3065687802317837578
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
if __name__ == '__main__':
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 将数据放入数据库,用x来训练,用y来计算误差
# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader( # 将数据分批
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True, # Do you want to break this order?
num_workers=2,
)
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training...
print('Epoch', epoch, '|Step', step, '|batch x', batch_x.numpy(), '|batch y:', batch_y.numpy())
输出:
Epoch 0 |Step 0 |batch x [ 1. 3. 7. 6. 10.] |batch y: [10. 8. 4. 5. 1.]
Epoch 0 |Step 1 |batch x [5. 8. 9. 4. 2.] |batch y: [6. 3. 2. 7. 9.]
Epoch 1 |Step 0 |batch x [ 2. 10. 3. 7. 8.] |batch y: [9. 1. 8. 4. 3.]
Epoch 1 |Step 1 |batch x [9. 5. 1. 6. 4.] |batch y: [ 2. 6. 10. 5. 7.]
Epoch 2 |Step 0 |batch x [ 4. 1. 5. 10. 6.] |batch y: [ 7. 10. 6. 1. 5.]
Epoch 2 |Step 1 |batch x [2. 8. 7. 3. 9.] |batch y: [9. 3. 4. 8. 2.]
代码在运行中报错:init() got an unexpected keyword argument ‘data_tensor’
此处参考 (https://blog.csdn.net/thunderf/article/details/94733747)
其次还要注意代码的缩进问题。