python itertools库详解(让你的代码更pythonic)

itertools

api:


Itertool functions

itertools.accumulate(iterable[, func ])

def accumulate(iterable, func=operator.add):
    """Yield the running results of folding *func* over *iterable*.

    The first item is yielded unchanged; every later output is
    ``func(previous_output, item)``.  An empty input yields nothing.
    """
    # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    source = iter(iterable)
    exhausted = object()  # private marker: distinguishes "no items" from any real item
    running = next(source, exhausted)
    if running is exhausted:
        return
    yield running
    for item in source:
        running = func(running, item)
        yield running


eg:
>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, operator.mul)) # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
>>> list(accumulate(data, max)) # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
# Amortize a 5% loan of 1000 with 4 annual payments of 90
>>> cashflows = [1000, -90, -90, -90, -90]
>>> list(accumulate(cashflows, lambda bal, pmt: bal*1.05 + pmt))
[1000, 960.0, 918.0, 873.9000000000001, 827.5950000000001]

itertools.chain 连接多个列表或者迭代器

>>> x = itertools.chain(range(3), range(4), [3,2,1])
>>> print(list(x))
[0, 1, 2, 0, 1, 2, 3, 3, 2, 1]

itertools.combinations 求列表或生成器中指定数目的元素不重复的所有组合

源码:
def combinations(iterable, r):
    """Yield every length-*r* combination of the input items as a tuple.

    Combinations are chosen by position and emitted in lexicographic order
    of those positions.  Nothing is yielded when *r* exceeds the number of
    items.
    """
    # combinations('ABCD', 2) --> AB AC AD BC BD CD
    # combinations(range(4), 3) --> 012 013 023 123
    items = tuple(iterable)
    size = len(items)
    if r > size:
        return
    picks = list(range(r))
    yield tuple(items[p] for p in picks)
    # Slot k has reached its final value once picks[k] == k + (size - r).
    slack = size - r
    while True:
        # Scan from the right for the first slot that can still advance.
        pos = r - 1
        while pos >= 0 and picks[pos] == pos + slack:
            pos -= 1
        if pos < 0:
            # Every slot is at its maximum: all combinations were emitted.
            return
        picks[pos] += 1
        # Reset the slots to the right to the smallest increasing run.
        for nxt in range(pos + 1, r):
            picks[nxt] = picks[nxt - 1] + 1
        yield tuple(items[p] for p in picks)

def combinations(iterable, r):
    """Yield length-*r* combinations by filtering the permutations.

    A permutation of positions is kept only when its positions are already
    in ascending order, which leaves exactly one ordering per combination.
    """
    items = tuple(iterable)
    yield from (
        tuple(items[i] for i in order)
        for order in permutations(range(len(items)), r)
        if list(order) == sorted(order)
    )

>>> x = itertools.combinations(range(4), 3)
>>> print(list(x))
[(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]

itertools.combinations_with_replacement(iterable, r) 允许重复元素的组合

>>> x = itertools.combinations_with_replacement('ABC', 2)
>>> print(list(x))
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]

itertools.compress(data, selectors) 按照真值表筛选元素

def compress(data, selectors):
    """Return a lazy iterator over the items of *data* whose matching
    entry in *selectors* is truthy.

    Pairing stops at the shorter of the two inputs.
    """
    # compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F
    pairs = zip(data, selectors)
    return (item for item, keep in pairs if keep)

>>> list(itertools.compress('abcdef', [1,0,1,0,1,1]))
['a', 'c', 'e', 'f']

itertools.count(start=0, step=1)

def count(start=0, step=1):
    """Yield *start*, then keep adding *step* forever.

    The generator never terminates on its own; the caller decides when to
    stop consuming it.
    """
    # count(10) --> 10 11 12 13 14 ...
    # count(2.5, 0.5) -> 2.5 3.0 3.5 ...
    value = start
    yield value
    while True:
        value += step
        yield value
>>> x = itertools.count(start=20, step=-1)
>>> print(list(itertools.islice(x, 0, 10, 1)))
[20, 19, 18, 17, 16, 15, 14, 13, 12, 11]

itertools.cycle(iterable) 循环指定的列表和迭代器

def cycle(iterable):
    """Yield the items of *iterable*, then replay them endlessly.

    Items are remembered during the first pass so that a one-shot iterator
    can still be repeated.  An empty input yields nothing.
    """
    # cycle('ABCD') --> A B C D A B C D A B C D ...
    seen = []
    for item in iterable:
        yield item
        seen.append(item)
    if not seen:
        # Nothing was produced, so there is nothing to replay.
        return
    while True:
        for item in seen:
            yield item

>>> x = itertools.cycle('ABC')
>>> print(list(itertools.islice(x, 0, 10, 1)))
['A', 'B', 'C', 'A', 'B', 'C', 'A', 'B', 'C', 'A']

itertools.dropwhile(predicate, iterable) 丢弃序列开头所有满足条件的元素,从第一个不满足条件的元素开始全部保留

