python学习笔记04

python学习笔记之04.

迭代器和生成器

迭代器

迭代是Python最强大的功能之一,是访问集合元素的一种方式。
迭代器是一个可以记住遍历的位置的对象。
迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。
迭代器有两个基本的方法:iter() 和 next()。

字符串,列表或元组对象都可用于创建迭代器:

实例(Python 3.0+)

>>>list=[1,2,3,4]
>>> it = iter(list)    # 创建迭代器对象
>>> print (next(it))   # 输出迭代器的下一个元素
1
>>> print (next(it))
2
>>>

迭代器对象可以使用常规for语句进行遍历:

实例(Python 3.0+)
#!/usr/bin/python3
 
list=[1,2,3,4]
it = iter(list)    # 创建迭代器对象
for x in it:
    print (x, end=" ")

执行以上程序,输出结果如下:

1 2 3 4

也可以使用 next() 函数:

实例(Python 3.0+)
#!/usr/bin/python3

import sys         # 引入 sys 模块
 
list=[1,2,3,4]
it = iter(list)    # 创建迭代器对象
 
while True:
    try:
        print (next(it))
    except StopIteration:
        sys.exit()

执行以上程序,输出结果如下:

1
2
3
4
生成器

在 Python 中,使用了 yield 的函数被称为生成器(generator)。

跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。

在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。
调用一个生成器函数,返回的是一个迭代器对象。

以下实例使用 yield 实现斐波那契数列:

实例(Python 3.0+)
#!/usr/bin/python3
 
import sys
 
def fibonacci(n): # 生成器函数 - 斐波那契
    a, b, counter = 0, 1, 0
    while True:
        if (counter > n): 
            return
        yield a
        a, b = b, a + b
        counter += 1
f = fibonacci(10) # f 是一个迭代器,由生成器返回生成
 
while True:
    try:
        print (next(f), end=" ")
    except StopIteration:
        sys.exit()

执行以上程序,输出结果如下:

0 1 1 2 3 5 8 13 21 34 55

装饰器

引用知乎大神非常形象的比喻:
https://www.zhihu.com/question/26930016

内裤可以用来遮羞,但是到了冬天它没法为我们防风御寒,聪明的人们发明了长裤,有了长裤后宝宝再也不冷了,装饰器就像我们这里说的长裤,在不影响内裤作用的前提下,给我们的身子提供了保暖的功效。

装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象。它经常用于有切面需求的场景,比如:插入日志、性能测试、事务处理、缓存、权限校验等场景。装饰器是解决这类问题的绝佳设计,有了装饰器,我们就可以抽离出大量与函数功能本身无关的雷同代码并继续重用。概括的讲,装饰器的作用就是为已经存在的对象添加额外的功能。

下面是一个重写了特殊方法 getattribute 的类装饰器, 可以打印日志:

def log_getattribute(cls):
    # Get the original implementation
    orig_getattribute = cls.__getattribute__

    # Make a new definition
    def new_getattribute(self, name):
        print('getting:', name)
        return orig_getattribute(self, name)

    # Attach to the class and return
    cls.__getattribute__ = new_getattribute
    return cls

# Example use
@log_getattribute
class A:
    def __init__(self,x):
        self.x = x
    def spam(self):
        pass

下面是使用效果:

>>> a = A(42)
>>> a.x
getting: x
42
>>> a.spam()
getting: spam
>>>

作业

员工信息管理程序
staffs.txt

3,Rain Wang,21,13451054608,IT,2017‐04‐01
4,Mack Qiao,44,15653354208,Sales,2016‐02‐01
5,Rachel Chen,23,13351024606,IT,2013‐03‐16
6,Eric Liu,19,18531054602,Marketing,2012‐12‐01
7,Chao Zhang,21,13235324334,Administration,2011‐08‐08
8,Kevin Chen,22,13151054603,Sales,2013‐04‐01
9,Shit Wen,20,13351024602,IT,2017‐07‐03
10,Shanshan Du,26,13698424612,Operation,2017‐07‐02
11,Libai,26,134435366,IT,2015‐10‐2
12,Libai,26,134435366,IT,2015‐10‐23

StaffManage.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Created by master on  2018/5/21 16:45.
import codecs
import os
from tabulate import tabulate


