【创新实训】推荐系统之召回池设计

召回池

我的想法是,输入形式为batch时模型总计算时间应当比一个一个喂小不少,因此可以建立一个比如200为大小的召回池,最多每隔0.5s送入模型进行计算,当waiting list已经到达了200,立即执行计算,重置定时任务。

消息队列

python的queue库是一个线程安全的队列,可以用作消息队列。
其基本用法参考:
Python之queue模块以及生产消费者模型

先写一个类包装

class Pack:
    def __init__(self, flag=False, _id=None, seq=None, hidden=None):
        self.flag = flag    # True means activation signal
        self._id = _id
        self.seq = seq
        if self.hidden:
            self.hidden = hidden
        else:
            self.hidden = [0] * setting['hidden_dim']

flag为True时,意味着这是一个有意义的输入;False时用作Timer对队列的唤醒令牌。

我们的召回池,wating_list就是消息队列:

class BatchPool:
    def __init__(self, pool_size=200, timeout=0.5):
        self.timeout = timeout
        self.uid2idx = dict()
        self.res = dict()
        self.waiting_list = Queue(pool_size)
        self.data = {'seq': list(), 'lengths': list(), 'hidden': list()}

        self.get_recalls()

    def clear(self):
        for key in self.data.keys():
            self.data[key].clear()

    def get_recalls(self):
        while True:
            item = self.waiting_list.get()
            if not item.flag:
                self.uid2idx[item._id] = len(self.data['seqs'])
                self.data['seq'].append(item.seq)
                self.data['hidden'].append(item.hidden)
                self.data['lengths'].append(len(item.seq))
            if item.flag or self.waiting_list.full():
                top_items = recommander.recall(**self.data)
                for uid, idx in self.uid2idx:
                    self.res[uid] = top_items[idx]
                self.clear()
                self.waiting_list.task_done()

当wating_list为空,消费者会阻塞在get()处。最后召回完成后,task_done()使所有join()挂起的线程唤醒。

但现在有个问题,如何保证所有发起请求者都拿到自己所需的数据后,召回池才开始新一轮召回任务?这里生产者变成了召回池,消费者变成了请求线程。

可以,让生产者向消费者发n次消息,然后等在一个condition上,消费者取走到之后判断队列是否为空,空则唤醒生产者。但这样有个问题,如果有个消费者线程挂掉了,那么永远没有唤醒生产者的第n个消费者。

因此需要让消费者通知生产者,生产者等n次通知继续,或者timeout继续。当然,这个方式还是存在问题,如果timeout太大,下一轮迟迟不能开始,如果太小,有可能来不及让存活的消费者拿走。不过进程挂掉的概率很小,所以可以暂不考虑。

    def get_recalls(self):
        while True:
            item = self.waiting_list.get()
            if item.flag:
                self.uid2idx[item._id] = len(self.data['seqs'])
                self.data['seq'].append(item.seq)
                self.data['hidden'].append(item.hidden)
                self.data['lengths'].append(len(item.seq))
            if not item.flag and len(self.uid2idx) > 0 or self.waiting_list.full():
                top_items = recommander.recall(**self.data)
                for uid, idx in self.uid2idx:
                    self.res[uid] = top_items[idx]
                self.waiting_list.task_done()
                for i in range(len(self.uid2idx)):
                    try:
                        self.notify_queue.get()
                    except Exception as e:
                        print(e)
                self.clear()

    def ask_for_recall(self, _id, seq, hidden):
        self.waiting_list.put(Pack(True, _id, seq, hidden))
        self.waiting_list.join()
        self.notify_queue.put(1)

定时任务

python提供了Timer定时器,但是只能按固定的interval开启。

因此使用APScheduler库进行调度。

参考Python 定时任务的实现方式

        self.shed = BackgroundScheduler()
        self.job = self.shed.add_job(self.awake, 'interval', seconds=timeout)
        self.shed.start()
        self.get_recalls()

    def awake(self):
        if len(self.uid2idx) > 0:
            self.waiting_list.put(Pack())

    def get_recalls(self):
        while True:
            item = self.waiting_list.get()
            if item.flag:
                self.uid2idx[item._id] = len(self.data['seqs'])
                self.data['seq'].append(item.seq)
                self.data['hidden'].append(item.hidden)
                self.data['lengths'].append(len(item.seq))
            if not item.flag and len(self.uid2idx) > 0 or len(self.uid2idx) >= self.pool_size:
                self.job.pause()
                top_items = recommander.recall(**self.data)
                for uid, idx in self.uid2idx:
                    self.res[uid] = top_items[idx]
                self.waiting_list.task_done()
                for i in range(len(self.uid2idx)):
                    try:
                        self.notify_queue.get()
                    except Exception as e:
                        print(e)
                self.clear()
                self.job.resume()

可能遇到的错误 APScheduler: LookupError: No trigger by the name “interval” was found

优化与Debug

后来想了想,消息队列的最大大小不能是200,必须比他大,因为在消费者召回计算中,还有新的生产者向消息队列中添加。设为无限大?问题是在新生产者添加后进入等待,消费者计算完之后notify所有的生产者,新生产者被提前唤醒了。

所以设置200比较安全,但是如果定时器一直发消息,模型一直在计算,那么有可能消息队列里全都是唤醒令牌。所以必须在计算时让定时器暂停,这样就没问题了。

另外,get_recalls不能在主线程调用,不然会一直等在get()处,要加一个线程运行它

recommander = Recommander(setting['model'])
pool = BatchPool()
thread = threading.Thread(target=pool.get_recalls)
thread.start()

因为api调用的进程需要等待召回池返回,需要挂起。可以使用用户级别的协程来管理。需要使用async关键字声明异步方法,使用await等待返回。

    async def ask_for_recall(self, _id, seq, hidden):
        self.waiting_list.put(Pack(True, _id, seq, hidden))
        self.waiting_list.join()
        iids = self.res[_id]['iids']
        emb = self.res[_id]['emb']
        self.notify_queue.put(1)
        return iids, emb
async def get_rec_items(id, emb):
    iids, emb = await pool.ask_for_recall(id, emb)
    return iids, emb
@api_view(['GET'])
def recommand(requests):
    emb = None
    _id = requests.COOKIES.get('_id')
    user = None
    if _id:
        user = User.objects.only(['emb']).with_id(ObjectId('_id'))
        if not user:
            return JsonResponse({'err': '用户不存在,请重新登录'},
                                json_dumps_params={'ensure_ascii': False})
        if user.get('emb'):
            emb = user['emb']
    seq = requests.COOKIES.get('sess')
    loop = asyncio.get_event_loop()
    future = asyncio.ensure_future(get_rec_items('_id', emb))
    loop.run_until_complete(future)
    iids, emb = future.result()
    if user:
        user.emb = emb
        user.save()
    items = Details.objects.filter(sourceId__in=iids)
    return JsonResponse({'rec': items},
                        json_dumps_params={'ensure_ascii': False})
    loop = asyncio.get_event_loop()
    future = asyncio.ensure_future(get_rec_items('_id', emb))
    loop.run_until_complete(future)
    iids, emb = future.result()

这是一般调用的关键代码,但发起的请求线程并非主线程,get_event_loop会报错,建议用run()来执行,或者new一个loop

    future = asyncio.ensure_future(get_rec_items(_id, seq, emb))
    asyncio.run(future)

API

/api/recommand

无参数

待办事项

现在如果sess为空不推荐,应当通过计算item(余弦)相似度的方式推荐。

你可能感兴趣的:(山软项目实训)