CNNS+BiLSTM代码学习

nn.py

def tag_dataset(dataset):
    """Predict a label sequence for every sample and collect the gold labels.

    Each element of ``dataset`` is unpacked as ``(tokens, casing, char,
    labels)``. The three feature sequences are fed to the module-level
    ``model`` one sample at a time, and the arg-max over the last axis of
    the prediction is taken as the predicted class index per position.

    Args:
        dataset: Sized iterable of ``(tokens, casing, char, labels)`` samples.

    Returns:
        A tuple ``(predLabels, correctLabels)``: two lists of equal length
        holding, per sample, the predicted class indices and the ``labels``
        entry of that sample.
    """
    correctLabels = []  # gold label sequences, one entry per sample
    predLabels = []  # predicted label sequences, one entry per sample
    b = Progbar(len(dataset))  # progress bar sized to the number of samples
    for i, data in enumerate(dataset):
        tokens, casing, char, labels = data
        # Wrap each feature in a list so the resulting array has a leading
        # axis of size 1 (a single-sample batch).
        tokens = np.asarray([tokens])
        casing = np.asarray([casing])
        char = np.asarray([char])
        # [0] selects the prediction for the only sample in the batch.
        pred = model.predict([tokens, casing, char], verbose=False)[0]
        # Pick the index of the highest score along the last axis.
        pred = pred.argmax(axis=-1)
        correctLabels.append(labels)
        predLabels.append(pred)
        b.update(i)
    # Finish the bar using the dataset size rather than the loop variable,
    # which is unbound when the dataset is empty.
    b.update(len(dataset))
    # Return the gold labels too; they were collected but never returned.
    return predLabels, correctLabels

prepro.py

def readfile(filename):
    """Read a space-separated, one-token-per-line file into sentences.

    Lines that start with ``-DOCSTART`` or that are blank (begin with a
    newline) act as sentence separators. Every other line is split on
    single spaces; its first and last fields are kept as ``[token, label]``.
    The last field is taken verbatim, so it keeps the line's trailing
    newline when one is present.

    Args:
        filename: Path of the UTF-8 encoded file to read.

    Returns:
        A list of sentences, each a list of ``[token, label]`` pairs, e.g.
        ``[['EU', 'B-ORG\\n'], ['rejects', 'O\\n'], ...]``.
    """
    sentences = []  # all completed sentences
    sentence = []  # [token, label] pairs of the sentence being read
    # The context manager guarantees the file is closed, even on error.
    with open(filename, encoding="UTF-8") as f:
        for line in f:
            if not line or line.startswith('-DOCSTART') or line[0] == "\n":
                # Separator line: flush the current sentence if it has content.
                if sentence:
                    sentences.append(sentence)
                    sentence = []
                continue
            splits = line.split(' ')
            sentence.append([splits[0], splits[-1]])
    # The file may end without a trailing separator line.
    if sentence:
        sentences.append(sentence)
    return sentences

1.如果该行为空,或者是-DOCSTART,将非空的sentence存到sentences中。
2.如果不是上述情况:按照每行的空格进行分割,然后获取每行分割的第一列元素和最后一列元素,存到sentence中。

def getCasing(word, caseLookup):
    """Classify the casing/digit pattern of ``word`` and look up its index.

    The first matching category wins, checked in this order: ``numeric``
    (all digits), ``mainly_numeric`` (more than half the characters are
    digits), ``allLower``, ``allUpper``, ``initialUpper`` (first character
    is upper case), ``contains_digit``; anything else is ``other``.

    Args:
        word: The token to classify. An empty string is classified ``other``.
        caseLookup: Mapping from category name to the value to return.

    Returns:
        ``caseLookup[category]`` for the detected category.
    """
    # Guard: an empty word would otherwise divide by zero below.
    if not word:
        return caseLookup['other']

    numDigits = sum(1 for char in word if char.isdigit())
    digitFraction = numDigits / float(len(word))

    if word.isdigit():  # every character is a digit
        casing = 'numeric'
    elif digitFraction > 0.5:
        casing = 'mainly_numeric'
    elif word.islower():  # all cased characters are lower case
        casing = 'allLower'
    elif word.isupper():  # all cased characters are upper case
        casing = 'allUpper'
    elif word[0].isupper():  # only the first character is checked
        casing = 'initialUpper'
    elif numDigits > 0:
        casing = 'contains_digit'
    else:
        casing = 'other'
    return caseLookup[casing]

1.判断当前word属于哪种类型,并返回该类型在caseLookup中对应的索引。按顺序依次判断:全部是数字(numeric)、超过一半的字符是数字(mainly_numeric)、全部小写(allLower)、全部大写(allUpper)、首字母大写(initialUpper)、包含数字(contains_digit);都不满足则为other。

def createBatches(data):
    """Regroup samples so that samples of equal length are contiguous.

    The length of a sample is ``len(sample[0])``. Samples are emitted one
    length group at a time, following the iteration order of the set of
    distinct lengths; within a group the original order is kept.

    Args:
        data: Sequence of samples whose first element is a sized sequence.

    Returns:
        A tuple ``(batches, batch_len)`` where ``batches`` is the regrouped
        list of samples and ``batch_len`` holds, for each length group, the
        cumulative number of samples up to and including that group.
    """
    distinct_lengths = set([len(sample[0]) for sample in data])
    batches = []
    batch_len = []
    for length in distinct_lengths:
        batches.extend(sample for sample in data if len(sample[0]) == length)
        # The running total doubles as the end offset of this group.
        batch_len.append(len(batches))
    return batches, batch_len

createBatches按句子长度(len(i[0]))对样本分组:长度相同的样本在batches中连续存放,batch_len记录每个长度分组结束位置的累计下标,之后可以据此按长度切出每个batch进行运算。

def createMatrices(sentences, word2Idx, label2Idx, case2Idx,char2Idx):
    """Convert sentences of ``(word, char, label)`` triples into index lists.

    Args:
        sentences: Iterable of sentences; each sentence is an iterable of
            ``(word, char, label)`` triples, where ``char`` is iterable.
        word2Idx: Mapping from word to index; must contain the keys
            ``'UNKNOWN_TOKEN'`` and ``'PADDING_TOKEN'``.
        label2Idx: Mapping from label to index.
        case2Idx: Mapping passed to ``getCasing`` as its lookup table.
        char2Idx: Mapping from each element of ``char`` to its index.

    Returns:
        A list with one ``[wordIndices, caseIndices, charIndices,
        labelIndices]`` entry per sentence.
    """
    unknownIdx = word2Idx['UNKNOWN_TOKEN']
    # Looked up but not used afterwards; kept so a vocabulary without this
    # key still fails here.
    word2Idx['PADDING_TOKEN']

    dataset = []
    for sentence in sentences:
        wordIndices, caseIndices, charIndices, labelIndices = [], [], [], []

        for word, char, label in sentence:
            # Exact match first, then the lower-cased form, else unknown.
            key = word if word in word2Idx else word.lower()
            wordIdx = word2Idx.get(key, unknownIdx)

            charIdx = [char2Idx[x] for x in char]

            wordIndices.append(wordIdx)
            caseIndices.append(getCasing(word, case2Idx))
            charIndices.append(charIdx)
            labelIndices.append(label2Idx[label])

        dataset.append([wordIndices, caseIndices, charIndices, labelIndices])

    return dataset

你可能感兴趣的:(CNNS+BiLSTM代码学习)