问题重述:我在学习pytorch时,跟着网课学到使用pickle模块序列化Word2Seq类,并将该类序列化保存在ws.pkl文件中。然后我又创建了lib.py文件使用下面代码反序列化ws.pkl时出现了该问题。
ws = pickle.load(open('ws.pkl', 'rb'))
我在自己创建的lib.py文件中写下如下代码
import pickle
from utils.word2seq import Word2Seq # 这个Word2Seq即为你序列化的类
ws = pickle.load(open('ws.pkl', 'rb'))
这样直接运行lib.py文件,ws即可被正确加载,且不会报错
但是这样又会出现新的问题
问题描述:我又创建了另一个build_dataset.py文件,并期望在build_dataset.py文件中导入lib.py文件中的ws对象。可问题出现了,即:AttributeError: Can’t get attribute ‘xxx’ on
因此 这种解决办法并不是最根本的解决方法
之所以会出现题目中的问题,是你在构架pickle要序列化的那个类时写发就是错的,pickle序列化的那个类所在的py文件一定要**“干净”**,即该py文件只能写这一个类,下面举例:
假如我要序列化word2seq.py文件中的Word2Seq类, 那么该py文件中除了 class Word2Seq() 中的内容外,其他任何内容 诸如 def函数等 都不要多写
# word2seq.py
'''
构建词典,实现方法把句子转化为数字序列和其翻转
'''
class Word2Seq():
UNK_TAG = 'UNK' # UNK表示特殊字符,没见见过的词语都用UNK代替,UNK对应数字0
PAD_TAG = 'PAD' # 把短句子进行填充,使用PAD进行填充,PAD对应数字为1
UNK = 0
PAD = 1
def __init__(self):
# 将 词语 和 编号对应起来
self.dict = {
self.UNK_TAG: self.UNK,
self.PAD_TAG: self.PAD
}
self.count = {} # 统计词频
def fit(self, text):
'''
把单个句子保存到dict中, 并统计每个词语的词频
:param text: [word1, word2, word3, ...]
:return:
'''
for word in text:
'''Tips: 编程技巧
self.count.get(word, 0) + 1
如果当前字典中'word'存在则返回key对应的值并+1,如果'word'不存在则返回0+1
'''
self.count[word] = self.count.get(word, 0) + 1
def build_vocab(self, min=5, max=None, max_features=None):
'''
生成词典, 剔除不符合数量要求的词语
:param min: 词语最小出现次数
:param max: 最大出现次数
:param max_features: 一共保留多少个词语
:return:
'''
# 删除count中词频小于min的word
if min is not None:
'''
PS: 遍历字典时,其实遍历的是key
'''
self.count = {word: value for word, value in self.count.items() if value >= min}
# 删除count中词频大于max的word
if max is not None:
self.count = {word: value for word, value in self.count.items() if value <= max}
# 限制保留的词语数
if max_features is not None:
'''
sorted后会将元组变成列表
self.count.items() 是一个可迭代对象, 其中的每一个值是一个(key,value)对
key=lambda x:x[-1] 使字典中的key根据items中的value进行排序, x[-1]表示取最后一个值也就是value
reverse=True 由大到小,降序排列
[:max_features] 将排序后的前 max_features 个数取出来(因为sorted已经将dict_items变为list,故可以这样取值)
'''
temp = sorted(self.count.items(), key=lambda x: x[-1], reverse=True)[
:max_features] # 这样得到的是一个列表,其中每个元素是一个二值元组
self.count = dict(temp) # 将[(key, value), (key, value)] 转化为 {key:value, key:value}
# 给每一个词语进行编号 {word:num}
for word in self.count:
'''
因为原来的self.dict中已有self.UNK_TAG: self.UNK 和 self.PAD_TAG: self.PAD 两组键值对
故新词的编号从 2 开始,也就不会和之前的重复
'''
self.dict[word] = len(self.dict)
# 得到一个翻转的dict词典 {num:word}
self.inverse_dict = dict(zip(self.dict.values(), self.dict.keys()))
def transform(self, text, max_len=None):
'''
把句子转换为序列
:param text: [word1, word2, ...]
:param max_len: int, 对句子进行填充或裁剪
:return: [1, 2, 4, ...]
'''
'''
在self.dict中找到句子中每一个词语对应的编号,组成list返回
'''
if max_len is not None:
if max_len > len(text):
text = text + [self.PAD_TAG] * (max_len - len(text)) # 如果句子长度小于max_len则对句子填充max_len-len(text)个'PAD'
else:
text = text[:max_len] # 如果句子长度大于max_len, 则对句子裁剪,取前max_len个
return [self.dict.get(word, self.UNK) for word in text]
def inverse_transform(self, indices):
'''
将序列转化为句子
:param indices: [1, 2, 4, 5, 3, ...]
:return: [word1, word2, word4, word3, ...]
'''
return [self.inverse_dict.get(index) for index in indices]
def __len__(self):
# 词语的个数
return len(self.dict)
if __name__ == '__main__':
pass
此时我在dataset.py文件中写下函数生成ws.pkl文件
# 保存词到编号的映射
def fit_save_word_seq(max_features=10000):
import os
import pickle
from utils.word2seq import Word2Seq
ws = Word2Seq()
path = '../data/aclImdb'
temp_data_path = [os.path.join(path, 'train/pos'),
os.path.join(path, 'train/neg'),]
# os.path.join(path, 'test/pos'),
# os.path.join(path, 'test/neg')]
for data_path in temp_data_path:
file_list = os.listdir(data_path) # 获取目录下所有文件名
file_path_list = [os.path.join(data_path, file_name) for file_name in file_list if file_name.endswith('.txt')] # 获取文件完整路径
for file_path in tqdm(file_path_list):
text = tokenlize(open(file_path, encoding='utf-8').read()) # 对每个句子进行分词
ws.fit(text) # 为每个词语映射为序列
if max_features is not None: # 是否对词的最大个数进行限制
ws.build_vocab(min=10, max_features=max_features)
else:
ws.build_vocab(min=10) # 为每个词语进行编号
pickle.dump(ws, open('../model_data/ws.pkl', 'wb'), protocol=4)
print(ws.dict, len(ws))
if __name__ == '__main__':
fit_save_word_seq()
然后在lib.py文件中加载ws.pkl文件
import pickle
# 注意这里最大的不同是,不需要像方法一那样 导入一遍 Word2Seq类
ws = pickle.load(open('model_data/ws.pkl', 'rb'))
print(ws)
最后我在build_dataset.py文件中直接导入lib.py文件反序列化的ws对象,这样便不会报错了,想在哪里导入lib.py文件中的ws对象 就在 哪个文件导入。不会再出现题目中的问题!
from utils.lib import ws
PS:不要问我为什么要这样调来调去的,因为代码量大的话你必须写一个配置文件,专门存放可供调节的参数,所以使用方法二才是最佳的解决办法! 一定要保证pickle要序列化的那个类所在的py文件不含有除了class Word2Seq() 外的其他内容