nonebot2聊天机器人插件8:基于GAN的虚拟巡天quasistar_engine

nonebot2聊天机器人插件8:基于GAN的虚拟巡天quasistar_engine

  • 1. 插件用途
  • 2. 目录结构
  • 3. 实现难点与解决方案
    • 3.1 图片文件读写冲突
    • 3.2 对抗式生成网络
    • 3.3 GPU模型转CPU运算
  • 4. 代码实现
  • 5. 插件配图
  • 6. 实际效果
  • 7. 下一个插件

该插件涉及知识点:图片文件读写冲突,对抗式生成网络,GPU模型转CPU运算
插件合集:nonebot2聊天机器人插件

该系列为用于QQ群聊天机器人的nonebot2相关插件,不保证完全符合标准规范写法,如有差错和改进余地,欢迎大佬指点修正。
前端:nonebot2
后端:go-cqhttp
插件所用语言:python3
前置环境安装过程建议参考零基础2分钟教你搭建QQ机器人——基于nonebot2,但是请注意该教程中的后端版本过旧导致私聊发图异常,需要手动更新go-cqhttp版本。

1. 插件用途

插件名:类星引擎(Quasistar Engine),或者虚拟巡天~
将训练好的多个对抗式生成网络模型安装到插件内,随机从多个模型中选择一个,通过随机噪声生成星球图片。
1、每个帐号每天都可以得到一张固定的星球图片,随机生成结果每日轮换。
2、bot管理员能够用指定字符串作为种子,生成指定星球。
由于图片由深度学习的对抗式生成网络绘制,因此除非哈希值相同,否则几乎不可能出现完全一样的图片(但受限于训练集数量,可能看起来很相似)。
插件需要能够在没有GPU的云服务器上运行,使用CPU运算。
注:与之前的插件不同,这个插件需要训练好的生成器模型才能工作,训练相关模型文件需要具备神经网络方面的知识,在之前AI相关的博文中给出过该插件中用到的模型文件的训练方法,也不排除之后进一步改进的可能。

2. 目录结构

在plugins文件夹中新建一个文件夹quasistar_engine,文件夹内目录结构如下:

|-quasistar_engine
    |-generator_models
        |-所有生成器模型的.pth参数文件
    |-temp_img
        |-临时图片存储位置
    |-__init__.py
    |-quasistar_engine.py
    |-config.py
    |-run_generator.py
    |-model.py

其中temp_img为用于存储发送的临时图片的文件夹,generator_models为储存生成器模型数据文件的文件夹,quasistar_engine.py为命令事件响应器的位置,config.py用于存储配置项,run_generator.py用于根据参数运行模型生成图片,model.py为pytorch的生成器网络架构,__init__.py为程序启动位置。

3. 实现难点与解决方案

3.1 图片文件读写冲突

如果与之前的插件一样采用固定临时图片名发送,那么在不同群同时发送命令时,可能会产生图片文件的读写冲突导致异常,因此对于不同群组和私聊的信息,使用不同的文件名命名图片文件。

3.2 对抗式生成网络

对抗式生成网络的训练方法在之前的博文里面已经有详细解释:
彩色星球图片生成1:使用Gan实现(pytorch版)
彩色星球图片生成2:同时使用传统Gan判别器和马尔可夫判别器(pytorch版)
彩色星球图片生成3:代码改进(pytorch版)
彩色星球图片生成4:转置卷积+插值缩放+卷积收缩(pytorch版)
此处不再过多赘述,后续仍然可能有进一步改进的博客。

3.3 GPU模型转CPU运算

网络参数模型为了速度,是在GPU环境下训练的,而GPU云服务器虽然快,却价格过于昂贵,没有足够的财力配备所以说挖矿的真是害人
因此在读取网络模型参数时,需要将GPU训练的参数文件转换为使用CPU运算的格式。
将参数加载代码从G_model.load_state_dict(load(G_model_path))
改为G_model.load_state_dict(load(G_model_path, map_location=device('cpu')))

4. 代码实现

__init__.py

from .quasistar_engine import *

config.py

class Config:
    # 记录在哪些群组中使用
    used_in_group = ["131551175"]
    # 插件执行优先级
    priority = 10
    # 机器人QQ号
    bot_id = "123456789"
    # 管理员QQ号,管理员无视冷却cd
    super_uid = ["673321342"]
    # 触发冷却时间(秒),在这段时间内不会连续两次触发
    cd = 10

