MakeDataset.py
首先准备好数据集：这里用 mydata 文件夹存放图片数据，每个类别放在一个子文件夹中（例如 mydata/ICH/1470718-1.JPG、mydata/Normal/...），
然后实现自定义 Dataset。
import os

import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Image-classification dataset read from a ``root/<class_name>/<image>`` tree.

    Every sub-directory of ``root`` is one class. Class indices are assigned in
    sorted order of the sub-directory names, so e.g. ``mydata/ICH/1470718-1.JPG``
    gets the index of ``"ICH"``.

    Args:
        resize: side length in pixels; every image is resized to (resize, resize).
        root: dataset root directory. Defaults to ``"mydata"``.
    """

    # Lower-case file extensions treated as images; all other files are skipped.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, resize, root='mydata'):
        super(MyDataset, self).__init__()
        self.resize = resize
        self.root = root
        # Sorted so the class -> index mapping is stable across runs and OSes.
        self.class_names = sorted(
            name for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )
        self.class_name_index = {name: i for i, name in enumerate(self.class_names)}
        self.images = []
        self.labels = []
        for name in self.class_names:
            class_dir = os.path.join(root, name)
            for fname in sorted(os.listdir(class_dir)):
                if fname.lower().endswith(self.IMG_EXTENSIONS):
                    self.images.append(os.path.join(class_dir, fname))
                    # Store the label directly instead of re-parsing the path later.
                    self.labels.append(self.class_name_index[name])
        # Built lazily on first use, so constructing the dataset does not need
        # torchvision. It contains no lambdas, so the dataset stays picklable
        # for DataLoader worker processes.
        self._transform = None

    def __len__(self):
        return len(self.images)

    def _get_transform(self):
        """Return (building it once) the resize / to-tensor / normalize pipeline."""
        if self._transform is None:
            from torchvision import transforms
            self._transform = transforms.Compose([
                transforms.Resize((self.resize, self.resize)),
                transforms.ToTensor(),
                # Standard ImageNet per-channel mean / std.
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        return self._transform

    def __getitem__(self, idx):
        """Return ``(img_tensor, label_tensor)`` for sample ``idx``.

        ``img_tensor`` is a float tensor of shape (3, resize, resize).
        ``label_tensor`` is a 0-d integer tensor holding the class index.
        """
        from PIL import Image
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(self.images[idx]) as im:
            img = im.convert('RGB')
        img_tensor = self._get_transform()(img)
        label_tensor = torch.tensor(self.labels[idx])
        return img_tensor, label_tensor
这里以一个玩具模型（toy model）作为演示，
模型的定义如下：
MyModel.py
import torch
from torch import nn


class MyResNet(nn.Module):
    """Small toy CNN classifier for 3x128x128 inputs (no skip connections despite the name).

    Args:
        num_classes: number of output scores. Defaults to 2.

    ``forward`` returns raw, unnormalized logits of shape (N, num_classes). This
    is what ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, num_classes=2):
        super(MyResNet, self).__init__()
        general_features = 32
        # Initial convolution block
        self.conv0 = nn.Conv2d(3, general_features, 3, 1, padding=1)
        self.conv1 = nn.Conv2d(general_features, general_features, 3, 1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(general_features, general_features, 3, 1, padding=1)
        self.relu2 = nn.ReLU()
        # Each max-pool halves the spatial size: 128 -> 64 -> 32 -> 16 -> 8.
        self.downsample0 = nn.MaxPool2d(2, 2)
        self.downsample1 = nn.MaxPool2d(2, 2)
        self.downsample2 = nn.MaxPool2d(2, 2)
        self.downsample3 = nn.MaxPool2d(2, 2)
        # Four halvings of a 128x128 input leave a general_features x 8 x 8 map.
        self.fc0 = nn.Linear(general_features * 8 * 8, num_classes)

    def forward(self, x):
        x = self.conv0(x)               # [N, 32, 128, 128]
        x = self.downsample0(x)         # [N, 32, 64, 64]
        x = self.downsample1(x)         # [N, 32, 32, 32]
        x = self.relu1(self.conv1(x))   # [N, 32, 32, 32]
        x = self.downsample2(x)         # [N, 32, 16, 16]
        x = self.relu2(self.conv2(x))   # [N, 32, 16, 16]
        x = self.downsample3(x)         # [N, 32, 8, 8]
        x = x.view(x.shape[0], -1)      # Flatten -> [N, 32*8*8]
        # Return logits. CrossEntropyLoss applies log-softmax itself, so a softmax
        # here would be applied twice during training.
        return self.fc0(x)

# Quick shape check (summary() comes from the third-party torchsummary package):
# x = torch.randn(1, 3, 128, 128)
# m = MyResNet()
# summary(m, (3, 128, 128))
# print(m(x).shape)  # torch.Size([1, 2])
训练：train.py
运行后得到权重文件 weights_MyResNet.mdl
import torch
from torch import optim, nn
from torch.utils.data import DataLoader
from MakeDataset import MyDataset  # module name must match the file MakeDataset.py
from MyModel import MyResNet


def main(epochs=5, batch_size=4, lr=0.001):
    """Train MyResNet on the ``mydata`` dataset and save its weights.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size for the DataLoader.
        lr: Adam learning rate.

    Raises:
        RuntimeError: if the dataset contains no images.
    """
    train_db = MyDataset(resize=128)
    num_train = len(train_db)
    if num_train == 0:
        raise RuntimeError('training dataset is empty - check the mydata folder')
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
    print('num_train:', num_train)

    model = MyResNet()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    # Set training mode once; it does not need to be repeated every batch.
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for x, y in train_loader:
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches to a Python float; weight by batch size for a true mean.
            total_loss += loss.item() * x.size(0)
        print('Epochs:', epoch, 'Loss:', total_loss / num_train)

    torch.save(model.state_dict(), 'weights_MyResNet.mdl')
    print('Save Done')


if __name__ == '__main__':
    main()
Visualize_feature_map
这里介绍 torchcam 中 SmoothGradCAMpp（Smooth Grad-CAM++）的用法
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import SmoothGradCAMpp, CAM, GradCAM, GradCAMpp, XGradCAM, ScoreCAM
from torchcam.utils import overlay_mask
from MyModel import MyResNet
from PIL import Image
import matplotlib.pyplot as plt

IMG_SIZE = 128

# Preprocessing; should match what the model was trained with
# (resize, to tensor, ImageNet mean/std normalization).
tf = transforms.Compose([
    lambda x: Image.open(x).convert('RGB'),
    # A tuple is required: Resize(128, 128) would pass 128 as `interpolation`.
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def show_cam(model, cam_extractor, img_path):
    """Compute the CAM for the model's top-scoring class on ``img_path`` and display it."""
    input_tensor = tf(img_path).unsqueeze(dim=0)  # [1, 3, 128, 128]
    output = model(input_tensor)
    print(output)
    class_idx = output.squeeze(0).argmax().item()
    # Returns one map per hooked target layer; take the first.
    activation_map = cam_extractor(class_idx, output)
    cam = activation_map[0]
    print(cam, cam.min(), cam.max(), cam.shape)
    # Overlay on the raw image, not on the normalized tensor: to_pil_image()
    # scales floats by 255 without clipping, so normalized values outside
    # [0, 1] come out with wrong colours.
    with Image.open(img_path) as im:
        raw = im.convert('RGB').resize((IMG_SIZE, IMG_SIZE))
    result = overlay_mask(raw, to_pil_image(cam.squeeze(0), mode='F'), alpha=0.3)
    plt.imshow(result)
    plt.title(img_path)
    plt.axis('off')
    plt.show()


def main():
    model = MyResNet()
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load('weights_MyResNet.mdl', map_location='cpu'))
    print('loaded from ckpt')
    model.eval()

    # Gradient-based CAM: create it before the forward pass, and do not run
    # the model under torch.no_grad().
    cam_extractor = SmoothGradCAMpp(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # Alternatives (SSCAM / ISCAM / LayerCAM must also be imported from torchcam.methods):
    # cam_extractor = GradCAMpp(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # cam_extractor = XGradCAM(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # cam_extractor = ScoreCAM(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # cam_extractor = SSCAM(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # cam_extractor = ISCAM(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # cam_extractor = LayerCAM(model, input_shape=(3, IMG_SIZE, IMG_SIZE))
    # fused_map = cam_extractor.fuse_cams(activation_map)  # combine multi-layer CAMs

    for img_path in ('ICH_test.jpg', 'Normal_test.jpg'):
        show_cam(model, cam_extractor, img_path)
    cam_extractor.remove_hooks()


if __name__ == '__main__':
    main()
调用 cam_extractor
对象时传入类别索引和模型输出。它返回的 activation_map 是一个序列（每个目标层对应一个元素），因此通过索引 0 取出激活图。随后用 overlay_mask
进行可视化：传入原图和激活图，并通过 alpha 参数设置透明度。这个包的主页在：https://pypi.org/project/torchcam/，感兴趣的可以看看。