itertools基本使用

Posted by Sky丶Memory on August 11, 2018

itertools算得上是Python中比较常用的一个模块,它主要提供了操作在可迭代对象上的一些实用函数,这篇文章主要介绍这些函数的基本等价实现与一些使用心得.(Python Version:3.6.3)

cycle

基本等价实现:

def cycle(iterable):
    # cycle('ABCD') ---> A B C D A B C D A B C D...
    saved = []
    for item in iterable:
        yield item
        saved.append(item)
    while saved:
        for item in saved:
            yield item

使用例子:基于cycle实现的Round-Robin负载均衡

class RoundRobinSelector:
    def __init__(self, hosts):
        self.hosts = hosts
        self._it = cycle(hosts)
    
    def select(self):
        return next(self._it)

之前去知乎面试,考过手写cycle的实现。。。。另外,cycle使用中有个地方需要注意,它需要额外的空间来存储,所以当输入集太大时,最好自己用取模实现类似的功能


count

基本等价实现:

def count(start=0, step=1):
    # count(10) ---> 0 1 2 3 4 5 6...
    # count(2.5, 0.5) ---> 2.5 3 3.5...
    n = start
    while True:
        yield n
        n += step

repeat

基本等价实现:

def repeat(o, times=None):
    if times is None:
        while 1:
            yield o
    else:
        for _ in range(times):
            yield o

使用例子: repeat函数主要用在一些需要常量序列的地方,比如对一个序列求平方—map(pow, range(5), repeat(2))


compress

基本等价实现:

def compress(iterable, selectors):
    # compress('ABCDED', [1, 0, 1, 0, 0, 0]) ---> A C
    return (d for d, s in zip(iterable, selectors) if s)

filterfalse

基本等价实现:

def filterfalse(pred, iterable):
    if pred is None:
        pred = bool
    return (item for item in iterable if not pred(item))

def filter(pred, iterable):
    if pred is None:
        pred = bool
    return (item for item in iterable if pred(item))

filterfalse配合内置函数filter可用于对输入集通过谓词(predicate)进行分组,对得到的序列可用于不同业务逻辑的处理


takewhile

基本等价实现:

def takewhile(pred, iterable):
    for item in iterable:
        if pred(item):
            yield item
        else:
            break

用于得到一个符合条件的前缀序列


dropwhile

基本等价实现:

def dropwhile(pred, iterable):
    # dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]) ---> 6 4 1
    it = iter(iterable)
    for item in it:
        if not pred(item):
            yield item
            break
    for item in it:
        yield item

用于得到一个符合条件的后缀序列


chain

基本等价实现:

def chain(*iterables):
    # chain('ABC', 'DEF') --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

用于将连续的多个序列当成单个序列处理


chain.from_iterable

基本等价实现:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F
    for it in iterables:
        for element in it:
            yield element

跟chain类似,唯一的区别在于,chain.from_iterable不用事先知道iterables的长度,而chain则需要


islice

基本等价实现:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) --> A B
    # islice('ABCDEFG', 2, 4) --> C D
    # islice('ABCDEFG', 2, None) --> C D E F G
    # islice('ABCDEFG', 0, None, 2) --> A C E G
    s = slice(*args)
    it = iter(range(s.start or 0, s.stop or sys.maxsize, s.step or 1))
    try:
        nexti = next(it)
    except StopIteration:
        return
    for i, element in enumerate(iterable):
        if i == nexti:
            yield element
            nexti = next(it)

该函数主要作用有如下两个:

  • 从一定程度上避免了常规切片带来的内存复制问题,同时也带来了另外一个问题,切片的时间复杂度增为O(n),使用时,需要自己权衡一下

  • 提供了generator切片,但start、stop参数不支持负数,需注意该限制


accumulate

基本等价实现:

def accumulate(iterable, func=operator.add):
    'Return running totals'
    # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    it = iter(iterable)
    try:
        total = next(it)
    except StopIteration:
        return
    yield total
    for element in it:
        total = func(total, element)
        yield total

该函数与functools.reduce类似,唯一区别在于,functools.reduce只返回最终计算的值,而accumulate返回了每步计算的值,通过accumulate函数,我们很容易得出前缀相关的性质,如前缀和、前缀积、前缀最值等


zip_longest

基本等价实现:

class ZipExhausted(Exception):
    pass

def zip_longest(*args, **kwds):
    # zip_longest('ABCD', 'xy', fillvalue='-') --> ('A', 'x'), ('B', 'y'), ('C', '-'), ('D', '-')
    fillvalue = kwds.get('fillvalue')
    counter = len(args) - 1
    def sentinel():
        nonlocal counter
        if not counter:
            raise ZipExhausted
        counter -= 1
        yield fillvalue
    fillers = repeat(fillvalue)
    iterators = [chain(it, sentinel(), fillers) for it in args]
    try:
        while iterators:
            yield tuple(map(next, iterators))
    except ZipExhausted:
        pass

内置函数zip当参数序列长度不等时,选择最短序列的长度做为最终结果的长度,而zip_longest刚好与其相反,选择最长序列长度做为最终结果长度,而不及最长长度的序列,可通过fillvalue指定填充值


groupby

基本等价实现:

class groupby:
    # [k for k, g in groupby('AAAABBBCCDAABBB')] --> A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] --> AAAA BBB CC D
    def __init__(self, iterable, key=None):
        if key is None:
            key = lambda x: x
        self.keyfunc = key
        self.it = iter(iterable)
        self.tgtkey = self.currkey = self.currvalue = object()
    def __iter__(self):
        return self
    def __next__(self):
        while self.currkey == self.tgtkey:
            self.currvalue = next(self.it)    # Exit on StopIteration
            self.currkey = self.keyfunc(self.currvalue)
        self.tgtkey = self.currkey
        return (self.currkey, self._grouper(self.tgtkey))
    def _grouper(self, tgtkey):
        while self.currkey == tgtkey:
            yield self.currvalue
            try:
                self.currvalue = next(self.it)
            except StopIteration:
                return
            self.currkey = self.keyfunc(self.currvalue)

通过指定key函数,对可迭代对象进行分组,分组的依据为相邻元素是否相等来决定是否属于同一个组,这与SQL中的Group By处理方式不同,使用时需要额外注意

参考资料