detr_vae.py的原始代码如下:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
import numpy as np
import IPython
e = IPython.embed
def reparametrize(mu, logvar):
    """Sample z ~ N(mu, exp(logvar)) using the VAE reparameterization trick.

    Expressing the sample as mu + std * eps with eps ~ N(0, I) keeps the
    randomness in `eps`, so gradients can flow back into `mu` and `logvar`.

    Parameters:
        mu: tensor of latent means.
        logvar: tensor of latent log-variances, same shape as `mu`.
    Returns:
        A tensor of the same shape as `mu`, one sample per element.
    """
    std = logvar.div(2).exp()  # std = exp(logvar / 2)
    # torch.randn_like replaces the deprecated autograd.Variable +
    # `std.data.new(std.size()).normal_()` idiom; the sampled distribution
    # (standard normal, same shape/device/dtype as std) is identical.
    eps = torch.randn_like(std)
    return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
    """Build a fixed sinusoidal positional-encoding table.

    Parameters:
        n_position: number of positions (sequence length) to encode.
        d_hid: embedding dimension.
    Returns:
        FloatTensor of shape (1, n_position, d_hid); even dims carry
        sin(pos / 10000^(2i/d_hid)), odd dims the matching cos.
    """
    positions = np.arange(n_position)[:, None].astype(np.float64)
    dims = np.arange(d_hid)
    # Each (2i, 2i+1) dimension pair shares the frequency 1 / 10000^(2i / d_hid).
    divisors = np.power(10000, 2 * (dims // 2) / d_hid)
    table = positions / divisors  # broadcast to (n_position, d_hid)
    table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
    table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(table).unsqueeze(0)
class DETRVAE(nn.Module):
    """ACT-style Conditional VAE policy built on a DETR transformer.

    Despite the inherited DETR wording, this module predicts a chunk of
    `num_queries` future actions rather than object detections: a CVAE
    encoder compresses (qpos, action sequence) into a latent z, and the
    DETR-style transformer decodes the action chunk conditioned on camera
    features, proprioception, and z.
    """
    def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
        """ Initializes the model.
        Parameters:
            backbones: list of image-backbone torch modules (one per camera),
                or None for the state-only variant. See backbone.py
            transformer: torch module of the transformer architecture used as
                the CVAE decoder. See transformer.py
            encoder: transformer encoder used as the CVAE encoder, or None to
                always condition on a fixed zero latent.
            state_dim: robot state (qpos) dimension of the environment.
            num_queries: number of action queries, i.e. the length of the
                predicted action chunk.
            camera_names: list of camera names; parallel to `backbones`.
            vq: if True, use a discrete (vector-quantized) latent instead of
                a Gaussian one.
            vq_class: number of categorical codes in the VQ latent.
            vq_dim: number of categories per code.
            action_dim: dimension of a single action vector.
        """
        super().__init__()
        self.num_queries = num_queries
        self.camera_names = camera_names
        self.transformer = transformer
        self.encoder = encoder
        self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
        self.state_dim, self.action_dim = state_dim, action_dim
        hidden_dim = transformer.d_model
        # Per-query output heads: action prediction and padding-flag logit.
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.is_pad_head = nn.Linear(hidden_dim, 1)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            # 1x1 conv projects backbone channels down to the transformer width.
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            # State-only variant: no images, tokens come from qpos + env state.
            # input_dim = 14 + 7 # robot_state + env_state
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            # NOTE(review): env_state is assumed to be 7-dimensional here — confirm against callers.
            self.input_proj_env_state = nn.Linear(7, hidden_dim)
            self.pos = torch.nn.Embedding(2, hidden_dim)
            self.backbones = None
        # encoder extra parameters (CVAE encoder side)
        self.latent_dim = 32 # final size of latent z # TODO tune
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim)  # project qpos to embedding
        print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
        if self.vq:
            self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
        else:
            # Projects the CLS output to [mean, log-variance] of the Gaussian latent.
            self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2)
        # Fixed sinusoidal positions for the encoder input: [CLS], qpos, a_seq.
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim))
        # decoder extra parameters
        if self.vq:
            self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
        else:
            self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent

    def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
        """Run the CVAE encoder (when present) and produce the latent decoder input.

        Parameters:
            qpos: (bs, state_dim) robot proprioception.
            actions: (bs, seq, action_dim) ground-truth action chunk, or None
                at inference time.
            is_pad: (bs, seq) boolean padding mask for `actions`.
            vq_sample: pre-sampled VQ code used at inference when vq=True.
        Returns:
            (latent_input, probs, binaries, mu, logvar); the unused entries for
            the active mode are None.
        """
        bs, _ = qpos.shape
        if self.encoder is None:
            # No CVAE encoder at all: condition on a fixed all-zero latent.
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
            latent_input = self.latent_out_proj(latent_sample)
            probs = binaries = mu = logvar = None
        else:
            # cvae encoder
            is_training = actions is not None # train or val
            ### Obtain latent z from action sequence
            if is_training:
                # project action sequence to embedding dim, and concat with a CLS token
                action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
                qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
                qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
                cls_embed = self.cls_embed.weight # (1, hidden_dim)
                cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
                # Token order matches pos_table: [CLS], qpos, action sequence.
                encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim)
                encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
                # do not mask cls token (nor the qpos token)
                cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
                is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+2)
                # obtain position embedding (fixed table, no gradient)
                pos_embed = self.pos_table.clone().detach()
                pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim)
                # query model
                encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
                encoder_output = encoder_output[0] # take cls output only
                latent_info = self.latent_proj(encoder_output)
                if self.vq:
                    # Discrete latent: sample one-hot codes per class, then use
                    # the straight-through estimator so gradients pass to probs.
                    logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
                    probs = torch.softmax(logits, dim=-1)
                    binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
                    binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
                    probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    # Forward uses the hard sample; backward sees probs' gradient.
                    straigt_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straigt_through)
                    mu = logvar = None
                else:
                    # Gaussian latent: split projection into mean and log-variance.
                    probs = binaries = None
                    mu = latent_info[:, :self.latent_dim]
                    logvar = latent_info[:, self.latent_dim:]
                    latent_sample = reparametrize(mu, logvar)
                    latent_input = self.latent_out_proj(latent_sample)
            else:
                # Inference: no actions available, so use the provided VQ sample
                # or the Gaussian prior mean (zeros).
                mu = logvar = binaries = probs = None
                if self.vq:
                    latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
                else:
                    latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
                    latent_input = self.latent_out_proj(latent_sample)
        return latent_input, probs, binaries, mu, logvar

    def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None (only used by the state-only variant)
        actions: batch, seq, action_dim
        Returns (a_hat, is_pad_hat, [mu, logvar], probs, binaries).
        """
        latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)
        # cvae decoder
        if self.backbones is not None:
            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id, cam_name in enumerate(self.camera_names):
                features, pos = self.backbones[cam_id](image[:, cam_id])
                features = features[0] # take the last layer feature
                pos = pos[0]
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)
            # proprioception features
            proprio_input = self.input_proj_robot_state(qpos)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3)
            pos = torch.cat(all_cam_pos, axis=3)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
        else:
            # State-only variant: two tokens (projected qpos and env_state).
            qpos = self.input_proj_robot_state(qpos)
            env_state = self.input_proj_env_state(env_state)
            transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
            hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
        a_hat = self.action_head(hs)
        is_pad_hat = self.is_pad_head(hs)
        return a_hat, is_pad_hat, [mu, logvar], probs, binaries
class CNNMLP(nn.Module):
    """Baseline policy: per-camera CNN features concatenated with qpos, fed
    through an MLP to predict a single action.
    """
    def __init__(self, backbones, state_dim, camera_names):
        """ Initializes the model.
        Parameters:
            backbones: list of image-backbone torch modules, one per camera.
                See backbone.py
            state_dim: robot state (qpos) dimension; also used as the predicted
                action dimension in this baseline.
            camera_names: list of camera names; parallel to `backbones`.
        """
        super().__init__()
        self.camera_names = camera_names
        # BUGFIX: the original read `self.action_dim` below without ever
        # assigning it, so construction raised AttributeError. In this
        # baseline the action shares the robot state dimensionality
        # (consistent with action_head's output size).
        self.action_dim = state_dim
        self.action_head = nn.Linear(1000, state_dim) # TODO add more (currently unused by forward)
        if backbones is not None:
            self.backbones = nn.ModuleList(backbones)
            backbone_down_projs = []
            for backbone in backbones:
                # Reduce each backbone feature map to 32 channels (and shrink
                # its spatial extent) before flattening.
                down_proj = nn.Sequential(
                    nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
                    nn.Conv2d(128, 64, kernel_size=5),
                    nn.Conv2d(64, 32, kernel_size=5)
                )
                backbone_down_projs.append(down_proj)
            self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
            # 768 = 32 channels * spatial size of the down-projected feature map
            # — assumes a fixed backbone output resolution. TODO confirm.
            mlp_in_dim = 768 * len(backbones) + state_dim
            self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
        else:
            raise NotImplementedError

    def forward(self, qpos, image, env_state, actions=None):
        """
        qpos: batch, qpos_dim
        image: batch, num_cam, channel, height, width
        env_state: None (unused)
        actions: batch, seq, action_dim
        Returns a_hat of shape (batch, action_dim).
        """
        is_training = actions is not None # train or val (unused: both share one path)
        bs, _ = qpos.shape
        # Image observation features and position embeddings
        all_cam_features = []
        for cam_id, cam_name in enumerate(self.camera_names):
            features, pos = self.backbones[cam_id](image[:, cam_id])
            features = features[0] # take the last layer feature
            pos = pos[0] # not used
            all_cam_features.append(self.backbone_down_projs[cam_id](features))
        # flatten everything
        flattened_features = []
        for cam_feature in all_cam_features:
            flattened_features.append(cam_feature.reshape([bs, -1]))
        flattened_features = torch.cat(flattened_features, axis=1) # 768 each
        features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
        a_hat = self.mlp(features)
        return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a ReLU multi-layer perceptron.

    Parameters:
        input_dim: size of the input features.
        hidden_dim: width of each hidden layer.
        output_dim: size of the output.
        hidden_depth: number of hidden layers; 0 yields a single linear map.
    Returns:
        nn.Sequential implementing the MLP.
    """
    if hidden_depth == 0:
        return nn.Sequential(nn.Linear(input_dim, output_dim))
    layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
    for _ in range(hidden_depth - 1):
        layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
    layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)
