def _pad_sequences(sequences, pad_tok, max_length):
"""
Args:
sequences: a generator of list or tuple
pad_tok: the char to pad with
Returns:
a list of list where each sublist has same length
"""
sequence_padded, sequence_length = [], []
for seq in sequences:
seq = list(seq)
seq_ = seq[:max_length] + [pad_tok]*max(max_length - len(seq), 0)
sequence_padded += [seq_]
sequence_length += [min(len(seq), max_length)]
return sequence_padded, sequence_length
def pad_sequences(sequences, pad_tok, nlevels=1):
    """Pad a batch of sequences so that they all share the same shape.

    Args:
        sequences: an iterable (list, tuple or generator) of lists or tuples.
            With ``nlevels=1`` each item is a sequence of tokens; with
            ``nlevels=2`` each item is a sequence of words and each word is
            itself a sequence (e.g. of character ids).
        pad_tok: the token to pad with
        nlevels: "depth" of padding; 1 pads a flat batch, 2 is for the case
            where we have character ids (words inside sentences).

    Returns:
        a tuple ``(sequence_padded, sequence_length)``.
        For ``nlevels=1``: ``sequence_padded`` is a list of lists that all
        have the length of the longest input sequence, and
        ``sequence_length`` lists the original length of each sequence.
        For ``nlevels=2``: ``sequence_padded`` is a list (sentences) of lists
        (words) of lists (tokens), padded on both levels, and
        ``sequence_length`` is a list of lists with the original length of
        every word (0 for words added as sentence padding).
        An empty ``sequences`` yields ``([], [])``.

    Raises:
        ValueError: if ``nlevels`` is neither 1 nor 2.
    """
    if nlevels not in (1, 2):
        raise ValueError("nlevels must be 1 or 2, got {!r}".format(nlevels))

    # The batch is traversed more than once (to find the maximum lengths and
    # then to pad), so materialize it first: a generator would be exhausted
    # after the first pass.
    sequences = [list(seq) for seq in sequences]
    if not sequences:
        # max() of an empty batch would raise; nothing to pad.
        return [], []

    if nlevels == 1:
        max_length = max(len(seq) for seq in sequences)
        return _pad_sequences(sequences, pad_tok, max_length)

    # nlevels == 2: materialize the words too, their lengths are needed
    # before they are padded.
    sequences = [[list(word) for word in seq] for seq in sequences]
    all_word_lengths = [len(word) for seq in sequences for word in seq]
    # Sentences may be empty; fall back to 0 instead of letting max() raise.
    max_length_word = max(all_word_lengths) if all_word_lengths else 0
    max_length_sentence = max(len(seq) for seq in sequences)

    sequence_padded, sequence_length = [], []
    for seq in sequences:
        # All words of this sentence get the same length.
        padded_words, lengths = _pad_sequences(seq, pad_tok, max_length_word)
        # Pad the sentence with all-padding words. Each filler word is a
        # fresh list so that mutating one of them later cannot silently
        # change the others.
        missing = max_length_sentence - len(seq)
        padded_words.extend(
            [pad_tok] * max_length_word for _ in range(missing))
        lengths.extend([0] * missing)
        sequence_padded.append(padded_words)
        sequence_length.append(lengths)

    return sequence_padded, sequence_length
# Example 1: token embedding -- one level of padding. Every sentence is
# padded with 0 up to the length of the longest sentence (9 tokens here).
words = [[1,2,3,4,5], [1,3,5,7,9,1,12,11,13], [1,3,5,7,11,13]]
word_ids, sequence_lengths = pad_sequences(words, 0)
print(word_ids)
# Expected output:
# [[1, 2, 3, 4, 5, 0, 0, 0, 0], [1, 3, 5, 7, 9, 1, 12, 11, 13], [1, 3, 5, 7, 11, 13, 0, 0, 0]]

# Example 2: char embedding -- two levels of padding. Every word is padded
# with 0 up to the length of the longest word in the whole batch (9 ids
# here); both sentences already contain the same number of words.
char_ids = [[[1,2,3,4,5], [1,3,5,7,9,1,12,11,13], [1,3,5,7,11,13]],
            [[1,2,3,4,5], [1,12,11,13], [1,3,5,7,2,1,8]]]
# Note: char_ids is rebound to the padded result.
char_ids, word_lengths = pad_sequences(char_ids, pad_tok=0,
                                       nlevels=2)
print(char_ids)
# Expected output:
# [[[1, 2, 3, 4, 5, 0, 0, 0, 0], [1, 3, 5, 7, 9, 1, 12, 11, 13], [1, 3, 5, 7, 11, 13, 0, 0, 0]],
#  [[1, 2, 3, 4, 5, 0, 0, 0, 0], [1, 12, 11, 13, 0, 0, 0, 0, 0], [1, 3, 5, 7, 2, 1, 8, 0, 0]]]