我自己对mnist官方例程进行了部分注解,希望分享出来有助于入门选手更好理解tensorflow的运行机制,可以拷贝到IDE再调试看看,看看具体数据流向还有一部分tensorflow里面用到的库。
我用的是pip安装的tensorflow-GPU-1.13,这段源码原始位置在https://github.com/tensorflow/models/blob/master/official/mnist/mnist.py
代码:
1 from __future__ import absolute_import 2 from __future__ import division 3 from __future__ import print_function 4 5 #absl是python标准库内的 6 from absl import app as absl_app 7 from absl import flags 8 9 import tensorflow as tf # pylint: disable=g-bad-import-order 10 11 from official.mnist import dataset 12 from official.utils.flags import core as flags_core 13 from official.utils.logs import hooks_helper 14 from official.utils.misc import distribution_utils 15 from official.utils.misc import model_helpers 16 17 18 LEARNING_RATE = 1e-4 19 20 #参数默认data_format = 'channels_first' 21 def create_model(data_format): 22 """Model to recognize digits in the MNIST dataset. 23 24 Network structure is equivalent to: 25 https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py 26 and 27 https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py 28 29 But uses the tf.keras API. 30 31 Args: 32 data_format: Either 'channels_first' or 'channels_last'. 'channels_first' is 33 typically faster on GPUs while 'channels_last' is typically faster on 34 CPUs. See 35 https://www.tensorflow.org/performance/performance_guide#data_formats 36 37 Returns: 38 A tf.keras.Model. 39 """ 40 41 #data_format:一个字符串,可以是channels_last(默认)或channels_first,\ 42 # 表示输入中维度的顺序,channels_last对应于具有形状(batch, height, width, channels)\ 43 # 的输入,而channels_first对应于具有形状(batch, channels, height, width)的输入. 44 #这里感觉输入只有三个维度,默认是单通道图? 45 if data_format == 'channels_first': 46 input_shape = [1, 28, 28] 47 else: 48 assert data_format == 'channels_last' 49 input_shape = [28, 28, 1] 50 51 #将tf.keras.layers.MaxPooling2D传递给max_pool 52 l = tf.keras.layers 53 max_pool = l.MaxPooling2D( 54 (2, 2), (2, 2), padding='same', data_format=data_format) 55 # The model consists of a sequential chain of layers, so tf.keras.Sequential 56 # (a subclass of tf.keras.Model) makes for a compact description. 57 return tf.keras.Sequential( 58 [ 59 #输入层确保输入的大小符合网络需要[28, 28]->[1, 28, 28] 60 l.Reshape( 61 target_shape=input_shape, 62 input_shape=(28 * 28,)), 63 #卷积 64 l.Conv2D( 65 32,#filters:整数, 输出空间的维数(即卷积中的滤波器数),就是卷积核个数 66 5,#卷积核大小,这里是5x5 67 padding='same', 68 data_format=data_format, 69 activation=tf.nn.relu), 70 #最大pooling 71 max_pool, 72 #卷积 73 l.Conv2D( 74 64, 75 5, 76 padding='same', 77 data_format=data_format, 78 activation=tf.nn.relu), 79 # 最大pooling 80 max_pool, 81 #在保留第0轴的情况下对输入的张量进行Flatten(扁平化),拉直? 82 l.Flatten(), 83 #fc 1024 -> units: 该层的神经单元结点数。 84 l.Dense(1024, activation=tf.nn.relu), 85 l.Dropout(0.4), 86 #fc输出 87 l.Dense(10) 88 ]) 89 90 #添加了很多参数,指定了一部分的值,数据url,模型url,batch_size等等 91 def define_mnist_flags(): 92 flags_core.define_base() 93 flags_core.define_performance(num_parallel_calls=False) 94 flags_core.define_image() 95 flags.adopt_module_key_flags(flags_core) 96 #自定义项参数都在这里设置了 97 flags_core.set_defaults(data_dir='./tmp/mnist_data', 98 model_dir='./tmp/mnist_model', 99 batch_size=100, 100 train_epochs=40, 101 stop_threshold=0.998) 102 103 104 def model_fn(features, labels, mode, params): 105 """The model_fn argument for creating an Estimator.""" 106 # 翻译成中文,注释的意思就是添加一个data_format的参数,下面的Estimator类需要用到 107 model = create_model(params['data_format']) 108 image = features 109 # 来判断一个对象是否是一个已知的类型。 110 if isinstance(image, dict): 111 image = features['image'] 112 113 #测试模式 114 if mode == tf.estimator.ModeKeys.PREDICT: 115 logits = model(image, training=False) 116 predictions = { 117 'classes': tf.argmax(logits, axis=1), 118 'probabilities': tf.nn.softmax(logits), 119 } 120 #如果只是测试到这里就返回了 121 return tf.estimator.EstimatorSpec( 122 mode=tf.estimator.ModeKeys.PREDICT, 123 predictions=predictions, 124 export_outputs={ 125 'classify': tf.estimator.export.PredictOutput(predictions) 126 }) 127 128 #训练模式 129 if mode == tf.estimator.ModeKeys.TRAIN: 130 #设置LEARNING_RATE 131 optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE) 132 133 logits = model(image, training=True) 134 loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 135 accuracy = tf.metrics.accuracy( 136 labels=labels, predictions=tf.argmax(logits, axis=1)) 137 138 # Name tensors to be logged with LoggingTensorHook. 139 tf.identity(LEARNING_RATE, 'learning_rate') 140 tf.identity(loss, 'cross_entropy') 141 tf.identity(accuracy[1], name='train_accuracy') 142 143 # Save accuracy scalar to Tensorboard output. 144 tf.summary.scalar('train_accuracy', accuracy[1]) 145 146 return tf.estimator.EstimatorSpec( 147 mode=tf.estimator.ModeKeys.TRAIN, 148 loss=loss, 149 train_op=optimizer.minimize(loss, tf.train.get_or_create_global_step())) 150 if mode == tf.estimator.ModeKeys.EVAL: 151 logits = model(image, training=False) 152 loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 153 return tf.estimator.EstimatorSpec( 154 mode=tf.estimator.ModeKeys.EVAL, 155 loss=loss, 156 eval_metric_ops={ 157 'accuracy': 158 tf.metrics.accuracy( 159 labels=labels, predictions=tf.argmax(logits, axis=1)), 160 }) 161 162 163 def run_mnist(flags_obj): 164 """Run MNIST training and eval loop. 165 166 Args: 167 flags_obj: An object containing parsed flag values. 168 """ 169 170 #apply_clean是官方例程里面提供的用来清理现存model的方法,\ 171 # 取决于flags_obj.clean(True则清理flags_obj.model_dir内的文件) 172 model_helpers.apply_clean(flags_obj) 173 174 #把自定义的实现传给tf.estimator.Estimator 175 model_function = model_fn 176 177 #tf.ConfigProto()主要的作用是配置tf.Session的运算方式,比如gpu运算或者cpu运算 178 session_config = tf.ConfigProto( 179 #设置线程一个操作内部并行运算的线程数,比如矩阵乘法,如果设置为0,则表示以最优的线程数处理 180 inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads, 181 #设置多个操作并行运算的线程数,比如 c = a + b,d = e + f . 可以并行运算 182 intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads, 183 #有时候,不同的设备,它的cpu和gpu是不同的,如果将这个选项设置成True,\ 184 # 那么当运行设备不满足要求时,会自动分配GPU或者CPU 185 allow_soft_placement=True) 186 187 #获取gpu数目,优化算法等,用于优化 188 distribution_strategy = distribution_utils.get_distribution_strategy( 189 flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg) 190 191 #所有输出(检查点,事件文件等)都被写入model_dir或其子目录.如果model_dir未设置,则使用临时目录. 192 #可以通过RunConfig对象(包含了有关执行环境的信息)传递config参数.它被传递给model_fn,\ 193 # 如果model_fn有一个名为“config”的参数(和输入函数以相同的方式).如果该config参数未被传递,\ 194 # 则由Estimator进行实例化.不传递配置意味着使用对本地执行有用的默认值.Estimator使配置对模型\ 195 # 可用(例如,允许根据可用的工作人员数量进行专业化),并且还使用其一些字段来控制内部,特别是关于检查点 196 run_config = tf.estimator.RunConfig( 197 train_distribute=distribution_strategy, session_config=session_config) 198 199 data_format = flags_obj.data_format 200 #channels_first,即(3,128,128,128)通道数在最前面 201 #channels_last,即(128,128,128,3)通道数在最后面 202 if data_format is None: 203 data_format = ('channels_first' 204 if tf.test.is_built_with_cuda() else 'channels_last')#判断安装的TF是否支持GPU 205 206 #estimator类对TensorFlow模型进行训练和计算. 207 #Estimator对象包装由model_fn指定的模型,其中,给定输入和其他一些参数,返回需要进行训练、计算,或预测的操作. 208 mnist_classifier = tf.estimator.Estimator( 209 #这个model_fn是参数名而已 210 model_fn=model_function,#模型对象 211 model_dir=flags_obj.model_dir,#模型目录,如果为空会创建一个临时目录 212 #猜测会去model_dir中寻找数据 213 config=run_config,#运行的一些参数 214 params={ 215 'data_format': data_format,#数据类型 216 }) 217 218 #这里定义了两个内部函数,只能被这个语句块的内部调用 219 # Set up training and evaluation input functions. 220 def train_input_fn(): 221 """Prepare data for training.""" 222 223 # When choosing shuffle buffer sizes, larger sizes result in better 224 # randomness, while smaller sizes use less memory. MNIST is a small 225 # enough dataset that we can easily shuffle the full epoch. 226 ds = dataset.train(flags_obj.data_dir) 227 ds = ds.cache().shuffle(buffer_size=50000).batch(flags_obj.batch_size) 228 229 # Iterate through the dataset a set number (`epochs_between_evals`) of times 230 # during each training session. 231 ds = ds.repeat(flags_obj.epochs_between_evals) 232 return ds 233 234 def eval_input_fn(): 235 return dataset.test(flags_obj.data_dir).batch( 236 flags_obj.batch_size).make_one_shot_iterator().get_next() 237 238 # Set up hook that outputs training logs every 100 steps. 239 train_hooks = hooks_helper.get_train_hooks( 240 flags_obj.hooks, model_dir=flags_obj.model_dir, 241 batch_size=flags_obj.batch_size) 242 243 # Train and evaluate model. 244 for _ in range(flags_obj.train_epochs // flags_obj.epochs_between_evals): 245 #训练一次,验证一次 246 mnist_classifier.train(input_fn=train_input_fn, hooks=train_hooks) 247 eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn) 248 print('\nEvaluation results:\n\t%s\n' % eval_results) 249 250 #如果eval_results['accuracy'] >= flags_obj.stop_threshold 说明模型训练好了 251 if model_helpers.past_stop_threshold(flags_obj.stop_threshold, 252 eval_results['accuracy']): 253 break 254 255 # Export the model 256 if flags_obj.export_dir is not None: 257 #预分配内存,等待数据进入 258 image = tf.placeholder(tf.float32, [None, 28, 28]) 259 input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ 260 'image': image, 261 }) 262 #输出模型 263 mnist_classifier.export_savedmodel(flags_obj.export_dir, input_fn) 264 265 266 def main(_): 267 run_mnist(flags.FLAGS) 268 269 270 if __name__ == '__main__': 271 #日志 272 tf.logging.set_verbosity(tf.logging.INFO) 273 #给flags.FLAGS添加了很多参数项目 274 define_mnist_flags() 275 #带参数的启动 276 absl_app.run(main)