该代码中编写了许多初始化权重的信息,其他的代码都没有加载过初始化参数的信息吗?
import string
string.punctuation #所有的标点字符
'!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
w = nn.Parameter(torch.ones(*param[:4]))
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里,于是在栈溢网看到了一篇解释,并做了几个实验才算完全理解了这个函数。首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
#根据tensor的维度初始化参数
import torch
import torch.nn as nn
w = torch.empty(2, 3)
# torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
# tensor([[ 0.2530, -0.4382, 1.5995],
# [ 0.0544, 1.6392, -2.0752]])
和上面的nn.ParameterList()配合使用
首先声明一个ParameterList()
如
vars=nn.ParameterList()
w=nn.Parameter(torch.ones(*param[:4]))
vars.append(w)
再输出ParameterList()就可以看到变化了
param[:4] [32, 3, 3, 3]
w = nn.Parameter(torch.ones(*param[:4]))
torch.nn.init.kaiming_normal_(w)
返回一个reader对象,利用该对象遍历csv文件中的行。
with open(self.root) as csvfile:
csvreader = csv.reader(csvfile, delimiter=',')
next(csvreader, None) # skip (filename, label)
for i, row in enumerate(csvreader):
#numpy.random.choice(a, size=None, replace=True, p=None)
#从a(只要是ndarray都可以,但必须是一维的)中随机抽取数字,并组成指定大小(size)的数组
#replace:True表示可以取相同数字,False表示不可以取相同数字
#数组p:与数组a相对应,表示取数组a中每个元素的概率,默认为选取每个元素的概率相同。
selected_cls = np.random.choice(self.cls_num, self.n_way, False)
#随机不重复地取数数据,在cls_num中随机不重复地取n_way个数。
selected_text = random.sample(self.dictLabels[str(cls)], self.k_spt + self.k_qry)
#第一个参数是数据,第二个参数是数据的个数
参考文档:https://www.jianshu.com/p/46eb3004beca
torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
参数列表如上,参考博客如下:
https://blog.csdn.net/a362682954/article/details/81196840
做的是将若干图像拼接成一幅图像,padding表示图片之间的间隔,padding默认值为2.可以设置为其他值。
参考:https://blog.csdn.net/u012343179/article/details/83007296
参考资料:https://blog.csdn.net/zbrwhut/article/details/80625702
https://www.jianshu.com/p/46a04f1f6085
tensor是Pytorch中的完美组件,但是构建神经网络还远远不够,我们需要能够构建计算图的tensor,这就是Variable。在torch中的Variable就是一个存放会变化的值的地理位置,里面的值会不断变化,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化,里面的鸡蛋就是torch的Tensor了(torch是用tensor计算的,tensor里面的参数都是variable的形式)。Variable是对tensor的封装,操作和tensor是一样的,但是每个Variable都有三个属性,Variable中的tensor本身.data,对应tensor的梯度.grad以及这个Variable是通过说明方式得到的.grad_fn。如果用Variable进行计算,那返回的也是一个同类型的Variable
def __init__中的参数是必须在调用该类的时候就需要传入的
而def forward()中的参数是需要在调用类的对象时候需要传入
maml=Meta(vocab_size,embedding_dim,n_filters,output_dim,dropout,pad_idx)
accs = maml(x_spt, y_spt, x_qry, y_qry)
而Meta类的定义如下
class Meta(nn.Module):
"""
Meta Learner
"""
def __init__(self, args,vocab_size, embedding_dim, n_filters,
filter_sizes, output_dim, dropout, pad_idx,static):
def forward(self, x_spt, y_spt, x_qry, y_qry):
如:
#调入类的实例(声明类)
maml = Meta(args,vocab_size=len(vocab),filter_sizes=[3,4,5],embedding_dim=100,n_filters=100,output_dim=1,dropout=0.5,pad_idx=0).to(device)
#调用类的对象(使用类)
accs = maml(x_spt, y_spt, x_qry, y_qry)
一个错误
在没有看到bert的博客前是想用word2vec来做预训练的模型,出的一个问题是直接用torch.tensor().long()直接去转换我们的英文字符,但是需要注意的是,我们需要先构建一个字典,将我们的数据在预处理的过程中,要将数据转换成对应的词典的索引。这样再继续构成tensor才可以,否则我们无法对其进行使用。这也就是为什么我们需要先构建词典,之后再使用预训练模型bert或者word2vec。通过这个我们也就知道了对文本的基本处理流程(1.导入数据 2.对数据进行处理(分词,除去停用词,和一些无用的标点符号。) 3.构建词典 4.dataloader 5.下游任务)他的__getitem__(self,index)
方法只要data和label一一对应,也就是大小尺寸一样就可以,不用注意index指的是什么,即使data的维度很复杂,我们也只需用data[index]
,label[index]
即可。
def __getitem__(self, index):
support_x = self.support_x_batch[index]
support_y = self.support_y_batch[index]
qry_x = self.qry_x_batch[index]
qry_y = self.qry_y_batch[index]
return support_x, qry_x, support_y, qry_y
np.random.choice
很重要。因为源程序中的是每个图片是一个单独的文件,那么我们也将每条评论构成一条单独的文件。对每个文件进行预处理,分词,去除停用词,构建词典。
#在当前py文件的目录下按csv文件中的数据按'\n'分成一个一个txt文件
import re
text = open(r'F:\研一\NLP\数据集\IMDB Dataset.csv',"r", encoding='UTF-8').read() #打开本体TXT文件
text = str(text)
print(text)
b=re.split('\n',text)
print(b)
n=0
for i in b:
n+=1
with open('%s.txt'%n,'w', encoding='UTF-8') as f:
f.write(i)
2.读取某文件夹下面的所有文件名,并将其保存到Excel中。
3.可以见其他几篇博客,文本数据预处理(自己定义),文本预处理(torchtext),读取文件以及一些处理方法。