Gram矩阵其实是一种度量矩阵。矩阵分析中有这样的定义。
设 V 是 n 维欧式空间 ϵ1,⋯,ϵn 是它的一个基, gij=(ϵi,ϵj),G=(gij)n×n , 则称 G 为基 ϵ1,⋯,ϵn 的度量矩阵,也称为Gram矩阵。
重点解释:
因此:对于三维的向量求Gram矩阵,就是要求 gij ,而 gij 就是第i通道与第j通道进行点乘,然后相加。其实专业点就是i通道的feature map与j通道的feature map进行内积。查看内积定义:对于两个竖直向量 α=(a1,⋯,an)T,β=(b1,⋯,bn) , 规定 <α,β>=αTβ=∑ni=1aibi 。所以可以看出,如果将矩阵先进行竖直向量化,然后将其转置,两者进行矩阵相乘,结果就是内积,或是说对应位置点乘,然后相加。
Gram矩阵是计算每个通道i的feature map与每个通道j的feature map的内积。自然就会得到C*C的矩阵。Gram矩阵的每个值可以说是代表i通道的feature map与j通道的feature map的互相关程度。而卷积网络的卷积其实也是互相关,具体情况见CNN基本问题 中的卷积到底是如何卷积的??
。 值得注意的是:卷积网络的卷积和互相关是一样的,不是信号处理中所说的要先将卷积核旋转180再计算
require("torch")
require('nn')
local net = nn.Sequential()
local module = nn.SpatialConvolution(1,1,2,2):noBias()
net:add(module)
local wt = torch.Tensor(2,2)
wt[{1,1}]=5
wt[{1,2}]=6
wt[{2,1}]=7
wt[{2,2}]=8
print(wt)
net:get(1).weight = wt:view(1,1,2,2)
local input = torch.Tensor(2,2)
--local wt = torch.Tensor(2,2)
input[{1,1}]=1
input[{1,2}]=5
input[{2,1}]=3
input[{2,2}]=4
print(net:forward(input:view(1,1,2,2)))
print('xcorr2')
print(torch.xcorr2(input,wt))
print('conv2')
print(torch.conv2(input,wt))
-- 输出结果
[torch.DoubleTensor of size 2x2]
(1,1,.,.) =
88
[torch.DoubleTensor of size 1x1x1x1]
xcorr2
88
[torch.DoubleTensor of size 1x1]
-- 这里确实是卷积核旋转180度,再卷
conv2
81
[torch.DoubleTensor of size 1x1]
结论:卷积网络中的卷积就是互相关,等价于torch.xcorr2或torch.xorr3之类的
,而信号处理说的卷积等价于torch.conv2之类的
两种计算Gram矩阵的方式
-- 这个只能针对batch=1的情况
-- 不过可以直接backward。
function GramMatrix()
local net = nn.Sequential()
net:add(nn.View(-1):setNumInputDims(2))
local concat = nn.ConcatTable()
concat:add(nn.Identity())
concat:add(nn.Identity())
net:add(concat)
net:add(nn.MM(false, true))
return net
end
-- 这个可以计算N*C*H*W的Gram矩阵,就是batch的。
-- 得到的是 N*C*C, 当然N=1的话就直接是普通三维的Gram。
local gramMatrix = function(input)
local N,C,H,W = input:size(1), input:size(2), input:size(3), input:size(4)
local vecInput = input:view(N,C,H*W)
print(vecInput:transpose(2,3))
local gramMatrix = torch.bmm(vecInput, vecInput:transpose(2,3))
local output = gramMatrix / H / W
return output
end
第一个没什么好说的,就是直接MM按照内积的定义进行计算,第二个进行转置为true。
第二个就是bmm,首先
local N,C,H,W = input:size(1), input:size(2), input:size(3), input:size(4)
local vecInput = input:view(N,C,H*W)
将N*C*H*W变成N*C*(H*W),然后N*C*(H*W)与N*(H*W)*C进行bmm,这里就是第一个C的i维的feature map与第二个C的j维的feature map进行内积(或是说点乘,全部相加)。