DNN regression prediction_TensorFlow Python: a trained DNN regression model's predictions change drastically across multiple runs...

```python
import math

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.train, tf.contrib)
from sklearn import metrics

# input_fn(features, targets, batch_size=..., shuffle=..., num_epochs=...)
# is a user-defined helper that is not shown here.


def train_dnn_regression_model(
    learning_rate,
    regularization_strength,
    steps,
    batch_size,
    hidden_units,
    feature_columns,
    training_examples,
    training_targets,
    validation_examples,
    validation_targets,
):
    periods = 10
    # Integer division: Estimator.train expects an integer step count.
    steps_per_period = steps // periods

    # Initialize DNN regressor: FTRL optimizer with L1 regularization,
    # gradients clipped to a norm of 5.0.
    optimizer = tf.train.FtrlOptimizer(
        learning_rate=learning_rate,
        l1_regularization_strength=regularization_strength,
    )
    optimizer = tf.contrib.estimator.clip_gradients_by_norm(optimizer, 5.0)
    dnn_regressor = tf.estimator.DNNRegressor(
        feature_columns=feature_columns,
        hidden_units=hidden_units,
        optimizer=optimizer,
        activation_fn=tf.nn.leaky_relu,
    )

    # Training functions
    training_input_fn = lambda: input_fn(
        training_examples, training_targets, batch_size=batch_size)
    predict_training_input_fn = lambda: input_fn(
        training_examples, training_targets, num_epochs=1, shuffle=False)

    # Validation function
    predict_validation_input_fn = lambda: input_fn(
        validation_examples, validation_targets, num_epochs=1, shuffle=False)

    # Train model
    training_rmse = []
    validation_rmse = []
    print("Training Model")
    for period in range(0, periods):
        # Train the DNN regressor built above; total steps are split into 10 periods.
        dnn_regressor.train(input_fn=training_input_fn, steps=steps_per_period)

        # Use sklearn to calculate training RMSE
        training_predictions = dnn_regressor.predict(input_fn=predict_training_input_fn)
        training_predictions = np.array(
            [item['predictions'][0] for item in training_predictions])
        training_root_mean_squared_error = math.sqrt(
            metrics.mean_squared_error(training_targets, training_predictions))

        # Calculate validation RMSE
        validation_predictions = dnn_regressor.predict(input_fn=predict_validation_input_fn)
        validation_predictions = np.array(
            [item['predictions'][0] for item in validation_predictions])
        validation_root_mean_squared_error = math.sqrt(
            metrics.mean_squared_error(validation_targets, validation_predictions))

        # Append losses
        training_rmse.append(training_root_mean_squared_error)
        validation_rmse.append(validation_root_mean_squared_error)
        print("Period:", period, "RMSE:", training_root_mean_squared_error)
    print("Training Finished")

    # Graph
    plt.ylabel("RMSE")
    plt.xlabel("Periods")
    plt.title("Root Mean Squared Error vs. Periods")
    plt.tight_layout()
    plt.plot(training_rmse, label="training")
    plt.plot(validation_rmse, label="validation")
    plt.legend()
    return dnn_regressor
```
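The function above trains in periods and records training and validation RMSE, $\sqrt{\frac{1}{n}\sum_i(\hat y_i - y_i)^2}$, after each one. Both random weight initialization and shuffled training batches draw from a random number generator, so unseeded runs will generally not match. The sketch below is a hypothetical, NumPy-only analogue of the same loop structure using a plain linear model (not the DNNRegressor internals); `train_in_periods`, `rmse`, and the synthetic data are all illustrative assumptions. It shows that fixing the seed makes two runs produce identical RMSE histories, which can help isolate whether run-to-run variation comes from randomness at all.

```python
import numpy as np


def rmse(predictions, targets):
    """Root mean squared error: sqrt(mean((pred - target)^2))."""
    return float(np.sqrt(np.mean((predictions - targets) ** 2)))


def train_in_periods(x_train, y_train, x_val, y_val, steps=500, periods=10,
                     batch_size=16, learning_rate=0.05, seed=None):
    """Hypothetical linear-model analogue of the periodic training loop.

    Fits y ~ x @ w + b with mini-batch SGD, splits `steps` into `periods`
    chunks, and records training/validation RMSE after each chunk.
    """
    rng = np.random.default_rng(seed)  # seed=None -> fresh entropy each run
    w = rng.normal(scale=0.1, size=x_train.shape[1])  # random initial weights
    b = 0.0
    steps_per_period = steps // periods
    training_rmse, validation_rmse = [], []
    for _ in range(periods):
        for _ in range(steps_per_period):
            # Random mini-batch, sampled without replacement
            idx = rng.choice(len(x_train), size=batch_size, replace=False)
            xb, yb = x_train[idx], y_train[idx]
            err = xb @ w + b - yb
            # Gradient step on mean squared error w.r.t. w and b
            w -= learning_rate * 2 * xb.T @ err / batch_size
            b -= learning_rate * 2 * err.mean()
        training_rmse.append(rmse(x_train @ w + b, y_train))
        validation_rmse.append(rmse(x_val @ w + b, y_val))
    return training_rmse, validation_rmse


# Hypothetical synthetic data: y = 3*x0 - 2*x1 + 1 + noise
data_rng = np.random.default_rng(0)
x = data_rng.normal(size=(200, 2))
y = 3 * x[:, 0] - 2 * x[:, 1] + 1 + data_rng.normal(scale=0.1, size=200)
x_train, y_train, x_val, y_val = x[:150], y[:150], x[150:], y[150:]

run_a = train_in_periods(x_train, y_train, x_val, y_val, seed=42)
run_b = train_in_periods(x_train, y_train, x_val, y_val, seed=42)
run_c = train_in_periods(x_train, y_train, x_val, y_val, seed=7)
print("same seed, identical histories:", run_a == run_b)
print("final validation RMSE (seed 42, seed 7):", run_a[1][-1], run_c[1][-1])
```

In the TensorFlow 1.x Estimator API, the corresponding setting is the `tf_random_seed` argument of `tf.estimator.RunConfig`, passed to `DNNRegressor` through its `config` argument. Whether that alone removes all of the variation in this particular setup is an assumption worth testing rather than a given.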
