nn.py
def tag_dataset(dataset):
    """Predict labels for every sample and collect them with the gold labels.

    Uses the module-level ``model``, ``np`` and ``Progbar`` names, which are
    defined outside this function.

    Args:
        dataset: Sized sequence of ``(tokens, casing, char, labels)`` samples.

    Returns:
        A tuple ``(predLabels, correctLabels)``. ``predLabels[k]`` holds the
        argmax over the last axis of the model output for sample ``k``;
        ``correctLabels[k]`` is the ``labels`` entry of that same sample.
    """
    correctLabels = []  # gold label sequences, one entry per sample
    predLabels = []  # predicted label sequences, one entry per sample
    total = len(dataset)
    b = Progbar(total)  # progress bar
    for i, data in enumerate(dataset):
        tokens, casing, char, labels = data
        # Wrap each input in a list so the resulting array gets a leading
        # axis of size 1 (one sample per predict call).
        tokens = np.asarray([tokens])
        casing = np.asarray([casing])
        char = np.asarray([char])
        # [0] takes the output for the single sample that was passed in.
        pred = model.predict([tokens, casing, char], verbose=False)[0]
        pred = pred.argmax(axis=-1)  # pick the highest-scoring class index
        correctLabels.append(labels)
        predLabels.append(pred)
        b.update(i)
    # Use the known total instead of the loop variable so an empty dataset
    # does not hit an unbound ``i``.
    b.update(total)
    return predLabels, correctLabels
prepro.py
def readfile(filename):
    """Read a token-per-line tagged file and group the lines into sentences.

    A line that starts with ``-DOCSTART`` or is blank ends the current
    sentence. Every other line is split on single spaces; its first field
    (the token) and last field (the label) are kept, with the line's
    trailing newline removed so labels are identical whether or not the file
    ends with a newline.

    Args:
        filename: Path of the UTF-8 encoded file to read.

    Returns:
        A list of sentences, each a list of ``[token, label]`` pairs, e.g.::

            [ [ ['EU', 'B-ORG'], ['rejects', 'O'], ['German', 'B-MISC'] ],
              [ ['Peter', 'B-PER'], ['Blackburn', 'I-PER'] ] ]
    """
    sentences = []  # all completed sentences
    sentence = []  # [token, label] pairs of the sentence being read
    with open(filename, encoding="UTF-8") as f:
        for line in f:
            if line.startswith('-DOCSTART') or line[0] == "\n":
                # Sentence boundary: flush what has been collected so far.
                if sentence:
                    sentences.append(sentence)
                    sentence = []
                continue
            splits = line.rstrip("\n").split(' ')
            sentence.append([splits[0], splits[-1]])
    # The file may end without a trailing blank line.
    if sentence:
        sentences.append(sentence)
    return sentences
1.如果该行为空,或者是-DOCSTART,将非空的sentence存到sentences中。
2.如果不是上述情况:按照每行的空格进行分割,然后获取每行分割的第一列元素和最后一列元素,存到sentence中。
def getCasing(word, caseLookup):
    """Classify the casing/digit pattern of ``word`` and return its index.

    The first matching category wins, checked in this order: ``numeric``
    (all digits), ``mainly_numeric`` (more than half of the characters are
    digits), ``allLower``, ``allUpper``, ``initialUpper`` (first character is
    upper case), ``contains_digit``; anything else, including the empty
    string, is ``other``.

    Args:
        word: The token to classify.
        caseLookup: Mapping from category name to index.

    Returns:
        ``caseLookup[category]`` for the category chosen above.

    Raises:
        KeyError: If the chosen category is missing from ``caseLookup``.
    """
    if not word:
        # An empty token has no characters to inspect; without this guard the
        # digit-fraction division below would divide by zero.
        return caseLookup['other']

    casing = 'other'
    numDigits = sum(1 for ch in word if ch.isdigit())
    digitFraction = numDigits / float(len(word))

    if word.isdigit():  # every character is a digit
        casing = 'numeric'
    elif digitFraction > 0.5:
        casing = 'mainly_numeric'
    elif word.islower():  # all cased characters are lower case
        casing = 'allLower'
    elif word.isupper():  # all cased characters are upper case
        casing = 'allUpper'
    elif word[0].isupper():  # first character is upper case
        casing = 'initialUpper'
    elif numDigits > 0:
        casing = 'contains_digit'

    return caseLookup[casing]
