广播操作允许你对不同形状的张量执行逐元素操作,而无需显式循环。
def cdists(mols, grid):
'''
Calculates the pairwise Euclidean distances between a set of molecules and a list
of positions on a grid (uses inplace operations to minimize memory demands).
Args:
mols (torch.Tensor): data set (of molecules) with shape
(batch_size x n_atoms x n_dims)
grid (torch.Tensor): array (of positions) with shape (n_positions x n_dims)
Returns:
torch.Tensor: batch of distance matrices (batch_size x n_atoms x n_positions)
'''
if len(mols.size()) == len(grid.size())+1:
grid = grid.unsqueeze(0) # add batch dimension
return F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1),
inplace=True).sqrt_()
那么,上面的代码中为什么要进行: (mols[:, :, None, :] - grid[:, None, :, :]) 这样的操作呢?
这段代码用于计算一组分子(mols
)与一个网格上的一组位置格点(grid
)之间的欧几里得距离。这里的(mols[:, :, None, :] - grid[:, None, :, :])
操作涉及到张量的广播操作,它的目的是计算每个分子的每个原子与每个位置之间的距离。
mols
张量的形状为 (batch_size x n_atoms x n_dims)
,其中 batch_size
是批次大小,n_atoms
是原子数量,n_dims
是原子的坐标维度(通常是3维,表示xyz坐标)。grid
张量的形状为 (n_positions x n_dims)
,其中 n_positions
是位置格点的数量,n_dims
同样是坐标的维度。首先,如果 mols
张量的维度比 grid
张量的维度多1,代码会通过 grid.unsqueeze(0)
添加一个额外的维度,以匹配 mols
张量的 batch 维度。这是为了使广播操作生效。
接下来,代码使用广播操作计算每个分子的每个原子与每个位置之间的距离。广播操作允许你对不具有相同形状的张量执行逐元素操作,而无需显式循环。
mols[:, :, None, :]
的形状变成 (batch_size x n_atoms x 1 x n_dims)
。这个操作在 n_atoms
维度上添加了一个额外的维度,以便与 grid[:, None, :, :]
进行广播。
grid[:, None, :, :]
的形状变成 (1 x 1 x n_positions x n_dims)
。这个操作在 batch 维度和 n_atoms
维度上添加了额外的维度,以便与 mols[:, :, None, :]
进行广播。
然后,这两个张量进行逐元素的减法操作,计算了每个分子的每个原子与每个位置格点之间的差异。结果是一个张量,其形状为 (batch_size x n_atoms x n_positions x n_dims)
。
最后,使用 .pow_(2)
和 .sqrt_()
操作,计算了每个分子的每个原子与每个位置格点之间的欧几里得距离。
总之,这段代码通过广播操作高效地计算了每个分子的每个原子与每个位置之间的距离,而无需显式的循环操作。这有助于提高计算效率,特别是在处理大规模数据时。
grid_test.py :
import torch
import torch.nn.functional as F
# 示例数据
batch_size = 2
n_atoms = 5
n_dims = 3
n_positions = 10
mols = torch.rand(batch_size, n_atoms, n_dims) # 随机生成分子坐标数据
grid = torch.rand(n_positions, n_dims) # 随机生成网格位置数据
print("batch_size = ", batch_size," n_atoms = ", n_atoms," n_dims = ",n_dims," n_positions = ",n_positions)
# 打印示例数据
print("示例数据 mols:")
print("mols = torch.rand(batch_size, n_atoms, n_dims)")
print(mols)
print(mols.shape)
print("\n示例数据 grid:")
print("grid = torch.rand(n_positions, n_dims)")
print(grid)
print(grid.shape)
# 如果 mols 张量的维度比 grid 张量的维度多1,添加一个额外的维度
if len(mols.size()) == len(grid.size()) + 1:
grid = grid.unsqueeze(0)
print("\n添加额外维度后的 grid:")
print(grid)
print(grid.shape)
print("\nmols[:, :, None, :]")
print(mols[:, :, None, :])
print(mols[:, :, None, :].shape)
print("\ngrid[:, None, :, :]")
print(grid[:, None, :, :])
print(grid[:, None, :, :].shape)
print("\nmols[:, :, None, :] - grid[:, None, :, :]")
print(mols[:, :, None, :] - grid[:, None, :, :])
print((mols[:, :, None, :] - grid[:, None, :, :]).shape)
print("\n(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)")
print((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2))
print("\ntorch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)")
print(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1))
print((torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)).shape)
print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True))
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)).shape)
print("\nF.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()")
print(F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_())
print((F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()).shape)
# 计算每个分子的每个原子与每个位置之间的距离
result = F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()
# 打印计算结果
print("\n计算结果:")
print(result)
$ python grid_test.py
batch_size = 2 n_atoms = 5 n_dims = 3 n_positions = 10
示例数据 mols:
mols = torch.rand(batch_size, n_atoms, n_dims)
tensor([[[0.3787, 0.1093, 0.5062],
[0.3149, 0.4295, 0.1202],
[0.6499, 0.6533, 0.6489],
[0.9395, 0.5027, 0.7664],
[0.5991, 0.5733, 0.6474]],
[[0.1370, 0.3499, 0.7365],
[0.3564, 0.4096, 0.1820],
[0.2576, 0.4737, 0.2487],
[0.3169, 0.5875, 0.0414],
[0.9958, 0.2101, 0.3953]]])
torch.Size([2, 5, 3])
示例数据 grid:
grid = torch.rand(n_positions, n_dims)
tensor([[0.0985, 0.5795, 0.3998],
[0.3772, 0.8160, 0.7968],
[0.9571, 0.0205, 0.5068],
[0.7847, 0.9675, 0.3421],
[0.9007, 0.0692, 0.3701],
[0.8763, 0.4045, 0.2783],
[0.5665, 0.1797, 0.8626],
[0.4253, 0.6738, 0.3789],
[0.3690, 0.3504, 0.3530],
[0.1773, 0.4790, 0.9227]])
torch.Size([10, 3])
添加额外维度后的 grid:
tensor([[[0.0985, 0.5795, 0.3998],
[0.3772, 0.8160, 0.7968],
[0.9571, 0.0205, 0.5068],
[0.7847, 0.9675, 0.3421],
[0.9007, 0.0692, 0.3701],
[0.8763, 0.4045, 0.2783],
[0.5665, 0.1797, 0.8626],
[0.4253, 0.6738, 0.3789],
[0.3690, 0.3504, 0.3530],
[0.1773, 0.4790, 0.9227]]])
torch.Size([1, 10, 3])
mols[:, :, None, :]
tensor([[[[0.3787, 0.1093, 0.5062]],
[[0.3149, 0.4295, 0.1202]],
[[0.6499, 0.6533, 0.6489]],
[[0.9395, 0.5027, 0.7664]],
[[0.5991, 0.5733, 0.6474]]],
[[[0.1370, 0.3499, 0.7365]],
[[0.3564, 0.4096, 0.1820]],
[[0.2576, 0.4737, 0.2487]],
[[0.3169, 0.5875, 0.0414]],
[[0.9958, 0.2101, 0.3953]]]])
torch.Size([2, 5, 1, 3])
grid[:, None, :, :]
tensor([[[[0.0985, 0.5795, 0.3998],
[0.3772, 0.8160, 0.7968],
[0.9571, 0.0205, 0.5068],
[0.7847, 0.9675, 0.3421],
[0.9007, 0.0692, 0.3701],
[0.8763, 0.4045, 0.2783],
[0.5665, 0.1797, 0.8626],
[0.4253, 0.6738, 0.3789],
[0.3690, 0.3504, 0.3530],
[0.1773, 0.4790, 0.9227]]]])
torch.Size([1, 1, 10, 3])
mols[:, :, None, :] - grid[:, None, :, :]
tensor([[[[ 2.8014e-01, -4.7024e-01, 1.0644e-01],
[ 1.4571e-03, -7.0671e-01, -2.9065e-01],
[-5.7839e-01, 8.8782e-02, -5.7650e-04],
[-4.0609e-01, -8.5820e-01, 1.6411e-01],
[-5.2208e-01, 4.0115e-02, 1.3611e-01],
[-4.9762e-01, -2.9525e-01, 2.2792e-01],
[-1.8779e-01, -7.0417e-02, -3.5645e-01],
[-4.6644e-02, -5.6448e-01, 1.2727e-01],
[ 9.6272e-03, -2.4109e-01, 1.5320e-01],
[ 2.0133e-01, -3.6968e-01, -4.1653e-01]],
[[ 2.1635e-01, -1.4999e-01, -2.7958e-01],
[-6.2337e-02, -3.8646e-01, -6.7667e-01],
[-6.4219e-01, 4.0903e-01, -3.8660e-01],
[-4.6988e-01, -5.3796e-01, -2.2191e-01],
[-5.8587e-01, 3.6036e-01, -2.4991e-01],
[-5.6142e-01, 2.4999e-02, -1.5810e-01],
[-2.5159e-01, 2.4983e-01, -7.4247e-01],
[-1.1044e-01, -2.4423e-01, -2.5875e-01],
[-5.4167e-02, 7.9156e-02, -2.3282e-01],
[ 1.3754e-01, -4.9439e-02, -8.0255e-01]],
[[ 5.5133e-01, 7.3810e-02, 2.4912e-01],
[ 2.7265e-01, -1.6266e-01, -1.4796e-01],
[-3.0720e-01, 6.3283e-01, 1.4211e-01],
[-1.3490e-01, -3.1416e-01, 3.0679e-01],
[-2.5089e-01, 5.8416e-01, 2.7880e-01],
[-2.2644e-01, 2.4880e-01, 3.7061e-01],
[ 8.3396e-02, 4.7363e-01, -2.1377e-01],
[ 2.2454e-01, -2.0431e-02, 2.6995e-01],
[ 2.8082e-01, 3.0296e-01, 2.9588e-01],
[ 4.7252e-01, 1.7436e-01, -2.7384e-01]],
[[ 8.4102e-01, -7.6851e-02, 3.6667e-01],
[ 5.6233e-01, -3.1332e-01, -3.0420e-02],
[-1.7516e-02, 4.8217e-01, 2.5965e-01],
[ 1.5479e-01, -4.6482e-01, 4.2434e-01],
[ 3.8800e-02, 4.3350e-01, 3.9634e-01],
[ 6.3253e-02, 9.8140e-02, 4.8815e-01],
[ 3.7309e-01, 3.2297e-01, -9.6225e-02],
[ 5.1423e-01, -1.7109e-01, 3.8750e-01],
[ 5.7050e-01, 1.5230e-01, 4.1342e-01],
[ 7.6221e-01, 2.3702e-02, -1.5630e-01]],
[[ 5.0062e-01, -6.2699e-03, 2.4768e-01],
[ 2.2193e-01, -2.4274e-01, -1.4941e-01],
[-3.5792e-01, 5.5275e-01, 1.4067e-01],
[-1.8561e-01, -3.9424e-01, 3.0535e-01],
[-3.0160e-01, 5.0408e-01, 2.7735e-01],
[-2.7715e-01, 1.6872e-01, 3.6916e-01],
[ 3.2683e-02, 3.9355e-01, -2.1521e-01],
[ 1.7383e-01, -1.0051e-01, 2.6851e-01],
[ 2.3010e-01, 2.2288e-01, 2.9444e-01],
[ 4.2181e-01, 9.4283e-02, -2.7528e-01]]],
[[[ 3.8479e-02, -2.2963e-01, 3.3670e-01],
[-2.4021e-01, -4.6610e-01, -6.0388e-02],
[-8.2006e-01, 3.2939e-01, 2.2968e-01],
[-6.4775e-01, -6.1759e-01, 3.9437e-01],
[-7.6374e-01, 2.8072e-01, 3.6637e-01],
[-7.3929e-01, -5.4636e-02, 4.5818e-01],
[-4.2946e-01, 1.7019e-01, -1.2619e-01],
[-2.8831e-01, -3.2387e-01, 3.5753e-01],
[-2.3204e-01, -4.7952e-04, 3.8346e-01],
[-4.0334e-02, -1.2907e-01, -1.8627e-01]],
[[ 2.5793e-01, -1.6992e-01, -2.1771e-01],
[-2.0759e-02, -4.0639e-01, -6.1480e-01],
[-6.0061e-01, 3.8910e-01, -3.2472e-01],
[-4.2830e-01, -5.5788e-01, -1.6004e-01],
[-5.4429e-01, 3.4043e-01, -1.8804e-01],
[-5.1984e-01, 5.0723e-03, -9.6227e-02],
[-2.1001e-01, 2.2990e-01, -6.8060e-01],
[-6.8861e-02, -2.6416e-01, -1.9688e-01],
[-1.2589e-02, 5.9229e-02, -1.7095e-01],
[ 1.7911e-01, -6.9366e-02, -7.4067e-01]],
[[ 1.5904e-01, -1.0583e-01, -1.5102e-01],
[-1.1965e-01, -3.4230e-01, -5.4810e-01],
[-6.9950e-01, 4.5319e-01, -2.5803e-01],
[-5.2719e-01, -4.9380e-01, -9.3347e-02],
[-6.4318e-01, 4.0452e-01, -1.2134e-01],
[-6.1873e-01, 6.9158e-02, -2.9533e-02],
[-3.0890e-01, 2.9399e-01, -6.1391e-01],
[-1.6775e-01, -2.0007e-01, -1.3018e-01],
[-1.1148e-01, 1.2331e-01, -1.0426e-01],
[ 8.0225e-02, -5.2802e-03, -6.7398e-01]],
[[ 2.1837e-01, 8.0081e-03, -3.5833e-01],
[-6.0314e-02, -2.2846e-01, -7.5541e-01],
[-6.4016e-01, 5.6703e-01, -4.6534e-01],
[-4.6786e-01, -3.7996e-01, -3.0066e-01],
[-5.8385e-01, 5.1836e-01, -3.2865e-01],
[-5.5940e-01, 1.8300e-01, -2.3684e-01],
[-2.4956e-01, 4.0783e-01, -8.2122e-01],
[-1.0842e-01, -8.6233e-02, -3.3749e-01],
[-5.2144e-02, 2.3716e-01, -3.1157e-01],
[ 1.3956e-01, 1.0856e-01, -8.8129e-01]],
[[ 8.9723e-01, -3.6938e-01, -4.4795e-03],
[ 6.1855e-01, -6.0585e-01, -4.0157e-01],
[ 3.8695e-02, 1.8964e-01, -1.1149e-01],
[ 2.1100e-01, -7.5734e-01, 5.3189e-02],
[ 9.5011e-02, 1.4097e-01, 2.5192e-02],
[ 1.1946e-01, -1.9439e-01, 1.1700e-01],
[ 4.2930e-01, 3.0443e-02, -4.6737e-01],
[ 5.7044e-01, -4.6362e-01, 1.6351e-02],
[ 6.2672e-01, -1.4023e-01, 4.2277e-02],
[ 8.1842e-01, -2.6882e-01, -5.2744e-01]]]])
torch.Size([2, 5, 10, 3])
(mols[:, :, None, :] - grid[:, None, :, :]).pow_(2)
tensor([[[[7.8481e-02, 2.2112e-01, 1.1329e-02],
[2.1231e-06, 4.9944e-01, 8.4477e-02],
[3.3454e-01, 7.8823e-03, 3.3235e-07],
[1.6491e-01, 7.3651e-01, 2.6931e-02],
[2.7257e-01, 1.6092e-03, 1.8526e-02],
[2.4763e-01, 8.7170e-02, 5.1948e-02],
[3.5266e-02, 4.9586e-03, 1.2706e-01],
[2.1757e-03, 3.1863e-01, 1.6198e-02],
[9.2683e-05, 5.8124e-02, 2.3469e-02],
[4.0534e-02, 1.3667e-01, 1.7349e-01]],
[[4.6807e-02, 2.2498e-02, 7.8165e-02],
[3.8859e-03, 1.4935e-01, 4.5788e-01],
[4.1240e-01, 1.6730e-01, 1.4946e-01],
[2.2079e-01, 2.8940e-01, 4.9245e-02],
[3.4325e-01, 1.2986e-01, 6.2455e-02],
[3.1519e-01, 6.2496e-04, 2.4995e-02],
[6.3296e-02, 6.2414e-02, 5.5127e-01],
[1.2197e-02, 5.9650e-02, 6.6951e-02],
[2.9340e-03, 6.2657e-03, 5.4207e-02],
[1.8916e-02, 2.4442e-03, 6.4408e-01]],
[[3.0397e-01, 5.4479e-03, 6.2063e-02],
[7.4335e-02, 2.6459e-02, 2.1893e-02],
[9.4375e-02, 4.0047e-01, 2.0195e-02],
[1.8198e-02, 9.8694e-02, 9.4122e-02],
[6.2946e-02, 3.4124e-01, 7.7727e-02],
[5.1273e-02, 6.1902e-02, 1.3735e-01],
[6.9548e-03, 2.2433e-01, 4.5697e-02],
[5.0420e-02, 4.1742e-04, 7.2876e-02],
[7.8857e-02, 9.1784e-02, 8.7545e-02],
[2.2327e-01, 3.0403e-02, 7.4989e-02]],
[[7.0732e-01, 5.9061e-03, 1.3445e-01],
[3.1622e-01, 9.8171e-02, 9.2536e-04],
[3.0679e-04, 2.3249e-01, 6.7419e-02],
[2.3960e-02, 2.1606e-01, 1.8006e-01],
[1.5054e-03, 1.8792e-01, 1.5708e-01],
[4.0009e-03, 9.6314e-03, 2.3829e-01],
[1.3919e-01, 1.0431e-01, 9.2593e-03],
[2.6444e-01, 2.9273e-02, 1.5016e-01],
[3.2548e-01, 2.3194e-02, 1.7092e-01],
[5.8096e-01, 5.6178e-04, 2.4429e-02]],
[[2.5062e-01, 3.9311e-05, 6.1346e-02],
[4.9254e-02, 5.8923e-02, 2.2322e-02],
[1.2810e-01, 3.0553e-01, 1.9787e-02],
[3.4452e-02, 1.5542e-01, 9.3239e-02],
[9.0964e-02, 2.5410e-01, 7.6924e-02],
[7.6812e-02, 2.8467e-02, 1.3628e-01],
[1.0682e-03, 1.5488e-01, 4.6316e-02],
[3.0217e-02, 1.0102e-02, 7.2099e-02],
[5.2947e-02, 4.9675e-02, 8.6694e-02],
[1.7792e-01, 8.8893e-03, 7.5782e-02]]],
[[[1.4806e-03, 5.2729e-02, 1.1337e-01],
[5.7700e-02, 2.1725e-01, 3.6467e-03],
[6.7250e-01, 1.0850e-01, 5.2755e-02],
[4.1958e-01, 3.8142e-01, 1.5553e-01],
[5.8330e-01, 7.8806e-02, 1.3423e-01],
[5.4655e-01, 2.9851e-03, 2.0993e-01],
[1.8443e-01, 2.8965e-02, 1.5925e-02],
[8.3122e-02, 1.0489e-01, 1.2783e-01],
[5.3842e-02, 2.2994e-07, 1.4704e-01],
[1.6269e-03, 1.6660e-02, 3.4695e-02]],
[[6.6527e-02, 2.8872e-02, 4.7397e-02],
[4.3095e-04, 1.6515e-01, 3.7797e-01],
[3.6073e-01, 1.5140e-01, 1.0545e-01],
[1.8344e-01, 3.1124e-01, 2.5613e-02],
[2.9626e-01, 1.1589e-01, 3.5358e-02],
[2.7023e-01, 2.5728e-05, 9.2596e-03],
[4.4104e-02, 5.2854e-02, 4.6322e-01],
[4.7418e-03, 6.9780e-02, 3.8761e-02],
[1.5849e-04, 3.5081e-03, 2.9225e-02],
[3.2082e-02, 4.8116e-03, 5.4860e-01]],
[[2.5293e-02, 1.1201e-02, 2.2806e-02],
[1.4316e-02, 1.1717e-01, 3.0042e-01],
[4.8930e-01, 2.0538e-01, 6.6580e-02],
[2.7793e-01, 2.4384e-01, 8.7136e-03],
[4.1369e-01, 1.6363e-01, 1.4724e-02],
[3.8283e-01, 4.7828e-03, 8.7222e-04],
[9.5418e-02, 8.6428e-02, 3.7688e-01],
[2.8140e-02, 4.0030e-02, 1.6948e-02],
[1.2428e-02, 1.5206e-02, 1.0870e-02],
[6.4360e-03, 2.7880e-05, 4.5425e-01]],
[[4.7686e-02, 6.4129e-05, 1.2840e-01],
[3.6378e-03, 5.2195e-02, 5.7065e-01],
[4.0981e-01, 3.2152e-01, 2.1654e-01],
[2.1889e-01, 1.4437e-01, 9.0394e-02],
[3.4088e-01, 2.6870e-01, 1.0801e-01],
[3.1292e-01, 3.3489e-02, 5.6095e-02],
[6.2282e-02, 1.6632e-01, 6.7440e-01],
[1.1754e-02, 7.4361e-03, 1.1390e-01],
[2.7190e-03, 5.6243e-02, 9.7075e-02],
[1.9477e-02, 1.1786e-02, 7.7667e-01]],
[[8.0503e-01, 1.3644e-01, 2.0066e-05],
[3.8260e-01, 3.6705e-01, 1.6126e-01],
[1.4973e-03, 3.5964e-02, 1.2431e-02],
[4.4522e-02, 5.7357e-01, 2.8291e-03],
[9.0270e-03, 1.9874e-02, 6.3462e-04],
[1.4272e-02, 3.7786e-02, 1.3690e-02],
[1.8430e-01, 9.2680e-04, 2.1844e-01],
[3.2541e-01, 2.1494e-01, 2.6736e-04],
[3.9277e-01, 1.9664e-02, 1.7874e-03],
[6.6981e-01, 7.2266e-02, 2.7820e-01]]]])
torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
0.0817, 0.3507],
[0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
0.0634, 0.6654],
[0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
0.2582, 0.3287],
[0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
0.5196, 0.6060],
[0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
0.1893, 0.2626]],
[[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
0.2009, 0.0530],
[0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
0.0329, 0.5855],
[0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
0.0385, 0.4607],
[0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
0.1560, 0.8079],
[0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
0.4142, 1.0203]]])
torch.Size([2, 5, 10])
F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True)
tensor([[[0.3109, 0.5839, 0.3424, 0.9284, 0.2927, 0.3867, 0.1673, 0.3370,
0.0817, 0.3507],
[0.1475, 0.6111, 0.7292, 0.5594, 0.5356, 0.3408, 0.6770, 0.1388,
0.0634, 0.6654],
[0.3715, 0.1227, 0.5150, 0.2110, 0.4819, 0.2505, 0.2770, 0.1237,
0.2582, 0.3287],
[0.8477, 0.4153, 0.3002, 0.4201, 0.3465, 0.2519, 0.2528, 0.4439,
0.5196, 0.6060],
[0.3120, 0.1305, 0.4534, 0.2831, 0.4220, 0.2416, 0.2023, 0.1124,
0.1893, 0.2626]],
[[0.1676, 0.2786, 0.8337, 0.9565, 0.7963, 0.7595, 0.2293, 0.3158,
0.2009, 0.0530],
[0.1428, 0.5436, 0.6176, 0.5203, 0.4475, 0.2795, 0.5602, 0.1133,
0.0329, 0.5855],
[0.0593, 0.4319, 0.7613, 0.5305, 0.5920, 0.3885, 0.5587, 0.0851,
0.0385, 0.4607],
[0.1761, 0.6265, 0.9479, 0.4537, 0.7176, 0.4025, 0.9030, 0.1331,
0.1560, 0.8079],
[0.9415, 0.9109, 0.0499, 0.6209, 0.0295, 0.0657, 0.4037, 0.5406,
0.4142, 1.0203]]])
torch.Size([2, 5, 10])
F.relu(torch.sum((mols[:, :, None, :] - grid[:, None, :, :]).pow_(2), -1), inplace=True).sqrt_()
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
0.2858, 0.5922],
[0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
0.2518, 0.8157],
[0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
0.5081, 0.5733],
[0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
0.7208, 0.7784],
[0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
0.4351, 0.5124]],
[[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
0.4482, 0.2302],
[0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
0.1814, 0.7652],
[0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
0.1962, 0.6788],
[0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
0.3950, 0.8989],
[0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
0.6436, 1.0101]]])
torch.Size([2, 5, 10])
计算结果:
tensor([[[0.5576, 0.7641, 0.5852, 0.9635, 0.5410, 0.6219, 0.4090, 0.5805,
0.2858, 0.5922],
[0.3840, 0.7817, 0.8539, 0.7480, 0.7318, 0.5838, 0.8228, 0.3726,
0.2518, 0.8157],
[0.6095, 0.3503, 0.7177, 0.4594, 0.6942, 0.5005, 0.5263, 0.3517,
0.5081, 0.5733],
[0.9207, 0.6445, 0.5479, 0.6481, 0.5887, 0.5019, 0.5028, 0.6662,
0.7208, 0.7784],
[0.5586, 0.3612, 0.6734, 0.5321, 0.6496, 0.4915, 0.4497, 0.3353,
0.4351, 0.5124]],
[[0.4094, 0.5278, 0.9131, 0.9780, 0.8924, 0.8715, 0.4789, 0.5620,
0.4482, 0.2302],
[0.3779, 0.7373, 0.7859, 0.7213, 0.6690, 0.5287, 0.7484, 0.3366,
0.1814, 0.7652],
[0.2435, 0.6572, 0.8725, 0.7283, 0.7694, 0.6233, 0.7475, 0.2918,
0.1962, 0.6788],
[0.4197, 0.7915, 0.9736, 0.6735, 0.8471, 0.6344, 0.9503, 0.3648,
0.3950, 0.8989],
[0.9703, 0.9544, 0.2234, 0.7880, 0.1719, 0.2564, 0.6353, 0.7353,
0.6436, 1.0101]]])