定义一个主函数,对sin函数进行显示:
if __name__ == "__main__":
samples = 200
xs = np.random.uniform(-np.pi, np.pi, [samples])
xs = sorted(xs)
ys = sorted(ys)
# 使用matplotlib进行图像的显示
plt.plot(xs, ys)
plt.show()
app = SinApp(Config())
with app:
xs_train, ys_train = app.train()
xs_predict, ys_predict = app.predict()
plt.plot(xs_train,ys_train)
plt.plot(xs_predict, ys_predict)
plt.show()
将用到的参数放在Config的类中
class Config:
def __init__(self):
self.save_path = './model_sin/sin'
self.lr = 0.001
self.epoches = 2000
self.batch_size = 200
# 定义隐藏层的数量
self.hidden_units = 200
定义一个张量类:Tensors
class Tensors:
def __init__(slef, config):
self.x = tf.placeholder(tf.float32, [None], 'x')
self.y = tf.placeholder(tf.float32, [None], 'y')
x = tf.reshape(slef.x, [-1, 1])
x = tf.layer.dense(x, config.hidden_units, tf.nn.relu)
y = tf.layer.dense(x, 1)
self.y_predict = tf.reshape(y, [-1])
self.loss = tf.reduce_mean(tf.square(tf.y_predict - self.y)) # 使用方差损失
self.lr = tf.placeholder(tf.float32, [], 'lr') # 定义学习步长(可以定义成动态的)
opt = tf.train.AdaOptimizer(self.lr) # 定义优化器
self.train_op = opt.minimize(self.loss)
self.loss = tf.sqrt(self.loss) # 取sinx的平方根(求平方差)打印的会更合理的,减少误差
定义一个样本类:Sample(实际上大部分工作都是在处理样本的)
class Sample:
def __init__(self, samples):
self.xs = np.random.uniform(-np.pi, np.pi, [samples]) # 可以自己定义samples
self.xs = sorted(self.xs)
self.ys = np.sin(self.xs)
@property
def num_examples(self):
return len(self.xs)
定义SinApp类
class SinApp:
def __init__(self, config):
self.ts = Tensors(config)
self.session = tf.Session()
self.saver = tf.train.Saver()
try:
self.saver.restore(self.session, config.save_path)
except:
self.session.run(tf.global_variables_initializer())
def train(self):
sample = Sample(self.config.samples)
cfg = self.config
ts = self.ts
for _ in range(cfg.epoches):
_ , loss = self.session.run([ts.train_op, ts.loss], {ts.x: sample.xs, ts.y:sample.ys, ts.lr:cfg.lr})
self.save()
return samples.xs, samples.ys # 通过训练得到的xs和ys
def save(self):
self.saver.save(self.session, self.config.save_path)
print('save model into', self.save_path)
def predict(self):
sample = Sample(400) # 样本数量400个,不是很重要就不写在config中了
ys = self.session.run(self.ts.y_predict, {self.ts.x: samples.xs}) # 预测的值弄出来,按照顺序算出ys的值了。
return sample.xs, ys # 返回xs是400个样本点,ys是对应的正弦值
def close(self):
self.session.close()
def __enter__(self):
return self
def __exit__(self):
self.close()