Accurate prediction of molecular targets using a self-supervised image rep... (code walkthrough)

GitHub - HongxinXiang/ImageMol: ImageMol is a molecular image-based pre-training deep learning framework for computational drug discovery.

I downloaded and tested the code on Windows.


First, what the files in each folder mean:

1. preparing dataset

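In the first step, "1. preparing dataset", the smiles column of data.csv is converted into molecular images (stored in the 224 folder), and data.csv is used to generate "data_for_pretrain.csv" (identical except that the smiles column is replaced by a filename column).

In other words, each image in 224 corresponds to one row of data.csv / data_for_pretrain.csv.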
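The row-to-image correspondence described above can be sketched as follows. This is a minimal illustration, not the repository's smiles2img_pretrain.py: the function names, the `index` column, and the `{row number}.png` naming scheme are all assumptions made for the example. Rendering uses RDKit's `Chem.MolFromSmiles` and `Draw.MolToImage`; the CSV part runs without RDKit installed.

```python
import csv
import io


def smiles_to_image(smiles, path, size=224):
    """Render one SMILES string to a size x size PNG (requires RDKit)."""
    # Imported lazily so the CSV logic below runs without RDKit installed.
    from rdkit import Chem
    from rdkit.Chem import Draw

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        return False
    Draw.MolToImage(mol, size=(size, size)).save(path)
    return True


def to_pretrain_rows(rows, smiles_col="smiles"):
    """Replace the smiles column with a filename column, one image per row."""
    out = []
    for i, row in enumerate(rows):
        new_row = {k: v for k, v in row.items() if k != smiles_col}
        new_row["filename"] = f"{i}.png"  # hypothetical naming scheme
        out.append(new_row)
    return out


# Toy stand-in for data.csv (the "index" column is an assumption).
data_csv = "index,smiles\n0,CCO\n1,c1ccccc1\n"
rows = list(csv.DictReader(io.StringIO(data_csv)))
pretrain_rows = to_pretrain_rows(rows)
print(pretrain_rows)
```

Calling `smiles_to_image(row["smiles"], os.path.join("224", name))` for each row would then produce the image folder, again assuming this naming scheme.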


 2. fine-tune

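The computation logic:

We fine-tune with a classifier: the images in 224 are fed into the model, which produces pred with shape (64, 617). The loss is then computed between pred and the ground-truth labels (64, 617) from data_for_pretrain.csv, where 64 is the batch size and 617 is the length of the label vector for each image. (As I understand it, fine-tuning here is supervised learning that adjusts the model with labeled data, whereas the pre-trained model is trained in a self-supervised, i.e. "unsupervised", way.) [In summary: the image information is embedded and the label is predicted directly; no image is generated at any intermediate step. This differs from using images for "OG, OP", because "OG, OP" has to generate images.]

# model(image) <==> label: the loss is computed between these two
# Molecular property prediction mainly takes the already pre-trained model and fine-tunes it. For the downstream task, a linear classifier is added on top of the model to predict the labels of the downstream images, and the whole network is then fine-tuned end to end. (Here the label is a property of the molecule.)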
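To make the shapes concrete, here is a dependency-free sketch of a loss between a (64, 617) prediction and a (64, 617) label matrix. It assumes the predictions are logits and that each of the 617 labels is scored with a binary cross-entropy (in PyTorch this would correspond to `torch.nn.BCEWithLogitsLoss`); whether finetune.py uses exactly this loss, or masks missing labels, is not confirmed by the text. The random inputs are placeholders, not real data.

```python
import math
import random


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for one logit/label pair."""
    # Equivalent to -[y*log(s) + (1-y)*log(1-s)] with s = sigmoid(logit).
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_loss(pred, labels):
    """Mean loss over a (batch, n_labels) grid of logits and 0/1 labels."""
    assert len(pred) == len(labels)
    total, count = 0.0, 0
    for pred_row, label_row in zip(pred, labels):
        assert len(pred_row) == len(label_row)
        for logit, target in zip(pred_row, label_row):
            total += bce_with_logits(logit, target)
            count += 1
    return total / count


random.seed(0)
batch_size, n_labels = 64, 617  # shapes quoted in the text
pred = [[random.gauss(0.0, 1.0) for _ in range(n_labels)] for _ in range(batch_size)]
labels = [[float(random.random() < 0.5) for _ in range(n_labels)] for _ in range(batch_size)]
loss = multilabel_loss(pred, labels)
print(f"loss over {batch_size}x{n_labels} entries: {loss:.4f}")
```

No image is produced anywhere in this computation, which is the point made above: the model output is compared with the labels directly.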



Install environment

1. GPU environment

CUDA 10.1

2. create a new conda environment

conda create -n imagemol python=3.7.3

conda activate imagemol

3. download some packages

conda install -c rdkit rdkit

windows:

  • pip install https://download.pytorch.org/whl/cu101/torch-1.4.0-cp37-cp37m-win_amd64.whl
  • pip install https://download.pytorch.org/whl/cu101/torchvision-0.5.0-cp37-cp37m-win_amd64.whl

linux:

  • pip install https://download.pytorch.org/whl/cu101/torch-1.4.0-cp37-cp37m-linux_x86_64.whl
  • pip install https://download.pytorch.org/whl/cu101/torchvision-0.5.0-cp37-cp37m-linux_x86_64.whl

pip install torch-cluster torch-scatter torch-sparse torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.4.0%2Bcu101.html

pip install -r requirements.txt

source activate imagemol

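Note: there is an error here. The requirements.txt on GitHub specifies scikit-learn==1.21.2, but the latest scikit-learn release when I tried this was 1.1.2, and that release also requires python>=3.8.
So I installed numpy-1.21.6 and scikit-learn-1.0.2 instead.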
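One way to apply this workaround without hand-editing the file is to rewrite the affected pins before installing. The helper below is a hypothetical sketch: `override_pins` is not part of the repository, and only the scikit-learn line is quoted from requirements.txt; the bare `numpy` line is illustrative.

```python
def override_pins(requirement_lines, overrides):
    """Return requirement lines with selected `name==version` pins replaced."""
    fixed = []
    for line in requirement_lines:
        name = line.split("==")[0].strip()
        if name.lower() in overrides:
            fixed.append(f"{name}=={overrides[name.lower()]}")
        else:
            fixed.append(line.strip())
    return fixed


# Versions are the ones reported as working under Python 3.7 in the note above.
original = ["scikit-learn==1.21.2", "numpy"]
overrides = {"scikit-learn": "1.0.2", "numpy": "1.21.6"}
fixed = override_pins(original, overrides)
print(fixed)
```

The resulting lines can be written back to a requirements file and passed to `pip install -r`.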


Pretraining

1. preparing dataset

The smiles column of data.csv is used to generate 224*224 molecular images (one image per row), and data.csv is used to generate "data_for_pretrain.csv" (identical except that the smiles column becomes a filename column).

Download pretraining data and put it into ./datasets/pretraining/data/

Preprocess dataset: (this takes a long time, since close to ten million files have to be processed; it took me 22 hours in total)

python ./data_process/smiles2img_pretrain.py --dataroot ./datasets/pretraining/ --dataset data

 Note: You can find the toy dataset in ./datasets/toy/pretraining/

2. start to pretrain

Train a pre-trained model that can then be used for downstream tasks.

Usage:

usage: pretrain.py [-h] [--lr LR] [--wd WD] [--workers WORKERS]
                   [--val_workers VAL_WORKERS] [--epochs EPOCHS]
                   [--start_epoch START_EPOCH] [--batch BATCH]
                   [--momentum MOMENTUM] [--checkpoints CHECKPOINTS]
                   [--seed SEED] [--dataroot DATAROOT] [--dataset DATASET]
                   [--ckpt_dir CKPT_DIR] [--modelname {ResNet18}]
                   [--verbose] [--ngpu NGPU] [--gpu GPU] [--nc NC] [--ndf NDF]
                   [--imageSize IMAGESIZE] [--Jigsaw_lambda JIGSAW_LAMBDA]
                   [--cluster_lambda CLUSTER_LAMBDA]
                   [--constractive_lambda CONSTRACTIVE_LAMBDA]
                   [--matcher_lambda MATCHER_LAMBDA]
                   [--is_recover_training IS_RECOVER_TRAINING]
                   [--cl_mask_type {random_mask,rectangle_mask,mix_mask}]
                   [--cl_mask_shape_h CL_MASK_SHAPE_H]
                   [--cl_mask_shape_w CL_MASK_SHAPE_W]
                   [--cl_mask_ratio CL_MASK_RATIO]

Code to pretrain:

python pretrain.py --ckpt_dir ./ckpts/pretraining/ \
                   --checkpoints 1 \
                   --Jigsaw_lambda 1 \
                   --cluster_lambda 1 \
                   --constractive_lambda 1 \
                   --matcher_lambda 1 \
                   --is_recover_training 1 \
                   --batch 256 \
                   --dataroot ./datasets/pretraining/ \
                   --dataset data \
                   --gpu 0,1,2,3 \
                   --ngpu 4

For testing, you can simply pre-train ImageMol using a single GPU on the toy dataset:

python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ \
                   --checkpoints 1 \
                   --Jigsaw_lambda 1 \
                   --cluster_lambda 1 \
                   --constractive_lambda 1 \
                   --matcher_lambda 1 \
                   --is_recover_training 1 \
                   --batch 16 \
                   --dataroot ./datasets/toy/pretraining/ \
                   --dataset data \
                   --gpu 0 \
                   --ngpu 1

Run and results:

python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ --checkpoints 1 --Jigsaw_lambda 1 --cluster_lambda 1 --constractive_lambda 1 --matcher_lambda 1 --is_recover_training 1 --batch 16 --dataroot ./datasets/toy/pretraining --dataset data --gpu 0 --ngpu 1

I trained on the toy dataset for 17 epochs, which produced 17 pre-trained model checkpoints in total:


(imagemol) D:\pycharm_workspace\1\ImageMol>python pretrain.py --ckpt_dir ./ckpts/pretraining-toy/ --checkpoints 1 --Jigsaw_lambda 1 --cluster_lambda 1 --constractive_lambda 1 --matcher_lambda 1 --is_recover_training 1 --batch 16 --dataroot ./datasets/toy/pretraining --dataset data --gpu 0 --ngpu 1
ImageMol(
  (embedding_layer): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (downsample): Sequential(
          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): BasicBlock(
        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (8): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (bn): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (jigsaw_classifier): Linear(in_features=512, out_features=101, bias=True)
  (class_classifier1): Linear(in_features=512, out_features=100, bias=True)
  (class_classifier2): Linear(in_features=512, out_features=1000, bias=True)
  (class_classifier3): Linear(in_features=512, out_features=10000, bias=True)
)
Matcher(
  (fc): Linear(in_features=512, out_features=2, bias=True)
  (logic): LogSoftmax()
)
generator(
  (projection): Sequential(
    (0): Linear(in_features=512, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (netG): Sequential(
    (0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
netlocalD(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
100%|█| 60/60 [02:29<00:00,  2.50s/it, C_loss=0.02, ClsLoss_100=4.36, ClsLoss_1000=7.16, ClsLoss_10000=9.29, ClsTotalLo
100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00,  4.79s/it]
Epoch: [1][train]       TotalLoss: 23.96166973114014    JigLoss: 1.9193571281929807     ClsLoss_100: 4.92697454293569  ClsLoss_1000: 7.11983088652293   ClsLoss_10000: 9.33748327891032 ClsTotalLoss(fftotal): 21.384288565317785       AvgConstractiveLoss: 0.1648771699983627 AvgReasonabilityLoss: 0.49314683731645315       AvgRecoverLoss: 5.0754999422313025
Epoch: [1][val] JigsawAcc: 0.82 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.004153344817459583  ReasonabilityLoss: 0.013568593263626099 RecoverLoss: 0.2023082345724106
100%|█| 60/60 [01:06<00:00,  1.12s/it, C_loss=0.0138, ClsLoss_100=4.49, ClsLoss_1000=7.29, ClsLoss_10000=8.7, ClsTotalL
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.66s/it]
Epoch: [2][train]       TotalLoss: 21.496857229868567   JigLoss: 1.0841593230764073     ClsLoss_100: 4.781556669871013  ClsLoss_1000: 6.793236716588339ClsLoss_10000: 8.56167837778727  ClsTotalLoss(fftotal): 20.136471748352054       AvgConstractiveLoss: 0.07614539431718487        AvgReasonabilityLoss: 0.20008064570526282       AvgRecoverLoss: 3.5995476151506103
Epoch: [2][val] JigsawAcc: 0.86 ClusterAcc100: 0.04     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.013333333333333334   ConstractiveLoss: 0.0020515684410929683 ReasonabilityLoss: 0.002394016683101654 RecoverLoss: 0.29341010212898255
100%|█| 60/60 [01:06<00:00,  1.11s/it, C_loss=0.0453, ClsLoss_100=4.83, ClsLoss_1000=6.7, ClsLoss_10000=8.39, ClsTotalLoss=19.9, JigLoss=1.68, M_loss=0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00,  5.18s/it]
Epoch: [3][train]       TotalLoss: 20.89204044342041    JigLoss: 1.003493670374155      ClsLoss_100: 4.813545378049215  ClsLoss_1000: 6.738298972447714ClsLoss_10000: 8.154170513153076 ClsTotalLoss(fftotal): 19.706014760335286       AvgConstractiveLoss: 0.04975322242826223        AvgReasonabilityLoss: 0.13277878270794943       AvgRecoverLoss: 3.6251505697766935
Epoch: [3][val] JigsawAcc: 0.86 ClusterAcc100: 0.04     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.013333333333333334   ConstractiveLoss: 0.0015005466900765898 ReasonabilityLoss: 0.01395109474658966  RecoverLoss: 0.20824745893478394
100%|█| 60/60 [01:08<00:00,  1.14s/it, C_loss=0.0533, ClsLoss_100=4.81, ClsLoss_1000=6.56, ClsLoss_10000=8, ClsTotalLoss=19.4, JigLoss=0.897, M_loss=0.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:19<00:00,  4.94s/it]
Epoch: [4][train]       TotalLoss: 20.512035369873047   JigLoss: 0.9941644340753556     ClsLoss_100: 4.804758842786154  ClsLoss_1000: 6.710355981191003ClsLoss_10000: 7.865747348467508 ClsTotalLoss(fftotal): 19.38086214065552        AvgConstractiveLoss: 0.048341524026667075       AvgReasonabilityLoss: 0.08866713357468446       AvgRecoverLoss: 4.077671728531518
Epoch: [4][val] JigsawAcc: 0.76 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.004614624008536339  ReasonabilityLoss: 0.008996209204196932 RecoverLoss: 0.31746475636959076
100%|█| 60/60 [01:07<00:00,  1.12s/it, C_loss=0.07, ClsLoss_100=4.49, ClsLoss_1000=6.72, ClsLoss_10000=7.93, ClsTotalLoss=19.1, JigLoss=3.02, M_loss=1.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.62s/it]
Epoch: [5][train]       TotalLoss: 20.410388692220042   JigLoss: 1.0800357540448504     ClsLoss_100: 4.7787369966506965 ClsLoss_1000: 6.687890203793843ClsLoss_10000: 7.672290404637656 ClsTotalLoss(fftotal): 19.13891773223877        AvgConstractiveLoss: 0.05595408985391261        AvgReasonabilityLoss: 0.13548127884666128       AvgRecoverLoss: 4.075054861356814
Epoch: [5][val] JigsawAcc: 0.84 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.002793522924184799  ReasonabilityLoss: 0.0062981796264648445        RecoverLoss: 0.18601167559623719
100%|█| 60/60 [01:06<00:00,  1.11s/it, C_loss=0.0246, ClsLoss_100=4.8, ClsLoss_1000=7.11, ClsLoss_10000=7.69, ClsTotalLoss=19.6, JigLoss=3.16, M_loss=1
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00,  5.27s/it]
Epoch: [6][train]       TotalLoss: 20.339342339833575   JigLoss: 1.1087447846929228     ClsLoss_100: 4.799536720911663  ClsLoss_1000: 6.672529379526773ClsLoss_10000: 7.573689786593119 ClsTotalLoss(fftotal): 19.04575576782228        AvgConstractiveLoss: 0.053185406218593315       AvgReasonabilityLoss: 0.13165633815030256       AvgRecoverLoss: 4.161377739906311
Epoch: [6][val] JigsawAcc: 0.84 ClusterAcc100: 0.04     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.013333333333333334   ConstractiveLoss: 0.0014511872828006744 ReasonabilityLoss: 0.005211712941527367 RecoverLoss: 0.1715515339374542
100%|█| 60/60 [01:07<00:00,  1.13s/it, C_loss=0.0319, ClsLoss_100=4.63, ClsLoss_1000=6.67, ClsLoss_10000=7.54, ClsTotalLoss=18.8, JigLoss=0.878, M_loss
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:22<00:00,  5.57s/it]
Epoch: [7][train]       TotalLoss: 20.192618624369306   JigLoss: 1.0693875233332315     ClsLoss_100: 4.75707601706187   ClsLoss_1000: 6.66301953792572 ClsLoss_10000: 7.496208945910134 ClsTotalLoss(fftotal): 18.916304334004717       AvgConstractiveLoss: 0.047720210254192354       AvgReasonabilityLoss: 0.15920654994746053       AvgRecoverLoss: 4.174393929044406
Epoch: [7][val] JigsawAcc: 0.78 ClusterAcc100: 0.02     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.006666666666666667   ConstractiveLoss: 0.0035598852671682836 ReasonabilityLoss: 0.0006275869905948638        RecoverLoss: 0.431844232082367
100%|█| 60/60 [01:06<00:00,  1.11s/it, C_loss=0.0174, ClsLoss_100=4.66, ClsLoss_1000=6.93, ClsLoss_10000=7.62, ClsTotalLoss=19.2, JigLoss=1.58, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.55s/it]
Epoch: [8][train]       TotalLoss: 19.8997561454773     JigLoss: 0.9543801690141359     ClsLoss_100: 4.7688616196314495 ClsLoss_1000: 6.649607149759929ClsLoss_10000: 7.420435182253519 ClsTotalLoss(fftotal): 18.83890390396117        AvgConstractiveLoss: 0.04308367188399036        AvgReasonabilityLoss: 0.06338832139347991       AvgRecoverLoss: 4.2416440991063915
Epoch: [8][val] JigsawAcc: 0.74 ClusterAcc100: 0.02     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.006666666666666667   ConstractiveLoss: 0.0024805641174316406 ReasonabilityLoss: 0.0048112948238849635        RecoverLoss: 0.25158809542655947
100%|█| 60/60 [01:06<00:00,  1.11s/it, C_loss=0.0381, ClsLoss_100=4.57, ClsLoss_1000=6.51, ClsLoss_10000=7.65, ClsTotalLoss=18.7, JigLoss=0.78, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.75s/it]
Epoch: [9][train]       TotalLoss: 20.007253615061437   JigLoss: 1.0684301803509397     ClsLoss_100: 4.787256534894307  ClsLoss_1000: 6.628996245066327ClsLoss_10000: 7.368737761179607 ClsTotalLoss(fftotal): 18.78499059677124        AvgConstractiveLoss: 0.045559848969181374       AvgReasonabilityLoss: 0.10827299573769171       AvgRecoverLoss: 4.120056470980247
Epoch: [9][val] JigsawAcc: 0.84 ClusterAcc100: 0.02     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.006666666666666667   ConstractiveLoss: 0.002133065667003393  ReasonabilityLoss: 0.0024875566363334657        RecoverLoss: 0.19161114037036897
100%|█| 60/60 [01:07<00:00,  1.12s/it, C_loss=0.0287, ClsLoss_100=4.74, ClsLoss_1000=5.96, ClsLoss_10000=7.37, ClsTotalLoss=18.1, JigLoss=1.38, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:23<00:00,  5.88s/it]
Epoch: [10][train]      TotalLoss: 19.906644948323567   JigLoss: 1.0861985241373382     ClsLoss_100: 4.752160032590231  ClsLoss_1000: 6.603877162933351ClsLoss_10000: 7.332753666241961 ClsTotalLoss(fftotal): 18.688790893554682       AvgConstractiveLoss: 0.04013523093114296        AvgReasonabilityLoss: 0.09152033478021618       AvgRecoverLoss: 4.039306613306204
Epoch: [10][val]        JigsawAcc: 0.82 ClusterAcc100: 0.02     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.006666666666666667  ConstractiveLoss: 0.0020973937958478927  ReasonabilityLoss: 0.0024364076554775236        RecoverLoss: 0.18724668622016907
100%|█| 60/60 [01:08<00:00,  1.14s/it, C_loss=0.0463, ClsLoss_100=4.77, ClsLoss_1000=6.7, ClsLoss_10000=7.51, ClsTotalLoss=19, JigLoss=1.42, M_loss=0.0
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:29<00:00,  7.25s/it]
Epoch: [11][train]      TotalLoss: 19.741016260782875   JigLoss: 0.960937575995922      ClsLoss_100: 4.751010529200236  ClsLoss_1000: 6.601791771252948ClsLoss_10000: 7.31111691792806  ClsTotalLoss(fftotal): 18.663919194539393       AvgConstractiveLoss: 0.040693634655326605       AvgReasonabilityLoss: 0.07546558417379856       AvgRecoverLoss: 4.537612893184026
Epoch: [11][val]        JigsawAcc: 0.82 ClusterAcc100: 0.06     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.02   ConstractiveLoss: 0.00219877889379859   ReasonabilityLoss: 0.00048301540315151217       RecoverLoss: 0.22352851390838624
100%|█| 60/60 [01:07<00:00,  1.13s/it, C_loss=0.0232, ClsLoss_100=4.68, ClsLoss_1000=6.71, ClsLoss_10000=7.56, ClsTotalLoss=19, JigLoss=0.0433, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00,  5.13s/it]
Epoch: [12][train]      TotalLoss: 19.715614541371657   JigLoss: 1.000559947391351      ClsLoss_100: 4.731101123491921  ClsLoss_1000: 6.583861041069033ClsLoss_10000: 7.277890539169311 ClsTotalLoss(fftotal): 18.592852783203124       AvgConstractiveLoss: 0.03278368460790566        AvgReasonabilityLoss: 0.08941795468175164       AvgRecoverLoss: 4.371748948221404
Epoch: [12][val]        JigsawAcc: 0.78 ClusterAcc100: 0.02     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.006666666666666667  ConstractiveLoss: 0.001537955142557621   ReasonabilityLoss: 0.009752252995967867 RecoverLoss: 0.18312466382980347
100%|█| 60/60 [01:07<00:00,  1.12s/it, C_loss=0.0647, ClsLoss_100=4.85, ClsLoss_1000=6.72, ClsLoss_10000=7.55, ClsTotalLoss=19.1, JigLoss=0.0258, M_los
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:23<00:00,  5.96s/it]
Epoch: [13][train]      TotalLoss: 19.82162863413493    JigLoss: 1.0898563351792594     ClsLoss_100: 4.738014833132426  ClsLoss_1000: 6.585498579343159ClsLoss_10000: 7.261646850903826 ClsTotalLoss(fftotal): 18.585160191853845       AvgConstractiveLoss: 0.03763424041680991        AvgReasonabilityLoss: 0.10897765451421342       AvgRecoverLoss: 4.363648629685243
Epoch: [13][val]        JigsawAcc: 0.74 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.0021116061881184577 ReasonabilityLoss: 0.0005279695987701416        RecoverLoss: 0.21621067449450493
100%|█| 60/60 [01:07<00:00,  1.12s/it, C_loss=0.0384, ClsLoss_100=4.52, ClsLoss_1000=6.62, ClsLoss_10000=7.4, ClsTotalLoss=18.5, JigLoss=0.752, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:18<00:00,  4.57s/it]
Epoch: [14][train]      TotalLoss: 19.565531730651863   JigLoss: 0.9612024967869126     ClsLoss_100: 4.7033114989598594 ClsLoss_1000: 6.577907792727153ClsLoss_10000: 7.236068709691365 ClsTotalLoss(fftotal): 18.51728795369466        AvgConstractiveLoss: 0.040761741731936744       AvgReasonabilityLoss: 0.04627952212467788       AvgRecoverLoss: 4.573408140925071
Epoch: [14][val]        JigsawAcc: 0.84 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.0019987770169973373 ReasonabilityLoss: 0.002506262362003326 RecoverLoss: 0.19087217301130294
100%|█| 60/60 [01:07<00:00,  1.13s/it, C_loss=0.00878, ClsLoss_100=4.72, ClsLoss_1000=6.57, ClsLoss_10000=7.43, ClsTotalLoss=18.7, JigLoss=0.0542, M_lo
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:29<00:00,  7.25s/it]
Epoch: [15][train]      TotalLoss: 19.65483560562134    JigLoss: 1.0458930411065623     ClsLoss_100: 4.722701446215311  ClsLoss_1000: 6.552511111895243ClsLoss_10000: 7.212150192260744 ClsTotalLoss(fftotal): 18.48736282984416        AvgConstractiveLoss: 0.045769707330813036       AvgReasonabilityLoss: 0.07581018048028149       AvgRecoverLoss: 4.728145783022047
Epoch: [15][val]        JigsawAcc: 0.82 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.0038376058358699085 ReasonabilityLoss: 0.00415668860077858  RecoverLoss: 0.1998233237862587
100%|█| 60/60 [01:14<00:00,  1.24s/it, C_loss=0.00582, ClsLoss_100=4.77, ClsLoss_1000=6.37, ClsLoss_10000=7.25, ClsTotalLoss=18.4, JigLoss=1.37, M_loss
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:20<00:00,  5.13s/it]
Epoch: [16][train]      TotalLoss: 19.632916323343913   JigLoss: 1.0232878476381304     ClsLoss_100: 4.723408063252767  ClsLoss_1000: 6.555336928367613ClsLoss_10000: 7.190120752652486 ClsTotalLoss(fftotal): 18.468865553538002       AvgConstractiveLoss: 0.049930801638402036       AvgReasonabilityLoss: 0.0908320085921635        AvgRecoverLoss: 4.406435360635318
Epoch: [16][val]        JigsawAcc: 0.8  ClusterAcc100: 0.04     ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.013333333333333334  ConstractiveLoss: 0.002326355976983905   ReasonabilityLoss: 0.00012519627809524537       RecoverLoss: 0.18016339898109435
100%|█| 60/60 [01:07<00:00,  1.13s/it, C_loss=0.0456, ClsLoss_100=5.02, ClsLoss_1000=6.61, ClsLoss_10000=7.08, ClsTotalLoss=18.7, JigLoss=1.66, M_loss=
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:21<00:00,  5.46s/it]
Epoch: [17][train]      TotalLoss: 19.40125710169475    JigLoss: 0.9150347652534644     ClsLoss_100: 4.6697273492813105 ClsLoss_1000: 6.52956008116404 ClsLoss_10000: 7.167783602078757 ClsTotalLoss(fftotal): 18.367071247100828       AvgConstractiveLoss: 0.0530752846505493 AvgReasonabilityLoss: 0.06607597895587485       AvgRecoverLoss: 4.356251256788772
Epoch: [17][val]        JigsawAcc: 0.84 ClusterAcc100: 0.0      ClusterAcc1000: 0.0     ClusterAcc10000: 0.0    ClusterAcc(avg): 0.0    ConstractiveLoss: 0.003248931393027306  ReasonabilityLoss: 0.01237701490521431  RecoverLoss: 0.20707530111074446
 15%|▏| 9/60 [00:13<01:16,  1.50s/it, C_loss=0.0265, ClsLoss_100=4.53, ClsLoss_1000=6.36, ClsLoss_10000=7.01, ClsTotalLoss=17.9, JigLoss=0.91, M_loss=0
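The logged numbers can be used to check how the total loss is assembled. The sketch below copies the epoch 1 training values from the log above and tests two relations: that ClsTotalLoss is the sum of the three cluster losses, and that TotalLoss is the weighted sum of the jigsaw, cluster, contrastive and reasonability terms. Which command-line lambda scales which term is an assumption made for illustration (all four are 1 in this run, so the log cannot distinguish them). AvgRecoverLoss is left out because the total already matches without it.

```python
import math

# Epoch 1 [train] values copied from the log above.
epoch1 = {
    "TotalLoss": 23.96166973114014,
    "JigLoss": 1.9193571281929807,
    "ClsLoss_100": 4.92697454293569,
    "ClsLoss_1000": 7.11983088652293,
    "ClsLoss_10000": 9.33748327891032,
    "ClsTotalLoss": 21.384288565317785,
    "AvgConstractiveLoss": 0.1648771699983627,
    "AvgReasonabilityLoss": 0.49314683731645315,
}

# Loss weights from the command line (all set to 1 in this run).
# The pairing of each flag with a loss term is an assumption.
weights = {
    "Jigsaw_lambda": 1.0,
    "cluster_lambda": 1.0,
    "constractive_lambda": 1.0,
    "matcher_lambda": 1.0,
}

cls_total = epoch1["ClsLoss_100"] + epoch1["ClsLoss_1000"] + epoch1["ClsLoss_10000"]
total = (
    weights["Jigsaw_lambda"] * epoch1["JigLoss"]
    + weights["cluster_lambda"] * cls_total
    + weights["constractive_lambda"] * epoch1["AvgConstractiveLoss"]
    + weights["matcher_lambda"] * epoch1["AvgReasonabilityLoss"]
)

print(f"sum of cluster losses: {cls_total:.6f} (logged {epoch1['ClsTotalLoss']:.6f})")
print(f"weighted sum of terms: {total:.6f} (logged {epoch1['TotalLoss']:.6f})")

# Cross-entropy of a uniform guess over K classes is ln(K), a handy baseline.
for k in (100, 1000, 10000):
    logged = epoch1["ClsLoss_" + str(k)]
    print(f"K={k}: chance-level loss {math.log(k):.3f}, logged {logged:.3f}")
```

Both sums agree with the logged values to within about 1e-7, which is consistent with single-precision accumulation during training.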



Finetuning

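The pre-trained model is fine-tuned on a downstream task; here the downstream task is classification. (The model is fine-tuned by computing the cross-entropy loss between the predicted class probabilities \tilde{\mathcal{Y}}_{n}^{gt} and the ground-truth labels \mathcal{Y}_{n}^{gt}.)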
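Written out with the same symbols, and assuming a binary cross-entropy averaged over N training molecules (the reduction, and any masking of missing labels, are assumptions rather than something the text confirms), the fine-tuning objective would read:

```latex
\mathcal{L}_{\mathrm{ft}}
  = -\frac{1}{N}\sum_{n=1}^{N}
    \Big[
      \mathcal{Y}_{n}^{gt}\,\log \tilde{\mathcal{Y}}_{n}^{gt}
      + \big(1-\mathcal{Y}_{n}^{gt}\big)\,\log\big(1-\tilde{\mathcal{Y}}_{n}^{gt}\big)
    \Big]
```

For a multi-label dataset the bracketed term would additionally be averaged over the label dimension.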

1. Download pre-trained ImageMol

You can download the pre-trained model and put it into the folder ckpts/

2. Finetune with pre-trained ImageMol

a) You can download molecular property prediction datasets, CYP450 datasets and SARS-CoV-2 datasets and put them into datasets/finetuning/

b) The usage is as follows:

usage: finetune.py [-h] [--dataset DATASET] [--dataroot DATAROOT] [--gpu GPU]
                   [--workers WORKERS] [--lr LR] [--weight_decay WEIGHT_DECAY]
                   [--momentum MOMENTUM] [--seed SEED] [--runseed RUNSEED]
                   [--split {random,stratified,scaffold,random_scaffold,scaffold_balanced}]
                   [--epochs EPOCHS] [--start_epoch START_EPOCH]
                   [--batch BATCH] [--resume PATH] [--imageSize IMAGESIZE]
                   [--image_model IMAGE_MODEL] [--image_aug]
                   [--task_type {classification,regression}]
                   [--save_finetune_ckpt {0,1}] [--log_dir LOG_DIR]

c) You can run ImageMol by simply using the following code:

python finetune.py --gpu ${gpu_no} \
                   --save_finetune_ckpt ${save_finetune_ckpt} \
                   --log_dir ${log_dir} \
                   --dataroot ${dataroot} \
                   --dataset ${dataset} \
                   --task_type ${task_type} \
                   --resume ${resume} \
                   --image_aug \
                   --lr ${lr} \
                   --batch ${batch} \
                   --epochs ${epoch}

For example:

python finetune.py --gpu 0 \
                   --save_finetune_ckpt 1 \
                   --log_dir ./logs/toxcast \
                   --dataroot ./datasets/finetuning/benchmarks \
                   --dataset toxcast \
                   --task_type classification \
                   --resume ./ckpts/ImageMol.pth.tar \
                   --image_aug \
                   --lr 0.5 \
                   --batch 64 \
                   --epochs 20

Note: You can tune more hyper-parameters during fine-tuning (see b) Usage).
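The multi-line commands above rely on the backslash line continuation of POSIX shells, which the Windows cmd.exe prompt does not understand. A portable alternative is to build the argument list in Python and launch it with `subprocess`. The helper name `build_finetune_command` is hypothetical; the flags and values are the ones from the example above.

```python
import sys


def build_finetune_command(options, flags=()):
    """Build an argv list for finetune.py from option/value pairs and bare flags."""
    argv = [sys.executable, "finetune.py"]
    for name, value in options.items():
        argv += [f"--{name}", str(value)]
    argv += [f"--{flag}" for flag in flags]
    return argv


options = {
    "gpu": 0,
    "save_finetune_ckpt": 1,
    "log_dir": "./logs/toxcast",
    "dataroot": "./datasets/finetuning/benchmarks",
    "dataset": "toxcast",
    "task_type": "classification",
    "resume": "./ckpts/ImageMol.pth.tar",
    "lr": 0.5,
    "batch": 64,
    "epochs": 20,
}
command = build_finetune_command(options, flags=["image_aug"])
print(" ".join(command))
# To launch from the repository root:
#   import subprocess; subprocess.run(command, check=True)
```

Passing a list rather than a single string also avoids shell quoting differences between platforms.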

Run and results:

[Figure 4: directory structure of the project]

The command I ran here was:

python finetune.py --gpu 0 --save_finetune_ckpt 1 --log_dir ./logs/toxcast --dataroot ./datasets/finetuning/MPP/classification/ --dataset toxcast --task_type classification --resume ./ckpts/ImageMol.pth.tar --image_aug --lr 0.5 --batch 64 --epochs 20

(imagemol) D:\pycharm_workspace\1\ImageMol>python finetune.py --gpu 0 --save_finetune_ckpt 1 --log_dir ./logs/toxcast --dataroot ./datasets/finetuning/MPP/classification/ --dataset toxcast --task_type classification --resume ./ckpts/ImageMol.pth.tar --image_aug --lr 0.5 --batch 64 --epochs 20
Architecture: ResNet18
eval_metric: rocauc
=> loading checkpoint './ckpts/ImageMol.pth.tar'
resume model info: arch: ResNet18
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=617, bias=True)
)
params: {'total_params': 11493033, 'total_trainable_params': 11493033}
[train epoch 0] loss: 0.243: 100%|███████████████████████████████████████████████████| 108/108 [03:11<00:00,  1.78s/it]
[valid epoch 0] loss: 0.210: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.01it/s]
[valid epoch 0] loss: 0.215: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 0] loss: 0.242: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 0, 'patience': 0, 'Loss': 0.21016514742815937, 'Train': 0.692967624681878, 'Validation': 0.6699200988533057, 'Test': 0.6675434573214044}
[train epoch 1] loss: 0.208: 100%|███████████████████████████████████████████████████| 108/108 [01:50<00:00,  1.02s/it]
[valid epoch 1] loss: 0.199: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.96it/s]
[valid epoch 1] loss: 0.207: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.14it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 1] loss: 0.230: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.05it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 1, 'patience': 0, 'Loss': 0.19889598422580296, 'Train': 0.736537038469584, 'Validation': 0.6994855864691358, 'Test': 0.7106267564003038}
[train epoch 2] loss: 0.205: 100%|███████████████████████████████████████████████████| 108/108 [01:49<00:00,  1.02s/it]
[valid epoch 2] loss: 0.191: 100%|███████████████████████████████████████████████████| 108/108 [00:37<00:00,  2.89it/s]
[valid epoch 2] loss: 0.204: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.03it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 2] loss: 0.225: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.09it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 2, 'patience': 0, 'Loss': 0.19094857463130244, 'Train': 0.7681750106520694, 'Validation': 0.713344336728154, 'Test': 0.7183564591204591}
[train epoch 3] loss: 0.197: 100%|███████████████████████████████████████████████████| 108/108 [01:51<00:00,  1.03s/it]
[valid epoch 3] loss: 0.188: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00,  3.00it/s]
[valid epoch 3] loss: 0.204: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.13it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 3] loss: 0.225: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 3, 'patience': 0, 'Loss': 0.18818754620022243, 'Train': 0.7921560594726607, 'Validation': 0.7106094323867567, 'Test': 0.7319484429668179}
[train epoch 4] loss: 0.190: 100%|███████████████████████████████████████████████████| 108/108 [01:48<00:00,  1.00s/it]
[valid epoch 4] loss: 0.186: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.96it/s]
[valid epoch 4] loss: 0.207: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.13it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 4] loss: 0.237: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.15it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 4, 'patience': 1, 'Loss': 0.18566039756492334, 'Train': 0.8042472190567201, 'Validation': 0.6931130950238147, 'Test': 0.7199134337688832}
[train epoch 5] loss: 0.188: 100%|███████████████████████████████████████████████████| 108/108 [01:52<00:00,  1.04s/it]
[valid epoch 5] loss: 0.430: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.98it/s]
[valid epoch 5] loss: 0.447: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.09it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 5] loss: 0.464: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 5, 'patience': 2, 'Loss': 0.4297037831059209, 'Train': 0.6456776040426601, 'Validation': 0.5981636500016693, 'Test': 0.6110698962504955}
[train epoch 6] loss: 0.183: 100%|███████████████████████████████████████████████████| 108/108 [01:49<00:00,  1.02s/it]
[valid epoch 6] loss: 0.177: 100%|███████████████████████████████████████████████████| 108/108 [00:37<00:00,  2.89it/s]
[valid epoch 6] loss: 0.202: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.07it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 6] loss: 0.229: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.08it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 6, 'patience': 3, 'Loss': 0.17700029302526404, 'Train': 0.8396707209323335, 'Validation': 0.7284599584377031, 'Test': 0.7407342942104}
[train epoch 7] loss: 0.178: 100%|███████████████████████████████████████████████████| 108/108 [01:52<00:00,  1.04s/it]
[valid epoch 7] loss: 0.167: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.00it/s]
[valid epoch 7] loss: 0.206: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.04it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 7] loss: 0.226: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.13it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 7, 'patience': 0, 'Loss': 0.16746581042254413, 'Train': 0.8556398733428161, 'Validation': 0.731332185707106, 'Test': 0.7492089206757857}
[train epoch 8] loss: 0.175: 100%|███████████████████████████████████████████████████| 108/108 [01:47<00:00,  1.00it/s]
[valid epoch 8] loss: 0.167: 100%|███████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.98it/s]
[valid epoch 8] loss: 0.208: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 8] loss: 0.222: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 8, 'patience': 0, 'Loss': 0.16673416561550564, 'Train': 0.8644607852261611, 'Validation': 0.7262760021080655, 'Test': 0.7608370643734503}
[train epoch 9] loss: 0.169: 100%|███████████████████████████████████████████████████| 108/108 [01:48<00:00,  1.01s/it]
[valid epoch 9] loss: 0.159: 100%|███████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.03it/s]
[valid epoch 9] loss: 0.210: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.14it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 9] loss: 0.232: 100%|█████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.08it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 9, 'patience': 1, 'Loss': 0.15940298857512297, 'Train': 0.8746694145575502, 'Validation': 0.7254201813375953, 'Test': 0.7528834309399343}
[train epoch 10] loss: 0.167: 100%|██████████████████████████████████████████████████| 108/108 [01:48<00:00,  1.01s/it]
[valid epoch 10] loss: 0.163: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.01it/s]
[valid epoch 10] loss: 0.203: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.99it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 10] loss: 0.225: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.84it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 10, 'patience': 2, 'Loss': 0.16331072206850406, 'Train': 0.8759709469067016, 'Validation': 0.7251833198910242, 'Test': 0.7444111139117744}
[train epoch 11] loss: 0.164: 100%|██████████████████████████████████████████████████| 108/108 [01:53<00:00,  1.05s/it]
[valid epoch 11] loss: 0.155: 100%|██████████████████████████████████████████████████| 108/108 [00:37<00:00,  2.88it/s]
[valid epoch 11] loss: 0.204: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.98it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 11] loss: 0.222: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.86it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 11, 'patience': 3, 'Loss': 0.15538870846783673, 'Train': 0.8912424252365218, 'Validation': 0.7240295542159721, 'Test': 0.753792204644943}
[train epoch 12] loss: 0.160: 100%|██████████████████████████████████████████████████| 108/108 [01:51<00:00,  1.03s/it]
[valid epoch 12] loss: 0.208: 100%|██████████████████████████████████████████████████| 108/108 [00:37<00:00,  2.86it/s]
[valid epoch 12] loss: 0.257: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.98it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 12] loss: 0.277: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  1.92it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 12, 'patience': 4, 'Loss': 0.20842158352887188, 'Train': 0.8415256259318344, 'Validation': 0.6897028119621558, 'Test': 0.7009206570443501}
[train epoch 13] loss: 0.156: 100%|██████████████████████████████████████████████████| 108/108 [01:51<00:00,  1.03s/it]
[valid epoch 13] loss: 0.152: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.01it/s]
[valid epoch 13] loss: 0.209: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.11it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 13] loss: 0.232: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.12it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 13, 'patience': 5, 'Loss': 0.15160143816912616, 'Train': 0.8998601673869299, 'Validation': 0.7268795973894058, 'Test': 0.7505586760702562}
[train epoch 14] loss: 0.155: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00,  1.00it/s]
[valid epoch 14] loss: 0.146: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.03it/s]
[valid epoch 14] loss: 0.211: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 14] loss: 0.237: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 14, 'patience': 6, 'Loss': 0.1462357309129503, 'Train': 0.9063451701636047, 'Validation': 0.7288304309373426, 'Test': 0.7492306562547975}
[train epoch 15] loss: 0.152: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00,  1.00it/s]
[valid epoch 15] loss: 0.143: 100%|██████████████████████████████████████████████████| 108/108 [00:35<00:00,  3.03it/s]
[valid epoch 15] loss: 0.204: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 15] loss: 0.231: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 15, 'patience': 7, 'Loss': 0.1434713469611274, 'Train': 0.913865910273068, 'Validation': 0.7261471932487874, 'Test': 0.7458198395498474}
[train epoch 16] loss: 0.149: 100%|██████████████████████████████████████████████████| 108/108 [01:47<00:00,  1.00it/s]
[valid epoch 16] loss: 0.142: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.97it/s]
[valid epoch 16] loss: 0.209: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.10it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 16] loss: 0.232: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.09it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 16, 'patience': 8, 'Loss': 0.14186685173599808, 'Train': 0.919563304190028, 'Validation': 0.7392030101383237, 'Test': 0.7529772831075494}
[train epoch 17] loss: 0.148: 100%|██████████████████████████████████████████████████| 108/108 [01:49<00:00,  1.01s/it]
[valid epoch 17] loss: 0.140: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.94it/s]
[valid epoch 17] loss: 0.207: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.01it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 17] loss: 0.237: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.03it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 17, 'patience': 0, 'Loss': 0.14025064750953956, 'Train': 0.9222371081994448, 'Validation': 0.7333314785222289, 'Test': 0.7468212697003538}
[train epoch 18] loss: 0.146: 100%|██████████████████████████████████████████████████| 108/108 [01:49<00:00,  1.01s/it]
[valid epoch 18] loss: 0.135: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.96it/s]
[valid epoch 18] loss: 0.208: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.02it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 18] loss: 0.238: 100%|████████████████████████████████████████████████████| 14/14 [00:07<00:00,  2.00it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 18, 'patience': 1, 'Loss': 0.1350550651550293, 'Train': 0.9267112674770283, 'Validation': 0.744386517934743, 'Test': 0.7505391845271615}
[train epoch 19] loss: 0.141: 100%|██████████████████████████████████████████████████| 108/108 [01:50<00:00,  1.02s/it]
[valid epoch 19] loss: 0.132: 100%|██████████████████████████████████████████████████| 108/108 [00:36<00:00,  2.97it/s]
[valid epoch 19] loss: 0.212: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.06it/s]
Some target is missing! Missing ratio: 0.02 [605/617]
[valid epoch 19] loss: 0.235: 100%|████████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.11it/s]
Some target is missing! Missing ratio: 0.01 [610/617]
{'epoch': 19, 'patience': 0, 'Loss': 0.13230810342011629, 'Train': 0.9326320968407155, 'Validation': 0.7282062713930527, 'Test': 0.7534077089166729}
final results: highest_valid: 0.744, final_train: 0.927, final_test: 0.751
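
The `final results` line reports the epoch with the highest validation ROC-AUC (epoch 18), not the last epoch. Below is a minimal sketch of that selection, using the last three per-epoch dicts copied from the log above for brevity (epoch 18 is also the maximum over all 20 epochs). The `pick_best` helper is hypothetical: it illustrates the rule the log appears to follow and is not the repository's code.

```python
import ast

# Per-epoch summary lines copied verbatim from the fine-tuning log (epochs 17-19).
log_lines = [
    "{'epoch': 17, 'patience': 0, 'Loss': 0.14025064750953956, 'Train': 0.9222371081994448, 'Validation': 0.7333314785222289, 'Test': 0.7468212697003538}",
    "{'epoch': 18, 'patience': 1, 'Loss': 0.1350550651550293, 'Train': 0.9267112674770283, 'Validation': 0.744386517934743, 'Test': 0.7505391845271615}",
    "{'epoch': 19, 'patience': 0, 'Loss': 0.13230810342011629, 'Train': 0.9326320968407155, 'Validation': 0.7282062713930527, 'Test': 0.7534077089166729}",
]


def pick_best(lines):
    """Parse the dict-style log lines and return the record with the highest validation score."""
    records = [ast.literal_eval(line) for line in lines]
    return max(records, key=lambda r: r["Validation"])


best = pick_best(log_lines)
summary = "highest_valid: %.3f, final_train: %.3f, final_test: %.3f" % (
    best["Validation"], best["Train"], best["Test"])
print(summary)  # highest_valid: 0.744, final_train: 0.927, final_test: 0.751
```

The three `[valid epoch N]` progress bars per epoch (108, 14 and 14 iterations) presumably correspond to evaluation on the training, validation and test splits, which matches the `Train`, `Validation` and `Test` keys in each dict.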

Finally, the fine-tuned model is saved under logs:

[Image 5: screenshot of the fine-tuned model files under logs]
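
As a sanity check on what the saved model contains, the `params` line from the fine-tuning log (11493033 parameters) can be reproduced by hand from the printed ResNet18 architecture. The sketch below is my own arithmetic, not code from the repository: it sums the weights of every Conv2d (no bias), every BatchNorm2d (one weight and one bias per channel) and the final Linear layer with 617 outputs.

```python
def conv(c_in, c_out, k):
    """Conv2d with bias=False: one k x k kernel per (input, output) channel pair."""
    return c_in * c_out * k * k


def bn(channels):
    """BatchNorm2d with affine=True: a learnable weight and bias per channel."""
    return 2 * channels


def basic_block(c_in, c_out, downsample):
    """Two 3x3 convs, each followed by BatchNorm, plus an optional 1x1 downsample branch."""
    n = conv(c_in, c_out, 3) + bn(c_out) + conv(c_out, c_out, 3) + bn(c_out)
    if downsample:
        n += conv(c_in, c_out, 1) + bn(c_out)
    return n


def resnet18_params(num_outputs):
    """Count trainable parameters of ResNet18 with a num_outputs-wide fc layer."""
    n = conv(3, 64, 7) + bn(64)  # stem: conv1 + bn1
    c_in = 64
    for c_out in (64, 128, 256, 512):  # layer1..layer4, two BasicBlocks each
        n += basic_block(c_in, c_out, downsample=(c_in != c_out))
        n += basic_block(c_out, c_out, downsample=False)
        c_in = c_out
    n += 512 * num_outputs + num_outputs  # fc: weight + bias
    return n


print(resnet18_params(617))  # 11493033, matching total_params in the log
```

The only task-specific part is the last line of `resnet18_params`: the 617 ToxCast labels add 512 * 617 + 617 = 316521 parameters on top of the shared backbone.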


Finetuned models

To ensure the reproducibility of ImageMol, we provide fine-tuned models for eight datasets:

  • BBBP

  • Tox21

  • ClinTox

  • HIV

  • BACE

  • SIDER

  • MUV

  • ToxCast

You can evaluate the finetuned model by using the following command:

python evaluate.py --dataroot ${dataroot} \
                   --dataset ${dataset} \
                   --task_type ${task_type} \
                   --resume ${resume} \
                   --batch ${batch}

For example:

python evaluate.py --dataroot ./datasets/finetuning/benchmarks \
                   --dataset toxcast \
                   --task_type classification \
                   --resume ./toxcast.pth \
                   --batch 128

The command I ran was:

python evaluate.py --dataroot ./datasets/finetuning/MPP/classification --dataset toxcast --task_type classification --resume ./ckpts/toxcast.pth --batch 128

[Image 6: screenshot from the evaluation run]
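
The fine-tuning log reports `eval_metric: rocauc` together with messages such as `Some target is missing! Missing ratio: 0.02 [605/617]`. A common way to score multi-label data like ToxCast is to compute ROC-AUC per task, skip tasks where the split lacks either positives or negatives, and average the rest. The sketch below illustrates that idea in plain Python. The function names, the use of `None` for unlabeled entries and the skipping rule are my assumptions, not the repository's evaluation code.

```python
def roc_auc(labels, scores):
    """ROC-AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def multitask_roc_auc(y_true, y_score):
    """Average per-task ROC-AUC; y_true[i][t] is 0, 1, or None when unlabeled.

    Returns (mean_auc, number_of_scored_tasks, number_of_tasks).
    """
    num_tasks = len(y_true[0])
    aucs = []
    for t in range(num_tasks):
        pairs = [(row[t], sc[t]) for row, sc in zip(y_true, y_score) if row[t] is not None]
        labels = [y for y, _ in pairs]
        if 0 not in labels or 1 not in labels:
            continue  # this task cannot be scored on this split
        aucs.append(roc_auc(labels, [s for _, s in pairs]))
    return sum(aucs) / len(aucs), len(aucs), num_tasks


# Toy data: 4 molecules, 3 tasks. Task 1 has no positives, so it is skipped.
y_true = [[1, 0, None], [0, 0, 1], [1, 0, 0], [0, 0, None]]
y_score = [[0.9, 0.2, 0.5], [0.3, 0.1, 0.7], [0.4, 0.3, 0.6], [0.6, 0.4, 0.1]]
mean_auc, scored, total = multitask_roc_auc(y_true, y_score)
print("%.3f [%d/%d]" % (mean_auc, scored, total))  # 0.875 [2/3]
```

Under this reading, `[605/617]` would mean that 605 of the 617 ToxCast tasks could be scored on that split, which agrees with the printed missing ratio (1 - 605/617 is about 0.02).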


GradCAM Visualization

More GradCAM heatmaps can be found at this link: https://drive.google.com/file/d/1uu3Q6WLz8bJqcDaHEG84o3mFvemHoA2v/view?usp=sharing

To make high-confidence regions in the GradCAM heatmap easier to see, we use a confidence threshold to filter out lower-confidence regions. The filtered heatmaps can be found at this link: https://drive.google.com/file/d/1631kSSiM_FSRBBkfh7PwI5p3LGqYYpMc/view?usp=sharing

Run script

We also provide a script to generate GradCAM heatmaps:

usage: main.py [-h] [--image_model IMAGE_MODEL] --resume PATH --img_path
               IMG_PATH --gradcam_save_path GRADCAM_SAVE_PATH
               [--thresh THRESH]

You can run the following script:

python main.py --resume ${resume} \
               --img_path ${img_path} \
               --gradcam_save_path ${gradcam_save_path} \
               --thresh ${thresh}
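
The `--thresh` option relates to the confidence filtering described above. How main.py applies it is not shown in this post, so the following is only a minimal sketch of what such filtering could look like: the heatmap is min-max normalized to [0, 1] and every value below the threshold is set to zero. Both the normalization and the zeroing rule are assumptions, and the function names are hypothetical.

```python
def normalize(heatmap):
    """Min-max scale a 2D heatmap (list of rows) to the range [0, 1]."""
    flat = [v for row in heatmap for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # A constant map carries no ranking information; return all zeros.
        return [[0.0 for _ in row] for row in heatmap]
    return [[(v - lo) / (hi - lo) for v in row] for row in heatmap]


def filter_low_confidence(heatmap, thresh):
    """Keep normalized activations >= thresh and zero out the rest."""
    return [[v if v >= thresh else 0.0 for v in row] for row in normalize(heatmap)]


cam = [[0.0, 2.0], [4.0, 8.0]]  # toy 2x2 activation map
filtered = filter_low_confidence(cam, 0.5)
print(filtered)  # [[0.0, 0.0], [0.5, 1.0]]
```

A higher threshold leaves fewer highlighted pixels, which makes the regions the model relies on most easier to pick out on the molecule image.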

Process your own dataset

If you want to process your own dataset and obtain molecular images, use the following steps:

  1. Preprocessing SMILES: use the method preprocess_list(smiles) (linked from the ImageMol README) to process your raw SMILES data;
  2. Transforming SMILES to images: convert the canonical SMILES to molecular images using dataloader.image_dataloader.Smiles2Img(smis, size=224, savePath=None)
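
The second step can be sketched with RDKit directly. This is an illustration under stated assumptions, not the repository's `Smiles2Img`: it uses RDKit's `Draw.MolToImage`, and the `image_filenames` helper with its `<row number>.png` naming scheme is hypothetical.

```python
import os


def image_filenames(smiles_list, out_dir):
    """Map row i of the SMILES list to '<out_dir>/<i+1>.png' (hypothetical naming scheme)."""
    return [os.path.join(out_dir, "%d.png" % (i + 1)) for i in range(len(smiles_list))]


def smiles_to_image(smiles, size=224, save_path=None):
    """Render one SMILES string as a size x size PIL image; returns None if parsing fails."""
    # RDKit is imported lazily so that image_filenames works without it installed.
    from rdkit import Chem
    from rdkit.Chem import Draw

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        return None
    img = Draw.MolToImage(mol, size=(size, size))
    if save_path is not None:
        img.save(save_path)
    return img


if __name__ == "__main__":
    smiles = ["CCO", "c1ccccc1"]  # ethanol and benzene as toy inputs
    for smi, path in zip(smiles, image_filenames(smiles, "224")):
        print(smi, "->", path)
```

Calling `smiles_to_image(smi, save_path=path)` inside the loop would write the PNG files. That requires RDKit and an existing output directory.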

https://github.com/HongxinXiang/ImageMol
