python关键词对联_对联(代码1:CNN+GRU+Attention网络、PyTorch)

功能:对对联,输入上联得到下联。

网络:CNN+GRU+Attention

环境:python3.7 + PyTorch 1.3.1

环境安装

安装anaconda环境:AIPyTorch13

cd

source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh

cd

conda create --name AIPyTorch13 python=3.7.4 ipykernel

安装第三方包

cd

source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh

conda activate AIPyTorch13

pip install msgpack

pip install jupyter

pip install torch==1.3.1

pip install torchvision==0.2.1

pip install matplotlib

pip install numpy==1.19.5

pip install scipy

pip install pytest

pip install tqdm

pip install tensorboardX

pip install tensorflow-gpu==2.2.0

conda install cudatoolkit=10.1

pip install gensim

pip install jieba

pip install  pypinyin

conda install yaml

pip install easydict

pip install numpy

pip install opencv-python

git下载代码

cd

cd /home/zyb/zyb

mkdir couplet

cd

cd /home/zyb/zyb/couplet

git clone https://github.com/neoql/open_couplet.git

获取和整理样本

这儿采用了清洁(删除了一些敏感词)后的样本。

cd

cd /home/zyb/zyb/couplet/open_couplet

git clone https://github.com/v-zich/couplet-clean-dataset.git

cd

cp -rf /home/zyb/zyb/couplet/open_couplet/couplet-clean-dataset/couplets /home/zyb/zyb/couplet/open_couplet/dataset

cd

cp -rf /home/zyb/zyb/couplet/open_couplet/dataset/test /home/zyb/zyb/couplet/open_couplet/dataset/dev

训练

整理代码

cd

cp -rf /home/zyb/zyb/couplet/open_couplet/open_couplet /home/zyb/zyb/couplet/open_couplet/scripts/open_couplet

cd

cp -rf /home/zyb/zyb/couplet/open_couplet/train /home/zyb/zyb/couplet/open_couplet/scripts/train

进入环境

cd

source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh

conda activate AIPyTorch13

cd

cd /home/zyb/zyb/couplet/open_couplet

生成词表

python scripts/build_vocab.py \

dataset/train/in.txt dataset/train/out.txt \

dataset/dev/in.txt dataset/dev/out.txt \

--add-cn-punctuations \

--unused-tokens=10 \

--output-file=experiment/vocab.txt

参数说明:

匿名参数:用于生成词表的文件列表

–add-cn-punctuations:预先添加中文标点(可不设置, 默认为False)

–unused-tokens: 添加未启用token(可不设置,默认为0)

–output-file: 设置词表的输出文件

注:该脚本在添加字符时会自动将英文标点转为中文标点

训练模型

python scripts/train_seq2seq.py \

--vocab_file=experiment/vocab.txt \

--hidden_size=1000 \

--rnn_layers=2 \

--cnn_kernel_size=3 \

--dropout_p=0.1 \

--train_set_dir=dataset/train \

--dev_set_dir=dataset/dev \

--save_dir=experiment/checkpoints \

--logging_dir=experiment/log \

--learning_rate=0.001 \

--num_epochs=50 \

--batch_size=128 \

--max_grad_norm=5

参数说明:

–vocab_file: 设置词表文件

–hidden_size: 隐层维度

–rnn_layers: RNN(GRU)层数

–cnn_kernel_size: CNN卷积核大小

–dropout_p: dropout层置0概率

–train_set_dir: 训练集所在目录

–dev_set_dir: 开发集所在目录

–save_dir: 模型保存路径

–logging_dir: tensorboard的Summary输出路径

–learning_rate: 学习率

–num_epochs: epoch数量

–batch_size: 批处理数量

–max_grad_norm: 梯度裁减,最大norm值

训练速度:2d

测试

查看/experiment/checkpoints/best_checkpoint.json中提示第63449为最优

STEP=63449

python scripts/demo.py \

--model=experiment/checkpoints/checkpoint_$STEP \

--vocab_file=experiment/vocab.txt

参数说明:

–model: 模型检查点目录

–vocab_file: 设置词表文件

输入测试的上联会得到对应的下联。

你可能感兴趣的:(python关键词对联)