论文代码:https://github.com/Uason-Chen/CTR-GCN
文件路径:CTR-GCN/model/ctrgcn.py
import math
import pdb
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
def import_class(name):
components = name.split('.')
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)
return mod
def conv_branch_init(conv, branches):
weight = conv.weight
n = weight.size(0)
k1 = weight.size(1)
k2 = weight.size(2)
nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
nn.init.constant_(conv.bias, 0)
def conv_init(conv):
if conv.weight is not None:
nn.init.kaiming_normal_(conv.weight, mode='fan_out')
if conv.bias is not None:
nn.init.constant_(conv.bias, 0)
def bn_init(bn, scale):
nn.init.constant_(bn.weight, scale)
nn.init.constant_(bn.bias, 0)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
if hasattr(m, 'weight'):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor):
nn.init.constant_(m.bias, 0)
elif classname.find('BatchNorm') != -1:
if hasattr(m, 'weight') and m.weight is not None:
m.weight.data.normal_(1.0, 0.02)
if hasattr(m, 'bias') and m.bias is not None:
m.bias.data.fill_(0)
class TemporalConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1):
super(TemporalConv, self).__init__()
pad = (kernel_size + (kernel_size-1) * (dilation-1) - 1) // 2
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(kernel_size, 1),
padding=(pad, 0),
stride=(stride, 1),
dilation=(dilation, 1))
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class MultiScale_TemporalConv(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
dilations=[1,2,3,4],
residual=True,
residual_kernel_size=1):
super().__init__()
assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'
# Multiple branches of temporal convolution
self.num_branches = len(dilations) + 2
branch_channels = out_channels // self.num_branches
if type(kernel_size) == list:
assert len(kernel_size) == len(dilations)
else:
kernel_size = [kernel_size]*len(dilations)
# Temporal Convolution branches
self.branches = nn.ModuleList([
nn.Sequential(
nn.Conv2d(
in_channels,
branch_channels,
kernel_size=1,
padding=0),
nn.BatchNorm2d(branch_channels),
nn.ReLU(inplace=True),
TemporalConv(
branch_channels,
branch_channels,
kernel_size=ks,
stride=stride,
dilation=dilation),
)
for ks, dilation in zip(kernel_size, dilations)
])
# Additional Max & 1x1 branch
self.branches.append(nn.Sequential(
nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
nn.BatchNorm2d(branch_channels),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)),
nn.BatchNorm2d(branch_channels) # 为什么还要加bn
))
self.branches.append(nn.Sequential(
nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)),
nn.BatchNorm2d(branch_channels)
))
# Residual connection
if not residual:
self.residual = lambda x: 0
elif (in_channels == out_channels) and (stride == 1):
self.residual = lambda x: x
else:
self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)
# initialize
self.apply(weights_init)
def forward(self, x):
# Input dim: (N,C,T,V)
res = self.residual(x)
branch_outs = []
for tempconv in self.branches:
out = tempconv(x)
branch_outs.append(out)
out = torch.cat(branch_outs, dim=1)
out += res
return out
class CTRGC(nn.Module):
def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1):
super(CTRGC, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if in_channels == 3 or in_channels == 9:
self.rel_channels = 8
self.mid_channels = 16
else:
self.rel_channels = in_channels // rel_reduction
self.mid_channels = in_channels // mid_reduction
self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1)
self.tanh = nn.Tanh()
for m in self.modules():
if isinstance(m, nn.Conv2d):
conv_init(m)
elif isinstance(m, nn.BatchNorm2d):
bn_init(m, 1)
def forward(self, x, A=None, alpha=1):
x1, x2, x3 = self.conv1(x).mean(-2), self.conv2(x).mean(-2), self.conv3(x)
x1 = self.tanh(x1.unsqueeze(-1) - x2.unsqueeze(-2))
x1 = self.conv4(x1) * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0) # N,C,V,V
x1 = torch.einsum('ncuv,nctv->nctu', x1, x3)
return x1
class unit_tcn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
super(unit_tcn, self).__init__()
pad = int((kernel_size - 1) / 2)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
stride=(stride, 1))
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
conv_init(self.conv)
bn_init(self.bn, 1)
def forward(self, x):
x = self.bn(self.conv(x))
return x
class unit_gcn(nn.Module):
def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True):
super(unit_gcn, self).__init__()
inter_channels = out_channels // coff_embedding
self.inter_c = inter_channels
self.out_c = out_channels
self.in_c = in_channels
self.adaptive = adaptive
self.num_subset = A.shape[0]
self.convs = nn.ModuleList()
for i in range(self.num_subset):
self.convs.append(CTRGC(in_channels, out_channels))
if residual:
if in_channels != out_channels:
self.down = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1),
nn.BatchNorm2d(out_channels)
)
else:
self.down = lambda x: x
else:
self.down = lambda x: 0
if self.adaptive:
self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
else:
self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
self.alpha = nn.Parameter(torch.zeros(1))
self.bn = nn.BatchNorm2d(out_channels)
self.soft = nn.Softmax(-2)
self.relu = nn.ReLU(inplace=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
conv_init(m)
elif isinstance(m, nn.BatchNorm2d):
bn_init(m, 1)
bn_init(self.bn, 1e-6)
def forward(self, x):
y = None
if self.adaptive:
A = self.PA
else:
A = self.A.cuda(x.get_device())
for i in range(self.num_subset):
z = self.convs[i](x, A[i], self.alpha)
y = z + y if y is not None else z
y = self.bn(y)
y += self.down(x)
y = self.relu(y)
return y
class TCN_GCN_unit(nn.Module):
def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5, dilations=[1,2]):
super(TCN_GCN_unit, self).__init__()
self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilations=dilations,
residual=False)
self.relu = nn.ReLU(inplace=True)
if not residual:
self.residual = lambda x: 0
elif (in_channels == out_channels) and (stride == 1):
self.residual = lambda x: x
else:
self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)
def forward(self, x):
y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
return y
class Model(nn.Module):
def __init__(self, num_class=60, num_point=25, num_person=2, graph=None, graph_args=dict(), in_channels=3,
drop_out=0, adaptive=True):
super(Model, self).__init__()
if graph is None:
raise ValueError()
else:
Graph = import_class(graph)
self.graph = Graph(**graph_args)
A = self.graph.A # 3,25,25
self.num_class = num_class
self.num_point = num_point
self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
base_channel = 64
self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
self.l3 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
self.l4 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
self.l5 = TCN_GCN_unit(base_channel, base_channel*2, A, stride=2, adaptive=adaptive)
self.l6 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
self.l7 = TCN_GCN_unit(base_channel*2, base_channel*2, A, adaptive=adaptive)
self.l8 = TCN_GCN_unit(base_channel*2, base_channel*4, A, stride=2, adaptive=adaptive)
self.l9 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)
self.l10 = TCN_GCN_unit(base_channel*4, base_channel*4, A, adaptive=adaptive)
self.fc = nn.Linear(base_channel*4, num_class)
nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
bn_init(self.data_bn, 1)
if drop_out:
self.drop_out = nn.Dropout(drop_out)
else:
self.drop_out = lambda x: x
def forward(self, x):
if len(x.shape) == 3:
N, T, VC = x.shape
x = x.view(N, T, self.num_point, -1).permute(0, 3, 1, 2).contiguous().unsqueeze(-1)
N, C, T, V, M = x.size()
x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
x = self.data_bn(x)
x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)
x = self.l1(x)
x = self.l2(x)
x = self.l3(x)
x = self.l4(x)
x = self.l5(x)
x = self.l6(x)
x = self.l7(x)
x = self.l8(x)
x = self.l9(x)
x = self.l10(x)
# N*M,C,T,V
c_new = x.size(1)
x = x.view(N, M, c_new, -1)
x = x.mean(3).mean(1)
x = self.drop_out(x)
return self.fc(x)
上述是图神经网络中的 C T R − G C N CTR-GCN CTR−GCN 的源码,我们将类比下面 C 语言的冒泡排序:
#include
void swap(int *a, int *b)
{
int temp = *a;
*a = *b;
*b = temp;
}
void bubble_sort(int arr[], int len)
{
int i, j;
for (i = 0; i < len - 1; i++)
{
for (j = 0; j < len - i - 1; j++)
{
if (arr[j] > arr[j + 1])
{
swap(&arr[j], &arr[j + 1]);
}
}
}
}
int main()
{
int i;
int arr[] = {3, 5, 1, 7, 2};
int len = sizeof(arr) / sizeof(arr[0]);
bubble_sort(arr, len);
printf("排序后的数组:\n");
for (i = 0; i < len; i++)
{
printf("%d ", arr[i]);
}
return 0;
}
import math
import pdb
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
这是代码的第一步,导入了一些需要用到的模块,如 m a t h math math、 n u m p y numpy numpy、 t o r c h torch torch 等。这些模块提供了一些数学函数、数组操作、张量计算等功能,方便我们编写和运行模型。
类比 C,这一步类似于 main.cpp
中的:
#include
像前面的函数:def import_class(name)
、def conv_branch_init(conv, branches)
、def conv_init(conv)
、def bn_init(bn, scale)
、def weights_init(m)
。
这些函数都是一些辅助函数,用于实现一些通用的功能,如:
import_class(name)
:这个函数可以根据一个字符串参数 n a m e name name,动态地导入一个类对象,并返回它。这样可以方便地根据配置文件中的参数来选择不同的类。conv_branch_init(conv, branches)
:这个函数可以对一个卷积层 c o n v conv conv 进行初始化,使其输出的方差在不同的分支 b r a n c h e s branches branches 上保持一致。这样可以避免某些分支的输出过大或过小,影响模型的收敛。conv_init(conv)
:这个函数可以对一个卷积层 c o n v conv conv 进行初始化,使其权重服从正态分布,偏置为0。这样可以避免权重过大或过小,影响模型的收敛。bn_init(bn, scale)
:这个函数可以对一个批标准化层 b n bn bn 进行初始化,使其权重为 s c a l e scale scale,偏置为0。这样可以控制批标准化层的缩放和平移效果。weights_init(m)
:这个函数可以对一个模块 m m m 进行递归地初始化,根据不同类型的子模块调用不同的初始化函数。这样可以方便地对整个模型进行统一的初始化。定义了很多类,class TemporalConv(nn.Module)
、class MultiScale_TemporalConv(nn.Module)
、class CTRGC(nn.Module)
、class unit_tcn(nn.Module)
、class unit_gcn(nn.Module)
、class TCN_GCN_unit(nn.Module)
它们都是用于构建 C T R − G C N CTR-GCN CTR−GCN 模型的不同组件。
TemporalConv
:这个类用于实现一维卷积操作,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有一个属性, c o n v conv conv,表示一个一维卷积层。它的前向传播函数接收一个输入张量,并返回一个输出张量,表示经过一维卷积后的特征。MultiScale_TemporalConv
:这个类用于实现多尺度的一维卷积操作,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有一个属性, c o n v conv conv,表示一个列表,包含多个不同尺度的一维卷积层。它的前向传播函数接收一个输入张量,并返回一个输出张量,表示经过多尺度一维卷积后的特征。CTRGC
:这个类用于实现 C T R − G C CTR-GC CTR−GC 操作,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有几个属性,如 s h a r e d _ c o n v shared\_conv shared_conv、 r e f i n e _ c o n v refine\_conv refine_conv、 b n bn bn 等,表示不同的子模块。它的前向传播函数接收一个输入张量和一个图对象,并返回一个输出张量,表示经过 C T R − G C CTR-GC CTR−GC 后的特征。unit_tcn
:这个类用于实现一个时间卷积单元,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有几个属性,如 t c n tcn tcn、 r e l u relu relu、 d r o p o u t dropout dropout 等,表示不同的子模块。它的前向传播函数接收一个输入张量,并返回一个输出张量,表示经过时间卷积单元后的特征。unit_gcn
:这个类用于实现一个图卷积单元,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有几个属性,如 g c n gcn gcn、 r e l u relu relu、 d r o p o u t dropout dropout 等,表示不同的子模块。它的前向传播函数接收一个输入张量和一个图对象,并返回一个输出张量,表示经过图卷积单元后的特征。TCN_GCN_unit
:这个类用于实现一个 T C N − G C N TCN-GCN TCN−GCN 单元,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。它有几个属性,如 t c n tcn tcn、 g c n gcn gcn 等,表示不同的子模块。它的前向传播函数接收一个输入张量和一个图对象,并返回一个输出张量,表示经过 T C N − G C N TCN-GCN TCN−GCN 单元后的特征。这些辅助(辅助函数和辅助类),类比 C 中 main.cpp
中的辅助函数:
void swap(int *a, int *b)
{
int temp = *a;
*a = *b;
*b = temp;
}
接下来,实现了 class Model(nn.Module)
,它继承了 t o r c h . n n . M o d u l e torch.nn.Module torch.nn.Module 类,并重写了初始化函数和前向传播函数。
它有几个属性,如 g r a p h graph graph、 d a t a _ b n data\_bn data_bn、 t c n _ g c n _ u n i t tcn\_gcn\_unit tcn_gcn_unit、 f c fc fc 等,表示不同的子模块。它的前向传播函数接收一个输入张量,并返回一个输出张量,表示每个骨架序列对应的动作类别概率。
Model 类是 C T R − G C N CTR-GCN CTR−GCN 模型的最终封装,它可以用于训练和测试。
我觉得这一步可以类比 C 中 main.cpp
中的:
void bubble_sort(int arr[], int len)
{
int i, j;
for (i = 0; i < len - 1; i++)
{
for (j = 0; j < len - i - 1; j++)
{
if (arr[j] > arr[j + 1])
{
swap(&arr[j], &arr[j + 1]);
}
}
}
}
相比于辅助函数的通用,辅助类的各个不完整组件, M o d e l Model Model 类是调用了辅助函数和辅助类,从而实现了完整的图神经网络功能。
这个 b u b b l e _ s o r t bubble\_sort bubble_sort 函数就是为了实现完整的冒泡排序功能。
封装完之后, M o d e l Model Model 类在 main.py
文件中被调用,这个文件是用于运行模型的主程序。在 main.py
文件中,有一个函数叫做 m o d e l _ l o a d model\_load model_load,它可以根据配置文件中的参数,动态地导入和创建 Model 类的对象,并返回它。然后,在 m a i n main main 函数中,会调用 m o d e l _ l o a d model\_load model_load 函数来创建模型,并将其传递给训练和测试的函数,进行模型的训练和测试。
类比 C 的 main.cpp 中 main()
函数:
int main()
{
int i;
int arr[] = {3, 5, 1, 7, 2};
int len = sizeof(arr) / sizeof(arr[0]);
bubble_sort(arr, len);
printf("排序后的数组:\n");
for (i = 0; i < len; i++)
{
printf("%d ", arr[i]);
}
return 0;
}
m a i n main main 是所有前面定义函数的归宿,最终在 m a i n main main 里有具体的输入,调用之前的功能函数:bubble_sort(arr, len);
,得到输出。
神经网络模型中的类通常是继承了 torch.nn.Module 类的子类,它需要重写两个方法,分别是初始化函数 __init__
和前向传播函数 forward
。
__init__
函数用于定义模型的参数和子模块,forward
函数用于定义模型的计算逻辑。__init__
方法是在模型被创建时执行的,而forward
方法是在模型被调用时执行的。例如,一个简单的全连接网络可以定义为:
import torch.nn as nn
class FCN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(FCN, self).__init__()
# 定义一个线性层,将输入映射到隐藏层
self.linear1 = nn.Linear(input_size, hidden_size)
# 定义一个激活函数,增加非线性
self.relu = nn.ReLU()
# 定义一个线性层,将隐藏层映射到输出
self.linear2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
# 前向传播函数,接收一个输入张量x,返回一个输出张量y
# 将输入张量通过第一个线性层
x = self.linear1(x)
# 将输出张量通过激活函数
x = self.relu(x)
# 将输出张量通过第二个线性层
y = self.linear2(x)
# 返回输出张量
return y
它有三个参数:输入大小,隐藏层大小和输出大小。
它有两个方法:__init__
和forward
。
__init__
方法用于初始化网络的权重和偏置,以及定义网络的结构。forward
方法用于实现网络的前向传播过程,即根据输入计算输出。具体来说,
class FCN(nn.Module)
class FCN(nn.Module)
从这一行得到启发,class ClassName(继承自哪个类)
这个括号里面写这个类继承自哪个类。
继承是面向对象编程中的一个概念,它表示一个类可以从另一个类获取属性和方法,从而实现代码的复用和扩展。
例如,FCN 类继承自 nn.Module 类,就可以使用 nn.Module 类提供的一些方法,比如 parameters(), to(), save()等。
nn.Module 是 PyTorch 框架中提供的一个基类,它封装了神经网络的一些基本功能,比如参数管理,设备转换,保存和加载等。继承自 nn.Module 类的子类可以方便地使用这些功能,而不需要自己实现。
def __init__(self, input_size, hidden_size, output_size)
s e l f self self 是一个特殊的参数,它表示类的实例对象本身。在 Python 中,定义类的方法时,第一个参数必须是 s e l f self self,用于区分类的实例和类的方法。
例如,当我们创建一个 FCN 类的实例 fcn 时,我们可以用 fcn.linear1 访问它的第一个线性层属性,而不是 FCN.linear1。这是因为 self.linear1 表示实例对象的属性,而 FCN.linear1 表示类的属性。
类的属性是指类或类的实例对象所拥有的变量或模块。类的属性可以分为两种:类属性和实例属性。
- 类属性是指类本身所拥有的属性,它可以被类或类的所有实例对象共享。
- 实例属性是指类的实例对象所拥有的属性,它只能被该实例对象访问。
例如,在 FCN 类中,input_size, hidden_size, output_size 是类属性,因为它们是在类定义时就确定的,而 linear1, relu, linear2 是实例属性,因为它们是在实例化时才创建的。
super(FCN, self).__init__()
它的意思是调用 FCN 类的父类 nn.Module 的 __init__
方法,从而继承父类的一些属性和方法。
这样做的好处是可以避免重复编写父类的初始化代码,也可以避免多重继承时的冲突。
super(FCN, self)
括号里的 FCN 和 self 是 super() 函数的两个参数,其中,
要注意,super(FCN, self).__init__()
并不是直接对子类和实例化对象进行初始化,而是通过调用父类的__init__()
方法来间接地初始化子类和实例化对象。即 super()
函数会返回一个父类的对象,然后用这个对象来执行父类的__init__()
方法。
当调用 super(FCN, self)
时,Python 会根据 FCN 的继承关系,找到它的一个父类,然后把 self 转换成那个父类的对象,再调用那个父类的方法。
super()
函数可以让你避免显式地引用父类的名字,这样可以方便地使用多重继承。
- 在 Python 3 中,可以直接写
super().__init__()
,而不需要写super(FCN, self).__init()__
,这样更简洁,Python 会自动推断出要调用的父类和实例。- 但是在 Python 2 中,必须写明这两个参数,否则会报错。
forward(self, x)
forward 函数是在你创建了子类的对象后,用这个对象对输入张量进行计算的时候被调用的。
比如,
model.CTRGCN.Model
",并将其分割为两部分,前面的部分表示模块的路径,后面的部分表示类的名称。import_class
函数,根据模块的路径,动态地导入模块对象,并从模块对象中获取类对象,例如 Model 类。torch.nn.DataParallel
函数,根据配置文件中的设备编号,将类对象包装为一个并行计算的对象,以便在多个 GPU 上运行模型。
torch.nn.DataParallel
函数是一个用于实现模型并行计算的函数,它可以将模型的参数和输入数据分配到多个 GPU 上,从而加速模型的训练和测试。使用 torch.nn.DataParallel 函数的好处有:
- 可以提高模型的运行效率,缩短训练和测试的时间。
- 可以增大模型的批量大小,提高模型的泛化能力。
- 可以简化模型的编写和调用,无需手动处理多个 GPU 之间的通信和同步。