在使用 mini-imagenet 数据集时,需要对 csv 文件中的 labels 列进行 one-hot 编码。网上没有查到现成的写法,于是顺手写了一个。
首先导入依赖库(示例):
from numpy import argmax
import pandas as pd
import numpy as np
编码实现如下(示例):
# Location of the mini-imagenet training CSV whose label column gets encoded.
new_train_path = r"D:\CsDataset\mini-imagenet\new_train.csv"
class OnehotLables:
    """One-hot encode one label column of a CSV file.

    The column is selected by position. Every distinct value found in it
    becomes one class, and each row's label is replaced by a list of floats
    holding a single 1.0 at that class's index.
    """

    def __init__(self, labrel_idx):
        """Remember which column to encode.

        Args:
            labrel_idx: Positional index of the label column, as used with
                ``DataFrame.iloc[:, labrel_idx]``.
        """
        self.label_idx = labrel_idx

    def _creat_one_hot(self, labels_len):
        """Return the ``labels_len`` x ``labels_len`` identity matrix as nested lists."""
        return np.eye(labels_len).tolist()

    @staticmethod
    def _unique_labels(labels):
        """Return the distinct values of ``labels`` in a reproducible order.

        Iterating a ``set`` of strings yields a different order from one
        interpreter run to the next (string hashing is randomized), which
        would silently change the class -> index mapping between runs. The
        values are therefore sorted; if they cannot be compared with each
        other, first-appearance order is used instead.
        """
        distinct = list(dict.fromkeys(labels))
        try:
            return sorted(distinct)
        except TypeError:
            return distinct

    def _create_onehot_labels(self, labels, labels_set_list):
        """Replace every label with the one-hot vector of its class.

        Args:
            labels: Sequence (list or ``pandas.Series``) of hashable raw labels.
            labels_set_list: Class names; the position of a name is the index
                of the 1.0 in its one-hot vector. If a name is repeated, its
                first position is used.

        Returns:
            A new container of the same kind as ``labels`` (a Series keeps its
            index and name). Values that are not in ``labels_set_list`` are
            passed through unchanged. ``labels`` itself is not modified.
        """
        identity_rows = self._creat_one_hot(len(labels_set_list))

        # Build the lookup once so encoding is a single pass over the labels.
        row_by_label = {}
        for class_name, row in zip(labels_set_list, identity_rows):
            row_by_label.setdefault(class_name, row)

        # Copy each row so that editing one encoded label cannot affect others.
        encoded = [
            list(row_by_label[value]) if value in row_by_label else value
            for value in labels
        ]

        if isinstance(labels, pd.Series):
            return pd.Series(
                encoded, index=labels.index, name=labels.name, dtype=object
            )
        return encoded

    def get_label_name(self, onehot_label, labels_set_list):
        """Return the class name that ``onehot_label`` encodes.

        Args:
            onehot_label: One-hot vector (the position of its maximum is used).
            labels_set_list: Class names in the order used for encoding.
        """
        return labels_set_list[int(np.argmax(onehot_label))]

    def forward(self, csv_path):
        """Read ``csv_path`` and one-hot encode the configured label column.

        Args:
            csv_path: Path of the CSV file to read.

        Returns:
            A tuple ``(onehot_labels, labels_set_list)``: a Series with one
            one-hot list per CSV row, and the list of class names whose order
            defines the encoding.
        """
        csv_data = pd.read_csv(csv_path)
        labels = csv_data.iloc[:, self.label_idx]
        labels_set_list = self._unique_labels(labels)
        onehot_labels = self._create_onehot_labels(labels, labels_set_list)
        return onehot_labels, labels_set_list
# Encode the label column at position 1 of the training CSV.
creat_onehot_labels = OnehotLables(labrel_idx=1)
onehot_labels, labels_set_list = creat_onehot_labels.forward(new_train_path)

# Print the first row's one-hot vector, then decode it back to its class name.
onehot_label = onehot_labels[0]
print(onehot_label)
decoded_name = creat_onehot_labels.get_label_name(onehot_label, labels_set_list)
print(decoded_name)
运行结果(打印第一条样本的 one-hot 编码及其对应的原始类别名):
(base) C:\Users\渺渺夕\Desktop\Transfer_Learning_learning>D:/Anaconda3/python.exe c:/Users/渺渺夕/Desktop/Transfer_Learning_learning/data.py
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
n01532829