def dropwhile(predicate, iterable):
    """Skip leading items while *predicate* is true, then yield the rest.

    Once one item fails the predicate, that item and everything after it
    are yielded, and the predicate is not called again.
    """
    # dropwhile(lambda x: x<5, [1,4,6,4,1]) --> 6 4 1
    dropping = True
    for item in iter(iterable):
        # Short-circuit keeps the predicate from running after the first miss.
        if dropping and predicate(item):
            continue
        dropping = False
        yield item

>>> x = itertools.dropwhile(lambda e: e < 5, range(10))
>>> print(list(x))
[5, 6, 7, 8, 9]

itertools.filterfalse 保留对应真值为False的元素

def filterfalse(predicate, iterable):
    """Yield the items of *iterable* for which *predicate* is false.

    When *predicate* is ``None`` the items' own truthiness is tested, so
    only falsy items are yielded.
    """
    # filterfalse(lambda x: x%2, range(10)) --> 0 2 4 6 8
    test = bool if predicate is None else predicate
    for item in iterable:
        if test(item):
            continue
        yield item

>>> x = itertools.filterfalse(lambda e: e < 5, (1, 5, 3, 6, 9, 4))
>>> print(list(x))
[5, 6, 9]

itertools.groupby(iterable, key=None) 按照分组函数的值对元素进行分组


class groupby:
    """Group consecutive items of an iterable that share the same key.

    Iterating yields ``(key, group)`` pairs, where *group* is a lazy iterator
    over the run of consecutive items whose ``key(item)`` equals *key*.  All
    groups read from the one underlying iterator, so a group stops producing
    items as soon as the parent ``groupby`` is advanced to the next group.
    """
    # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D

    def __init__(self, iterable, key=None):
        """Store the source iterator and key function.

        *key* maps an item to its grouping key; ``None`` means the item
        itself is the key.
        """
        if key is None:
            key = lambda x: x
        self.keyfunc = key
        self.it = iter(iterable)
        # One fresh sentinel for all three, so the first __next__ call sees
        # currkey == tgtkey and pulls the first item from the source.
        self.tgtkey = self.currkey = self.currvalue = object()

    def __iter__(self):
        """Return self; a groupby object is its own iterator."""
        return self

    def __next__(self):
        """Return the next ``(key, group_iterator)`` pair.

        Raises StopIteration when the source iterator is exhausted.
        """
        # A new token invalidates any group iterator handed out earlier.
        self.id = object()
        # Skip whatever is left of the previous group if it was not consumed.
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)  # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        return (self.currkey, self._grouper(self.tgtkey, self.id))

    def _grouper(self, tgtkey, group_id):
        """Yield the items of the current run while this group is still live.

        The generator stops when the key changes, when the source runs out,
        or when the parent has moved on (``self.id`` is no longer *group_id*).
        """
        while self.id is group_id and self.currkey == tgtkey:
            yield self.currvalue
            try:
                self.currvalue = next(self.it)
            except StopIteration:
                return
            self.currkey = self.keyfunc(self.currvalue)

>>> x = itertools.groupby(range(10), lambda x: x < 5 or x > 8)
>>> for condition, numbers in x:
...     print(condition, list(numbers))
True [0, 1, 2, 3, 4]
False [5, 6, 7, 8]
True [9]

itertools.islice(iterable, stop)

itertools.islice(iterable, start, stop[, step ])

def islice(iterable, *args):
    """Yield selected items of *iterable*, like slicing without indexing.

    *args* is ``(stop,)``, ``(start, stop)`` or ``(start, stop, step)``, with
    the same meaning as the arguments of ``slice``.  A *stop* of ``None``
    means "until the input is exhausted"; a *stop* of ``0`` selects nothing.
    """
    # islice('ABCDEFG', 2) --> A B
    # islice('ABCDEFG', 2, 4) --> C D
    # islice('ABCDEFG', 2, None) --> C D E F G
    # islice('ABCDEFG', 0, None, 2) --> A C E G
    s = slice(*args)
    start = s.start or 0
    # Only None means "unbounded"; 0 is a real bound and must not be
    # replaced by sys.maxsize.
    stop = sys.maxsize if s.stop is None else s.stop
    step = s.step or 1
    wanted = iter(range(start, stop, step))
    try:
        nexti = next(wanted)
    except StopIteration:
        return
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            try:
                nexti = next(wanted)
            except StopIteration:
                # All requested positions were produced.  Returning here
                # avoids leaking StopIteration out of the generator, which
                # PEP 479 turns into RuntimeError.
                return

>>> x = itertools.islice(range(10), 0, 9, 2)
>>> list(x)
[0, 2, 4, 6, 8]

itertools.permutations(iterable, r=None),输出输入序列的全排列

The number of items returned is n! / (n-r)! when 0 <= r <= n or zero when r > n.

根据序列位置进行全排列,而不是值,所以如果输入序列有重复值,输出亦会有

