前面学习了LSTM-FCN的相关知识,现在针对该框架我们找到了一份代码资源,来通过对实现代码的解读进一步理解该模型。
Windows10 python3.6 CUDA8.0 CuDNN5.1
GPU:GeForce GTX 960M
tensorflow-gpu>=1.2.0
keras>=2.0.4
scipy
numpy
pandas
scikit-learn>=0.18.2
h5py
matplotlib
joblib>=0.12
博主运行时使用了GPU加速,这可以大幅提高运行速度,如果没有GPU的话只需要python环境即可,只是CPU运行起来速度确实有些拉跨。当然我的显卡也很差劲,利用率直接拉满。
初始时我们指定选择使用GPU运行
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
构建LSTM-FCN模型
def generate_lstmfcn(MAX_SEQUENCE_LENGTH, NB_CLASS, NUM_CELLS=8):
ip = Input(shape=(1, MAX_SEQUENCE_LENGTH))
x = LSTM(NUM_CELLS)(ip)
#以一定概率丢弃一训练的参数,防止其过拟合
x = Dropout(0.8)(x)#缀学层
#Permute可以同时多次交换tensor的维度
y = Permute((2, 1))(ip)
y = Conv1D(128, 8, padding='same', kernel_initializer='he_uniform')(y)
#批归一化 让我们的均值方差变化没有那么猛烈
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = Conv1D(256, 5, padding='same', kernel_initializer='he_uniform')(y)
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = Conv1D(128, 3, padding='same', kernel_initializer='he_uniform')(y)
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = GlobalAveragePooling1D()(y)
x = concatenate([x, y])
out = Dense(NB_CLASS, activation='softmax')(x)
model = Model(ip, out)
model.summary()
# add load model code here to fine-tune
return model
def generate_alstmfcn(MAX_SEQUENCE_LENGTH, NB_CLASS, NUM_CELLS=8):
ip = Input(shape=(1, MAX_SEQUENCE_LENGTH))
x = AttentionLSTM(NUM_CELLS)(ip)#注意力机制LSTM
x = Dropout(0.8)(x)
y = Permute((2, 1))(ip)
y = Conv1D(128, 8, padding='same', kernel_initializer='he_uniform')(y)
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = Conv1D(256, 5, padding='same', kernel_initializer='he_uniform')(y)
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = Conv1D(128, 3, padding='same', kernel_initializer='he_uniform')(y)
y = BatchNormalization()(y)
y = Activation('relu')(y)
y = GlobalAveragePooling1D()(y)
x = concatenate([x, y])
out = Dense(NB_CLASS, activation='softmax')(x)
model = Model(ip, out)
model.summary()
# add load model code here to fine-tune
return model
关于模型的具体构建都在layer-utils.py中,这里就不再赘述。
这里指定我们的数据集名称集合,在项目中为方便运行,它使用了循环来执行127个数据集,但我目前还没有成功,希望我能够后期完成吧,该文件主要用于加载数据集中使用,名称顺序与constants.py相同,内分训练集与测试集,数据读取时使用constants.py的数据集目录。
dataset_map = [
('ChlorineConcentration', 2),
('InsectWingbeatSound', 3),
('Lighting7', 4),
('Wine', 5),
('WordsSynonyms', 6),
('50words', 7),
('Beef', 8),
('DistalPhalanxOutlineAgeGroup', 9),
('DistalPhalanxOutlineCorrect', 10),
('DistalPhalanxTW', 11),
('ECG200', 12),
('ECGFiveDays', 13),
('BeetleFly', 14),
('BirdChicken', 15),
('ItalyPowerDemand', 16),
('SonyAIBORobotSurface', 17),
('SonyAIBORobotSurfaceII', 18),
('MiddlePhalanxOutlineAgeGroup', 19),
('MiddlePhalanxOutlineCorrect', 20),
('MiddlePhalanxTW', 21),
('ProximalPhalanxOutlineAgeGroup', 22),
('ProximalPhalanxOutlineCorrect', 23),
('ProximalPhalanxTW', 24),
('MoteStrain', 25),
('MedicalImages', 26),
('Strawberry', 27),
('ToeSegmentation1', 28),
('Coffee', 29),
('Cricket_X', 30),
('Cricket_Y', 31),
('Cricket_Z', 32),
('uWaveGestureLibrary_X', 33),
('uWaveGestureLibrary_Y', 34),
('uWaveGestureLibrary_Z', 35),
('ToeSegmentation2', 36),
('DiatomSizeReduction', 37),
('car', 38),
('CBF', 39),
('CinC_ECG_torso', 40),
('Computers', 41),
('Earthquakes', 42),
('ECG5000', 43),
('ElectricDevices', 44),
('FaceAll', 45),
('FaceFour', 46),
('FacesUCR', 47),
('Fish', 48),
('FordA', 49),
('FordB', 50),
('Gun_Point', 51),
('Ham', 52),
('HandOutlines', 53),
('Haptics', 54),
('Herring', 55),
('InlineSkate', 56),
('LargeKitchenAppliances', 57),
('Lighting2', 58),
('MALLAT', 59),
('Meat', 60),
('NonInvasiveFatalECG_Thorax1', 61),
('NonInvasiveFatalECG_Thorax2', 62),
('OliveOil', 63),
('OSULeaf', 64),
('PhalangesOutlinesCorrect', 65),
('Phoneme', 66),
('plane', 67),
('RefrigerationDevices', 68),
('ScreenType', 69),
('ShapeletSim', 70),
('ShapesAll', 71),
('SmallKitchenAppliances', 72),
('StarlightCurves', 73),
('SwedishLeaf', 74),
('Symbols', 75),
('synthetic_control', 76),
('Trace', 77),
('Patterns', 78),
('TwoLeadECG', 79),
('UWaveGestureLibraryAll', 80),
('wafer', 81),
('Worms', 82),
('WormsTwoClass', 83),
('yoga', 84),
('ACSF1', 85),
('AllGestureWiimoteX', 86),
('AllGestureWiimoteY', 87),
('AllGestureWiimoteZ', 88),
('BME', 89),
('Chinatown', 90),
('Crop', 91),
('DodgerLoopDay', 92),
('DodgerLoopGame', 93),
('DodgerLoopWeekend', 94),
('EOGHorizontalSignal', 95),
('EOGVerticalSignal', 96),
('EthanolLevel', 97),
('FreezerRegularTrain', 98),
('FreezerSmallTrain', 99),
('Fungi', 100),
('GestureMidAirD1', 101),
('GestureMidAirD2', 102),
('GestureMidAirD3', 103),
('GesturePebbleZ1', 104),
('GesturePebbleZ2', 105),
('GunPointAgeSpan', 106),
('GunPointMaleVersusFemale', 107),
('GunPointOldVersusYoung', 108),
('HouseTwenty', 109),
('InsectEPGRegularTrain', 110),
('InsectEPGSmallTrain', 111),
('MelbournePedestrian', 112),
('MixedShapesRegularTrain', 113),
('MixedShapesSmallTrain', 114),
('PickupGestureWiimoteZ', 115),
('PigAirwayPressure', 116),
('PigArtPressure', 117),
('PigCVP', 118),
('PLAID', 119),
('PowerCons', 120),
('Rock', 121),
('SemgHandGenderCh2', 122),
('SemgHandMovementCh2', 123),
('SemgHandSubjectCh2', 124),
('ShakeGestureWiimoteZ', 125),
('SmoothSubspace', 126),
('UMD', 127)
]
下面是具体实现的伪代码:
MODELS = [
('lstmfcn', generate_lstmfcn),
('alstmfcn', generate_alstmfcn),
]#指定两个模型
for model_id, (MODEL_NAME, model_fn) in enumerate(MODELS):#两个模型循环调用
if not os.path.exists()#判断记录文件是否存在并打开准备写入;
for dname, did in dataset_map:#循环数据集目录开始读取数据集
load_data()#加载数据并完成预处理
train_model()#训练模型
evaluate_model()#评估模型
写入实验结果;
关闭文件
画图
dataset_id,dataset_name,dataset_name_,test_accuracy
0,Adiac,lstmfcn_8_cells_weights/Adiac,0.849105
1,ArrowHead,lstmfcn_8_cells_weights/ArrowHead,0.822857
2,data/ChlorineConcentration,lstmfcn_8_cells_weights/ChlorineConcentration,0.821354