prc曲线 sklearn_Python sklearn.metrics 模块,average_precision_score() 实例源码 - 编程字典...

def main():
    """
    Calculate the Average Precision (AP) at k for every hyper-vs-other relation pair.

    For each relation in the test set other than 'hyper', the scored test results are
    restricted to the word pairs labelled with that relation or with 'hyper', ranked by
    descending score, and the AP of every rank prefix below the cutoff is printed,
    followed by the AP over the whole ranking.

    Command-line arguments (parsed by docopt):
        <test_results_file>  tab-separated file with four fields per line: x, y, label, score
        <test_set_file>      dataset path prefix; '.test' is appended before loading
        <cutoff>             integer cutoff for the per-rank AP report
    """
    # Get the arguments
    args = docopt("""Calculate the Average Precision (AP) at k for every hyper-other relation in the dataset.

    Usage:
        ap_on_each_relation.py <test_results_file> <test_set_file> <cutoff>

        <test_results_file> = the test set result file.
        <test_set_file> = the test set containing the original relations.
        <cutoff> = the cutoff; if it is equal to zero, all the rank is considered.
    """)

    test_set_file = args['<test_set_file>']
    test_results_file = args['<test_results_file>']
    cutoff = int(args['<cutoff>'])

    # Load the test set
    print('Loading the dataset...')
    test_set, relations = load_dataset(test_set_file + '.test')

    # The results file does not depend on the relation being tested, so parse it once
    with codecs.open(test_results_file, 'r', 'utf-8') as f_in:
        result_rows = [tuple(line.strip().split('\t')) for line in f_in]

    hyper_relation = 'hyper'

    for other_relation in [relation for relation in relations if relation != hyper_relation]:
        curr_relations = [other_relation, hyper_relation]

        print('==================================================')
        print('Testing %s vs. %s ...' % (hyper_relation, other_relation))

        # Filter out the dataset to contain only these two relations
        curr_test_set = {(x, y): relation for (x, y), relation in test_set.items()
                         if relation in curr_relations}

        # Sort the lines in the file in descending order according to the score
        dataset = [(x, y, label, float(score)) for (x, y, label, score) in result_rows
                   if (x, y) in curr_test_set]
        dataset = sorted(dataset, key=lambda line: line[-1], reverse=True)

        # relevance: rel(i) is an indicator function equaling 1 if the item at rank i is a hypernym
        gold = np.array([1 if label == 'True' else 0 for (x, y, label, score) in dataset])
        scores = np.array([score for (x, y, label, score) in dataset])

        # AP of each rank prefix; the complete ranking is reported by the FINAL line below
        for i in range(1, min(cutoff + 1, len(dataset))):
            score = average_precision_score(gold[:i], scores[:i])
            # NOTE(review): the -1 check treats -1 as a "no score" sentinel - confirm that
            # average_precision_score can actually return it
            print('Average Precision at %d is %.3f' % (i, 0 if score == -1 else score))

        print('FINAL: Average Precision at %d is %.3f'
              % (len(dataset), average_precision_score(gold, scores)))

你可能感兴趣的:(prc曲线,sklearn)