model.py

import torch.nn as nn


# 生成器,此处填入模型文件对应的生成器网络结构
class G_net(nn.Module):
    def __init__(self):
        super(G_net, self).__init__()
        pass

    def forward(self, img_seeds):
        pass
        return imgs


# 获取模型
def get_G_model():
    model = G_net()
    return model

run_generator.py

import random
from .model import get_G_model
import os
from time import sleep
from torchvision.utils import save_image
import torch

# 存储所有模型的地址
models_path = os.path.split(os.path.realpath(__file__))[0] + '/generator_models'

# 存储所有临时图片的地址
temp_img_path = os.path.split(os.path.realpath(__file__))[0] + '/temp_img/'

# 存储所有待用的生成器模型
G_models = []


# 初始化加载所有生成器模型参数
def init_G_models(G_models):
    # 首先,读出所有模型文件的存储目录【允许任意个模型文件】
    # 读取文件夹下所有数据文件
    for path, dirs, files in os.walk(models_path, topdown=False):
        for name in files:
            G_models.append(os.path.join(path, name))

    # 对每一个模型文件,都创建一个生成器模型,并且保存在内存中
    def path_to_model(G_model_path):
        G_model = get_G_model()
        # 从磁盘加载之前保存的模型参数
        try:
            G_model.load_state_dict(torch.load(G_model_path, map_location=torch.device('cpu')))
            return G_model
        except:
            print("模型数据文件" + G_model_path + "不存在或使用了错误的文件,请关闭程序并检查文件,一分钟后程序将自动关闭。")
            sleep(60)
            exit()
    # 所有路径字符串都转换为了生成器模型
    G_models = list(map(path_to_model, G_models))
    # 需要设置为训练模式,以通过dropout产生更多的GAN多样性
    # 但是BN层必须被设置为测试模式,以保证输出性能
    for G_model in G_models:
        G_model.train()
        for _, module in G_model.named_modules():
            # 将所有BN层设置为测试模式
            if isinstance(module, nn.BatchNorm1d) or isinstance(module, nn.BatchNorm2d):
                module.training = False
    return G_models


# 初始化加载所有生成器模型参数
G_models = init_G_models(G_models)


# 创建图片
# 输入种子数和图片名,将图片写入磁盘,返回一张由GAN生成的图片的地址
def create_img(seed_int, img_name):
    # 生成器输入种子数的维度
    img_seed_dim = 128
    # 图片分辨率
    img_size = 528
    # 输出图片的路径
    img_path = temp_img_path + img_name + '.png'

    # 固定随机数种子
    random.seed(seed_int)
    torch.manual_seed(seed_int)

    # 根据种子数随机选择一个生成器模型使用
    G_model = random.choice(G_models)

    with torch.no_grad():
        # 输出一张生成器产生的图片到输出文件夹
        # 产生随机正态分布噪声
        img_seeds = torch.randn(1, img_seed_dim)
        # 生成图像
        fake_images = G_model(img_seeds)
        # 将图像缩放到[0,1]的区间
        fake_images = 0.5 * (fake_images + 1)
        fake_images = fake_images.clamp(0, 1)
        # 用torchvision自带的save_image()函数输出到磁盘文件
        fake_images = fake_images.view(-1, 3, img_size, img_size)
        save_image(fake_images, img_path)

    # 返回输出图像的路径
    return img_path

quasistar_engine.py

from nonebot import on_command
from nonebot.typing import T_State
from nonebot.adapters import Bot, Event
from nonebot.adapters.cqhttp import MessageSegment
from .config import Config
from time import time, localtime
import json
from hashlib import md5
from .run_generator import *
from nonebot.permission import SUPERUSER

# 记录上一次响应时间
last_response = {
     }


# 判断是否过了响应cd的函数,默认使用配置文件中的cd
# 如果已经超过了最短响应间隔,返回True
def cool_down(group_id, cd = Config.cd):
    global last_response
    if group_id not in last_response:
        return True
    else:
        return time() - last_response[group_id] > cd


# 每个用户都可以得到的每日随机星球图片
search_star = on_command("观星", priority=Config.priority)


