GitHub - HongxinXiang/ImageMol: ImageMol is a molecular image-based pre-training deep learning framework for computational drug discovery.
I downloaded and tested this on Windows.
First, what the files in each folder are for:
1. preparing dataset
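In the first step, "1. preparing dataset", the smiles column of data.csv is converted into molecular images (stored in the 224 folder), and data.csv is used to generate "data_for_pretrain.csv" (identical except that the original smiles column becomes a filename column).
In other words, each image in the 224 folder corresponds to one row of data.csv / data_for_pretrain.csv.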
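The row-to-image correspondence described above can be sketched as follows. This is only an illustration, not the repository's smiles2img_pretrain.py: the column names `index` and `label`, the helper `smiles_to_filename_rows`, and the `<row number>.png` naming scheme are all assumptions made for the example.

```python
import csv
import io


def smiles_to_filename_rows(rows, smiles_col="smiles", filename_col="filename"):
    """Return copies of the rows with the smiles column replaced by a filename column.

    The "<row number>.png" naming scheme is hypothetical; the real script may differ.
    """
    converted = []
    for row_number, row in enumerate(rows):
        new_row = {}
        for key, value in row.items():
            if key == smiles_col:
                # One image per row: the filename stands in for the SMILES string.
                new_row[filename_col] = f"{row_number}.png"
            else:
                new_row[key] = value
        converted.append(new_row)
    return converted


# A tiny stand-in for data.csv (contents are made up for the example).
source = io.StringIO("index,smiles,label\n0,CCO,1\n1,c1ccccc1,0\n")
rows = list(csv.DictReader(source))
converted = smiles_to_filename_rows(rows)

for row in converted:
    print(row)
```

Every other column is carried over unchanged and in the same position, which matches the note that only the smiles column differs between the two files.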
2. fine-tune
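The computation works as follows:
We fine-tune with a classifier: the images in the 224 folder are fed into the model, which outputs pred with shape (64, 617), and the loss is computed between pred and the ground-truth labels (64, 617) from data_for_pretrain.csv. Here 64 is the batch size and 617 is the length of the label vector for each image. (My understanding is that fine-tuning here is supervised learning: it fine-tunes on labeled data, whereas the pre-trained model is trained in an "unsupervised" way through self-supervision.) [In summary: the image information is embedded and the label is predicted directly from it; no image is generated anywhere in the process. This differs from using images for "OG, OP", because "OG, OP" has to generate images.]
# model(image) <==> label: the loss is computed between these two
# Molecular property prediction mainly takes the already pre-trained model and fine-tunes it. For the downstream task, a linear classifier is added on top of the model to predict the labels of the downstream images, and the whole network is then fine-tuned end to end. (Here the label is the molecular property.)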
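The loss between model(image) and the labels can be sketched in plain Python as a multi-label binary cross-entropy averaged over a (batch, tasks) grid. This is a dependency-free illustration, not the code in finetune.py: treating the outputs as raw logits and marking missing labels with -1 are both assumptions, and the 2 x 3 example stands in for the real (64, 617) shapes.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit and one 0/1 target."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def masked_multilabel_loss(logits, labels, missing=-1):
    """Average the per-entry loss over every (sample, task) pair that has a label.

    `missing` is a hypothetical marker for tasks without a measured label.
    """
    total, count = 0.0, 0
    for logit_row, label_row in zip(logits, labels):
        for logit, label in zip(logit_row, label_row):
            if label == missing:
                continue  # no ground truth for this task, so it adds no loss
            total += bce_with_logits(logit, float(label))
            count += 1
    return total / count if count else 0.0


# Toy batch: 2 samples x 3 tasks instead of 64 x 617.
logits = [[0.0, 2.0, -1.0], [1.5, 0.0, 0.0]]
labels = [[1, 1, -1], [0, -1, 0]]
loss = masked_multilabel_loss(logits, labels)
print(f"loss = {loss:.6f}")
```

With PyTorch tensors the same idea is usually written with `torch.nn.BCEWithLogitsLoss(reduction="none")` followed by a mask, but whether the repository does exactly that is not confirmed here.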
1. GPU environment
CUDA 10.1
2. create a new conda environment
conda create -n imagemol python=3.7.3
conda activate imagemol
3. download some packages
conda install -c rdkit rdkit
windows:
linux:
pip install torch-cluster torch-scatter torch-sparse torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.4.0%2Bcu101.html
pip install -r requirements.txt
source activate imagemol
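Note: there is an error here. The requirements.txt on GitHub specifies scikit-learn==1.21.2, but the latest scikit-learn release is currently 1.1.2, and that release also requires python>=3.8.
So I installed "numpy-1.21.6 scikit-learn-1.0.2" instead.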
The smiles column of data.csv is used to generate 224*224 molecular images (one molecular image per row), and data.csv is used to generate "data_for_pretrain.csv" (only the original smiles column becomes a filename column).
Download pretraining data and put it into ./datasets/pretraining/data/
Preprocess dataset: (this takes a long time; close to ten million files have to be processed, and it took me 22 hours in total)
python ./data_process/smiles2img_pretrain.py --dataroot ./datasets/pretraining/ --dataset data
Note: You can find the toy dataset in ./datasets/toy/pretraining/
Train a pre-trained model to be used for downstream tasks.
Usage:
usage: pretrain.py [-h] [--lr LR] [--wd WD] [--workers WORKERS] [--val_workers VAL_WORKERS] [--epochs EPOCHS] [--start_epoch START_EPOCH] [--batch BATCH] [--momentum MOMENTUM] [--checkpoints CHECKPOINTS] [--seed SEED] [--dataroot DATAROOT] [--dataset DATASET] [--ckpt_dir CKPT_DIR] [--modelname {ResNet18}] [--verbose] [--ngpu NGPU] [--gpu GPU] [--nc NC] [--ndf NDF] [--imageSize IMAGESIZE] [--Jigsaw_lambda JIGSAW_LAMBDA] [--cluster_lambda CLUSTER_LAMBDA] [--constractive_lambda CONSTRACTIVE_LAMBDA] [--matcher_lambda MATCHER_LAMBDA] [--is_recover_training IS_RECOVER_TRAINING] [--cl_mask_type {random_mask,rectangle_mask,mix_mask}] [--cl_mask_shape_h CL_MASK_SHAPE_H] [--cl_mask_shape_w CL_MASK_SHAPE_W] [--cl_mask_ratio CL_MASK_RATIO]
Code to pretrain:
python pretrain.py --ckpt_dir ./ckpts/pretraining/ \
    --checkpoints 1 \
    --Jigsaw_lambda 1 \
    --cluster_lambda 1 \
    --constractive_lambda 1 \
    --matcher_lambda 1 \
    --is_recover_training 1 \
    --batch 256 \
    --dataroot ./datasets/pretraining/ \
    --dataset data \
    --gpu 0,1,2,3 \
    --ngpu 4
For testing, you can simply pre-train ImageMol using single GPU on toy dataset:
python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ \
    --checkpoints 1 \
    --Jigsaw_lambda 1 \
    --cluster_lambda 1 \
    --constractive_lambda 1 \
    --matcher_lambda 1 \
    --is_recover_training 1 \
    --batch 16 \
    --dataroot ./datasets/toy/pretraining/ \
    --dataset data \
    --gpu 0 \
    --ngpu 1
python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ --checkpoints 1 --Jigsaw_lambda 1 --cluster_lambda 1 --constractive_lambda 1 --matcher_lambda 1 --is_recover_training 1 --batch 16 --dataroot ./datasets/toy/pretraining --dataset data --gpu 0 --ngpu 1
I ran the toy pre-training for 17 epochs, which produced 17 pre-trained model checkpoints in total:
(imagemol) D:\pycharm_workspace\1\ImageMol>python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ --checkpoints 1 --Jigsaw_lambda 1 --cluster_lambda 1 --constractive_lambda 1 --matcher_lambda 1 --is_recover_training 1 --batch 16 --dataroot ./datasets/toy/pretraining --dataset data --gpu 0 --ngpu 1
ImageMol(
(embedding_layer): Sequential(
(0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(8): AdaptiveAvgPool2d(output_size=(1, 1))
)
(bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(jigsaw_classifier): Linear(in_features=512, out_features=101, bias=True)
(class_classifier1): Linear(in_features=512, out_features=100, bias=True)
(class_classifier2): Linear(in_features=512, out_features=1000, bias=True)
(class_classifier3): Linear(in_features=512, out_features=10000, bias=True)
)
Matcher(
(fc): Linear(in_features=512, out_features=2, bias=True)
(logic): LogSoftmax()
)
generator(
(projection): Sequential(
(0): Linear(in_features=512, out_features=128, bias=True)
(1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): LeakyReLU(negative_slope=0.2, inplace=True)
)
(netG): Sequential(
(0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
netlocalD(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
100%|█| 60/60 [02:29<00:00, 2.50s/it, C_loss=0.02, ClsLoss_100=4.36, ClsLoss_1000=7.16, ClsLoss_10000=9.29, ClsTotalLo
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00, 4.79s/it]
Epoch: [1][train] TotalLoss: 23.96166973114014 JigLoss: 1.9193571281929807 ClsLoss_100: 4.92697454293569 ClsLoss_1000: 7.11983088652293 ClsLoss_10000: 9.33748327891032 ClsTotalLoss(fftotal): 21.384288565317785 AvgConstractiveLoss: 0.1648771699983627 AvgReasonabilityLoss: 0.49314683731645315 AvgRecoverLoss: 5.0754999422313025
Epoch: [1][val] JigsawAcc: 0.82 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.004153344817459583 ReasonabilityLoss: 0.013568593263626099 RecoverLoss: 0.2023082345724106
100%|█| 60/60 [01:06<00:00, 1.12s/it, C_loss=0.0138, ClsLoss_100=4.49, ClsLoss_1000=7.29, ClsLoss_10000=8.7, ClsTotalL
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00, 4.66s/it]
Epoch: [2][train] TotalLoss: 21.496857229868567 JigLoss: 1.0841593230764073 ClsLoss_100: 4.781556669871013 ClsLoss_1000: 6.793236716588339ClsLoss_10000: 8.56167837778727 ClsTotalLoss(fftotal): 20.136471748352054 AvgConstractiveLoss: 0.07614539431718487 AvgReasonabilityLoss: 0.20008064570526282 AvgRecoverLoss: 3.5995476151506103
Epoch: [2][val] JigsawAcc: 0.86 ClusterAcc100: 0.04 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.013333333333333334 ConstractiveLoss: 0.0020515684410929683 ReasonabilityLoss: 0.002394016683101654 RecoverLoss: 0.29341010212898255
100%|█| 60/60 [01:06<00:00, 1.11s/it, C_loss=0.0453, ClsLoss_100=4.83, ClsLoss_1000=6.7, ClsLoss_10000=8.39, ClsTotalLoss=19.9, JigLoss=1.68, M_loss=0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00, 5.18s/it]
Epoch: [3][train] TotalLoss: 20.89204044342041 JigLoss: 1.003493670374155 ClsLoss_100: 4.813545378049215 ClsLoss_1000: 6.738298972447714ClsLoss_10000: 8.154170513153076 ClsTotalLoss(fftotal): 19.706014760335286 AvgConstractiveLoss: 0.04975322242826223 AvgReasonabilityLoss: 0.13277878270794943 AvgRecoverLoss: 3.6251505697766935
Epoch: [3][val] JigsawAcc: 0.86 ClusterAcc100: 0.04 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.013333333333333334 ConstractiveLoss: 0.0015005466900765898 ReasonabilityLoss: 0.01395109474658966 RecoverLoss: 0.20824745893478394
100%|█| 60/60 [01:08<00:00, 1.14s/it, C_loss=0.0533, ClsLoss_100=4.81, ClsLoss_1000=6.56, ClsLoss_10000=8, ClsTotalLoss=19.4, JigLoss=0.897, M_loss=0.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00, 4.94s/it]
Epoch: [4][train] TotalLoss: 20.512035369873047 JigLoss: 0.9941644340753556 ClsLoss_100: 4.804758842786154 ClsLoss_1000: 6.710355981191003ClsLoss_10000: 7.865747348467508 ClsTotalLoss(fftotal): 19.38086214065552 AvgConstractiveLoss: 0.048341524026667075 AvgReasonabilityLoss: 0.08866713357468446 AvgRecoverLoss: 4.077671728531518
Epoch: [4][val] JigsawAcc: 0.76 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.004614624008536339 ReasonabilityLoss: 0.008996209204196932 RecoverLoss: 0.31746475636959076
100%|█| 60/60 [01:07<00:00, 1.12s/it, C_loss=0.07, ClsLoss_100=4.49, ClsLoss_1000=6.72, ClsLoss_10000=7.93, ClsTotalLoss=19.1, JigLoss=3.02, M_loss=1.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00, 4.62s/it]
Epoch: [5][train] TotalLoss: 20.410388692220042 JigLoss: 1.0800357540448504 ClsLoss_100: 4.7787369966506965 ClsLoss_1000: 6.687890203793843ClsLoss_10000: 7.672290404637656 ClsTotalLoss(fftotal): 19.13891773223877 AvgConstractiveLoss: 0.05595408985391261 AvgReasonabilityLoss: 0.13548127884666128 AvgRecoverLoss: 4.075054861356814
Epoch: [5][val] JigsawAcc: 0.84 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.002793522924184799 ReasonabilityLoss: 0.0062981796264648445 RecoverLoss: 0.18601167559623719
100%|█| 60/60 [01:06<00:00, 1.11s/it, C_loss=0.0246, ClsLoss_100=4.8, ClsLoss_1000=7.11, ClsLoss_10000=7.69, ClsTotalLoss=19.6, JigLoss=3.16, M_loss=1
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00, 5.27s/it]
Epoch: [6][train] TotalLoss: 20.339342339833575 JigLoss: 1.1087447846929228 ClsLoss_100: 4.799536720911663 ClsLoss_1000: 6.672529379526773ClsLoss_10000: 7.573689786593119 ClsTotalLoss(fftotal): 19.04575576782228 AvgConstractiveLoss: 0.053185406218593315 AvgReasonabilityLoss: 0.13165633815030256 AvgRecoverLoss: 4.161377739906311
Epoch: [6][val] JigsawAcc: 0.84 ClusterAcc100: 0.04 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.013333333333333334 ConstractiveLoss: 0.0014511872828006744 ReasonabilityLoss: 0.005211712941527367 RecoverLoss: 0.1715515339374542
100%|█| 60/60 [01:07<00:00, 1.13s/it, C_loss=0.0319, ClsLoss_100=4.63, ClsLoss_1000=6.67, ClsLoss_10000=7.54, ClsTotalLoss=18.8, JigLoss=0.878, M_loss
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:22<00:00, 5.57s/it]
Epoch: [7][train] TotalLoss: 20.192618624369306 JigLoss: 1.0693875233332315 ClsLoss_100: 4.75707601706187 ClsLoss_1000: 6.66301953792572 ClsLoss_10000: 7.496208945910134 ClsTotalLoss(fftotal): 18.916304334004717 AvgConstractiveLoss: 0.047720210254192354 AvgReasonabilityLoss: 0.15920654994746053 AvgRecoverLoss: 4.174393929044406
Epoch: [7][val] JigsawAcc: 0.78 ClusterAcc100: 0.02 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.006666666666666667 ConstractiveLoss: 0.0035598852671682836 ReasonabilityLoss: 0.0006275869905948638 RecoverLoss: 0.431844232082367
100%|█| 60/60 [01:06<00:00, 1.11s/it, C_loss=0.0174, ClsLoss_100=4.66, ClsLoss_1000=6.93, ClsLoss_10000=7.62, ClsTotalLoss=19.2, JigLoss=1.58, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00, 4.55s/it]
Epoch: [8][train] TotalLoss: 19.8997561454773 JigLoss: 0.9543801690141359 ClsLoss_100: 4.7688616196314495 ClsLoss_1000: 6.649607149759929ClsLoss_10000: 7.420435182253519 ClsTotalLoss(fftotal): 18.83890390396117 AvgConstractiveLoss: 0.04308367188399036 AvgReasonabilityLoss: 0.06338832139347991 AvgRecoverLoss: 4.2416440991063915
Epoch: [8][val] JigsawAcc: 0.74 ClusterAcc100: 0.02 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.006666666666666667 ConstractiveLoss: 0.0024805641174316406 ReasonabilityLoss: 0.0048112948238849635 RecoverLoss: 0.25158809542655947
100%|█| 60/60 [01:06<00:00, 1.11s/it, C_loss=0.0381, ClsLoss_100=4.57, ClsLoss_1000=6.51, ClsLoss_10000=7.65, ClsTotalLoss=18.7, JigLoss=0.78, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00, 4.75s/it]
Epoch: [9][train] TotalLoss: 20.007253615061437 JigLoss: 1.0684301803509397 ClsLoss_100: 4.787256534894307 ClsLoss_1000: 6.628996245066327ClsLoss_10000: 7.368737761179607 ClsTotalLoss(fftotal): 18.78499059677124 AvgConstractiveLoss: 0.045559848969181374 AvgReasonabilityLoss: 0.10827299573769171 AvgRecoverLoss: 4.120056470980247
Epoch: [9][val] JigsawAcc: 0.84 ClusterAcc100: 0.02 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.006666666666666667 ConstractiveLoss: 0.002133065667003393 ReasonabilityLoss: 0.0024875566363334657 RecoverLoss: 0.19161114037036897
100%|█| 60/60 [01:07<00:00, 1.12s/it, C_loss=0.0287, ClsLoss_100=4.74, ClsLoss_1000=5.96, ClsLoss_10000=7.37, ClsTotalLoss=18.1, JigLoss=1.38, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:23<00:00, 5.88s/it]
Epoch: [10][train] TotalLoss: 19.906644948323567 JigLoss: 1.0861985241373382 ClsLoss_100: 4.752160032590231 ClsLoss_1000: 6.603877162933351ClsLoss_10000: 7.332753666241961 ClsTotalLoss(fftotal): 18.688790893554682 AvgConstractiveLoss: 0.04013523093114296 AvgReasonabilityLoss: 0.09152033478021618 AvgRecoverLoss: 4.039306613306204
Epoch: [10][val] JigsawAcc: 0.82 ClusterAcc100: 0.02 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.006666666666666667 ConstractiveLoss: 0.0020973937958478927 ReasonabilityLoss: 0.0024364076554775236 RecoverLoss: 0.18724668622016907
100%|█| 60/60 [01:08<00:00, 1.14s/it, C_loss=0.0463, ClsLoss_100=4.77, ClsLoss_1000=6.7, ClsLoss_10000=7.51, ClsTotalLoss=19, JigLoss=1.42, M_loss=0.0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:29<00:00, 7.25s/it]
Epoch: [11][train] TotalLoss: 19.741016260782875 JigLoss: 0.960937575995922 ClsLoss_100: 4.751010529200236 ClsLoss_1000: 6.601791771252948ClsLoss_10000: 7.31111691792806 ClsTotalLoss(fftotal): 18.663919194539393 AvgConstractiveLoss: 0.040693634655326605 AvgReasonabilityLoss: 0.07546558417379856 AvgRecoverLoss: 4.537612893184026
Epoch: [11][val] JigsawAcc: 0.82 ClusterAcc100: 0.06 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.02 ConstractiveLoss: 0.00219877889379859 ReasonabilityLoss: 0.00048301540315151217 RecoverLoss: 0.22352851390838624
100%|█| 60/60 [01:07<00:00, 1.13s/it, C_loss=0.0232, ClsLoss_100=4.68, ClsLoss_1000=6.71, ClsLoss_10000=7.56, ClsTotalLoss=19, JigLoss=0.0433, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00, 5.13s/it]
Epoch: [12][train] TotalLoss: 19.715614541371657 JigLoss: 1.000559947391351 ClsLoss_100: 4.731101123491921 ClsLoss_1000: 6.583861041069033ClsLoss_10000: 7.277890539169311 ClsTotalLoss(fftotal): 18.592852783203124 AvgConstractiveLoss: 0.03278368460790566 AvgReasonabilityLoss: 0.08941795468175164 AvgRecoverLoss: 4.371748948221404
Epoch: [12][val] JigsawAcc: 0.78 ClusterAcc100: 0.02 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.006666666666666667 ConstractiveLoss: 0.001537955142557621 ReasonabilityLoss: 0.009752252995967867 RecoverLoss: 0.18312466382980347
100%|█| 60/60 [01:07<00:00, 1.12s/it, C_loss=0.0647, ClsLoss_100=4.85, ClsLoss_1000=6.72, ClsLoss_10000=7.55, ClsTotalLoss=19.1, JigLoss=0.0258, M_los
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:23<00:00, 5.96s/it]
Epoch: [13][train] TotalLoss: 19.82162863413493 JigLoss: 1.0898563351792594 ClsLoss_100: 4.738014833132426 ClsLoss_1000: 6.585498579343159ClsLoss_10000: 7.261646850903826 ClsTotalLoss(fftotal): 18.585160191853845 AvgConstractiveLoss: 0.03763424041680991 AvgReasonabilityLoss: 0.10897765451421342 AvgRecoverLoss: 4.363648629685243
Epoch: [13][val] JigsawAcc: 0.74 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.0021116061881184577 ReasonabilityLoss: 0.0005279695987701416 RecoverLoss: 0.21621067449450493
100%|█| 60/60 [01:07<00:00, 1.12s/it, C_loss=0.0384, ClsLoss_100=4.52, ClsLoss_1000=6.62, ClsLoss_10000=7.4, ClsTotalLoss=18.5, JigLoss=0.752, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00, 4.57s/it]
Epoch: [14][train] TotalLoss: 19.565531730651863 JigLoss: 0.9612024967869126 ClsLoss_100: 4.7033114989598594 ClsLoss_1000: 6.577907792727153ClsLoss_10000: 7.236068709691365 ClsTotalLoss(fftotal): 18.51728795369466 AvgConstractiveLoss: 0.040761741731936744 AvgReasonabilityLoss: 0.04627952212467788 AvgRecoverLoss: 4.573408140925071
Epoch: [14][val] JigsawAcc: 0.84 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.0019987770169973373 ReasonabilityLoss: 0.002506262362003326 RecoverLoss: 0.19087217301130294
100%|█| 60/60 [01:07<00:00, 1.13s/it, C_loss=0.00878, ClsLoss_100=4.72, ClsLoss_1000=6.57, ClsLoss_10000=7.43, ClsTotalLoss=18.7, JigLoss=0.0542, M_lo
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:29<00:00, 7.25s/it]
Epoch: [15][train] TotalLoss: 19.65483560562134 JigLoss: 1.0458930411065623 ClsLoss_100: 4.722701446215311 ClsLoss_1000: 6.552511111895243ClsLoss_10000: 7.212150192260744 ClsTotalLoss(fftotal): 18.48736282984416 AvgConstractiveLoss: 0.045769707330813036 AvgReasonabilityLoss: 0.07581018048028149 AvgRecoverLoss: 4.728145783022047
Epoch: [15][val] JigsawAcc: 0.82 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.0038376058358699085 ReasonabilityLoss: 0.00415668860077858 RecoverLoss: 0.1998233237862587
100%|█| 60/60 [01:14<00:00, 1.24s/it, C_loss=0.00582, ClsLoss_100=4.77, ClsLoss_1000=6.37, ClsLoss_10000=7.25, ClsTotalLoss=18.4, JigLoss=1.37, M_loss
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00, 5.13s/it]
Epoch: [16][train] TotalLoss: 19.632916323343913 JigLoss: 1.0232878476381304 ClsLoss_100: 4.723408063252767 ClsLoss_1000: 6.555336928367613ClsLoss_10000: 7.190120752652486 ClsTotalLoss(fftotal): 18.468865553538002 AvgConstractiveLoss: 0.049930801638402036 AvgReasonabilityLoss: 0.0908320085921635 AvgRecoverLoss: 4.406435360635318
Epoch: [16][val] JigsawAcc: 0.8 ClusterAcc100: 0.04 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.013333333333333334 ConstractiveLoss: 0.002326355976983905 ReasonabilityLoss: 0.00012519627809524537 RecoverLoss: 0.18016339898109435
100%|█| 60/60 [01:07<00:00, 1.13s/it, C_loss=0.0456, ClsLoss_100=5.02, ClsLoss_1000=6.61, ClsLoss_10000=7.08, ClsTotalLoss=18.7, JigLoss=1.66, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00, 5.46s/it]
Epoch: [17][train] TotalLoss: 19.40125710169475 JigLoss: 0.9150347652534644 ClsLoss_100: 4.6697273492813105 ClsLoss_1000: 6.52956008116404 ClsLoss_10000: 7.167783602078757 ClsTotalLoss(fftotal): 18.367071247100828 AvgConstractiveLoss: 0.0530752846505493 AvgReasonabilityLoss: 0.06607597895587485 AvgRecoverLoss: 4.356251256788772
Epoch: [17][val] JigsawAcc: 0.84 ClusterAcc100: 0.0 ClusterAcc1000: 0.0 ClusterAcc10000: 0.0 ClusterAcc(avg): 0.0 ConstractiveLoss: 0.003248931393027306 ReasonabilityLoss: 0.01237701490521431 RecoverLoss: 0.20707530111074446
15%|▏| 9/60 [00:13<01:16, 1.50s/it, C_loss=0.0265, ClsLoss_100=4.53, ClsLoss_1000=6.36, ClsLoss_10000=7.01, ClsTotalLoss=17.9, JigLoss=0.91, M_loss=0
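The epoch summary lines in the log above can be parsed to check how the reported losses relate to each other. The sketch below uses the epoch 1 training line verbatim and tests whether TotalLoss equals the lambda-weighted sum of four logged terms. The pairing of each `--*_lambda` flag with a loss name (in particular `--matcher_lambda` with AvgReasonabilityLoss) is an assumption inferred from the numbers, not something read from pretrain.py; `parse_epoch_line` is a hypothetical helper.

```python
import re

# "Key: value" pairs; also copes with log lines where the space before the next key is missing.
PAIR = re.compile(r"([A-Za-z_]\w*(?:\([^)]*\))?):\s*(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_epoch_line(line):
    """Turn one 'Epoch: [n][train] ...' summary line into a {name: float} dict."""
    return {key: float(value) for key, value in PAIR.findall(line)}


line = (
    "Epoch: [1][train] TotalLoss: 23.96166973114014 JigLoss: 1.9193571281929807 "
    "ClsLoss_100: 4.92697454293569 ClsLoss_1000: 7.11983088652293 "
    "ClsLoss_10000: 9.33748327891032 ClsTotalLoss(fftotal): 21.384288565317785 "
    "AvgConstractiveLoss: 0.1648771699983627 AvgReasonabilityLoss: 0.49314683731645315 "
    "AvgRecoverLoss: 5.0754999422313025"
)
metrics = parse_epoch_line(line)

# All four lambdas were set to 1 on the command line; the flag-to-loss pairing is assumed.
lambdas = {
    "JigLoss": 1.0,                # --Jigsaw_lambda
    "ClsTotalLoss(fftotal)": 1.0,  # --cluster_lambda
    "AvgConstractiveLoss": 1.0,    # --constractive_lambda
    "AvgReasonabilityLoss": 1.0,   # --matcher_lambda (assumed pairing)
}
reconstructed = sum(weight * metrics[name] for name, weight in lambdas.items())

print("logged TotalLoss :", metrics["TotalLoss"])
print("weighted sum     :", reconstructed)
```

For this line the two values agree to within about 1e-7, and AvgRecoverLoss is not part of the sum, which suggests the recovery term is optimized separately. That reading is based only on the logged numbers.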
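Fine-tune the pre-trained model on a downstream task. The downstream task here is classification (the model is fine-tuned with the cross-entropy loss between the predicted class probabilities and the true labels).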
1. Download pre-trained ImageMol
You can download the pre-trained model and put it into the folder ckpts/
2. Finetune with pre-trained ImageMol
a) You can download molecular property prediction datasets, CYP450 datasets and SARS-CoV-2 datasets and put them into datasets/finetuning/
b) The usage is as follows:
usage: finetune.py [-h] [--dataset DATASET] [--dataroot DATAROOT] [--gpu GPU] [--workers WORKERS] [--lr LR] [--weight_decay WEIGHT_DECAY] [--momentum MOMENTUM] [--seed SEED] [--runseed RUNSEED] [--split {random,stratified,scaffold,random_scaffold,scaffold_balanced}] [--epochs EPOCHS] [--start_epoch START_EPOCH] [--batch BATCH] [--resume PATH] [--imageSize IMAGESIZE] [--image_model IMAGE_MODEL] [--image_aug] [--task_type {classification,regression}] [--save_finetune_ckpt {0,1}] [--log_dir LOG_DIR]
c) You can run ImageMol by simply using the following code:
python finetune.py --gpu ${gpu_no} \
    --save_finetune_ckpt ${save_finetune_ckpt} \
    --log_dir ${log_dir} \
    --dataroot ${dataroot} \
    --dataset ${dataset} \
    --task_type ${task_type} \
    --resume ${resume} \
    --image_aug \
    --lr ${lr} \
    --batch ${batch} \
    --epochs ${epoch}
For example:
python finetune.py --gpu 0 \
    --save_finetune_ckpt 1 \
    --log_dir ./logs/toxcast \
    --dataroot ./datasets/finetuning/benchmarks \
    --dataset toxcast \
    --task_type classification \
    --resume ./ckpts/ImageMol.pth.tar \
    --image_aug \
    --lr 0.5 \
    --batch 64 \
    --epochs 20
Note: You can tune more hyper-parameters during fine-tuning (see b) Usage).
The directory structure is:
The command I ran here is:
python finetune.py --gpu 0 --save_finetune_ckpt 1 --log_dir ./logs/toxcast --dataroot ./datasets/finetuning/MPP/classification/ --dataset toxcast --task_type classification --resume ./ckpts/ImageMol.pth.tar --image_aug --lr 0.5 --batch 64 --epochs 20
(imagemol) D:\pycharm_workspace\1\ImageMol>python finetune.py --gpu 0 --save_finetune_ckpt 1 --log_dir ./logs/toxcast --dataroot ./datasets/finetuning/MPP/classification/ --dataset toxcast --task_type classification --resume ./ckpts/ImageMol.pth.tar --image_aug --lr 0.5 --batch 64 --epochs 20
Architecture: ResNet18
eval_metric: rocauc
=> loading checkpoint './ckpts/ImageMol.pth.tar'
resume model info: arch: ResNet18
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=617, bias=True)
)
params: {'total_params': 11493033, 'total_trainable_params': 11493033}
[train epoch 0] loss: 0.243: 100%|███████████████████████████████████████████████████| 108/108 [03:11<00:00, 1.78s/it]
[valid epoch 0] loss: 0.210: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.01it/s]
[valid epoch 0] loss: 0.215: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 0] loss: 0.242: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 0, 'patience': 0, 'Loss': 0.21016514742815937, 'Train': 0.692967624681878, 'Validation': 0.6699200988533057, 'Test': 0.6675434573214044}
[train epoch 1] loss: 0.208: 100%|███████████████████████████████████████████████████| 108/108 [01:50<00:00, 1.02s/it]
[valid epoch 1] loss: 0.199: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.96it/s]
[valid epoch 1] loss: 0.207: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.14it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 1] loss: 0.230: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.05it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 1, 'patience': 0, 'Loss': 0.19889598422580296, 'Train': 0.736537038469584, 'Validation': 0.6994855864691358, 'Test': 0.7106267564003038}
[train epoch 2] loss: 0.205: 100%|███████████████████████████████████████████████████| 108/108 [01:49<00:00, 1.02s/it]
[valid epoch 2] loss: 0.191: 100%|███████████████████████████████████████████████████| 108/108 [00:37<00:00, 2.89it/s]
[valid epoch 2] loss: 0.204: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.03it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 2] loss: 0.225: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.09it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 2, 'patience': 0, 'Loss': 0.19094857463130244, 'Train': 0.7681750106520694, 'Validation': 0.713344336728154, 'Test': 0.7183564591204591}
[train epoch 3] loss: 0.197: 100%|███████████████████████████████████████████████████| 108/108 [01:51<00:00, 1.03s/it]
[valid epoch 3] loss: 0.188: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00, 3.00it/s]
[valid epoch 3] loss: 0.204: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.13it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 3] loss: 0.225: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 3, 'patience': 0, 'Loss': 0.18818754620022243, 'Train': 0.7921560594726607, 'Validation': 0.7106094323867567, 'Test': 0.7319484429668179}
[train epoch 4] loss: 0.190: 100%|███████████████████████████████████████████████████| 108/108 [01:48<00:00, 1.00s/it]
[valid epoch 4] loss: 0.186: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.96it/s]
[valid epoch 4] loss: 0.207: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.13it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 4] loss: 0.237: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.15it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 4, 'patience': 1, 'Loss': 0.18566039756492334, 'Train': 0.8042472190567201, 'Validation': 0.6931130950238147, 'Test': 0.7199134337688832}
[train epoch 5] loss: 0.188: 100%|███████████████████████████████████████████████████| 108/108 [01:52<00:00, 1.04s/it]
[valid epoch 5] loss: 0.430: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.98it/s]
[valid epoch 5] loss: 0.447: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.09it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 5] loss: 0.464: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 5, 'patience': 2, 'Loss': 0.4297037831059209, 'Train': 0.6456776040426601, 'Validation': 0.5981636500016693, 'Test': 0.6110698962504955}
[train epoch 6] loss: 0.183: 100%|███████████████████████████████████████████████████| 108/108 [01:49<00:00, 1.02s/it]
[valid epoch 6] loss: 0.177: 100%|███████████████████████████████████████████████████| 108/108 [00:37<00:00, 2.89it/s]
[valid epoch 6] loss: 0.202: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.07it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 6] loss: 0.229: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.08it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 6, 'patience': 3, 'Loss': 0.17700029302526404, 'Train': 0.8396707209323335, 'Validation': 0.7284599584377031, 'Test': 0.7407342942104}
[train epoch 7] loss: 0.178: 100%|███████████████████████████████████████████████████| 108/108 [01:52<00:00, 1.04s/it]
[valid epoch 7] loss: 0.167: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.00it/s]
[valid epoch 7] loss: 0.206: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.04it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 7] loss: 0.226: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.13it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 7, 'patience': 0, 'Loss': 0.16746581042254413, 'Train': 0.8556398733428161, 'Validation': 0.731332185707106, 'Test': 0.7492089206757857}
[train epoch 8] loss: 0.175: 100%|███████████████████████████████████████████████████| 108/108 [01:47<00:00, 1.00it/s]
[valid epoch 8] loss: 0.167: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.98it/s]
[valid epoch 8] loss: 0.208: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 8] loss: 0.222: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 8, 'patience': 0, 'Loss': 0.16673416561550564, 'Train': 0.8644607852261611, 'Validation': 0.7262760021080655, 'Test': 0.7608370643734503}
[train epoch 9] loss: 0.169: 100%|███████████████████████████████████████████████████| 108/108 [01:48<00:00, 1.01s/it]
[valid epoch 9] loss: 0.159: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.03it/s]
[valid epoch 9] loss: 0.210: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.14it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 9] loss: 0.232: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.08it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 9, 'patience': 1, 'Loss': 0.15940298857512297, 'Train': 0.8746694145575502, 'Validation': 0.7254201813375953, 'Test': 0.7528834309399343}
[train epoch 10] loss: 0.167: 100%|██████████████████████████████████████████████████| 108/108 [01:48<00:00, 1.01s/it]
[valid epoch 10] loss: 0.163: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.01it/s]
[valid epoch 10] loss: 0.203: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.99it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 10] loss: 0.225: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.84it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 10, 'patience': 2, 'Loss': 0.16331072206850406, 'Train': 0.8759709469067016, 'Validation': 0.7251833198910242, 'Test': 0.7444111139117744}
[train epoch 11] loss: 0.164: 100%|██████████████████████████████████████████████████| 108/108 [01:53<00:00, 1.05s/it]
[valid epoch 11] loss: 0.155: 100%|██████████████████████████████████████████████████| 108/108 [00:37<00:00, 2.88it/s]
[valid epoch 11] loss: 0.204: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.98it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 11] loss: 0.222: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.86it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 11, 'patience': 3, 'Loss': 0.15538870846783673, 'Train': 0.8912424252365218, 'Validation': 0.7240295542159721, 'Test': 0.753792204644943}
[train epoch 12] loss: 0.160: 100%|██████████████████████████████████████████████████| 108/108 [01:51<00:00, 1.03s/it]
[valid epoch 12] loss: 0.208: 100%|██████████████████████████████████████████████████| 108/108 [00:37<00:00, 2.86it/s]
[valid epoch 12] loss: 0.257: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.98it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 12] loss: 0.277: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 1.92it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 12, 'patience': 4, 'Loss': 0.20842158352887188, 'Train': 0.8415256259318344, 'Validation': 0.6897028119621558, 'Test': 0.7009206570443501}
[train epoch 13] loss: 0.156: 100%|██████████████████████████████████████████████████| 108/108 [01:51<00:00, 1.03s/it]
[valid epoch 13] loss: 0.152: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.01it/s]
[valid epoch 13] loss: 0.209: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.11it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 13] loss: 0.232: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 13, 'patience': 5, 'Loss': 0.15160143816912616, 'Train': 0.8998601673869299, 'Validation': 0.7268795973894058, 'Test': 0.7505586760702562}
[train epoch 14] loss: 0.155: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00, 1.00it/s]
[valid epoch 14] loss: 0.146: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.03it/s]
[valid epoch 14] loss: 0.211: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 14] loss: 0.237: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 14, 'patience': 6, 'Loss': 0.1462357309129503, 'Train': 0.9063451701636047, 'Validation': 0.7288304309373426, 'Test': 0.7492306562547975}
[train epoch 15] loss: 0.152: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00, 1.00it/s]
[valid epoch 15] loss: 0.143: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00, 3.03it/s]
[valid epoch 15] loss: 0.204: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 15] loss: 0.231: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 15, 'patience': 7, 'Loss': 0.1434713469611274, 'Train': 0.913865910273068, 'Validation': 0.7261471932487874, 'Test': 0.7458198395498474}
[train epoch 16] loss: 0.149: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00, 1.00it/s]
[valid epoch 16] loss: 0.142: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.97it/s]
[valid epoch 16] loss: 0.209: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 16] loss: 0.232: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.09it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 16, 'patience': 8, 'Loss': 0.14186685173599808, 'Train': 0.919563304190028, 'Validation': 0.7392030101383237, 'Test': 0.7529772831075494}
[train epoch 17] loss: 0.148: 100%|██████████████████████████████████████████████████| 108/108 [01:49<00:00, 1.01s/it]
[valid epoch 17] loss: 0.140: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.94it/s]
[valid epoch 17] loss: 0.207: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.01it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 17] loss: 0.237: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.03it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 17, 'patience': 0, 'Loss': 0.14025064750953956, 'Train': 0.9222371081994448, 'Validation': 0.7333314785222289, 'Test': 0.7468212697003538}
[train epoch 18] loss: 0.146: 100%|██████████████████████████████████████████████████| 108/108 [01:49<00:00, 1.01s/it]
[valid epoch 18] loss: 0.135: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.96it/s]
[valid epoch 18] loss: 0.208: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.02it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 18] loss: 0.238: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00, 2.00it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 18, 'patience': 1, 'Loss': 0.1350550651550293, 'Train': 0.9267112674770283, 'Validation': 0.744386517934743, 'Test': 0.7505391845271615}
[train epoch 19] loss: 0.141: 100%|██████████████████████████████████████████████████| 108/108 [01:50<00:00, 1.02s/it]
[valid epoch 19] loss: 0.132: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00, 2.97it/s]
[valid epoch 19] loss: 0.212: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.06it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 19] loss: 0.235: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00, 2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 19, 'patience': 0, 'Loss': 0.13230810342011629, 'Train': 0.9326320968407155, 'Validation': 0.7282062713930527, 'Test': 0.7534077089166729}
final results: highest_valid: 0.744, final_train: 0.927, final_test: 0.751
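The final results line appears to report the epoch with the highest validation score (epoch 18 in this run), together with the train and test scores from that same epoch. This reading is inferred from the numbers in the log, not confirmed from the training script. A minimal sketch that reproduces the line from the last four per-epoch summaries (copied verbatim from the log above):

```python
import ast

# Per-epoch summaries for epochs 16-19, copied verbatim from the log.
# Epoch 18 also has the highest validation score of all 20 epochs.
log_lines = [
    "{'epoch': 16, 'patience': 8, 'Loss': 0.14186685173599808, 'Train': 0.919563304190028, 'Validation': 0.7392030101383237, 'Test': 0.7529772831075494}",
    "{'epoch': 17, 'patience': 0, 'Loss': 0.14025064750953956, 'Train': 0.9222371081994448, 'Validation': 0.7333314785222289, 'Test': 0.7468212697003538}",
    "{'epoch': 18, 'patience': 1, 'Loss': 0.1350550651550293, 'Train': 0.9267112674770283, 'Validation': 0.744386517934743, 'Test': 0.7505391845271615}",
    "{'epoch': 19, 'patience': 0, 'Loss': 0.13230810342011629, 'Train': 0.9326320968407155, 'Validation': 0.7282062713930527, 'Test': 0.7534077089166729}",
]

# Each line is a Python dict literal, so literal_eval parses it safely.
records = [ast.literal_eval(line) for line in log_lines]

# Model selection: pick the epoch with the best validation score.
best = max(records, key=lambda r: r["Validation"])

summary = "highest_valid: {:.3f}, final_train: {:.3f}, final_test: {:.3f}".format(
    best["Validation"], best["Train"], best["Test"]
)
print(best["epoch"], summary)
```

Note that the test score at the selected epoch (0.751) is not the highest test score in the log (epoch 8 reaches 0.761); choosing by validation score avoids tuning on the test set.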
Finally, the fine-tuned model is saved under logs.
To ensure the reproducibility of ImageMol, we provided finetuned models for eight datasets, including:
BBBP
Tox21
ClinTox
HIV
BACE
SIDER
MUV
ToxCast
You can evaluate the finetuned model by using the following command:
python evaluate.py --dataroot ${dataroot} --dataset ${dataset} --task_type ${task_type} --resume ${resume} --batch ${batch}
For example:
python evaluate.py --dataroot ./datasets/finetuning/benchmarks --dataset toxcast --task_type classification --resume ./toxcast.pth --batch 128
python evaluate.py --dataroot ./datasets/finetuning/MPP/classification --dataset toxcast --task_type classification --resume ./ckpts/toxcast.pth --batch 128
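Backslash line continuations do not work in the Windows command prompt, so when scripting evaluations over several datasets it can be convenient to assemble the command as an argument list in Python. The sketch below uses only the flags shown in the commands above; the helper name build_evaluate_command is made up for illustration, and the paths are the ones from the last example.

```python
def build_evaluate_command(dataroot, dataset, task_type, resume, batch):
    """Return the evaluate.py invocation as a list of arguments."""
    return [
        "python", "evaluate.py",
        "--dataroot", dataroot,
        "--dataset", dataset,
        "--task_type", task_type,
        "--resume", resume,
        "--batch", str(batch),
    ]


cmd = build_evaluate_command(
    dataroot="./datasets/finetuning/MPP/classification",
    dataset="toxcast",
    task_type="classification",
    resume="./ckpts/toxcast.pth",
    batch=128,
)
print(" ".join(cmd))

# To actually run it from the repository root:
#   import subprocess
#   subprocess.run(cmd, check=True)
```

Passing a list to subprocess.run avoids shell quoting differences between Windows and Linux.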
More GradCAM heatmaps can be found at this link: https://drive.google.com/file/d/1uu3Q6WLz8bJqcDaHEG84o3mFvemHoA2v/view?usp=sharing
To make high-confidence regions in the GradCAM heatmap easier to see, we apply a confidence threshold to filter out lower-confidence regions; the filtered heatmaps can be found at this link: https://drive.google.com/file/d/1631kSSiM_FSRBBkfh7PwI5p3LGqYYpMc/view?usp=sharing
We also provide a script to generate GradCAM heatmaps:
usage: main.py [-h] [--image_model IMAGE_MODEL] --resume PATH --img_path IMG_PATH --gradcam_save_path GRADCAM_SAVE_PATH [--thresh THRESH]
You can run the following script:
python main.py --resume ${resume} --img_path ${img_path} --gradcam_save_path ${gradcam_save_path} --thresh ${thresh}
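The idea of filtering a heatmap by a confidence threshold can be sketched as follows. This is an assumption about what --thresh does, not the repository's implementation: the heatmap is min-max normalised to [0, 1] and every value below the threshold is set to zero. The function names and the toy 2x2 heatmap are made up for illustration.

```python
def normalise(heatmap):
    """Min-max normalise a 2D list of activations to the range [0, 1]."""
    lo = min(min(row) for row in heatmap)
    hi = max(max(row) for row in heatmap)
    span = hi - lo
    if span == 0:
        # A constant map carries no ranking information; return all zeros.
        return [[0.0 for _ in row] for row in heatmap]
    return [[(v - lo) / span for v in row] for row in heatmap]


def filter_heatmap(heatmap, thresh):
    """Keep values >= thresh and zero out the lower-confidence ones."""
    return [[v if v >= thresh else 0.0 for v in row] for row in heatmap]


raw_cam = [[0.0, 2.0],
           [1.0, 4.0]]           # toy activations, not real GradCAM output
cam = normalise(raw_cam)          # [[0.0, 0.5], [0.25, 1.0]]
filtered = filter_heatmap(cam, thresh=0.5)
print(filtered)                   # [[0.0, 0.5], [0.0, 1.0]]
```

A higher threshold leaves fewer highlighted regions, which makes the strongest regions easier to see at the cost of hiding weaker evidence.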
If you want to process your own dataset and obtain molecular images, use the following steps:
1. Use the method preprocess_list(smiles) of this link to process your raw SMILES data;
2. Convert the processed SMILES to molecular images with dataloader.image_dataloader.Smiles2Img(smis, size=224, savePath=None).
https://github.com/HongxinXiang/ImageMol