计算多分类任务中各类别的精确率(precision)与召回率(recall)

# One row per test sample: 'real' is the true class index and 'pre1' the index
# of the highest-scoring class (first column of the descending argsort).
# NOTE(review): assumes y_pred is a 2-D (samples, classes) score array and
# y_test holds integer class indices — confirm against the caller.
df = pd.DataFrame({
    'real': list(y_test),
    'pre1': list(np.argsort(-y_pred, axis=1)[:, 0]),
}).astype(int)
# Translate class indices to display names: index -> label id -> name.
df['pre1'] = df['pre1'].apply(lambda x: id_to_name_dict[index_to_labelid_dict[x]])
df['real'] = df['real'].apply(lambda x: id_to_name_dict[index_to_labelid_dict[x]])
# 1 when the top-1 prediction matches the true class, else 0. A vectorized
# comparison avoids the per-row Python call of DataFrame.apply(axis=1).
df['right'] = (df['pre1'] == df['real']).astype(int)

def get_real_num(key):
    """Return the number of rows in the module-level ``df`` whose true class is ``key``.

    The count is the number of non-null ``pre1`` values in the ``real == key``
    group. Returns 0 when ``key`` never appears in the ``real`` column.
    """
    # Aggregate once per call and use a defaulted lookup instead of building
    # the same grouped dict twice (once for the test, once for the read).
    counts_by_real = df.groupby(['real'])['pre1'].count().to_dict()
    return counts_by_real.get(key, 0)
    
def get_pre_num(key):
    """Return the number of rows in the module-level ``df`` predicted as ``key``.

    The count is the number of non-null ``real`` values in the ``pre1 == key``
    group. Returns 0 when ``key`` never appears in the ``pre1`` column.
    """
    # Aggregate once per call and use a defaulted lookup instead of building
    # the same grouped dict twice (once for the test, once for the read).
    counts_by_pred = df.groupby(['pre1'])['real'].count().to_dict()
    return counts_by_pred.get(key, 0)
    
def get_tp(key):
    """Return the true-positive count for class ``key`` from the module-level ``df``.

    This is the sum of the ``right`` column (1 for a correct top-1 prediction,
    0 otherwise) over the rows whose true class is ``key``. Returns 0 when
    ``key`` never appears in the ``real`` column.
    """
    # Aggregate once per call and use a defaulted lookup instead of building
    # the same grouped dict twice (once for the test, once for the read).
    hits_by_real = df.groupby(['real'])['right'].sum().to_dict()
    return hits_by_real.get(key, 0)

def get_recall(tp, real_num):
    """Return ``tp / real_num``, or 0 when ``real_num`` is 0.

    With ``tp`` as the true-positive count and ``real_num`` as the number of
    samples that truly belong to a class, the ratio is that class's recall.
    """
    # A class with no true samples has no defined recall; report 0 for it.
    return tp / real_num if real_num != 0 else 0
    
def get_accuracy(tp, pre_num):
    """Return ``tp / pre_num``, or 0 when ``pre_num`` is 0.

    With ``tp`` as the true-positive count and ``pre_num`` as the number of
    samples predicted as a class, the ratio is that class's precision (the
    function name says "accuracy", but the quantity is precision).
    """
    # A class that was never predicted has no defined precision; report 0.
    return tp / pre_num if pre_num != 0 else 0
    
def _class_stats(key):
    """Return the per-class statistics dict for class name ``key``."""
    real_num = get_real_num(key)
    pre_num = get_pre_num(key)
    tp = get_tp(key)
    return {
        'real_num': real_num,
        'pre_num': pre_num,
        'tp': tp,
        'fn': real_num - tp,  # true samples of the class that were missed
        'fp': pre_num - tp,   # predictions of the class that were wrong
        'recall': get_recall(tp, real_num),
        # tp / pre_num is per-class precision; the 'accuracy' key name is
        # kept so existing readers of roc_dict keep working.
        'accuracy': get_accuracy(tp, pre_num),
    }


# Per-class statistics for every class seen as a true label or a prediction.
# dict.fromkeys de-duplicates while keeping first-appearance order, so classes
# that tie on real_num come out of the stable sort below in a reproducible
# order (iterating a set of strings would vary between runs).
roc_dict = {
    key: _class_stats(key)
    for key in dict.fromkeys(df['real'].tolist() + df['pre1'].tolist())
}
# Most frequent true classes first.
roc_dict = dict(sorted(roc_dict.items(), key=lambda item: item[1]['real_num'], reverse=True))
roc_dict

你可能感兴趣的:(算法)