# AI Challenger 场景分类：观察验证集中的错误分类情况 (AI Challenger scene classification: inspect misclassified images in the validation set)

# 基于官方验证脚本修改而来。(Adapted from the official evaluation script.)

import json
import argparse
import time
import shutil
import pandas as pd

def __load_data(submit_file, reference_file, submit_dict, ref_dict, result=None):
    """Load a submission file and a reference file into ``image_id -> label`` dicts.

    Each file is parsed with ``json.load`` and iterated as a list of records
    carrying ``'image_id'`` and ``'label_id'`` keys. Submission labels are
    stored unchanged; reference labels are converted with ``int()``.

    Args:
        submit_file: path to the submission JSON file.
        reference_file: path to the reference (ground-truth) JSON file.
        submit_dict: dict filled in place with ``image_id -> label_id`` from
            the submission.
        ref_dict: dict filled in place with ``image_id -> int(label_id)`` from
            the reference.
        result: report dict with a ``'warning'`` list. When omitted, the
            module-level ``result`` dict created by the script section is
            used if it exists; otherwise a throwaway report is used, so the
            function no longer fails with ``NameError`` outside the script.

    Returns:
        ``(submit_dict, ref_dict)`` -- the same dict objects that were passed in.

    Raises:
        OSError if a file cannot be opened, ``json.JSONDecodeError`` on
        malformed JSON, ``KeyError`` if a record lacks a key, and
        ``ValueError``/``TypeError`` if a reference label is not int-convertible.
    """
    if result is None:
        # Backward compatibility: the script section defines a module-level
        # ``result`` report and expects warnings to be appended to it.
        result = globals().get('result')
        if result is None:
            result = {'error': [], 'warning': [], 'score': None}

    # JSON text is UTF-8; be explicit so the platform default encoding
    # (e.g. GBK on a Chinese-locale Windows machine) is not used.
    with open(submit_file, 'r', encoding='utf-8') as file1:
        submit_data = json.load(file1)
    with open(reference_file, 'r', encoding='utf-8') as file1:
        ref_data = json.load(file1)
    if len(submit_data) != len(ref_data):
        result['warning'].append('Inconsistent number of images between submission and reference data \n')

    for item in submit_data:
        submit_dict[item['image_id']] = item['label_id']
    for item in ref_data:
        ref_dict[item['image_id']] = int(item['label_id'])
    return submit_dict, ref_dict


def __eval_result(submit_dict, ref_dict, result=None, max_correct=100):
    """Score top-3 accuracy and collect image ids for visual inspection.

    An image is counted as correct when its reference label appears among the
    first three entries of its submission value (``submit_dict[key][:3]``).
    Reference images missing from the submission are reported as warnings and
    count against the score.

    Args:
        submit_dict: ``image_id -> label_id`` from the submission; each value
            must support slicing, e.g. a list of predicted class ids.
        ref_dict: ``image_id -> label`` ground truth.
        result: report dict with ``'warning'`` and ``'score'`` entries. When
            omitted, the module-level ``result`` dict created by the script
            section is used if it exists; otherwise a fresh report is created.
        max_correct: maximum number of correctly classified ids to collect
            (a sample for inspection). Defaults to 100.

    Returns:
        ``(result, wrong_ids, correct_ids)``. ``result['score']`` is the
        accuracy as a string. ``wrong_ids`` lists every misclassified id in
        ``ref_dict`` iteration order. ``correct_ids`` holds at most the first
        ``max_correct`` correctly classified ids.
    """
    if result is None:
        # Backward compatibility: the script section defines a module-level
        # ``result`` report and expects it to be updated and returned.
        result = globals().get('result')
        if result is None:
            result = {'error': [], 'warning': [], 'score': None}

    wrong_ids = []
    correct_ids = []
    right_count = 0
    for key, value in ref_dict.items():
        # Direct dict membership is O(1); rebuilding a set of the keys on
        # every iteration made this loop quadratic.
        if key not in submit_dict:
            result['warning'].append('lacking image %s in your submission file \n' % key)
            print('warning: lacking image %s in your submission file' % key)
            continue

        if value in submit_dict[key][:3]:
            right_count += 1
            # Only keep a bounded sample of correct predictions to copy out.
            if right_count <= max_correct:
                correct_ids.append(key)
        else:
            wrong_ids.append(key)

    # max() avoids division by zero for an empty reference set.
    result['score'] = str(float(right_count) / max(len(ref_dict), 1e-5))
    return result, wrong_ids, correct_ids


if __name__ == '__main__':
    # Script entry point: score a submission against the reference labels,
    # then copy misclassified images (and a sample of correct ones) into
    # separate folders for visual inspection.
    import os  # portable path handling for the copy step below

    # NOTE(review): loaded but not used below; kept so the script's inputs
    # are unchanged -- confirm whether it is still needed.
    scene_classes = pd.read_csv('scene_classes.csv')
    wrongs = 'validation_wrong'      # output folder for misclassified images
    corrects = 'validation_correct'  # output folder for sampled correct images
    # Validation image directory (built with os.path.join instead of a
    # hard-coded Windows backslash so the script also runs on POSIX).
    path = os.path.join('ai_challenger_scene_validation_20170908',
                        'scene_validation_images_20170908')
    submit_dict = {}
    ref_dict = {}

    PARSER = argparse.ArgumentParser()

    PARSER.add_argument(
        '--submit',
        type=str,
        default='./submit.json',
        help="""\
        Path to submission file\
        """
    )

    PARSER.add_argument(
        '--ref',
        type=str,
        default='./ref.json',
        help="""\
        Path to reference file\
        """
    )

    FLAGS = PARSER.parse_args()

    result = {'error': [], 'warning': [], 'score': None}

    START_TIME = time.time()
    SUBMIT = {}
    REF = {}
    # Pre-initialise so the copy loops below are no-ops (instead of raising
    # NameError) when evaluation fails.
    wrong_ids = []
    correct_ids = []

    try:
        SUBMIT, REF = __load_data(FLAGS.submit, FLAGS.ref, submit_dict, ref_dict)
    except Exception as error:
        result['error'].append(str(error))
    try:
        result, wrong_ids, correct_ids = __eval_result(SUBMIT, REF)
    except Exception as error:
        result['error'].append(str(error))
    print('Evaluation time of your result: %f s' % (time.time() - START_TIME))

    print(result)

    for out_dir, image_ids in ((wrongs, wrong_ids), (corrects, correct_ids)):
        # copyfile does not create parent folders, so make sure they exist.
        os.makedirs(out_dir, exist_ok=True)
        for item in image_ids:
            # Output name: <image stem><submission label_id><reference label>.jpg
            stem = os.path.splitext(item)[0]
            name = stem + str(submit_dict[item]) + str(ref_dict[item]) + '.jpg'
            shutil.copyfile(os.path.join(path, item), os.path.join(out_dir, name))

# 你可能感兴趣的：PyTorch (related topic: PyTorch)