论文地址:https://arxiv.org/abs/2011.03029
代码地址:https://github.com/InterDigitalInc/CompressAI
一个关于End-to-End Image Compression 的pytorch库,预计复现以下四篇论文的模型。
因为国内网的问题,先下载好torch是比较好的安装方式,先将Conda内源换成清华镜像源,虽然清华源也很慢很慢很慢,但是比国外源略微快上一丢丢。
用conda创建python=3.8, cuda=10.1的虚拟环境(建议)
conda create -n env_name python=3.8 cudatoolkit=10.1 cudnn
创建环境后需要激活虚拟环境,以在该虚拟环境下下载对应的库包。
conda activate env_name
如果使用基础的环境,则直接通过conda下载cuda即可
conda install cudatoolkit=10.1 cudnn
安装pytotch,进入pytorch官网:https://pytorch.org,选择对应的版本
根据提示输入下载命令, -c pytorch 是指指定官方通道,torch的服务器在国外,emmm,会断的厉害,但是清华源断的也挺厉害的,多下载几次=-=:
conda install pytorch torchvision torchaudio
torch下载完成后,根据官方指导,开始下载CompressAI, 从clone工程到你的机器上,下载结束后,进行pip安装。
git clone https://github.com/InterDigitalInc/CompressAI compressai
cd compressai
pip install -U pip && pip install -e . //该命令用下一条命令替换,更快地安装。
pip install -e . -i https://pypi.douban.com/simple //pip 豆瓣源比清华源好=-=
安装结束后,输入
conda list //查看该环境下的安装包,如果出现compressai,即安装成功
python
import compressai //不报错即安装成功
Compressai的一级结构如下,具体使用API指导:https://interdigitalinc.github.io/CompressAI/
其中,主要关注两个目录,compressai目录下即pip编译的源码,修改这里的代码会修改compresssai的API应用, example目录下的是代码是使用范例。
使用:
其中
/path/to/my/image/dataset/ 表示数据集的目录, 该数据集下分为 train 和test目录, train内部放train的 .png图像, test放测试图像。 --cuda 使用GPU,–save保存训练好的模型。
python examples/train.py -d /path/to/my/image/dataset/ --epochs 300 -lr 1e-4 --batch-size 16 --cuda --save
训练结束后需要更新CDF保证熵编码的正常运行:
python -m compressai.utils.update_model [-h] [-n NAME] [-d DIR] [–no-update] [–architecture {factorized-prior,jarhp,mean-scale-hyperprior,scale-hyperprior}] filepath
python -m compressai.utils.update_model [-d DIR] [--architecture {factorized-prior,jarhp,mean-scale-hyperprior,scale-hyperprior}] filepath
评价模型:
/path/to/images/folder/ 和上述的不同,该文件夹内直接存储需要test的png图像。
-a $ARCH 表示采用的预设定的模型,列表如下六种。
-p $MODEL_CHECKPOINT 表示存储的网络模型。
python -m compressai.utils.eval_model checkpoint /path/to/images/folder/ -a $ARCH -p $MODEL_CHECKPOINT...
由于训练结束需要更新Entropy的CDF以正常进行测试阶段的熵编码工作,但是上述的CDF更新制定了预先定义好的框架,当采用自己的框架的时候,CDF的更新需要自行阅读对应源码并且修改进行CDF的更新。