tensorflow数据结构-CollectionDef


计算图结构

  1. MetaGraphDef(计算图)
    • MetaInfoDef(运算方法)
      • OpList(运算方法集合)
        • OpDef(运算方法)
          • ArgDef(输入,输出)
          • AttrDef(属性)
    • GraphDef (连接结构)
      • NodeDef(节点)
    • SaverDef (模型持久化)
      • CheckpointFormatVersion(模型定义使用的版本)
    • map (集合)
      • NodeList(节点value)
      • BytesList(序列化value)
    • map(签名)
    • AssetFileDef (权重值)

文章目录

      • collection_def
        • 案列
        • NodeList
        • BytesList

collection_def

message CollectionDef {

  // NodeList用于收集图中的节点。
  message NodeList {
    repeated string value = 1;
  }

  // BytesList用于收集字符串和序列化的protobufs。
  message BytesList {
    repeated bytes value = 1;
  }

  // Int64List用于收集int,int64和long值。
  message Int64List {
    repeated int64 value = 1 [packed = true];
  }

  // FloatList用于收集浮点值。
  message FloatList {
    repeated float value = 1 [packed = true];
  }

  // AnyList用于收集Any protos。
  message AnyList {
    repeated google.protobuf.Any value = 1;
  }
  
  // 以上定义必须属于oneos中
  oneof kind {
    NodeList node_list = 1;
    BytesList bytes_list = 2;
    Int64List int64_list = 3;
    FloatList float_list = 4;
    AnyList any_list = 5;
  }
}

案列

// 1. 对于单一的数据类型, 列如 string, int, float:
tf.add_to_collection("your_collection_name", your_simple_value)
strings 将会被保存为 bytes_list.

// 2. 对于序列化数据, 有3种方法添加:

//1) 
tf.add_to_collection("your_collection_name",your_proto.SerializeToString())
collection_def {
  key: "user_defined_bytes_collection"
  value {
    bytes_list {
       value: "queue_name: \"test_queue\"\n"
    }
  }
}

//2) 
tf.add_to_collection("your_collection_name", str(your_proto))
collection_def {
  key: "user_defined_string_collection"
  value {
   bytes_list {
      value: "\n\ntest_queue"
    }
  }
}


//3) any_buf = any_pb2.Any()
tf.add_to_collection("your_collection_name",any_buf.Pack(your_proto))
   collection_def {
     key: "user_defined_any_collection"
     value {
       any_list {
         value {
           type_url: "type.googleapis.com/tensorflow.QueueRunnerDef"
           value: "\n\ntest_queue"
         }
       }
     }
   }

//对于Pyhon类型的对象, implement to_proto() 和 from_proto(), 并以下列方式在tensorflow中进行注册:
ops.register_proto_function("your_collection_name",
                            proto_type,
                            to_proto=YourPythonObject.to_proto,
                            from_proto=YourPythonObject.from_proto)
//并且使用这些函数来序列化和反序列化集合。例如,
ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
                            proto_type=variable_pb2.VariableDef,
                            to_proto=Variable.to_proto,
                            from_proto=Variable.from_proto)

NodeList

维护节点集合

// summaries 集合中,收集要保存的节点
collection_def {
	key: "summaries"
    value {
    	node_list {
   	 	value: "input_producer/ScalarSummary:0"
    	value: "shuffle_batch/ScalarSummary:0"
    	value: "ImageSummary:0"
    	}
    }
}

BytesList

维护字符串或者序列化之后的集合

// 所有可以训练变量的集合,以bytes_list二级只能
collection_def {
   key: "trainable_variables"
   value {
   		bytes_list {
   			value: "\n\017conv1/weights:0\022\024conv1/weights/Assign\032\024conv1/weights/read:0"
  			value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032\023conv1/biases/read:0"
   		}
  }
}

你可能感兴趣的:(tensorflow数据结构)