python学习笔记--ThreadLocal

我们知道多线程环境下,每一个线程均可以使用所属进程的全局变量。如果一个线程对全局变量进行了修改,将会影响到其他所有的线程。为了避免多个线程同时对变量进行修改,引入了线程同步机制,通过互斥锁,条件变量或者读写锁来控制对全局变量的访问。

只用全局变量并不能满足多线程环境的需求,很多时候线程还需要拥有自己的私有数据,这些数据对于其他线程来说不可见。因此线程中也可以使用局部变量,局部变量只有线程自身可以访问,同一个进程下的其他线程不可访问。

有时候使用局部变量不太方便,因此 python 还提供了 ThreadLocal 变量,它本身是一个全局变量,但是每个线程却可以利用它来保存属于自己的私有数据,这些私有数据对其他线程也是不可见的。下图给出了线程中这几种变量的存在情况:


python学习笔记--ThreadLocal_第1张图片

首先借助一个小程序来看看多线程环境下全局变量的同步问题。

import threading
# 如果是一个类对象,结果就完全不一样
 class Foo(object):
     def __init__(self):
         self.name = 0
local_values = Foo()
def func(num):
    local_values.name = num
    import time
    time.sleep(1)
    print(local_values.name, threading.current_thread().name)
for i in range(20):
    th = threading.Thread(target=func, args=(i,), name='线程%s' % i)
    th.start()

程序运行的结果:

19 线程1
19 线程3
19 线程4
19 线程5
19 线程2
19 线程0
19 线程9
19 线程6
19 线程8
19 线程11
19 线程7
19 线程10
19 线程15
19 线程16
19 线程13
19 线程14
19 线程18
19 线程19
19 线程12
19 线程17

在程序中设置time.sleep(1)是为了展示,当取全局变量,修改全局变量这个过程在CPU执行过程中并不是原子性的,它是可以被打断的。

再看看用TreadLocal来实现,效果如何:

import threading
# 实例化对象
local_values = threading.local()
def func(num):
    # 给对象加属性,这个属性就会保存在当前线程开辟的空间中
    local_values.name = num
    import time
    time.sleep(1)
    # 取值的时候也从当前线程开辟的空间中取值
    print(local_values.name, threading.current_thread().name)
for i in range(20):
    th = threading.Thread(target=func, args=(i,), name='线程%s' % i)
    th.start()

结果如下:

0 线程0
1 线程1
2 线程2
3 线程3
10 线程10
9 线程9
4 线程4
8 线程8
5 线程5
7 线程7
6 线程6
13 线程13
11 线程11
17 线程17
15 线程15
14 线程14
16 线程16
12 线程12
18 线程18
19 线程19

很明显每个线程之间没有干涉,均取到了线程对应的正确值。

避免线程之间共享数据干涉的方法就是建立一个全局字典,保存进程 ID 到该进程局部变量的映射关系,运行中的线程可以根据自己的 ID 来获取本身拥有的数据。这样,就可以避免在函数调用中传递参数。

我们仿照源码写一个简单版的‘threadlocal’,

依据这个思路,我们自己实现给线程开辟独有的空间保存特有的值

协程和线程都有自己的唯一标识get_ident,利用这个唯一标识作为字典的key,key对应的value就是当前线程或协程特有的值,取值的时候也拿这个key来取:

import threading
# get_ident就是获取线程或协程唯一标识的
try:
    from greenlet import getcurrent as get_ident # 协程
    # 当没有协程的模块时就用线程
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident # 线程
class Local(object):
    def __init__(self):
        self.storage = {}
        self.get_ident = get_ident
    def set(self,k,v):
        ident = self.get_ident()
        origin = self.storage.get(ident)
        if not origin:
            origin = {k:v}
        else:
            origin[k] = v
        self.storage[ident] = origin
    def get(self,k):
        ident = self.get_ident()
        origin = self.storage.get(ident)
        if not origin:
            return None
        return origin.get(k,None)
# 实例化自定义local对象对象
local_values = Local()
def task(num):
    local_values.set('name',num)
    import time
    time.sleep(1)
    print(local_values.get('name'), threading.current_thread().name)
for i in range(20):
    th = threading.Thread(target=task, args=(i,),name='线程%s' % i)
    th.start()

本程序中添加了,获取协程Id的模块。类似Flask中Local的功能。

python的object类中有__setattr__和__getattr__内置方法,可以通过改写这两个方法来完成对类属性的赋值与取值操作。如下:

import threading
try:
    from greenlet import getcurrent as get_ident # 协程
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident # 线程
class Local(object):
    def __init__(self):
        # 这里一定要用object来调用,因为用self调用的就会触发__setattr__方法,__setattr__方法里
        # 又会用self去赋值就又会调用__setattr__方法,就变成递归了
        object.__setattr__(self, '__storage__', {})
        object.__setattr__(self, '__ident_func__', get_ident)
    def __getattr__(self, name):
        try:
            return self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)
    def __setattr__(self, name, value):
        ident = self.__ident_func__()
        storage = self.__storage__
        try:
            storage[ident][name] = value
        except KeyError:
            storage[ident] = {name: value}
    def __delattr__(self, name):
        try:
            del self.__storage__[self.__ident_func__()][name]
        except KeyError:
            raise AttributeError(name)
local_values = Local()
def task(num):
    local_values.name = num
    import time
    time.sleep(1)
    print(local_values.name, threading.current_thread().name)
for i in range(20):
    th = threading.Thread(target=task, args=(i,),name='线程%s' % i)
    th.start()

python的ThreadLocal要比这个更复杂,功能更强大,源码目前还不能完全看懂,等功力提升后再继续研究。


参考博客:

https://www.cnblogs.com/wanghl1011/articles/8619148.html

https://selfboot.cn/2016/08/22/threadlocal_overview/


你可能感兴趣的:(python进阶)