功能:对对联,输入上联得到下联。
网络:CNN+GRU+Attention
环境:python3.7 + PyTorch 1.3.1
环境安装
安装anaconda环境:AIPyTorch13
cd
source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh
cd
conda create --name AIPyTorch13 python=3.7.4 ipykernel
安装第三方包
cd
source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh
conda activate AIPyTorch13
pip install msgpack
pip install jupyter
pip install torch==1.3.1
pip install torchvision==0.2.1
pip install matplotlib
pip install numpy==1.19.5
pip install scipy
pip install pytest
pip install tqdm
pip install tensorboardX
pip install tensorflow-gpu==2.2.0
conda install cudatoolkit=10.1
pip install gensim
pip install jieba
pip install pypinyin
conda install yaml
pip install easydict
pip install numpy
pip install opencv-python
git下载代码
cd
cd /home/zyb/zyb
mkdir couplet
cd
cd /home/zyb/zyb/couplet
git clone https://github.com/neoql/open_couplet.git
获取和整理样本
这儿采用了清洁(删除了一些敏感词)后的样本。
cd
cd /home/zyb/zyb/couplet/open_couplet
git clone https://github.com/v-zich/couplet-clean-dataset.git
cd
cp -rf /home/zyb/zyb/couplet/open_couplet/couplet-clean-dataset/couplets /home/zyb/zyb/couplet/open_couplet/dataset
cd
cp -rf /home/zyb/zyb/couplet/open_couplet/dataset/test /home/zyb/zyb/couplet/open_couplet/dataset/dev
训练
整理代码
cd
cp -rf /home/zyb/zyb/couplet/open_couplet/open_couplet /home/zyb/zyb/couplet/open_couplet/scripts/open_couplet
cd
cp -rf /home/zyb/zyb/couplet/open_couplet/train /home/zyb/zyb/couplet/open_couplet/scripts/train
进入环境
cd
source /home/zyb/zyb/anaconda3/etc/profile.d/conda.sh
conda activate AIPyTorch13
cd
cd /home/zyb/zyb/couplet/open_couplet
生成词表
python scripts/build_vocab.py \
dataset/train/in.txt dataset/train/out.txt \
dataset/dev/in.txt dataset/dev/out.txt \
--add-cn-punctuations \
--unused-tokens=10 \
--output-file=experiment/vocab.txt
参数说明:
匿名参数:用于生成词表的文件列表
–add-cn-punctuations:预先添加中文标点(可不设置, 默认为False)
–unused-tokens: 添加未启用token(可不设置,默认为0)
–output-file: 设置词表的输出文件
注:该脚本在添加字符时会自动将英文标点转为中文标点
训练模型
python scripts/train_seq2seq.py \
--vocab_file=experiment/vocab.txt \
--hidden_size=1000 \
--rnn_layers=2 \
--cnn_kernel_size=3 \
--dropout_p=0.1 \
--train_set_dir=dataset/train \
--dev_set_dir=dataset/dev \
--save_dir=experiment/checkpoints \
--logging_dir=experiment/log \
--learning_rate=0.001 \
--num_epochs=50 \
--batch_size=128 \
--max_grad_norm=5
参数说明:
–vocab_file: 设置词表文件
–hidden_size: 隐层维度
–rnn_layers: RNN(GRU)层数
–cnn_kernel_size: CNN卷积核大小
–dropout_p: dropout层置0概率
–train_set_dir: 训练集所在目录
–dev_set_dir: 开发集所在目录
–save_dir: 模型保存路径
–logging_dir: tensorboard的Summary输出路径
–learning_rate: 学习率
–num_epochs: epoch数量
–batch_size: 批处理数量
–max_grad_norm: 梯度裁减,最大norm值
训练速度:2d
测试
查看/experiment/checkpoints/best_checkpoint.json中提示第63449为最优
STEP=63449
python scripts/demo.py \
--model=experiment/checkpoints/checkpoint_$STEP \
--vocab_file=experiment/vocab.txt
参数说明:
–model: 模型检查点目录
–vocab_file: 设置词表文件
输入测试的上联会得到对应的下联。