[tf]模仿keras写可重用的层

  • 方式一：自定义类实现 __call__，借助 tf.variable_scope 的 reuse 复用变量
class LSTM(object):
  """Reusable (multi-layer) LSTM layer built on ``tf.nn.dynamic_rnn``.

  Imitates the Keras "layer object" pattern in TF 1.x graph mode: every call
  enters the same ``tf.variable_scope(name)``. The first call creates the
  variables (``reuse=None``); every later call re-enters the scope with
  ``reuse=True`` so the same variables are shared instead of re-created.

  After the first call, ``trainable_weights`` holds the list returned by the
  scope's ``global_variables()``. NOTE(review): this is the scope's *global*
  variables, not ``trainable_variables()``; confirm this is intended.

  Args:
    cell_size: number of units in each ``BasicLSTMCell``.
    num_layers: number of ``BasicLSTMCell`` layers stacked in a
      ``MultiRNNCell``.
    keep_prob: keep probability for dropout on the LSTM outputs. Dropout is
      only added when ``keep_prob < 1.``. NOTE(review): it is added on every
      call (no training/inference switch); confirm this is intended.
    name: name of the variable scope that holds this layer's variables.
  """

  def __init__(self, cell_size, num_layers=1, keep_prob=1., name='LSTM'):
    self.cell_size = cell_size
    self.num_layers = num_layers
    self.keep_prob = keep_prob
    # None on the first call (create variables); set to True afterwards (reuse).
    self.reuse = None
    # Populated by the first call with the variables under this layer's scope.
    self.trainable_weights = None
    self.name = name

  def __call__(self, x, initial_state, seq_length):
    """Run the stacked LSTM over ``x``.

    Args:
      x: input passed to ``tf.nn.dynamic_rnn``. Presumably
        (batch_size, timesteps, input_dim), since ``time_major`` is left at its
        default; confirm against callers.
      initial_state: initial state of the ``MultiRNNCell``. No ``dtype`` is
        passed to ``dynamic_rnn`` here, so this must not be None.
      seq_length: forwarded as ``sequence_length`` to ``dynamic_rnn``.

    Returns:
      ``(lstm_out, next_state)``: the (optionally dropped-out) per-step outputs
      and the final state returned by ``dynamic_rnn``.
    """
    with tf.variable_scope(self.name, reuse=self.reuse) as vs:
      # One fresh cell object per layer. `range` replaces the Python-2-only
      # `xrange`, which raised NameError under Python 3.
      cell = tf.contrib.rnn.MultiRNNCell([
          tf.contrib.rnn.BasicLSTMCell(
              self.cell_size,
              forget_bias=0.0,
              reuse=tf.get_variable_scope().reuse)
          for _ in range(self.num_layers)
      ])

      lstm_out, next_state = tf.nn.dynamic_rnn(
          cell, x, initial_state=initial_state, sequence_length=seq_length)

      # shape(lstm_out) = (batch_size, timesteps, cell_size)

      if self.keep_prob < 1.:
        # TF 1.x signature: the second positional argument is keep_prob.
        lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)

      if self.reuse is None:
        # Only the variable-creating (first) call records the variables.
        self.trainable_weights = vs.global_variables()

    # All subsequent calls share the variables created above.
    self.reuse = True

    return lstm_out, next_state
  • 方式二：直接使用 keras 层（层对象自己持有权重，可被重复调用）
class Actionselect(object):
  """Thin reusable wrapper around one Keras ``Dense`` layer.

  The ``Dense`` layer is built once in ``__init__`` and the same layer object
  is applied on every call. ``K`` is assumed to be a Keras namespace (for
  example ``tf.keras``) imported elsewhere; confirm.
  """

  def __init__(self, action_class, **kwargs):
    """Build the wrapped layer.

    Args:
      action_class: passed as the first positional argument (``units``) of
        ``K.layers.Dense``.
      **kwargs: accepted for call-site compatibility but currently ignored.
    """
    dense = K.layers.Dense(action_class)
    self.multiclass_dense_layer = dense

  def __call__(self, input_data):
    """Apply the stored ``Dense`` layer to ``input_data`` and return the result."""
    layer = self.multiclass_dense_layer
    return layer(input_data)

你可能感兴趣的:([tf]模仿keras写可重用的层)