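Weather forecasting with PyTorch: the basic steps of training a neural network model
Tip: this article collects a sample of weather data and uses a neural network to predict and analyze it. Working through the model training shows how to train a network with the PyTorch framework.
Tip: the complete source code is available by private message.
Steps:
(1) Import the required modules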
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import datetime
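(2) Data preprocessing
Actual weather observations for one region over a period of time were collected, organized, and saved as a .csv file. The data is then loaded and looks like this: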
data=pd.read_csv('temp.csv')
print(data.head())  # head() shows the first 5 rows by default
year month day week temp_2 temp_1 average actual random
0 2021 1 1 Fri 45 45 45.6 45 29
1 2021 1 2 Sat 44 45 45.7 44 61
2 2021 1 3 Sun 45 44 45.8 41 56
3 2021 1 4 Mon 44 41 45.9 40 53
4 2021 1 5 Tues 41 40 46.0 44 41
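In the table:
year, month, day, week: the date of the record (week is the day of the week)
temp_2: the maximum temperature two days earlier
temp_1: the maximum temperature on the previous day
average: the historical average maximum temperature for this calendar day
actual: the label, i.e. the true maximum temperature of the day
random: random values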
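Judging by the first rows of the table, temp_1 on each day equals the previous row's actual, and temp_2 equals the value two rows back. Assuming the rows form a contiguous daily series (the real file has some gaps and repeated rows, as the date output below shows), such lag columns could be derived with pandas `shift()`. A minimal sketch using the first five actual values from the table:

```python
import pandas as pd

# First five 'actual' values from the table above (2021-01-01 .. 2021-01-05)
df = pd.DataFrame({'actual': [45, 44, 41, 40, 44]})

# Lag features: shift(n) moves each value down n rows, so row t holds day t-n
df['temp_1'] = df['actual'].shift(1)
df['temp_2'] = df['actual'].shift(2)

print(df)
```

The first rows have no history, so `shift()` leaves NaN there; the remaining values match the temp_1 and temp_2 columns shown above.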
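# Pull out the date columns
years,months,days=data['year'],data['month'],data['day']
# Build 'YYYY-M-D' strings, then parse them into datetime objects
datas=[str(int(year))+'-'+str(int(month))+'-'+str(int(day)) for year,month,day in zip(years,months,days)]
print(datas)
datas=[datetime.datetime.strptime(date,'%Y-%m-%d') for date in datas]
print(datas)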
Output:
['2021-1-1', '2021-1-2', '2021-1-3', '2021-1-4', '2021-1-5', '2021-1-6', '2021-1-7', '2021-1-8', '2021-1-9', '2021-1-10', '2021-1-11', '2021-1-12', '2021-1-13', '2021-1-14', '2021-1-15', '2021-1-16', '2021-1-17', '2021-1-18', '2021-1-19', '2021-1-20', '2021-1-21', '2021-1-22', '2021-1-23', '2021-1-24', '2021-1-25', '2021-1-26', '2021-1-27', '2021-1-28', '2021-1-29', '2021-1-30', '2021-1-31', '2021-2-1', '2021-2-2', '2021-2-3', '2021-2-4', '2021-2-5', '2021-2-6', '2021-2-7', '2021-2-8', '2021-2-9', '2021-2-10', '2021-2-11', '2021-2-12', '2021-2-15', '2021-2-16', '2021-2-17', '2021-2-18', '2021-2-19', '2021-2-20', '2021-2-21', '2021-2-22', '2021-2-23', '2021-2-24', '2021-2-25', '2021-2-26', '2021-2-27', '2021-2-28', '2021-3-1', '2021-3-2', '2021-3-3', '2021-3-4', '2021-3-5', '2021-3-6', '2021-3-7', '2021-3-8', '2021-3-9', '2021-3-10', '2021-3-11', '2021-3-12', '2021-3-13', '2021-3-14', '2021-3-15', '2021-3-16', '2021-3-17', '2021-3-18', '2021-3-19', '2021-3-20', '2021-3-21', '2021-3-22', '2021-3-23', '2021-3-24', '2021-3-25', '2021-3-26', '2021-3-27', '2021-3-28', '2021-3-29', '2021-3-30', '2021-3-31', '2021-4-1', '2021-4-2', '2021-4-3', '2021-4-4', '2021-4-5', '2021-4-6', '2021-4-7', '2021-4-8', '2021-4-9', '2021-4-10', '2021-4-11', '2021-4-12', '2021-4-13', '2021-4-14', '2021-4-15', '2021-4-16', '2021-4-17', '2021-4-18', '2021-4-19', '2021-4-20', '2021-4-21', '2021-4-22', '2021-4-23', '2021-4-24', '2021-4-25', '2021-4-26', '2021-4-27', '2021-4-28', '2021-4-29', '2021-4-30', '2021-5-1', '2021-5-2', '2021-5-3', '2021-5-4', '2021-5-5', '2021-5-6', '2021-5-7', '2021-5-8', '2021-5-9', '2021-5-10', '2021-5-11', '2021-5-12', '2021-5-13', '2021-5-14', '2021-5-15', '2021-5-16', '2021-5-17', '2021-5-18', '2021-5-19', '2021-5-20', '2021-5-21', '2021-5-22', '2021-5-23', '2021-5-24', '2021-5-25', '2021-5-26', '2021-5-27', '2021-5-28', '2021-5-29', '2021-5-30', '2021-5-31', '2021-6-1', '2021-6-2', '2021-6-3', '2021-6-4', '2021-6-5', '2021-6-6', '2021-6-7', '2021-6-8', 
'2021-6-9', '2021-6-10', '2021-6-11', '2021-6-12', '2021-6-13', '2021-6-14', '2021-6-15', '2021-6-16', '2021-6-17', '2021-6-18', '2021-6-19', '2021-6-20', '2021-6-21', '2021-6-22', '2021-6-23', '2021-6-24', '2021-6-25', '2021-6-26', '2021-6-27', '2021-6-28', '2021-6-29', '2021-6-30', '2021-7-1', '2021-7-2', '2021-7-3', '2021-7-4', '2021-7-5', '2021-7-6', '2021-7-7', '2021-7-8', '2021-7-9', '2021-7-10', '2021-7-11', '2021-7-12', '2021-7-13', '2021-7-14', '2021-7-15', '2021-7-16', '2021-7-17', '2021-7-18', '2021-7-19', '2021-7-20', '2021-7-21', '2021-7-22', '2021-7-23', '2021-7-24', '2021-7-25', '2021-7-26', '2021-7-27', '2021-7-28', '2021-7-29', '2021-7-30', '2021-7-31', '2021-8-1', '2021-8-2', '2021-8-3', '2021-8-4', '2021-8-5', '2021-8-6', '2021-8-7', '2021-8-8', '2021-8-9', '2021-8-10', '2021-8-11', '2021-8-12', '2021-8-13', '2021-8-14', '2021-8-15', '2021-8-16', '2021-8-23', '2021-8-28', '2021-8-30', '2021-9-3', '2021-9-4', '2021-9-5', '2021-9-6', '2021-9-7', '2021-9-8', '2021-9-9', '2021-9-10', '2021-9-11', '2021-9-12', '2021-9-13', '2021-9-14', '2021-9-15', '2021-9-16', '2021-9-17', '2021-9-18', '2021-9-19', '2021-9-20', '2021-9-21', '2021-9-22', '2021-9-23', '2021-9-24', '2021-9-25', '2021-9-26', '2021-9-27', '2021-9-28', '2021-9-29', '2021-9-30', '2021-10-1', '2021-10-2', '2021-10-3', '2021-10-4', '2021-10-5', '2021-10-6', '2021-10-7', '2021-10-8', '2021-10-9', '2021-10-10', '2021-10-11', '2021-10-12', '2021-10-13', '2021-10-14', '2021-10-15', '2021-10-16', '2021-10-17', '2021-10-18', '2021-10-19', '2021-10-20', '2021-10-21', '2021-10-22', '2021-10-23', '2021-10-24', '2021-10-25', '2021-10-26', '2021-10-27', '2021-10-28', '2021-10-29', '2021-10-31', '2021-11-1', '2021-11-2', '2021-11-3', '2021-11-4', '2021-11-5', '2021-11-6', '2021-11-7', '2021-11-8', '2021-11-9', '2021-11-10', '2021-11-11', '2021-11-12', '2021-11-13', '2021-11-14', '2021-11-15', '2021-11-16', '2021-11-17', '2021-11-18', '2021-11-19', '2021-11-20', '2021-11-21', '2021-11-22', '2021-11-23', 
'2021-11-24', '2021-11-25', '2021-11-26', '2021-11-27', '2021-11-28', '2021-11-29', '2021-11-30', '2021-12-1', '2021-12-2', '2021-12-3', '2021-12-4', '2021-12-5', '2021-12-6', '2021-12-7', '2021-12-8', '2021-12-9', '2021-12-10', '2021-12-11', '2021-12-12', '2021-12-13', '2021-12-14', '2021-12-15', '2021-12-16', '2021-12-17', '2021-12-18', '2021-12-19', '2021-12-20', '2021-12-21', '2021-12-22', '2021-12-23', '2021-12-24', '2021-12-25', '2021-12-26', '2021-12-27', '2021-12-28', '2021-12-29', '2021-12-30', '2021-12-31', '2021-12-31', '2021-12-31']
[datetime.datetime(2021, 1, 1, 0, 0), datetime.datetime(2021, 1, 2, 0, 0), datetime.datetime(2021, 1, 3, 0, 0), datetime.datetime(2021, 1, 4, 0, 0), datetime.datetime(2021, 1, 5, 0, 0), datetime.datetime(2021, 1, 6, 0, 0), datetime.datetime(2021, 1, 7, 0, 0), datetime.datetime(2021, 1, 8, 0, 0), datetime.datetime(2021, 1, 9, 0, 0), datetime.datetime(2021, 1, 10, 0, 0), datetime.datetime(2021, 1, 11, 0, 0), datetime.datetime(2021, 1, 12, 0, 0), datetime.datetime(2021, 1, 13, 0, 0), datetime.datetime(2021, 1, 14, 0, 0), datetime.datetime(2021, 1, 15, 0, 0), datetime.datetime(2021, 1, 16, 0, 0), datetime.datetime(2021, 1, 17, 0, 0), datetime.datetime(2021, 1, 18, 0, 0), datetime.datetime(2021, 1, 19, 0, 0), datetime.datetime(2021, 1, 20, 0, 0), datetime.datetime(2021, 1, 21, 0, 0), datetime.datetime(2021, 1, 22, 0, 0), datetime.datetime(2021, 1, 23, 0, 0), datetime.datetime(2021, 1, 24, 0, 0), datetime.datetime(2021, 1, 25, 0, 0), datetime.datetime(2021, 1, 26, 0, 0), datetime.datetime(2021, 1, 27, 0, 0), datetime.datetime(2021, 1, 28, 0, 0), datetime.datetime(2021, 1, 29, 0, 0), datetime.datetime(2021, 1, 30, 0, 0), datetime.datetime(2021, 1, 31, 0, 0), datetime.datetime(2021, 2, 1, 0, 0), datetime.datetime(2021, 2, 2, 0, 0), datetime.datetime(2021, 2, 3, 0, 0), datetime.datetime(2021, 2, 4, 0, 0), datetime.datetime(2021, 2, 5, 0, 0), datetime.datetime(2021, 2, 6, 0, 0), datetime.datetime(2021, 2, 7, 0, 0), datetime.datetime(2021, 2, 8, 0, 0), datetime.datetime(2021, 2, 9, 0, 0), datetime.datetime(2021, 2, 10, 0, 0), datetime.datetime(2021, 2, 11, 0, 0), datetime.datetime(2021, 2, 12, 0, 0), datetime.datetime(2021, 2, 15, 0, 0), datetime.datetime(2021, 2, 16, 0, 0), datetime.datetime(2021, 2, 17, 0, 0), datetime.datetime(2021, 2, 18, 0, 0), datetime.datetime(2021, 2, 19, 0, 0), datetime.datetime(2021, 2, 20, 0, 0), datetime.datetime(2021, 2, 21, 0, 0), datetime.datetime(2021, 2, 22, 0, 0), datetime.datetime(2021, 2, 23, 0, 0), datetime.datetime(2021, 2, 24, 0, 0), 
datetime.datetime(2021, 2, 25, 0, 0), datetime.datetime(2021, 2, 26, 0, 0), datetime.datetime(2021, 2, 27, 0, 0), datetime.datetime(2021, 2, 28, 0, 0), datetime.datetime(2021, 3, 1, 0, 0), datetime.datetime(2021, 3, 2, 0, 0), datetime.datetime(2021, 3, 3, 0, 0), datetime.datetime(2021, 3, 4, 0, 0), datetime.datetime(2021, 3, 5, 0, 0), datetime.datetime(2021, 3, 6, 0, 0), datetime.datetime(2021, 3, 7, 0, 0), datetime.datetime(2021, 3, 8, 0, 0), datetime.datetime(2021, 3, 9, 0, 0), datetime.datetime(2021, 3, 10, 0, 0), datetime.datetime(2021, 3, 11, 0, 0), datetime.datetime(2021, 3, 12, 0, 0), datetime.datetime(2021, 3, 13, 0, 0), datetime.datetime(2021, 3, 14, 0, 0), datetime.datetime(2021, 3, 15, 0, 0), datetime.datetime(2021, 3, 16, 0, 0), datetime.datetime(2021, 3, 17, 0, 0), datetime.datetime(2021, 3, 18, 0, 0), datetime.datetime(2021, 3, 19, 0, 0), datetime.datetime(2021, 3, 20, 0, 0), datetime.datetime(2021, 3, 21, 0, 0), datetime.datetime(2021, 3, 22, 0, 0), datetime.datetime(2021, 3, 23, 0, 0), datetime.datetime(2021, 3, 24, 0, 0), datetime.datetime(2021, 3, 25, 0, 0), datetime.datetime(2021, 3, 26, 0, 0), datetime.datetime(2021, 3, 27, 0, 0), datetime.datetime(2021, 3, 28, 0, 0), datetime.datetime(2021, 3, 29, 0, 0), datetime.datetime(2021, 3, 30, 0, 0), datetime.datetime(2021, 3, 31, 0, 0), datetime.datetime(2021, 4, 1, 0, 0), datetime.datetime(2021, 4, 2, 0, 0), datetime.datetime(2021, 4, 3, 0, 0), datetime.datetime(2021, 4, 4, 0, 0), datetime.datetime(2021, 4, 5, 0, 0), datetime.datetime(2021, 4, 6, 0, 0), datetime.datetime(2021, 4, 7, 0, 0), datetime.datetime(2021, 4, 8, 0, 0), datetime.datetime(2021, 4, 9, 0, 0), datetime.datetime(2021, 4, 10, 0, 0), datetime.datetime(2021, 4, 11, 0, 0), datetime.datetime(2021, 4, 12, 0, 0), datetime.datetime(2021, 4, 13, 0, 0), datetime.datetime(2021, 4, 14, 0, 0), datetime.datetime(2021, 4, 15, 0, 0), datetime.datetime(2021, 4, 16, 0, 0), datetime.datetime(2021, 4, 17, 0, 0), datetime.datetime(2021, 4, 18, 0, 0), 
datetime.datetime(2021, 4, 19, 0, 0), datetime.datetime(2021, 4, 20, 0, 0), datetime.datetime(2021, 4, 21, 0, 0), datetime.datetime(2021, 4, 22, 0, 0), datetime.datetime(2021, 4, 23, 0, 0), datetime.datetime(2021, 4, 24, 0, 0), datetime.datetime(2021, 4, 25, 0, 0), datetime.datetime(2021, 4, 26, 0, 0), datetime.datetime(2021, 4, 27, 0, 0), datetime.datetime(2021, 4, 28, 0, 0), datetime.datetime(2021, 4, 29, 0, 0), datetime.datetime(2021, 4, 30, 0, 0), datetime.datetime(2021, 5, 1, 0, 0), datetime.datetime(2021, 5, 2, 0, 0), datetime.datetime(2021, 5, 3, 0, 0), datetime.datetime(2021, 5, 4, 0, 0), datetime.datetime(2021, 5, 5, 0, 0), datetime.datetime(2021, 5, 6, 0, 0), datetime.datetime(2021, 5, 7, 0, 0), datetime.datetime(2021, 5, 8, 0, 0), datetime.datetime(2021, 5, 9, 0, 0), datetime.datetime(2021, 5, 10, 0, 0), datetime.datetime(2021, 5, 11, 0, 0), datetime.datetime(2021, 5, 12, 0, 0), datetime.datetime(2021, 5, 13, 0, 0), datetime.datetime(2021, 5, 14, 0, 0), datetime.datetime(2021, 5, 15, 0, 0), datetime.datetime(2021, 5, 16, 0, 0), datetime.datetime(2021, 5, 17, 0, 0), datetime.datetime(2021, 5, 18, 0, 0), datetime.datetime(2021, 5, 19, 0, 0), datetime.datetime(2021, 5, 20, 0, 0), datetime.datetime(2021, 5, 21, 0, 0), datetime.datetime(2021, 5, 22, 0, 0), datetime.datetime(2021, 5, 23, 0, 0), datetime.datetime(2021, 5, 24, 0, 0), datetime.datetime(2021, 5, 25, 0, 0), datetime.datetime(2021, 5, 26, 0, 0), datetime.datetime(2021, 5, 27, 0, 0), datetime.datetime(2021, 5, 28, 0, 0), datetime.datetime(2021, 5, 29, 0, 0), datetime.datetime(2021, 5, 30, 0, 0), datetime.datetime(2021, 5, 31, 0, 0), datetime.datetime(2021, 6, 1, 0, 0), datetime.datetime(2021, 6, 2, 0, 0), datetime.datetime(2021, 6, 3, 0, 0), datetime.datetime(2021, 6, 4, 0, 0), datetime.datetime(2021, 6, 5, 0, 0), datetime.datetime(2021, 6, 6, 0, 0), datetime.datetime(2021, 6, 7, 0, 0), datetime.datetime(2021, 6, 8, 0, 0), datetime.datetime(2021, 6, 9, 0, 0), datetime.datetime(2021, 6, 10, 0, 0), 
datetime.datetime(2021, 6, 11, 0, 0), datetime.datetime(2021, 6, 12, 0, 0), datetime.datetime(2021, 6, 13, 0, 0), datetime.datetime(2021, 6, 14, 0, 0), datetime.datetime(2021, 6, 15, 0, 0), datetime.datetime(2021, 6, 16, 0, 0), datetime.datetime(2021, 6, 17, 0, 0), datetime.datetime(2021, 6, 18, 0, 0), datetime.datetime(2021, 6, 19, 0, 0), datetime.datetime(2021, 6, 20, 0, 0), datetime.datetime(2021, 6, 21, 0, 0), datetime.datetime(2021, 6, 22, 0, 0), datetime.datetime(2021, 6, 23, 0, 0), datetime.datetime(2021, 6, 24, 0, 0), datetime.datetime(2021, 6, 25, 0, 0), datetime.datetime(2021, 6, 26, 0, 0), datetime.datetime(2021, 6, 27, 0, 0), datetime.datetime(2021, 6, 28, 0, 0), datetime.datetime(2021, 6, 29, 0, 0), datetime.datetime(2021, 6, 30, 0, 0), datetime.datetime(2021, 7, 1, 0, 0), datetime.datetime(2021, 7, 2, 0, 0), datetime.datetime(2021, 7, 3, 0, 0), datetime.datetime(2021, 7, 4, 0, 0), datetime.datetime(2021, 7, 5, 0, 0), datetime.datetime(2021, 7, 6, 0, 0), datetime.datetime(2021, 7, 7, 0, 0), datetime.datetime(2021, 7, 8, 0, 0), datetime.datetime(2021, 7, 9, 0, 0), datetime.datetime(2021, 7, 10, 0, 0), datetime.datetime(2021, 7, 11, 0, 0), datetime.datetime(2021, 7, 12, 0, 0), datetime.datetime(2021, 7, 13, 0, 0), datetime.datetime(2021, 7, 14, 0, 0), datetime.datetime(2021, 7, 15, 0, 0), datetime.datetime(2021, 7, 16, 0, 0), datetime.datetime(2021, 7, 17, 0, 0), datetime.datetime(2021, 7, 18, 0, 0), datetime.datetime(2021, 7, 19, 0, 0), datetime.datetime(2021, 7, 20, 0, 0), datetime.datetime(2021, 7, 21, 0, 0), datetime.datetime(2021, 7, 22, 0, 0), datetime.datetime(2021, 7, 23, 0, 0), datetime.datetime(2021, 7, 24, 0, 0), datetime.datetime(2021, 7, 25, 0, 0), datetime.datetime(2021, 7, 26, 0, 0), datetime.datetime(2021, 7, 27, 0, 0), datetime.datetime(2021, 7, 28, 0, 0), datetime.datetime(2021, 7, 29, 0, 0), datetime.datetime(2021, 7, 30, 0, 0), datetime.datetime(2021, 7, 31, 0, 0), datetime.datetime(2021, 8, 1, 0, 0), datetime.datetime(2021, 8, 2, 0, 
0), datetime.datetime(2021, 8, 3, 0, 0), datetime.datetime(2021, 8, 4, 0, 0), datetime.datetime(2021, 8, 5, 0, 0), datetime.datetime(2021, 8, 6, 0, 0), datetime.datetime(2021, 8, 7, 0, 0), datetime.datetime(2021, 8, 8, 0, 0), datetime.datetime(2021, 8, 9, 0, 0), datetime.datetime(2021, 8, 10, 0, 0), datetime.datetime(2021, 8, 11, 0, 0), datetime.datetime(2021, 8, 12, 0, 0), datetime.datetime(2021, 8, 13, 0, 0), datetime.datetime(2021, 8, 14, 0, 0), datetime.datetime(2021, 8, 15, 0, 0), datetime.datetime(2021, 8, 16, 0, 0), datetime.datetime(2021, 8, 23, 0, 0), datetime.datetime(2021, 8, 28, 0, 0), datetime.datetime(2021, 8, 30, 0, 0), datetime.datetime(2021, 9, 3, 0, 0), datetime.datetime(2021, 9, 4, 0, 0), datetime.datetime(2021, 9, 5, 0, 0), datetime.datetime(2021, 9, 6, 0, 0), datetime.datetime(2021, 9, 7, 0, 0), datetime.datetime(2021, 9, 8, 0, 0), datetime.datetime(2021, 9, 9, 0, 0), datetime.datetime(2021, 9, 10, 0, 0), datetime.datetime(2021, 9, 11, 0, 0), datetime.datetime(2021, 9, 12, 0, 0), datetime.datetime(2021, 9, 13, 0, 0), datetime.datetime(2021, 9, 14, 0, 0), datetime.datetime(2021, 9, 15, 0, 0), datetime.datetime(2021, 9, 16, 0, 0), datetime.datetime(2021, 9, 17, 0, 0), datetime.datetime(2021, 9, 18, 0, 0), datetime.datetime(2021, 9, 19, 0, 0), datetime.datetime(2021, 9, 20, 0, 0), datetime.datetime(2021, 9, 21, 0, 0), datetime.datetime(2021, 9, 22, 0, 0), datetime.datetime(2021, 9, 23, 0, 0), datetime.datetime(2021, 9, 24, 0, 0), datetime.datetime(2021, 9, 25, 0, 0), datetime.datetime(2021, 9, 26, 0, 0), datetime.datetime(2021, 9, 27, 0, 0), datetime.datetime(2021, 9, 28, 0, 0), datetime.datetime(2021, 9, 29, 0, 0), datetime.datetime(2021, 9, 30, 0, 0), datetime.datetime(2021, 10, 1, 0, 0), datetime.datetime(2021, 10, 2, 0, 0), datetime.datetime(2021, 10, 3, 0, 0), datetime.datetime(2021, 10, 4, 0, 0), datetime.datetime(2021, 10, 5, 0, 0), datetime.datetime(2021, 10, 6, 0, 0), datetime.datetime(2021, 10, 7, 0, 0), datetime.datetime(2021, 10, 8, 0, 
0), datetime.datetime(2021, 10, 9, 0, 0), datetime.datetime(2021, 10, 10, 0, 0), datetime.datetime(2021, 10, 11, 0, 0), datetime.datetime(2021, 10, 12, 0, 0), datetime.datetime(2021, 10, 13, 0, 0), datetime.datetime(2021, 10, 14, 0, 0), datetime.datetime(2021, 10, 15, 0, 0), datetime.datetime(2021, 10, 16, 0, 0), datetime.datetime(2021, 10, 17, 0, 0), datetime.datetime(2021, 10, 18, 0, 0), datetime.datetime(2021, 10, 19, 0, 0), datetime.datetime(2021, 10, 20, 0, 0), datetime.datetime(2021, 10, 21, 0, 0), datetime.datetime(2021, 10, 22, 0, 0), datetime.datetime(2021, 10, 23, 0, 0), datetime.datetime(2021, 10, 24, 0, 0), datetime.datetime(2021, 10, 25, 0, 0), datetime.datetime(2021, 10, 26, 0, 0), datetime.datetime(2021, 10, 27, 0, 0), datetime.datetime(2021, 10, 28, 0, 0), datetime.datetime(2021, 10, 29, 0, 0), datetime.datetime(2021, 10, 31, 0, 0), datetime.datetime(2021, 11, 1, 0, 0), datetime.datetime(2021, 11, 2, 0, 0), datetime.datetime(2021, 11, 3, 0, 0), datetime.datetime(2021, 11, 4, 0, 0), datetime.datetime(2021, 11, 5, 0, 0), datetime.datetime(2021, 11, 6, 0, 0), datetime.datetime(2021, 11, 7, 0, 0), datetime.datetime(2021, 11, 8, 0, 0), datetime.datetime(2021, 11, 9, 0, 0), datetime.datetime(2021, 11, 10, 0, 0), datetime.datetime(2021, 11, 11, 0, 0), datetime.datetime(2021, 11, 12, 0, 0), datetime.datetime(2021, 11, 13, 0, 0), datetime.datetime(2021, 11, 14, 0, 0), datetime.datetime(2021, 11, 15, 0, 0), datetime.datetime(2021, 11, 16, 0, 0), datetime.datetime(2021, 11, 17, 0, 0), datetime.datetime(2021, 11, 18, 0, 0), datetime.datetime(2021, 11, 19, 0, 0), datetime.datetime(2021, 11, 20, 0, 0), datetime.datetime(2021, 11, 21, 0, 0), datetime.datetime(2021, 11, 22, 0, 0), datetime.datetime(2021, 11, 23, 0, 0), datetime.datetime(2021, 11, 24, 0, 0), datetime.datetime(2021, 11, 25, 0, 0), datetime.datetime(2021, 11, 26, 0, 0), datetime.datetime(2021, 11, 27, 0, 0), datetime.datetime(2021, 11, 28, 0, 0), datetime.datetime(2021, 11, 29, 0, 0), 
datetime.datetime(2021, 11, 30, 0, 0), datetime.datetime(2021, 12, 1, 0, 0), datetime.datetime(2021, 12, 2, 0, 0), datetime.datetime(2021, 12, 3, 0, 0), datetime.datetime(2021, 12, 4, 0, 0), datetime.datetime(2021, 12, 5, 0, 0), datetime.datetime(2021, 12, 6, 0, 0), datetime.datetime(2021, 12, 7, 0, 0), datetime.datetime(2021, 12, 8, 0, 0), datetime.datetime(2021, 12, 9, 0, 0), datetime.datetime(2021, 12, 10, 0, 0), datetime.datetime(2021, 12, 11, 0, 0), datetime.datetime(2021, 12, 12, 0, 0), datetime.datetime(2021, 12, 13, 0, 0), datetime.datetime(2021, 12, 14, 0, 0), datetime.datetime(2021, 12, 15, 0, 0), datetime.datetime(2021, 12, 16, 0, 0), datetime.datetime(2021, 12, 17, 0, 0), datetime.datetime(2021, 12, 18, 0, 0), datetime.datetime(2021, 12, 19, 0, 0), datetime.datetime(2021, 12, 20, 0, 0), datetime.datetime(2021, 12, 21, 0, 0), datetime.datetime(2021, 12, 22, 0, 0), datetime.datetime(2021, 12, 23, 0, 0), datetime.datetime(2021, 12, 24, 0, 0), datetime.datetime(2021, 12, 25, 0, 0), datetime.datetime(2021, 12, 26, 0, 0), datetime.datetime(2021, 12, 27, 0, 0), datetime.datetime(2021, 12, 28, 0, 0), datetime.datetime(2021, 12, 29, 0, 0), datetime.datetime(2021, 12, 30, 0, 0), datetime.datetime(2021, 12, 31, 0, 0), datetime.datetime(2021, 12, 31, 0, 0), datetime.datetime(2021, 12, 31, 0, 0)]
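The two list comprehensions above go through a string and back. As an alternative (not the author's code), pandas can assemble dates directly from columns named year, month and day. A minimal sketch on three rows shaped like temp.csv; the resulting datetime64 Series should be usable in place of the `datas` list when plotting:

```python
import pandas as pd

# A few rows shaped like temp.csv (date values taken from the table above)
df = pd.DataFrame({'year': [2021, 2021, 2021],
                   'month': [1, 1, 1],
                   'day': [1, 2, 3]})

# pd.to_datetime assembles dates from columns named year/month/day
dates = pd.to_datetime(df[['year', 'month', 'day']])
print(dates)
```

Like the original approach, this keeps one entry per row, so the gaps and the repeated 2021-12-31 rows visible in the output above are preserved.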
Plot the data as charts to make it easier to inspect and compare (and to review some matplotlib along the way).
# Plot the data with pyplot for easier observation
# Set the style first: it mainly affects figures created after this call
plt.style.use('fivethirtyeight')
# 2x2 grid of panels; figsize is the size of the whole figure (10x10 inches)
fig,((x1,x2),(x3,x4))=plt.subplots(nrows=2,ncols=2,figsize=(10,10))
# All 350 actual values
x1.plot(datas,data['actual'])
x1.set_xlabel('')
x1.set_ylabel('Temp')
x1.set_title('Max Temp')
# All 350 temp_1 values
x2.plot(datas,data['temp_1'])
x2.set_xlabel('')
x2.set_ylabel('Temp')
x2.set_title('temp_1 Max Temp')
# All 350 temp_2 values
x3.plot(datas,data['temp_2'])
x3.set_xlabel('')
x3.set_ylabel('Temp')
x3.set_title('temp_2 Max Temp')
# All 350 random values
x4.plot(datas,data['random'])
x4.set_xlabel('')
x4.set_ylabel('Value')
x4.set_title('random')
plt.show()
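The four plotting blocks above differ only in the column name, so they can also be written as a loop over the axes. A minimal sketch, assuming the first five rows of the table as stand-in data and a headless Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# First five rows of temp.csv, copied from the table above
df = pd.DataFrame({
    'actual': [45, 44, 41, 40, 44],
    'temp_1': [45, 45, 44, 41, 40],
    'temp_2': [45, 44, 45, 44, 41],
    'random': [29, 61, 56, 53, 41],
})
dates = pd.date_range('2021-01-01', periods=len(df), freq='D')

# Set the style before creating the figure
plt.style.use('fivethirtyeight')
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 10))

# One panel per column; axes.flat walks the 2x2 grid row by row
for ax, col in zip(axes.flat, ['actual', 'temp_1', 'temp_2', 'random']):
    ax.plot(dates, df[col])
    ax.set_ylabel('Value')
    ax.set_title(col)

fig.autofmt_xdate()  # rotate the date tick labels so they do not overlap
fig.tight_layout()
# plt.show() would display the figure in an interactive session
```

`fig.autofmt_xdate()` is optional but helps when a full year of dates is squeezed onto each x-axis.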
# Extract the actual column as the labels
labels=np.array(data['actual'])
print(labels)
print('ooooooooo')
# Drop the label column from the features
data=data.drop('actual',axis=1)
print(data)
print('xxxxxxxx')
data_list=list(data.columns)
print(data_list)
data=np.array(data)
print(data)
Output:
[45 44 41 40 44 51 45 48 50 52 45 49 55 49 48 54 50 54 48 52 52 57 48 51
54 56 57 56 52 48 47 46 51 49 49 53 49 51 57 62 56 55 58 55 56 57 53 51
53 51 51 60 59 61 60 57 53 58 55 59 57 64 60 53 54 55 56 55 52 54 49 51
53 58 63 61 55 56 57 53 54 57 59 51 56 64 68 73 71 63 69 60 57 68 77 76
66 59 58 60 59 59 60 68 77 89 81 81 73 64 65 55 59 60 61 64 61 68 77 87
74 60 68 77 82 63 67 75 81 77 82 65 57 60 71 64 63 66 59 66 65 66 66 65
64 64 64 71 79 75 71 80 81 92 86 85 67 65 67 65 70 66 60 67 71 67 65 70
76 73 75 68 69 71 78 85 79 74 73 76 76 71 68 69 76 68 74 71 74 74 77 75
77 76 72 80 73 78 82 81 71 75 80 85 79 83 85 88 76 73 77 73 75 80 79 72
72 73 72 76 80 87 90 83 84 81 79 75 70 67 68 68 68 67 72 74 77 70 74 75
79 71 75 68 69 71 67 68 67 64 67 76 77 69 68 66 67 63 65 61 63 66 63 64
68 57 60 62 66 60 60 62 60 60 61 58 62 59 62 62 61 65 58 60 65 68 59 57
57 65 65 58 61 63 71 65 64 63 59 55 57 55 50 52 55 57 55 54 54 49 52 52
53 48 52 52 52 46 50 49 46 40 42 40 41 36 44 44 43 40 39 39 35 35 39 46
51 49 45 40 41 42 42 47 48 48 57 40 40 40]
ooooooooo
year month day week temp_2 temp_1 average random
0 2021 1 1 Fri 45 45 45.6 29
1 2021 1 2 Sat 44 45 45.7 61
2 2021 1 3 Sun 45 44 45.8 56
3 2021 1 4 Mon 44 41 45.9 53
4 2021 1 5 Tues 41 40 46.0 41
.. ... ... ... ... ... ... ... ...
345 2021 12 29 Thurs 47 48 45.3 65
346 2021 12 30 Fri 48 48 45.4 42
347 2021 12 31 Sat 48 57 45.5 57
348 2021 12 31 Sat 48 57 45.5 57
349 2021 12 31 Sat 48 57 45.5 57
[350 rows x 8 columns]
xxxxxxxx
['year', 'month', 'day', 'week', 'temp_2', 'temp_1', 'average', 'random']
[[2021 1 1 ... 45 45.6 29]
[2021 1 2 ... 45 45.7 61]
[2021 1 3 ... 44 45.8 56]
...
[2021 12 31 ... 57 45.5 57]
[2021 12 31 ... 57 45.5 57]
[2021 12 31 ... 57 45.5 57]]
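The feature matrix printed above mixes very different scales: year is 2021 in every row, the temperatures lie roughly between 35 and 90, and the week dummies added in section 3 are 0/1. Section 3 therefore passes the features through sklearn's `preprocessing.StandardScaler` before training. A minimal NumPy sketch of what that transform computes for each column, using four columns from the first five rows of the table:

```python
import numpy as np

# Columns: year, temp_2, temp_1, average (first five rows of the table above)
X = np.array([
    [2021, 45, 45, 45.6],
    [2021, 44, 45, 45.7],
    [2021, 45, 44, 45.8],
    [2021, 44, 41, 45.9],
    [2021, 41, 40, 46.0],
])

mean = X.mean(axis=0)
std = X.std(axis=0)      # population std (ddof=0), as StandardScaler uses
std[std == 0] = 1.0      # constant column (year): avoid dividing by zero
X_scaled = (X - mean) / std

print(X_scaled.round(3))
```

If scikit-learn is installed, `preprocessing.StandardScaler().fit_transform(X)` should return the same array. It also treats a zero-variance column as having scale 1, so the constant year column becomes all zeros and carries no information into the network.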
(3) Model building, loss function, iterative training, and display
3.1 Building the model by hand
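# Reload the raw DataFrame (the previous step turned data into a NumPy array)
data=pd.read_csv('temp.csv')
# One-hot encode the week column (week_Fri, week_Mon, ...)
print('xxxxxxx')
data=pd.get_dummies(data)
print(data.head())
print('xxxxxxxx')
# Extract the actual column as the labels
labels=np.array(data['actual'])
# Drop the label column from the features
data=data.drop('actual',axis=1)
print(data)
data_list=list(data.columns)
data=np.array(data)
print(data.shape)
# Build the model: standardize each feature to zero mean and unit variance
from sklearn import preprocessing
input_data=preprocessing.StandardScaler().fit_transform(data)
x=torch.tensor(input_data,dtype=float)
# Reshape the labels to (350, 1) to match the predictions; a 1-D y would
# broadcast (predictions-y) into a (350, 350) matrix
y=torch.tensor(labels,dtype=float).reshape(-1,1)
# Hidden-layer parameters: 14 input features -> 128 hidden units
weigh=torch.randn((14,128),dtype=float,requires_grad=True)
bias=torch.randn(128,dtype=float,requires_grad=True)
# Output-layer parameters: 128 hidden units -> 1 prediction
weighs=torch.randn((128,1),dtype=float,requires_grad=True)
biass=torch.randn(1,dtype=float,requires_grad=True)
learning_rate=0.001
losses=[]
for i in range(100):
    # Hidden layer
    hidden=x.mm(weigh)+bias
    # ReLU activation
    hidden=torch.relu(hidden)
    # Predictions
    predictions=hidden.mm(weighs)+biass
    # Mean squared error loss
    loss=torch.mean((predictions-y)**2)
    losses.append(loss.item())
    # Print progress every 10 iterations
    if i%10==0:
        print('loss:',loss)
    # Backpropagation
    loss.backward()
    # Gradient-descent update of the parameters
    weigh.data.add_(-learning_rate*weigh.grad.data)
    bias.data.add_(-learning_rate*bias.grad.data)
    weighs.data.add_(-learning_rate*weighs.grad.data)
    biass.data.add_(-learning_rate*biass.grad.data)
    # Reset the gradients every iteration (backward() accumulates them)
    weigh.grad.data.zero_()
    bias.grad.data.zero_()
    weighs.grad.data.zero_()
    biass.grad.data.zero_()
print(predictions.shape)
Output: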
xxxxxxx
year month day temp_2 ... week_Sun week_Thurs week_Tues week_Wed
0 2021 1 1 45 ... 0 0 0 0
1 2021 1 2 44 ... 0 0 0 0
2 2021 1 3 45 ... 1 0 0 0
3 2021 1 4 44 ... 0 0 0 0
4 2021 1 5 41 ... 0 0 1 0
[5 rows x 15 columns]
xxxxxxxx
year month day temp_2 ... week_Sun week_Thurs week_Tues week_Wed
0 2021 1 1 45 ... 0 0 0 0
1 2021 1 2 44 ... 0 0 0 0
2 2021 1 3 45 ... 1 0 0 0
3 2021 1 4 44 ... 0 0 0 0
4 2021 1 5 41 ... 0 0 1 0
.. ... ... ... ... ... ... ... ... ...
345 2021 12 29 47 ... 0 1 0 0
346 2021 12 30 48 ... 0 0 0 0
347 2021 12 31 48 ... 0 0 0 0
348 2021 12 31 48 ... 0 0 0 0
349 2021 12 31 48 ... 0 0 0 0
[350 rows x 14 columns]
(350, 14)
loss: tensor(2109.5258, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(692.9397, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(469.2138, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(381.2370, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(332.9759, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(302.4142, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(281.3954, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(266.1516, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(254.6720, dtype=torch.float64, grad_fn=<MeanBackward0>)
loss: tensor(245.7872, dtype=torch.float64, grad_fn=<MeanBackward0>)
torch.Size([350, 1])
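One detail in the hand-written model matters: predictions has shape (350, 1), and a 1-D y of shape (350,) would broadcast the subtraction into a (350, 350) matrix, so the loss would compare every prediction with every label. y therefore needs the shape (350, 1), for example via `reshape(-1, 1)`. A minimal sketch of the same 14-128-1 network on hypothetical random data (not temp.csv) that keeps the shapes aligned and performs the update inside `torch.no_grad()`, an alternative to the `.data.add_()` calls; the biases start at zero here rather than from `randn`:

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in data (not temp.csv) with the shapes of section 3.1:
# 350 samples, 14 standardized features, and a made-up linear target
x = torch.randn(350, 14, dtype=torch.float64)
y = x[:, :1] * 3.0 + 60.0             # shape (350, 1), a column like the predictions

# Shape pitfall: (350, 1) minus (350,) broadcasts to (350, 350)
print((torch.zeros(350, 1) - torch.zeros(350)).shape)  # torch.Size([350, 350])

# 14 -> 128 -> 1 network parameters; biases start at zero in this sketch
w1 = torch.randn(14, 128, dtype=torch.float64, requires_grad=True)
b1 = torch.zeros(128, dtype=torch.float64, requires_grad=True)
w2 = torch.randn(128, 1, dtype=torch.float64, requires_grad=True)
b2 = torch.zeros(1, dtype=torch.float64, requires_grad=True)
params = [w1, b1, w2, b2]
learning_rate = 0.001

losses = []
for i in range(100):
    hidden = torch.relu(x @ w1 + b1)    # hidden layer with ReLU
    predictions = hidden @ w2 + b2      # shape (350, 1), same as y
    loss = torch.mean((predictions - y) ** 2)
    losses.append(loss.item())

    loss.backward()
    # Update in place without recording the update in the autograd graph
    with torch.no_grad():
        for p in params:
            p -= learning_rate * p.grad
            p.grad.zero_()               # backward() accumulates, so reset each step

print(losses[0], losses[-1])
```

`torch.no_grad()` makes it explicit that the parameter update is not part of the computation being differentiated, which is the usual replacement for manipulating `.data` directly.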
3.2 Building and training the model with PyTorch's built-in modules
'''
Building a neural network model with PyTorch: weather forecast prediction
author:leecj2015
time:2021-09-02
'''
# Step 1: import the required modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import datetime
# Step 2: load the data
data=pd.read_csv('temp.csv')
yaers=data['year']
months=data['month']
days=data['day']
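'''
zip() takes iterables as arguments and packs their corresponding elements into tuples.
If the iterables have different lengths, the result is as long as the shortest one; the * operator unpacks the tuples again.
In Python 3, zip() returns an iterator, so wrap it in list() to see the tuples.
zip([iterable, ...])
eg:
>>> a = [1,2,3]
>>> b = [4,5,6]
>>> c = [4,5,6,7,8]
>>> zipped = list(zip(a,b))  # pack into a list of tuples
>>> zipped
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(a,c))  # length matches the shortest list
[(1, 4), (2, 5), (3, 6)]
>>> list(zip(*zipped))  # inverse of zip: *zipped unpacks the pairs
[(1, 2, 3), (4, 5, 6)]
'''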
datas=[str(int(year))+'-'+str(int(month))+'-'+str(int(day)) for year,month ,day in zip(yaers,months,days)]
datas=[datetime.datetime.strptime(date,'%Y-%m-%d') for date in datas]
# Plot the data with pyplot for easier observation
# Set the style first: it mainly affects figures created after this call
plt.style.use('fivethirtyeight')
# 2x2 grid of panels; figsize is the size of the whole figure (10x10 inches)
fig,((x1,x2),(x3,x4))=plt.subplots(nrows=2,ncols=2,figsize=(10,10))
# All 350 actual values
x1.plot(datas,data['actual'])
x1.set_xlabel('')
x1.set_ylabel('Temp')
x1.set_title('Max Temp')
# All 350 temp_1 values
x2.plot(datas,data['temp_1'])
x2.set_xlabel('')
x2.set_ylabel('Temp')
x2.set_title('temp_1 Max Temp')
# All 350 temp_2 values
x3.plot(datas,data['temp_2'])
x3.set_xlabel('')
x3.set_ylabel('Temp')
x3.set_title('temp_2 Max Temp')
# All 350 random values
x4.plot(datas,data['random'])
x4.set_xlabel('')
x4.set_ylabel('Value')
x4.set_title('random')
plt.show()
# One-hot encode the week column
data=pd.get_dummies(data)
# Extract the actual column as the labels
labels=np.array(data['actual'])
# Drop the label column from the features
data=data.drop('actual',axis=1)
data_list=list(data.columns)
data=np.array(data)
# Build the model: standardize the features, then define the network
from sklearn import preprocessing
input_data=preprocessing.StandardScaler().fit_transform(data)
input_size=input_data.shape[1]
hidden_size=128
output_size=1
batch_size=16
my_nn=torch.nn.Sequential(
torch.nn.Linear(input_size,hidden_size),
torch.nn.Sigmoid(),
torch.nn.Linear(hidden_size,output_size),
)
cost=torch.nn.MSELoss(reduction='mean')
optimizer=torch.optim.Adam(my_nn.parameters(),lr=0.001)
# Train the network with mini-batches
losses=[]
for i in range(100):
    batch_loss=[]
    for s in range(0,len(input_data),batch_size):
        end=s+batch_size if s+batch_size<len(input_data) else len(input_data)
        # Inputs and targets need no gradients; reshape y to (batch, 1) to match the predictions
        x=torch.tensor(input_data[s:end],dtype=torch.float)
        y=torch.tensor(labels[s:end],dtype=torch.float).reshape(-1,1)
        predictions=my_nn(x)
        loss=cost(predictions,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    # Record and print the mean batch loss every 10 epochs
    if i%10==0:
        losses.append(np.mean(batch_loss))
        print(i,np.mean(batch_loss))
# Predict on the training data (no gradients needed)
x=torch.tensor(input_data,dtype=torch.float)
with torch.no_grad():
    predict=my_nn(x).numpy()
print(predict)
# Convert the dates
datas=[str(int(year))+'-'+str(int(month))+'-'+str(int(day)) for year,month,day in zip(yaers,months,days)]
datas=[datetime.datetime.strptime(date,'%Y-%m-%d') for date in datas]
# Table holding each date and its true label
tab_data=pd.DataFrame(data={'date':datas,'actual':labels})
print(tab_data.head(5))
# Rebuild the dates from the feature matrix so they line up with the predictions
years=data[:,data_list.index('year')]
months=data[:,data_list.index('month')]
days=data[:,data_list.index('day')]
test_dates=[str(int(year))+'-'+str(int(month))+'-'+str(int(day)) for year,month,day in zip(years,months,days)]
test_dates=[datetime.datetime.strptime(date,'%Y-%m-%d') for date in test_dates]
predictions_data=pd.DataFrame(data={'date':test_dates,'predictions':predict.reshape(-1)})
# True values
plt.plot(tab_data['date'],tab_data['actual'],'b-',label='actual')
# Predicted values
plt.plot(predictions_data['date'],predictions_data['predictions'],'ro',label='predict')
plt.legend()
plt.xlabel('Date')
plt.ylabel('Maximum Temp')
plt.title('Actual and Predicted')
plt.show()
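The inner loop of the 3.2 script slices input_data by hand and visits the batches in the same order every epoch. PyTorch's `TensorDataset` and `DataLoader` can do the batching and also reshuffle the rows each epoch. A minimal sketch of that swapped-in approach, on hypothetical random data with the same shapes as temp.csv (350 rows, 14 features) rather than the real file:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical standardized features and labels with the shapes of temp.csv
features = torch.randn(350, 14)
labels = features[:, :1] * 3.0 + 60.0   # made-up target, shape (350, 1)

# TensorDataset pairs rows of features and labels; DataLoader batches and shuffles them
loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

my_nn = torch.nn.Sequential(
    torch.nn.Linear(14, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, 1),
)
cost = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)

epoch_losses = []
for epoch in range(100):
    batch_loss = []
    for x, y in loader:   # the last batch holds the remaining 350 % 16 = 14 rows
        loss = cost(my_nn(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss.append(loss.item())
    epoch_losses.append(sum(batch_loss) / len(batch_loss))

# Inference: switch off gradient tracking
with torch.no_grad():
    predict = my_nn(features).numpy()
print(predict.shape)   # (350, 1)
```

Because the labels are stored as a (350, 1) column, each batch target already matches the network output, so `MSELoss` does not need to broadcast.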
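Tip: this article used PyTorch to predict 2021 maximum temperatures as a way to get familiar with the basic training steps of a neural network model.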