论文是讲乳腺癌分类的,简单做个复现,模型如图1,详细介绍如图2,其余细节:batchsize给的是64.(剩下关于激活函数比较,dropout与BN层影响在此不做讨论,通过代码变形实施起来也简单)
import torch
from torch import nn
from torchinfo import summary
2.对论文中所论述的3个卷积+一个BN层+一个maxpooling层的 结构进行定义
class TCIS(nn.Module):
def __init__(self,channel,padding=True,First=False):
super().__init__()
self.channel = channel
self.padding = padding
self.First = First
self.inchannel = int(self.channel/2)
#第一层第一个卷积接收到的channel是3,单独定义
self.conv_first = nn.Conv2d(in_channels=3, out_channels=self.channel,
kernel_size=3,stride=1, padding = 1)
#除第一层,其余层的第一个卷积channel是当前channel的一半
self.conv1_1 = nn.Conv2d(in_channels=self.inchannel, out_channels=self.channel,
kernel_size=3, stride=1, padding=1)
#除第五层外,第二个卷积和第三个卷积定义
self.conv1_2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel,
kernel_size=3,stride=1, padding = 1)
#第五层中,第二个卷积和第三个卷积定义
self.conv2_2 = nn.Conv2d(in_channels=self.channel, out_channels=self.channel,
kernel_size=3, stride=1)
#bn层定义
self.bn = nn.BatchNorm2d(self.channel)
#最大池化层定义
self.mp = nn.MaxPool2d(kernel_size=2)
def forward(self,x):
if self.First == False:
x1 = self.conv1_1(x)
else:
x1 = self.conv_first(x)
if self.padding==True:
x2 = self.conv1_2(x1)
x3 = self.conv1_2(x2)
else:
x2 = self.conv2_2(x1)
x3 = self.conv2_2(x2)
xbn = self.bn(x3)
xout = self.mp(xbn)
return xout
3.对主程序进行定义
class NACHICNN(nn.Module):
def __init__(self):
super().__init__()
#第一层
self.convb1 = TCIS(32,padding=True, First=True)
#第二层
self.convb2 = TCIS(64)
#第三层
self.convb3 = TCIS(128)
#第4层
self.convb4 = TCIS(256)
#第5层
self.convb5 = TCIS(512,padding=False)
self.fc1 = nn.Linear(32768,512)#输入神经元数是通过debug找到的
self.fc2 = nn.Linear(512,512)
self.sigmoid_ = nn.Sigmoid()
def forward(self,x):
x1 = self.convb1(x)
x2 = self.convb2(x1)
x3 = self.convb3(x2)
x4 = self.convb4(x3)
x5 = self.convb5(x4)
x6 = x5.flatten()#(这里我按照原文直接使用flatten(),深度学习中,多数情况下使用全局平均池化+view(),的组合对其进行定义)
x6 = self.fc1(x6)
x7 = self.fc2(x6)
return torch.sigmoid(x7)
4.模型测试代码
if __name__ == "__main__":
#实例化
net = NACHICNN().to(device='cuda')
#打印模型
show_ = summary(net, input_size=(64,3,96,96))
import sys; print(‘Python %s on %s’ % (sys.version, sys.platform))
Python 3.9.13 (main, Aug 25 2022, 23:51:50) [MSC v.1916 64 bit (AMD64)]
Type ‘copyright’, ‘credits’ or ‘license’ for more information
IPython 7.31.1 – An enhanced Interactive Python. Type ‘?’ for help.
PyDev console: using IPython 7.31.1
Python 3.9.13 (main, Aug 25 2022, 23:51:50) [MSC v.1916 64 bit (AMD64)] on win32
待施工。。。
1.数据读取
3.训练代码
4.测试代码
5.主程序