class StaffManage(object):
    file_path = "staffs.txt"

    # 增删改查入口
    def manage(self, sql):
        lower_sql = sql.lower()
        if lower_sql.startswith("add"):
            self.add(sql)
        elif lower_sql.startswith("del"):
            self.remove(sql)
        elif lower_sql.startswith("update"):
            self.update(sql)
        elif lower_sql.startswith("find"):
            self.find(sql)
        elif lower_sql.startswith("exit"):
            exit("退出程序")
        else:
            log("命令不存在或者输入错误,请仔细检查后重新输入!", "error")

    # 查询
    def find(self, sql):
        '''
        find name,age from staff_table where age > 22
        find * from staff_table where dept = "IT"
        find * from staff_table where enroll_date like "2013"
        '''
        lower_sql = sql.lower()
        lower_sql = lower_sql.replace("find", "").replace("*", "").replace("from", "") \
            .replace("staff_table", "").replace('"', "").replace("'", "").strip()
        list_param = lower_sql.split("where")
        new_staffs = []
        if list_param[0]:
            log("没有这个命令,请确认!", "error")
        else:
            list_staff = self.load_file()
            if "=" in list_param[1]:
                key = list_param[1].split("=")[0].strip()
                value = list_param[1].split("=")[1].strip()
                for staff in list_staff:
                    if staff[key].lower() == value:
                        new_staffs.append(staff)
            elif "like" in list_param[1]:
                key = list_param[1].split("like")[0].strip()
                value = list_param[1].split("like")[1].strip()
                for staff in list_staff:
                    if value in staff[key].lower():
                        new_staffs.append(staff)
            else:
                if ">" in list_param[1]:
                    key = list_param[1].split(">")[0].strip()
                    value = list_param[1].split(">")[1].strip()
                    for staff in list_staff:
                        if int(staff[key]) > int(value):
                            new_staffs.append(staff)
                elif "<" in list_param[1]:
                    key = list_param[1].split("<")[0].strip()
                    value = list_param[1].split("<")[1].strip()
                    for staff in list_staff:
                        if int(staff[key]) < int(value):
                            new_staffs.append(staff)
        print_table(new_staffs)

    # 添加纪录
    def add(self, sql):
        #  "add staff_table Alex Li,25,134435344,IT,2015‐10‐29"
        sql = "0," + sql.replace("add", "").replace("ADD", "").replace("staff_table", "").strip()
        d_staff = self.__split_line(sql)
        if not self.phone_exist(self.load_file(), d_staff["phone"]):
            d_staff["staff_id"] = str(int(self.__get_last_staff_id()) + 1)
            self.__insert([d_staff])
            log("影响了1行")
        else:
            log("存在重复的手机号码,请确认!", "error")

    # 判断手机号是否存在
    def phone_exist(self, staffs, phone):
        phones = []
        for staff in staffs:
            phones.append(staff["phone"])
        return phone in phones

    # 插入记录
    def __insert(self, staffs, mode="a+"):
        with codecs.open(self.file_path, mode, "utf-8") as f:
            for i in range(len(staffs)):
                staff = staffs[i]
                record = ""
                len_values = len(staff.values())
                for j in range(len_values):
                    v = list(staff.values())[j]
                    if j < len_values - 1:
                        record += v + ","
                    else:
                        record += v
                if i > 0:
                    f.write("\n" + record)
                else:
                    f.write(record)

    # 更新记录
    def update(self, sql):
        effect_lines = 0
        lower_sql = sql.lower()
        lower_sql = lower_sql.replace("update", "").replace("staff_table", "") \
            .replace('"', "").replace('set', "").replace("'", "").strip()
        list_param = lower_sql.split("where")
        u_key = list_param[0].split("=")[0].strip()
        u_value = list_param[0].split("=")[1].strip()
        o_key = list_param[1].split("=")[0].strip()
        o_value = list_param[1].split("=")[1].strip()
        list_staff = self.load_file()
        for staff in list_staff:
            if staff[o_key].lower() == o_value:
                staff[u_key] = u_value
                effect_lines += 1

        log("影响了%s行" % effect_lines)
        self.__insert(list_staff, "w")

    # 删除记录
    def remove(self, sql):
        # del from staff where id=3
        effect_lines = 0
        sql = sql.lower().replace("del", "").replace("from", "").replace("staff_table", "") \
            .replace("where", "").strip()
        list_param = sql.split("=")
        id = list_param[1].strip()
        list_staff = self.load_file()
        for staff in list_staff:
            if staff["staff_id"] == id:
                list_staff.remove(staff)
                effect_lines += 1
            else:
                return
        log("影响了%s行" % effect_lines)
        self.__insert(list_staff, "w")

    # 切分sql,提取有效字段
    def __split_line(self, line):
        list_param = line.split(",")
        d_staff = {}
        for i in range(len(list_param)):
            if i == 0:
                d_staff["staff_id"] = list_param[i].strip()
            elif i == 1:
                d_staff["name"] = list_param[i].strip()
            elif i == 2:
                d_staff["age"] = list_param[i].strip()
            elif i == 3:
                d_staff["phone"] = list_param[i].strip()
            elif i == 4:
                d_staff["dept"] = list_param[i].strip()
            elif i == 5:
                d_staff["enroll_date"] = list_param[i].strip()
        return d_staff

    # 加载数据表
    def load_file(self):
        if os.path.exists(self.file_path):
            list_staffs = []
            with codecs.open(self.file_path, "r", "UTF-8") as staff_list:
                for line in staff_list:
                    staff = self.__split_line(line)
                    list_staffs.append(staff)
                return list_staffs

    # 读取数据表的最后一行
    def __read_last_line(self):
        with codecs.open(self.file_path, 'r', "UTF-8") as f:  # 打开文件
            lines = f.readlines()
            return lines[-1]  # 取最后一行

    # 获取最后一个id,实现id自增
    def __get_last_staff_id(self):
        line = self.__read_last_line()
        d_staff = self.__split_line(line)
        return d_staff["staff_id"]


# 日志工具
def log(msg, log_type="info"):
    if log_type == 'info':
        print("\033[32;1m%s\033[0m" % msg)
    elif log_type == 'error':
        print("\033[31;1m%s\033[0m" % msg)


# 打印表格
def print_table(staffs):
    print(tabulate(staffs, tablefmt="grid"))


if __name__ == '__main__':
    staffManage = StaffManage()
    while True:
        sql_eval = input(">>>")
        staffManage.manage(sql_eval)

你可能感兴趣的:(python学习笔记04)