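When sklearn's LabelEncoder runs transform on a value that was not in the encoding learned during fit, it raises an error. Spark's counterpart behaves differently: StringIndexer with handleInvalid="keep" puts new values into one extra index right after the known labels. Here are two ways to deal with this: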
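For reference, here is a minimal reproduction of that error. The training labels are made up for illustration: fitting on three labels and then transforming a label that was never seen makes transform raise a ValueError.

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(['a', 'b', 'c'])  # hypothetical training labels

raised = False
try:
    le.transform(['a', 'z'])  # 'z' was never seen during fit
except ValueError as err:
    raised = True
    # sklearn lists the unseen labels in the error message
    print('ValueError:', err)
```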
You have probably come across this article:
https://blog.csdn.net/qq_19446965/article/details/120110169
from sklearn.preprocessing import LabelEncoder as LEncoder

# Subclass LabelEncoder so that values outside the learned encoding become 'Unknown'
class LabelEncoder(LEncoder):
    def fit(self, y):
        # Add 'Unknown' as an extra class so it always has a code
        return super().fit(list(y) + ['Unknown'])

    def fit_transform(self, y):
        # Fit with 'Unknown' included, but return codes for y only
        # (transforming list(y) + ['Unknown'] would return one code too many)
        return self.fit(y).transform(y)

    def transform(self, y):
        # Build the lookup set once instead of once per element
        known = set(self.classes_)
        new_y = [x if x in known else 'Unknown' for x in y]
        return super().transform(new_y)
Subclassing and overriding the class is convenient, but it has problems, as the following examples show:
country_list = ['A', 'a', 'b', 'c', 'd']
label_encoder = LabelEncoder()
label_encoder.fit(country_list)
print('country_list: ', label_encoder.classes_)
print('encode_country_list: ', label_encoder.transform(country_list))
new_country_list = ['a', 'b', 'c', 'g', 'h', 'i']
print('new_encode_country_list: ', label_encoder.transform(new_country_list))
Output:
country_list: ['A' 'Unknown' 'a' 'b' 'c' 'd']
encode_country_list: [0 2 3 4 5]
new_encode_country_list: [2 3 4 1 1 1]
country_list = ['889', '778', '567', '1920', '999']
label_encoder = LabelEncoder()
label_encoder.fit(country_list)
print('country_list: ', label_encoder.classes_)
print('encode_country_list: ', label_encoder.transform(country_list))
new_country_list = ['889', '778', '100', '200', '300']
print('new_encode_country_list: ', label_encoder.transform(new_country_list))
Output:
country_list: ['1920' '567' '778' '889' '999' 'Unknown']
encode_country_list: [3 2 1 0 4]
new_encode_country_list: [3 2 5 5 5]
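The odd ordering in the second example comes from sorting strings character by character, so '1920' sorts before '567'. A minimal sketch with a plain sklearn LabelEncoder (not the subclass above) shows the difference; converting the values to integers before fitting is an assumption about the data, valid only when every value really is a number.

```python
from sklearn.preprocessing import LabelEncoder

codes_as_str = ['889', '778', '567', '1920', '999']

# Strings sort lexicographically: '1920' comes before '567'
le_str = LabelEncoder().fit(codes_as_str)
print(le_str.classes_)

# Converting to int first (assumes all values are numeric) gives numeric order
le_int = LabelEncoder().fit([int(x) for x in codes_as_str])
print(le_int.classes_)
```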
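My requirement is that values outside the encoding must be deleted. After subclassing, though, I cannot tell from the encoded values which ones to delete: the code assigned to 'Unknown' depends on where it sorts among the classes (1 in the first example, 5 in the second). Numeric strings are also encoded in string order rather than numeric order ('1920' gets 0 while '567' gets 1). So I use the following approach instead: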
from sklearn.preprocessing import LabelEncoder

# X: the training values of the column to encode
le = LabelEncoder()
le.fit(X)
# A label encoding is really just a mapping, so save it as a dictionary
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))
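To make the dictionary concrete, here is a small sketch with hypothetical training values. Because classes_ is sorted, le.transform(le.classes_) is simply 0..n-1, so le_dict maps each known class to its position. The second lookup is an illustrative variation, not part of the approach above: falling back to len(le_dict) puts every unseen value into one extra code, similar to Spark's StringIndexer with handleInvalid="keep".

```python
from sklearn.preprocessing import LabelEncoder

X = ['b', 'a', 'c', 'a']  # hypothetical training values
le = LabelEncoder()
le.fit(X)

# classes_ is sorted, and transform(classes_) is just 0..n-1
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))

# Approach used here: unseen values become 'Unknown'
mapped = [le_dict.get(x, 'Unknown') for x in ['a', 'z']]

# Spark-style variation: unseen values share one extra code, len(le_dict)
spark_like = [le_dict.get(x, len(le_dict)) for x in ['a', 'z']]
```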
Look up the label of a single new item; if the item is not in the dictionary, return 'Unknown':
le_dict.get(new_item, 'Unknown')
Look up the labels for a DataFrame column:
df['col'] = df['col'].apply(lambda x: le_dict.get(x, 'Unknown'))
# Then drop the rows with new values; .copy() avoids pandas' SettingWithCopyWarning on the next line
df = df[df['col'] != 'Unknown'].copy()
df['col'] = df['col'].astype(dtype='int64')
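Putting the steps together, here is a self-contained sketch on a hypothetical DataFrame. The column names 'col' and 'other' and the data are made up for illustration. Rows whose value was not seen during fit are encoded as 'Unknown', then removed, and the column is cast back to int64.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical training values and a new DataFrame with unseen values 'g', 'h'
train_values = ['a', 'b', 'c', 'd']
df = pd.DataFrame({'col': ['a', 'b', 'g', 'c', 'h'], 'other': [1, 2, 3, 4, 5]})

le = LabelEncoder()
le.fit(train_values)
le_dict = dict(zip(le.classes_, le.transform(le.classes_)))

# Encode; anything missing from le_dict becomes 'Unknown'
df['col'] = df['col'].apply(lambda x: le_dict.get(x, 'Unknown'))

# Drop rows with unseen values, then restore an integer dtype
df = df[df['col'] != 'Unknown'].copy()
df['col'] = df['col'].astype('int64')
```

A design note: pandas' Series.map(le_dict) would return NaN for unseen values instead of 'Unknown', so dropna(subset=['col']) could replace the string comparison. That is a swapped-in alternative, not the approach described above.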
For other approaches, see:
https://stackoom.com/question/1QM33