I have recently been using PyTorch's 1D convolution to process text for a text-classification task. After reading related blog posts and the official API docs, here is a summary.
A 1D convolution, as the name suggests, convolves along a single dimension: the kernel slides along the sequence axis. It is typically used for sequential data.
The data fed into the convolution has shape [batch_size, seq_len, embedding_dim], and after convolution it has shape [batch_size, out_channels, seq_len - kernel_size + 1]. Since the convolution is applied over the last dimension, the data first needs a small adjustment (moving embedding_dim into the channel position), as shown in the code below.
import torch
import torch.nn as nn

data = torch.randn(4, 5, 8)  # [batch_size, seq_len, embedding_dim]
conv1d = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=2)
# Conv1d expects [batch_size, in_channels, seq_len], so swap the last two dims
data = torch.transpose(data, 2, 1)  # same as data.permute(0, 2, 1)
conv1d_out = conv1d(data)  # [batch_size, out_channels, seq_len - kernel_size + 1] -> [4, 16, 4]
print(conv1d_out.shape)
print(conv1d_out)
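Note that seq_len - kernel_size + 1 is the output length only for the default stride=1 and padding=0. In general, Conv1d's output length follows the formula from the PyTorch docs: L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. A small sketch to check this against PyTorch (the helper name conv1d_out_len is just my own illustration):

import torch
import torch.nn as nn

def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    # General Conv1d output-length formula from the PyTorch docs
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.randn(4, 8, 5)  # [batch_size, in_channels, seq_len]
conv = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=2, stride=2, padding=1)
print(conv(x).shape[-1])                                       # 3
print(conv1d_out_len(5, kernel_size=2, stride=2, padding=1))   # 3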
Here transpose is used to swap the dim=1 and dim=2 axes; permute achieves exactly the same thing, so it comes down to personal preference.
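As a quick sanity check that the two calls produce identical results:

import torch

x = torch.randn(4, 5, 8)
print(torch.equal(x.transpose(1, 2), x.permute(0, 2, 1)))  # True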
The final output:
torch.Size([4, 16, 4])
tensor([[[-0.0851,  0.0582, -0.3878, -0.4815],
         [-0.0192,  0.0096, -0.4060, -0.2221],
         [-0.9653, -0.5644, -0.0039, -0.0162],
         [-1.0623, -0.4552,  0.7921, -0.1066],
         [-1.1642,  0.4845,  0.2344, -0.6042],
         [-0.5638,  0.7780,  0.2239,  0.1187],
         [-0.1438, -0.3047, -0.7292, -0.2968],
         [-0.6816,  0.3791, -0.4561, -0.3937],
         [-0.7172, -0.3273,  0.1383, -0.1623],
         [-0.8436, -0.4637,  0.0030, -0.1074],
         [ 1.0775, -0.5268,  0.7428,  0.5231],
         [ 0.7474,  0.4146,  0.1968,  0.8429],
         [-0.4140, -0.0394,  0.1463,  0.0412],
         [ 0.5727,  0.4103, -0.1047, -0.2016],
         [-0.1253,  0.0839,  0.1986, -0.7732],
         [ 0.5374,  0.3954, -0.2495,  0.3254]],

        [[ 0.2526,  0.2576, -0.5052,  0.0083],
         [ 0.4127,  1.1993, -0.2114,  0.0136],
         [ 0.0678, -0.1660, -1.3183,  0.2356],
         [ 0.2819,  0.0628, -0.0574, -0.2374],
         [ 0.3254,  0.9099, -0.5498, -0.2885],
         [ 0.2731, -0.2013,  0.2595, -0.4752],
         [ 0.6139,  0.0260,  0.4239, -1.0684],
         [ 0.1177,  0.0573, -0.4777, -0.2491],
         [-1.1266, -0.0891, -1.1373,  0.0738],
         [-0.6815, -0.0559, -0.0862,  0.3590],
         [ 0.1607, -0.6313,  1.2955,  0.5061],
         [ 0.7632, -0.2714,  0.3060,  0.2704],
         [ 0.7875, -0.5344,  0.3310, -0.5986],
         [ 0.6162,  0.0442,  0.5216,  0.0574],
         [-0.1813, -0.2603,  0.1043,  0.0509],
         [ 0.4927, -0.1088, -0.5338, -0.2337]],

        [[-0.0164, -0.6398,  0.0220,  1.4367],
         [ 0.6438,  0.4777,  0.7895, -0.1808],
         [ 0.9122,  0.0554, -0.3439, -0.2880],
         [ 0.0640,  0.5090,  0.1620,  0.3268],
         [ 1.4083,  0.2696, -0.8962,  0.7982],
         [ 0.3067,  0.3309,  0.3118,  0.5801],
         [-0.6267,  0.3782,  0.2978, -0.8898],
         [ 0.2732, -0.4754,  0.0591, -0.2874],
         [ 0.3752, -0.3867,  0.4108, -1.1205],
         [-0.3308, -0.3190,  0.4023, -0.2092],
         [-1.3494,  0.8448,  0.1239, -1.1028],
         [-0.5598, -0.8947,  0.9866,  0.1430],
         [-0.0092,  0.8585, -0.2731, -0.4883],
         [-0.2728,  0.3041,  0.6107,  0.1400],
         [-0.0886, -0.0418, -1.2089,  1.2100],
         [ 0.7111, -0.0909,  0.3468,  0.6367]],

        [[-0.2285, -1.0907,  1.0207, -0.0771],
         [ 0.7745,  0.2723,  0.6125,  0.0904],
         [ 0.5187, -0.2803,  0.1677, -0.9214],
         [ 0.1704, -0.3473,  0.8135, -1.3735],
         [ 0.3837, -0.0601,  0.9199, -0.6026],
         [-0.3494,  0.2429,  1.0142, -0.1163],
         [-0.9631,  0.2257, -0.2325, -0.3615],
         [ 0.2249,  0.2316, -0.0267, -0.6608],
         [ 0.4972,  0.2225,  0.1074, -0.3682],
         [ 0.7068, -0.5119,  0.4362, -1.1837],
         [ 0.1957,  0.2654, -0.8077,  0.3657],
         [ 0.3629,  1.2386, -1.0372, -0.2023],
         [-0.8409,  0.2340,  0.2384, -0.2724],
         [-0.3382,  0.1901,  0.3490,  0.4499],
         [-0.4086, -1.0089, -0.0738,  0.8813],
         [-0.0946,  0.2343, -0.9303,  1.0733]]], grad_fn=<ConvolutionBackward0>)
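To tie this back to text classification: below is a minimal, purely illustrative sketch of how such a Conv1d layer is commonly combined with an embedding layer and max-over-time pooling to build a classifier. The class name TextConvClassifier and all hyperparameters here are my own placeholders, not a fixed recipe.

import torch
import torch.nn as nn

class TextConvClassifier(nn.Module):
    # Illustrative TextCNN-style classifier: embed -> Conv1d -> max-over-time -> linear
    def __init__(self, vocab_size=1000, embedding_dim=8, out_channels=16,
                 kernel_size=2, num_classes=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, out_channels, kernel_size)
        self.fc = nn.Linear(out_channels, num_classes)

    def forward(self, token_ids):      # [batch_size, seq_len]
        x = self.embedding(token_ids)  # [batch_size, seq_len, embedding_dim]
        x = x.transpose(1, 2)          # [batch_size, embedding_dim, seq_len]
        x = torch.relu(self.conv(x))   # [batch_size, out_channels, seq_len - kernel_size + 1]
        x = x.max(dim=2).values        # max-over-time pooling -> [batch_size, out_channels]
        return self.fc(x)              # [batch_size, num_classes]

model = TextConvClassifier()
tokens = torch.randint(0, 1000, (4, 5))  # fake batch: 4 sentences, 5 token ids each
print(model(tokens).shape)               # torch.Size([4, 2])

Running several such convolutions with different kernel sizes in parallel and concatenating the pooled features gives the classic TextCNN architecture.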