def build_encoder(args):
    """Build the standalone TransformerEncoder used as the CVAE encoder.

    Parameters:
        args: parsed CLI namespace providing hidden_dim, dropout, nheads,
            dim_feedforward, enc_layers, and pre_norm.
    Returns:
        A TransformerEncoder stack of `args.enc_layers` layers.
    """
    activation = "relu"
    layer = TransformerEncoderLayer(
        args.hidden_dim,       # d_model, e.g. 256
        args.nheads,           # e.g. 8
        args.dim_feedforward,  # e.g. 2048
        args.dropout,          # e.g. 0.1
        activation,
        args.pre_norm,         # e.g. False
    )
    # Final LayerNorm only for pre-norm encoders, mirroring DETR's convention.
    final_norm = nn.LayerNorm(args.hidden_dim) if args.pre_norm else None
    # TODO shared with VAE decoder (enc_layers, e.g. 4)
    return TransformerEncoder(layer, args.enc_layers, final_norm)
def build(args):
    """Build the DETRVAE (ACT) policy from parsed CLI args.

    Parameters:
        args: namespace with camera_names, num_queries, vq, vq_class, vq_dim,
            action_dim, no_encoder, plus the backbone/transformer options.
    Returns:
        The constructed DETRVAE model.
    """
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image: one backbone per camera stream.
    backbones = []
    for _ in args.camera_names:
        backbone = build_backbone(args)
        backbones.append(backbone)

    transformer = build_transformer(args)

    if args.no_encoder:
        encoder = None
    else:
        # BUGFIX: previously this called build_transformer(args), producing a
        # full encoder-decoder Transformer. DETRVAE.encode invokes
        # `self.encoder(encoder_input, pos=..., src_key_padding_mask=...)`,
        # which is the TransformerEncoder interface that build_encoder
        # (otherwise unused in this file) returns.
        encoder = build_encoder(args)

    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        vq=args.vq,
        vq_class=args.vq_class,
        vq_dim=args.vq_dim,
        action_dim=args.action_dim,
    )

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (n_parameters/1e6,))
    return model
