【2021-4-2】手动码出AlexNet(3.主文件供参考)

声明一下,这只是基本的网络和运行结构。调参之后的模型和数据由于需要花费精力而且没有指导意义,不会放。

from Data_Channel import Data_Channel
from AlexNet_Model import AlexNet
import numpy as np
import matplotlib.pyplot as plt
'''
主训练函数,这个就简单了,直接调用数据通道就行。
创建三组数据通道,依次调用,传入模型进行训练。
'''
Alex_model = AlexNet(learning_rate=0.02, drop_out=0.8, n_classes=6)
Labels_OH = {"cloudy":np.tile(np.array([0,0,0,0,0,1]), (80,1)), "haze":np.tile(np.array([0,0,0,0,1,0]), (80,1)),
             "rainy":np.tile(np.array([0,0,0,1,0,0]), (80,1)), "snow":np.tile(np.array([0,0,1,0,0,0]), (80,1)),
             "sunny":np.tile(np.array([0,1,0,0,0,0]), (80,1)), "thunder":np.tile(np.array([1,0,0,0,0,0]), (80,1))}
DC_Dic = {"cloudy":Data_Channel(category="cloudy", pool_size=20), "haze":Data_Channel(category="haze", pool_size=20),
          "rainy":Data_Channel(category="rainy", pool_size=20), "snow":Data_Channel(category="snow", pool_size=20),
          "sunny":Data_Channel(category="sunny", pool_size=20), "thunder":Data_Channel(category="thunder", pool_size=20)}

DC_list = ["cloudy", "haze", "rainy", "snow", "sunny", "thunder"]

for i in range(50):
    print("running")
    Index = DC_list[i%6]
    labels_now = Labels_OH[Index]
    DC = DC_Dic[Index]
    DC.Renew_dataset()
    Alex_model.learn(DC.RF_pool, labels_now)

X = np.arange(len(Alex_model.Loss_list))
plt.plot(X, Alex_model.Loss_list, '-r')
plt.grid()
plt.show()


你可能感兴趣的:(深度学习,人工智能,tensorflow,神经网络)