1.判断当前word属于哪种大小写/数字类型,并返回该类型在caseLookup中对应的索引。按顺序依次判断:全部是数字(numeric)、超过一半的字符是数字(mainly_numeric)、全部小写(allLower)、全部大写(allUpper)、首字母大写(initialUpper)、包含数字(contains_digit);以上都不满足时类型为other。
def createBatches(data):
    """Regroup samples so that samples with equal token length are adjacent.

    The length of a sample is ``len(sample[0])``. Samples are grouped in a
    single pass over ``data`` instead of rescanning the whole dataset once
    per distinct length.

    Args:
        data: Iterable of samples whose first element is a sized sequence.

    Returns:
        A tuple ``(batches, batch_len)``. ``batches`` lists the samples group
        by group, keeping the original relative order inside each group.
        ``batch_len`` holds, for each group, the cumulative number of samples
        up to and including that group, so consecutive entries delimit the
        groups inside ``batches``.
    """
    lengths = []
    by_length = {}
    for sample in data:
        size = len(sample[0])
        lengths.append(size)
        by_length.setdefault(size, []).append(sample)

    batches = []
    batch_len = []
    # Iterate the distinct lengths through a set built from the full list of
    # lengths, so the groups come out in the same order as before.
    for size in set(lengths):
        batches.extend(by_length[size])
        batch_len.append(len(batches))
    return batches, batch_len
createBatches按句子长度对数据分组:长度相同的样本在batches中连续存放,batch_len记录每个长度分组结束时的累计样本数,后续可据此把数据切分成一个个batch再进行运算。
def createMatrices(sentences, word2Idx, label2Idx, case2Idx, char2Idx):
    """Convert sentences into parallel lists of integer indices.

    Args:
        sentences: Iterable of sentences; each sentence is an iterable of
            ``(word, char, label)`` items, where ``char`` is iterated
            character by character.
        word2Idx: Mapping from word to index; must contain
            ``'UNKNOWN_TOKEN'``.
        label2Idx: Mapping from label to index.
        case2Idx: Mapping from casing category to index, passed to
            ``getCasing``.
        char2Idx: Mapping from character to index.

    Returns:
        A list with one entry per sentence, each of the form
        ``[wordIndices, caseIndices, charIndices, labelIndices]``.
        ``charIndices`` holds one list of character indices per word.

    Raises:
        KeyError: If ``'UNKNOWN_TOKEN'`` is missing from ``word2Idx``, or a
            character or label is missing from ``char2Idx`` / ``label2Idx``.
    """
    unknownIdx = word2Idx['UNKNOWN_TOKEN']

    dataset = []
    for sentence in sentences:
        wordIndices = []
        caseIndices = []
        charIndices = []
        labelIndices = []

        for word, char, label in sentence:
            # Try the exact word first, then its lower-cased form, and fall
            # back to the unknown-token index.
            if word in word2Idx:
                wordIdx = word2Idx[word]
            elif word.lower() in word2Idx:
                wordIdx = word2Idx[word.lower()]
            else:
                wordIdx = unknownIdx

            wordIndices.append(wordIdx)
            caseIndices.append(getCasing(word, case2Idx))
            charIndices.append([char2Idx[x] for x in char])
            # Map the label to its integer index.
            labelIndices.append(label2Idx[label])

        dataset.append([wordIndices, caseIndices, charIndices, labelIndices])

    return dataset