@search_star.handle()
async def handle_first_receive(bot: Bot, event: Event, state: T_State):
    ids = event.get_session_id()
    allow_use = True
    # 如果这是一条群聊信息
    if ids.startswith("group"):
        _, group_id, user_id = event.get_session_id().split("_")
        if group_id not in Config.used_in_group:
            allow_use = False
    # 对于私聊信息,在前方加上private_作为冷却时间存储的key
    else:
        user_id = ids
        group_id = 'private_'+ids
    # 如果允许使用
    if allow_use:
        # 如果已经过了冷却时间,或者用户是管理员
        if cool_down(group_id) or user_id in Config.super_uid:
            # 如果用户不是超级用户,更新cd时间
            if user_id not in Config.super_uid:
                last_response[group_id] = time()
            # 对超级用户,直接使用super_qq号作为存储key与图片保存名,防止在cd期间内出现文件读写冲突
            if user_id in Config.super_uid:
                group_id = 'super_'+user_id
            # 获取用户昵称
            infos = str(await bot.get_stranger_info(user_id=user_id))
            nickname = json.loads(infos.replace("'", '"'))['nickname']
            # 昵称与用户名的组合
            nickname_and_id = nickname + '(' + str(user_id) + ')'
            # 使用[QQ号+当日日期]作为随机数种子
            # 获取日期
            local = localtime(time())
            today = f"{
       local[0]}{
       local[1]}{
       local[2]}日"
            # 组合种子字符串
            seed_str = user_id + today
            # 将种子字符串转换为种子数
            # text to md5
            md5_str = md5(seed_str.encode('utf-8'))
            # md5 to int
            seed_int = int(str(str(int('0x' + md5_str.hexdigest(), 0)))[-16:])
            # 输入种子数和图片名创建图片,使用group_id作为图片名防止冲突
            img_path = create_img(seed_int, group_id)
            # 将图片发送给用户
            await search_star.send(f"{
       nickname_and_id}\n{
       today}\n你观测到了星球:\n"+MessageSegment.image('file:///' + img_path))
            # 发送完之后删除临时文件夹中的图片
            os.remove(img_path)


# bot管理员专属命令,根据后缀信息创造星球
create_star = on_command("创星", permission=SUPERUSER, priority=Config.priority)


@create_star.handle()
async def handle_first_receive(bot: Bot, event: Event, state: T_State):
    msg = str(event.get_message()).strip().replace('\r\n', '').replace('\n', '')
    ids = event.get_session_id()
    # 如果这是一条群聊信息
    if ids.startswith("group"):
        _, group_id, user_id = event.get_session_id().split("_")
    # 对于私聊信息,在前方加上private_作为冷却时间存储的key
    else:
        user_id = ids
    # 将种子字符串转换为种子数
    # text to md5
    md5_str = md5(msg.encode('utf-8'))
    # md5 to int
    seed_int = int(str(str(int('0x' + md5_str.hexdigest(), 0)))[-16:])
    # 输入种子数和图片名创建图片,使用group_id作为图片名防止冲突
    img_path = create_img(seed_int, "create_star")

    # 获取用户昵称
    infos = str(await bot.get_stranger_info(user_id=user_id))
    nickname = json.loads(infos.replace("'", '"'))['nickname']
    # 昵称与用户名的组合
    nickname_and_id = nickname + '(' + str(user_id) + ')'

    # 将图片发送给用户
    await search_star.send(f"{
       nickname_and_id}\n你创造了{
       msg}星:\n"+MessageSegment.image('file:///' + img_path))
    # 发送完之后删除临时文件夹中的图片
    os.remove(img_path)


# 查询帮助命令
star_helper = on_command("巡天帮助", priority=Config.priority)


@star_helper.handle()
async def handle_first_receive(bot: Bot, event: Event, state: T_State):
    await star_helper.finish('''巡天指令说明
普通用户指令:
观星——每天创建一颗随机星球,当天结果固定

管理员指令:
创星 [星球名字]——根据星球名字创建一颗随机星球''')

5. 插件配图

该插件无配图

6. 实际效果

7. 下一个插件

暂未完成

你可能感兴趣的:(nonebot2聊天机器人插件,深度学习,神经网络,python,生成对抗网络)