pickle是常用的保存对象和数据的工具,总结使用以来碰到的问题对应的解决方法。尤其是,在保存defaultdict的时候遇到了问题,在stackoverflow上得到解答,感觉补充了以前的很多不足,所以在此小结巩固一下。
在python2中,有pickle和cPickle两个版本,主要区别在于cPickle的底层是C语言实现的,速度更加快,其他基本一致。
在python3中,只有保留了最优的版本,包名pickle,减少了开发人员的困惑。
以下内容来自于stackoverflow问题Can’t pickle defaultdict中两个大神“sloth”和“Martijn”的答案,推荐观看原版。
对于pickle来说任何要保存的模型或者对象都可以分为数据和代码两部分。
int, float, dict, set, tuple, list, string
等python自带的数据类型。对于数据部分,可以通过pickle保存起来然后导入即可。但是代码部分无法保存起来,但是会保存它与数据的关系,当执行pickle.load
的时候,会根据数据与代码部分的关系,恢复保存的数据。所以,通过import将代码部分定义引入或者将直接代码部分与pickle.load
放在同一个文件中。可以猜想,当我们要保存的数据或者对象中含有非顶级的类或函数,我们无法在pickle.load
的时候,引入或者找到对应的代码定义,导致加载失败,所以当保存时遇到非顶级的类对象和函数均会报错。
保存和加载一个顶级类对象和一个以顶级函数为值的dict,代码入下:
import cPickle as cp
class A(object):
def __init__(self, a):
self._a = a
def print_out(self):
print("this is an instance of modul-level class!", self._a)
a = A(3)
a.print_out()
with open("./data.pkl", "wb") as f:
cp.dump(a, f)
print("----------load----------")
with open("./data.pkl", 'rb') as f:
b = cp.load(f)
b.print_out()
#output
# ('this is an instance of modul-level class!', 3)
# ----------load----------
# ('this is an instance of modul-level class!', 3)
def modul_level_func():
print("this is a modul-level functiion!")
dic = {
"1": modul_level_func}
dic['1']()
with open("./dic.pkl", "wb") as f:
cp.dump(dic, f)
print("----------load----------")
with open("./dic.pkl", 'rb') as f:
b_dic = cp.load(f)
b_dic['1']()
#output
# this is a modul-level functiion!
# ----------load----------
# this is a modul-level functiion!
一些非顶级函数和类对象的例子:
# lambda表达式
dic = {
"1": lambda : 1}
cp.dump(dic, open("./lambda.pkl", "wb"))
# output
"""
TypeError Traceback (most recent call last)
in ()
1 dic = {"1": lambda : 1}
----> 2 cp.dump(dic, open("./lambda.pkl", "wb"))
c:\python27\lib\copy_reg.pyc in _reduce_ex(self, proto)
68 else:
69 if base is self.__class__:
---> 70 raise TypeError, "can't pickle %s objects" % base.__name__
71 state = base(self)
72 args = (self.__class__, base, state)
TypeError: can't pickle function objects
"""
# 嵌套函数
def outer():
def inner():
print("inside!")
dic = {
'1': inner}
return dic
dic = outer()
dic['1']()
cp.dump(dic, open("inner.pkl", 'wb'))
# output
"""
inside!
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
in ()
7 dic = outer()
8 dic['1']()
----> 9 cp.dump(dic, open("inner.pkl", 'wb'))
c:\python27\lib\copy_reg.pyc in _reduce_ex(self, proto)
68 else:
69 if base is self.__class__:
---> 70 raise TypeError, "can't pickle %s objects" % base.__name__
71 state = base(self)
72 args = (self.__class__, base, state)
TypeError: can't pickle function objects
"""
defaultdict一般与无参lambda表达式一起使用,保存时会报“TypeError: can’t pickle function objects”,保存失败。解决方案多种:
from collections import defaultdict as ddt
def mean(lis):
return sum(lis) * 1.0 / len(lis)
# 强制类型转换
a = ddt(lambda :1)
cp.dump(dict(a), open("ddt.pkl", "wb"))
# 重定义带有默认值的dict,不再使用工厂模式产生默认值。
class dictwithdefault(dict):
def __init__(self, default_value):
super(dictwithdefault, self).__init__()
self._default_value = default_value
def __missing__(self, key):
if callable(self._default_value):
return self._default_value(key)
else:
return self._default_value
class A(object):
def __init__(self, value):
self._value = value
def print_out(self):
print(self._value)
mean_score = mean(student_score_lis)
a_score_dic = dictwithdefault(mean_score)
cp.dump(a_score_dic, open("./daaa.pkl", "wb"))
b_score_dic = dictwithdefault(A)
cp.dump(b_score_dic, open("./daaa.pkl", "wb"))
# 自定义对象保存默认值。
class default_value(object):
def __init__(self, value):
self._value
def __call__(self):
return self._value
c_ddt = ddt(default_value(mean_score))
cp.dump(c_ddt, open("./daaa.pkl", "wb"))
pickle2和pickle3,分别指的是在python2和python3中的pickle,而不是真的存在包名为pickle2和pickle3包。在pickle保存数据时候,是按照一定的算法协议进行保存的,这些协议有不同的版本号,比如在pickle2中使用的版本为2,在pickle3中保存时默认版本号为3,同时由于高版本向低版本兼容,所以也支持版本2。所以,当要在pickle2打开pickle3保存的数据文件,此文件一定是要使用协议版本2进行保存的。
# 在python3中保存数据,使用协议2
import pickle as pkl
with open("num.pkl", "wb") as f:
pkl.dump(1, f, protocol=2)
# 在python2中可以直接打开
import cPickle as cp
with open("num.pkl", "rb") as f:
cp.dump(1, f)