源码模拟:
def permutations(iterable, r=None):
    """Yield every length-*r* ordering of the input items as a tuple.

    *r* defaults to the number of items.  Items are treated as distinct by
    position, not by value, so repeated input values produce repeated output
    tuples.  Nothing is yielded when *r* exceeds the number of items.
    """
    # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) --> 012 021 102 120 201 210
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return
    indices = list(range(n))
    # cycles[i] counts down the choices still untried at position i.
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])
    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i is exhausted: rotate its item to the end and
                # reset its counter, then carry on to the position on its left.
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # The for loop finished without a break: every position has
            # wrapped around, so all permutations have been produced.
            return
eg:
>>> import itertools as iters
>>> x = iters.permutations([0,1,2], 3)
>>> x
<itertools.permutations object at 0x0000018B993C40A0>
>>> list(x)
[(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
>>> x = iters.permutations([0,1,1], 3)
>>> list(x)
[(0, 1, 1), (0, 1, 1), (1, 0, 1), (1, 1, 0), (1, 0, 1), (1, 1, 0)]

itertools.product(*iterables, repeat=1) Cartesian product of input iterables.

product(A, B) returns the same as ((x,y) for x in A for y in B)

def product(*args, repeat=1):
    """Yield the Cartesian product of the input iterables as tuples.

    The rightmost input varies fastest.  *repeat* multiplies the whole list
    of inputs, so ``product(A, repeat=2)`` equals ``product(A, A)``.  All
    results are built in memory before the first one is yielded.
    """
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = list(map(tuple, args)) * repeat
    partials = [()]
    for pool in pools:
        # Extend every partial result with every item of the next pool.
        partials = [head + (item,) for head in partials for item in pool]
    for combo in partials:
        yield combo

>>> list(iters.product('ABC',repeat =1))
[('A',), ('B',), ('C',)]
>>> list(iters.product('ABC',repeat =2))
[('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'B'), ('B', 'C'), ('C', 'A'), ('C', 'B'), ('C', 'C')]

itertools.repeat(object[, times ])

def repeat(object, times=None):
    """Yield *object* over and over.

    With *times* left as ``None`` the generator never ends; otherwise the
    object is yielded ``times`` times (not at all when *times* is zero or
    negative).
    """
    # repeat(10, 3) --> 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

itertools.starmap(function, iterable)

map() and starmap()的区别类似于function(a,b) and function(*c).

def starmap(function, iterable):
    """Call *function* with each item of *iterable* unpacked as its
    positional arguments, yielding each result in turn.
    """
    # starmap(pow, [(2,5), (3,2), (10,3)]) --> 32 9 1000
    for packed in iterable:
        outcome = function(*packed)
        yield outcome

>>> x = itertools.starmap(str.islower, 'aBCDefGhI')
>>> print(list(x))
[True, False, False, False, True, True, False, True, False]

itertools.takewhile(predicate, iterable) 保留序列元素直到条件不满足


def takewhile(predicate, iterable):
    """Yield items from *iterable* until one fails *predicate*.

    The first failing item ends the generator; it is consumed from the
    input but not yielded.
    """
    # takewhile(lambda x: x<5, [1,4,6,4,1]) --> 1 4
    for item in iterable:
        if not predicate(item):
            return
        yield item

itertools.tee(iterable, n=2)

一旦迭代器被 tee() 分割,就不要再直接使用原来的迭代器,否则各分支会漏掉元素;如果某个分支会先消耗掉大部分或全部数据,直接用 list() 其实更快

def tee(iterable, n=2):
    """Split one iterable into *n* independent iterators.

    Returns a tuple of *n* generators.  Each has its own queue; whenever a
    generator finds its queue empty it pulls one item from the shared source
    and appends it to every queue, so each branch sees every item in order.
    """
    it = iter(iterable)
    deques = [collections.deque() for i in range(n)]

    def gen(mydeque):
        """Yield items for one branch, refilling all queues on demand."""
        while True:
            if not mydeque:  # when the local deque is empty
                try:
                    newval = next(it)  # fetch a new value and
                except StopIteration:
                    return
                for d in deques:  # load it to all the deques
                    d.append(newval)
            yield mydeque.popleft()

    return tuple(gen(d) for d in deques)

>>> x = itertools.tee(range(10), 2)
>>> for letters in x:
...     print(list(letters))
...
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

itertools.zip_longest(*iterables, fillvalue=None)

def zip_longest(*args, fillvalue=None):
    """Zip the inputs together, padding the shorter ones with *fillvalue*.

    Tuples are yielded until every input is exhausted.  With no inputs at
    all, nothing is yielded.
    """
    # zip_longest('ABCD', 'xy', fillvalue='-') --> Ax By C- D-
    iterators = [iter(it) for it in args]
    num_active = len(iterators)
    if not num_active:
        return
    while True:
        values = []
        for i, it in enumerate(iterators):
            try:
                value = next(it)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    # The last live input just ran out; the partly built
                    # tuple is discarded.
                    return
                # Swap the exhausted input for an endless supply of the
                # fill value so later rounds need no special casing.
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

>>> x = itertools.zip_longest(range(3), range(5))
>>> y = zip(range(3), range(5))
>>> print(list(x))
[(0, 0), (1, 1), (2, 2), (None, 3), (None, 4)]
>>> print(list(y))
[(0, 0), (1, 1), (2, 2)]

你可能感兴趣的:(python)