def build_cnnmlp(args):
    """Build the CNNMLP baseline policy from parsed CLI args.

    Parameters:
        args: namespace with camera_names plus the backbone options.
    Returns:
        The constructed CNNMLP model.
    """
    state_dim = 14 # TODO hardcode

    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image: one backbone per camera stream.
    backbones = [build_backbone(args) for _ in args.camera_names]

    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable/1e6,))
    return model
让我们通读这段代码:
函数reparametrize部分:
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
这个 reparametrize
函数是变分自编码器(VAE)中的一个重要部分,用于实现所谓的“重参数化技巧”。重参数化技巧是一种在训练VAE时常用的方法,它允许模型在训练过程中通过反向传播更新其参数。下面是对这个函数的详细解释:
计算标准差:
std = logvar.div(2).exp()
这一步将对数方差除以2,然后对结果取指数,得到方差的平方根,即标准差。在正态分布中,方差是标准差的平方,因此这里先除以2,再取指数得到标准差。
生成随机噪声:
eps = Variable(std.data.new(std.size()).normal_())
这一步生成与标准差具有相同形状的随机噪声(从标准正态分布中采样)。Variable
是 PyTorch 早期用于封装张量以便自动求导的类,自 PyTorch 0.4 起已被弃用(张量本身即可参与自动求导)。这里使用 std.data.new(std.size()).normal_()
生成正态分布的随机数;在现代 PyTorch 中,等价且更简洁的写法是 torch.randn_like(std)。
重参数化操作:
return mu + std * eps
最后,将均值与标准差和随机噪声的乘积相加。这实际上是从以 mu
为均值、std
为标准差的正态分布中采样。这种方法使得采样操作可导,允许在训练过程中通过反向传播算法更新 mu
和 logvar
。
使用重参数化技巧可以使VAE的训练通过随机梯度下降(或其他优化算法)进行,因为它允许模型在训练过程中反向传播梯度。这对于训练生成模型如VAE至关重要,因为它允许模型学习如何编码输入数据到一个潜在的、连续的表示空间中,并从这个空间中有效地生成新的样本。
函数get_sinusoid_encoding_table部分:
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
这段代码定义了一个函数 get_sinusoid_encoding_table
,用于生成正弦波编码表(Sinusoidal Positional Encoding),这是在 Transformer 模型中用于位置编码的一种方法。这种编码方式是为了使模型能够利用序列中元素的顺序信息。下面是对这个函数的详细解释:
定义获取位置角度向量的函数:
get_position_angle_vec(position)
:这个内部函数为给定的位置生成一个角度向量。向量中的每个元素对应于该位置的不同维度。对于每个维度 hid_j
，该位置的角度计算为 position / (10000^(2 * (hid_j // 2) / d_hid))
。注意指数中使用的是 hid_j // 2(整除),因此相邻的偶/奇维度对 (2i, 2i+1) 共享同一频率;这种计算方式确保了不同位置的角度变化在各频率维度上是不同的,从而让模型能够区分序列中不同的位置。
生成正弦波编码表:
创建一个数组 sinusoid_table
,其中包含从 0 到 n_position-1
的每个位置的角度向量。
对于表中的偶数索引维度(0::2
),使用 np.sin
函数应用正弦变换。
对于表中的奇数索引维度(1::2
),使用 np.cos
函数应用余弦变换。
返回一个经过正弦和余弦变换的编码表,并使用 torch.FloatTensor
将其转换为 PyTorch 张量,并通过 unsqueeze(0)
增加一个维度,这通常用于批处理。
这种正弦波位置编码方式为 Transformer 模型提供了一种有效的方式来编码序列中元素的位置信息。由于 Transformer 模型本身不包含任何递归或卷积层,因此无法自然地处理序列数据中的顺序信息。通过添加正弦波位置编码,模型能够利用位置信息来更好地理解和处理序列数据。这种编码方式是 Transformer 架构的一个关键组成部分,广泛应用于自然语言处理和其他序列处理任务中。
DETRVAE类:
class DETRVAE(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
self.state_dim, self.action_dim = state_dim, action_dim
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
if self.vq:
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
else:
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
if self.vq:
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
else:
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
bs, _ = qpos.shape
if self.encoder is None:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
probs = binaries = mu = logvar = None
else:
# cvae encoder
is_training = actions is not None # train or val
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
if self.vq:
logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
probs = torch.softmax(logits, dim=-1)
binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
straigt_through = binaries_flat - probs_flat.detach() + probs_flat
latent_input = self.latent_out_proj(straigt_through)
mu = logvar = None
else:
probs = binaries = None
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = binaries = probs = None
if self.vq:
latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
else:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
return latent_input, probs, binaries, mu, logvar
def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)
# cvae decoder
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
这段代码定义了一个名为 DETRVAE
的类,它继承自 PyTorch 的 nn.Module
,并且似乎是一种结合了变分自编码器(VAE)和 Transformer 的深度学习模型。这种模型可能用于处理包含图像、位置和动作序列的复杂数据。让我们逐步解析这个类的主要部分:
__init__
action_head
和 is_pad_head
是线性层,用于最终的动作预测。query_embed
是嵌入层,用于处理对象查询。backbones
,选择不同的特征提取方法。encode
这个函数实现了 VAE 编码器的功能,将输入数据(如机器人的位置 qpos
和动作序列 actions
)编码为潜在空间的表示。
forward
backbones
,使用这些网络来处理图像特征并与位置和潜在输入结合。backbones
,直接处理位置和环境状态。action_head
和填充标记头 is_pad_head
生成动作预测和填充标记预测。DETRVAE类中的函数__init__部分:
def __init__(self, backbones, transformer, encoder, state_dim, num_queries, camera_names, vq, vq_class, vq_dim, action_dim):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
self.state_dim, self.action_dim = state_dim, action_dim
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
self.encoder_joint_proj = nn.Linear(state_dim, hidden_dim) # project qpos to embedding
print(f'Use VQ: {self.vq}, {self.vq_class}, {self.vq_dim}')
if self.vq:
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
else:
self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], qpos, a_seq
# decoder extra parameters
if self.vq:
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
else:
self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for proprio and latent
这段代码定义了 DETRVAE
类的初始化方法,是一个结合了变分自编码器(VAE)和 Transformer 架构的深度学习模型的构造函数。该模型设计用于处理包含图像、状态信息和动作序列的复杂数据。以下是对初始化方法的详细解析:
__init__
该函数用于初始化 DETRVAE
模型的各个组件。
参数:
模型组件:
特征提取:
backbones
,使用它们来处理图像特征,并通过 self.input_proj
进行投影。变分编码器组件:
解码器额外参数:
DETRVAE类中的函数encode部分:
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
bs, _ = qpos.shape
if self.encoder is None:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
probs = binaries = mu = logvar = None
else:
# cvae encoder
is_training = actions is not None # train or val
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
encoder_input = torch.cat([cls_embed, qpos_embed, action_embed], axis=1) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(1, 0, 2) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(qpos.device) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
if self.vq:
logits = latent_info.reshape([*latent_info.shape[:-1], self.vq_class, self.vq_dim])
probs = torch.softmax(logits, dim=-1)
binaries = F.one_hot(torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(-1), self.vq_dim).view(-1, self.vq_class, self.vq_dim).float()
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
straigt_through = binaries_flat - probs_flat.detach() + probs_flat
latent_input = self.latent_out_proj(straigt_through)
mu = logvar = None
else:
probs = binaries = None
mu = latent_info[:, :self.latent_dim]
logvar = latent_info[:, self.latent_dim:]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = binaries = probs = None
if self.vq:
latent_input = self.latent_out_proj(vq_sample.view(-1, self.vq_class * self.vq_dim))
else:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
return latent_input, probs, binaries, mu, logvar
这段代码是 DETRVAE
类中 encode
方法的实现,它负责将输入数据编码为潜在空间的表示。这个方法是变分自编码器(VAE)和向量量化(Vector Quantization, VQ)技术的结合。下面是对这个方法的详细解释:
判断编码器是否存在:
self.encoder
为空),则创建一个零向量作为潜在样本,并通过投影层(self.latent_out_proj
)转换。使用编码器:
is_training
),即检查是否提供了动作序列(actions
)。self.encoder_action_proj
)投影到嵌入空间。qpos
)执行类似的投影(self.encoder_joint_proj
)。pos_embed
)和可选的填充掩码(is_pad
向量量化(VQ)或标准VAE处理:
mu
)和对数方差(logvar
)。非训练模式:
DETRVAE类中的函数forward部分:
def forward(self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
latent_input, probs, binaries, mu, logvar = self.encode(qpos, actions, is_pad, vq_sample)
# cvae decoder
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(transformer_input, None, self.query_embed.weight, self.pos.weight)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
这段代码定义了 DETRVAE
类的 forward
方法,它实现了模型的前向传播过程,即如何处理输入数据并生成输出。该方法结合了变分自编码器(VAE)和 Transformer 架构。下面是对这个方法的详细解释:
self.encode
方法对输入数据进行编码,生成潜在的表示(latent_input
)、概率(probs
)、二进制表示(binaries
)、均值(mu
)和对数方差(logvar
)。backbones
不为空):
backbones
)提取特征。self.input_proj
)和组合。self.input_proj_robot_state
处理)一起传递给 Transformer 模型。backbones
):
action_head
和 is_pad_head
生成动作预测和填充标记预测。a_hat
)、填充标记预测(is_pad_hat
)、以及编码阶段生成的统计数据和概率信息。CNNMLP类:
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5)
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
这段代码定义了一个名为 CNNMLP
的类,它继承自 PyTorch 的 nn.Module
。这个类似乎是为了实现一个结合卷积神经网络(CNN)和多层感知机(MLP)的模型,主要用于处理图像和状态信息,并输出动作预测。以下是对这个类的详细解析:
__init__
参数:
模型组件:
state_dim
,用于动作预测。forward
参数:
处理过程:
self.backbone_down_projs
中的下采样投影将特征图的维度降低。self.mlp
)处理合并的特征以生成动作预测(a_hat
)。CNNMLP类中的函数__init__部分:
def __init__(self, backbones, state_dim, camera_names):
""" Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
DETR can detect in a single image. For COCO, we recommend 100 queries.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5)
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(input_dim=mlp_in_dim, hidden_dim=1024, output_dim=self.action_dim, hidden_depth=2)
else:
raise NotImplementedError
这段代码定义了一个名为 CNNMLP
的类的初始化函数,该类继承自 PyTorch 的 nn.Module
。CNNMLP
类是一个深度学习模型,它结合了卷积神经网络(CNN)和多层感知机(MLP)来处理图像和状态信息。以下是对初始化函数的详细解析:
__init__
参数:
模型组件初始化:
state_dim
,用于最终的动作预测。self.action_head
的输入维度和 self.mlp
的各个参数。backbones
中的每个模块都有一个 num_channels
属性,这个属性表示该模块输出特征图的通道数。CNNMLP类中的函数forward部分:
def forward(self, qpos, image, env_state, actions=None):
    """
    qpos: batch, qpos_dim
    image: batch, num_cam, channel, height, width
    env_state: None
    actions: batch, seq, action_dim
    """
    is_training = actions is not None  # train or val
    batch_size = qpos.shape[0]
    # Run each camera's image through its backbone, then down-project.
    cam_feats = []
    for idx in range(len(self.camera_names)):
        feat_maps, pos_embeds = self.backbones[idx](image[:, idx])
        last_feat = feat_maps[0]  # take the last layer feature
        _ = pos_embeds[0]         # position embedding not used here
        cam_feats.append(self.backbone_down_projs[idx](last_feat))
    # Flatten every per-camera feature map to (batch, -1) and join with qpos.
    flat_feats = torch.cat(
        [f.reshape([batch_size, -1]) for f in cam_feats], dim=1
    )  # 768 each
    combined = torch.cat([flat_feats, qpos], dim=1)  # qpos: 14
    a_hat = self.mlp(combined)
    return a_hat
这段代码定义了 `CNNMLP` 类的 `forward` 方法,它实现了模型的前向传播过程,即如何处理输入数据并生成动作预测。该方法主要涉及图像特征提取和多层感知机(MLP)的应用。以下是对这个方法的详细解释:

- 使用 `backbones` 中的 CNN 模块提取各摄像头的图像特征;
- 使用 `self.backbone_down_projs` 中定义的下采样投影进一步处理每个摄像头的特征;
- 将展平后的特征与机器人状态 `qpos` 拼接,交由 MLP(`self.mlp`)处理合并的特征以生成动作预测(`a_hat`);
- 最终返回预测结果 `a_hat`。

函数 mlp 部分:
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
    """Build a fully-connected trunk as an nn.Sequential.

    hidden_depth == 0 yields a single input->output Linear. Otherwise the
    stack is: input->hidden Linear + ReLU, then (hidden_depth - 1) repeats of
    hidden->hidden Linear + ReLU, and a final hidden->output Linear.
    """
    if hidden_depth == 0:
        layers = [nn.Linear(input_dim, output_dim)]
    else:
        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
        layers.extend(
            module
            for _ in range(hidden_depth - 1)
            for module in (nn.Linear(hidden_dim, hidden_dim),
                           nn.ReLU(inplace=True))
        )
        layers.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*layers)
这段代码定义了一个名为 `mlp` 的函数,用于构建一个多层感知机(MLP)网络。这个网络由多个线性层(全连接层)和非线性激活函数(ReLU)组成。以下是对这个函数的详细解释:

- 无隐藏层(`hidden_depth == 0`):网络仅由一个从输入维度到输出维度的线性层构成;
- 有隐藏层:先是输入层到隐藏层的线性层加 ReLU,再堆叠 `hidden_depth - 1` 组"隐藏层到隐藏层的线性层 + ReLU",最后是隐藏层到输出层的线性层;
- 组合模块:使用 `nn.Sequential` 将所有创建的模块(线性层和激活函数)按顺序组合在一起,形成完整的 MLP 网络。

函数 build_encoder 部分:
def build_encoder(args):
    """Construct a Transformer encoder from command-line args (see transformer.py)."""
    activation = "relu"
    # One encoder layer, configured from args (typical values noted).
    layer = TransformerEncoderLayer(
        args.hidden_dim,       # d_model, e.g. 256
        args.nheads,           # attention heads, e.g. 8
        args.dim_feedforward,  # e.g. 2048
        args.dropout,          # e.g. 0.1
        activation,
        args.pre_norm,         # normalize_before, False by default
    )
    # A final LayerNorm is only added when pre-norm layers are used.
    final_norm = nn.LayerNorm(args.hidden_dim) if args.pre_norm else None
    # args.enc_layers: e.g. 4 # TODO shared with VAE decoder
    return TransformerEncoder(layer, args.enc_layers, final_norm)
这段代码定义了一个名为 `build_encoder` 的函数,用于构建一个 Transformer 编码器。这个函数根据提供的参数来配置和创建编码器。以下是对这个函数的详细解释:

- 创建单个编码器层:使用 `TransformerEncoderLayer`,按隐藏维度、注意力头数、前馈维度、dropout 等参数创建一个编码器层;
- 创建层归一化(如果启用):若 `normalize_before` 为真,则创建一个 `nn.LayerNorm` 层用于归一化;
- 创建编码器:使用 `TransformerEncoder` 创建编码器,包含多个编码器层和可选的层归一化。

函数 build 部分:
def build(args):
    """Assemble the DETRVAE policy: per-camera conv backbones, the main
    transformer, and an optional VAE-style encoder transformer."""
    state_dim = 14  # TODO hardcode
    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image: one conv backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]
    transformer = build_transformer(args)
    # The encoder is optional; when enabled it reuses the transformer builder.
    encoder = None if args.no_encoder else build_transformer(args)
    model = DETRVAE(
        backbones,
        transformer,
        encoder,
        state_dim=state_dim,
        num_queries=args.num_queries,
        camera_names=args.camera_names,
        vq=args.vq,
        vq_class=args.vq_class,
        vq_dim=args.vq_dim,
        action_dim=args.action_dim,
    )
    # Report the trainable-parameter count in millions.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
这段代码定义了一个名为 `build` 的函数,用于构建名为 `DETRVAE` 的复合模型,它结合了卷积神经网络(用于图像处理)、Transformer 架构(用于序列数据处理)和变分自编码器(VAE)。以下是对这个函数的详细解释:

- 设置状态维度:`state_dim` 被设置为 14,这是机器人或环境状态的维度;
- 构建图像处理的卷积网络(CNN)骨干模型:对 `args.camera_names` 中的每个摄像头,使用 `build_backbone` 函数构建一个卷积网络,并将其添加到 `backbones` 列表中;
- 构建 Transformer 模型:使用 `build_transformer` 函数构建 Transformer 模型;
- 条件性地构建编码器:若 `args.no_encoder` 为真,则不构建编码器;否则使用 `build_transformer` 函数构建编码器;
- 构建 DETRVAE 模型:使用从 `args` 中提取的其他参数来初始化 `DETRVAE` 类的实例;
- 计算模型的参数数量:统计可训练参数总数并打印,最后返回 `DETRVAE` 模型实例。

函数 build_cnnmlp 部分:
def build_cnnmlp(args):
    """Assemble the CNNMLP baseline policy from command-line args."""
    state_dim = 14  # TODO hardcode
    # From state
    # backbone = None # from state for now, no need for conv nets
    # From image: one conv backbone per camera.
    backbones = [build_backbone(args) for _ in args.camera_names]
    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args.camera_names,
    )
    # Report the trainable-parameter count in millions.
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of parameters: %.2fM" % (trainable / 1e6,))
    return model
这段代码定义了一个名为 `build_cnnmlp` 的函数,用于构建名为 `CNNMLP` 的深度学习模型。该模型结合了卷积神经网络(CNN)和多层感知机(MLP),主要用于处理图像和状态信息。以下是对这个函数的详细解释:

- 设置状态维度:`state_dim` 被设置为 14,这可能是机器人或环境状态的维度;
- 构建图像处理的卷积网络(CNN)骨干模型:对 `args.camera_names` 中的每个摄像头,使用 `build_backbone` 函数构建一个卷积网络,并将其添加到 `backbones` 列表中;
- 构建 CNNMLP 模型:使用骨干网络列表 `backbones`、状态维度 `state_dim` 和摄像头名称 `args.camera_names` 来初始化 `CNNMLP` 类的实例;
- 计算模型的参数数量:统计可训练参数总数并打印,最后返回 `CNNMLP` 模型实例。