TensorFlow2代码解读(7)

import os

# Silence TensorFlow's C++ INFO and WARNING logs (keep ERROR and above).
# This is set BEFORE importing TensorFlow so that it also applies to messages
# emitted while the library is being loaded; setting it after the import
# does not reliably filter those.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the env var above)
from tensorflow.keras import layers, Sequential  # noqa: E402

class BasicBlock(layers.Layer):
    """ResNet basic residual block.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
    Shortcut:   identity, or a 1x1 conv projection when the block changes the
                spatial resolution (``stride != 1``) or the channel count.
    The two paths are summed and passed through a final ReLU.

    Args:
        filter_num: number of output channels of both 3x3 convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filter_num, stride=1, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        self.filter_num = filter_num
        self.stride = stride

        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')

        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        # The shortcut depends on the input channel count, which is only known
        # in build(); it is created there.
        self.downsample = None

    def build(self, input_shape):
        """Create the shortcut branch once the input shape is known."""
        in_channels = input_shape[-1]
        # A projection is needed whenever the identity could not be added to the
        # main path: spatial downsampling, or a channel mismatch (when known).
        needs_projection = self.stride != 1 or (
            in_channels is not None and in_channels != self.filter_num
        )
        if needs_projection:
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(self.filter_num, (1, 1), strides=self.stride))
        else:
            self.downsample = lambda x: x
        super(BasicBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        """Apply the block.

        Args:
            inputs: 4-D feature map (channels-last is assumed, as for Conv2D's
                default data format).
            training: forwarded to the BatchNormalization layers so they use
                batch statistics (True) or moving averages (False).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)

        output = layers.add([out, identity])
        return tf.nn.relu(output)
    
class ResNet(tf.keras.Model):
    """ResNet assembled from ``BasicBlock`` stages.

    Layout: stem (3x3 conv, BN, ReLU, 2x2 max-pool with stride 1) ->
    four residual stages with 64/128/256/512 filters (stages 2-4 use stride 2)
    -> global average pooling -> dense layer.

    Args:
        layer_dims: number of ``BasicBlock``s per stage; the first four entries
            are used.
        num_classes: output size of the final dense layer.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, layer_dims, num_classes=100, **kwargs):
        super(ResNet, self).__init__(**kwargs)

        # The stem conv uses Conv2D's default 'valid' padding, so it trims one
        # pixel from each border; the max-pool (stride 1, 'same') keeps the size.
        self.stem = Sequential([
            layers.Conv2D(64, (3, 3), strides=(1, 1)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
        ])

        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        self.avgpool = layers.GlobalAveragePooling2D()
        # No activation: the model outputs unnormalised scores (logits).
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Run the network; ``training`` is forwarded to every sub-model."""
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.avgpool(x)
        return self.fc(x)

    def build_resblock(self, filter_num, blocks, stride=1):
        """Stack ``BasicBlock``s into one stage.

        The first block uses ``stride`` (and so may downsample); the remaining
        ``blocks - 1`` use stride 1. At least one block is always added, even
        if ``blocks`` is less than 1.
        """
        res_blocks = Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks
        

def resnet18(num_classes=100):
    """Build a ResNet-18: four stages of two ``BasicBlock``s each.

    Args:
        num_classes: forwarded to ``ResNet``; defaults to 100 as before.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)
        
import tensorflow as tf
from tensorflow.keras import layers, Sequential
import os

导入所需的模块
# Filter TensorFlow C++ INFO/WARNING logs (level '2' keeps ERROR and above).
# NOTE(review): this runs after `import tensorflow` above, so messages logged
# during that import are likely not filtered; consider setting it before the
# TensorFlow import.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='2')

设置 TensorFlow C++ 后端的日志级别:将 TF_CPP_MIN_LOG_LEVEL 设为 '2' 会屏蔽 INFO 和 WARNING 信息,只输出错误(ERROR)及以上级别的信息。
class BasicBlock(layers.Layer):
    """ResNet basic residual block.

    Main path:  conv3x3(stride) -> BN -> ReLU -> conv3x3(stride 1) -> BN.
    Shortcut:   identity, or a 1x1 conv projection when the block changes the
                spatial resolution (``stride != 1``) or the channel count.
    The two paths are summed and passed through a final ReLU.

    Args:
        filter_num: number of output channels of both 3x3 convolutions.
        stride: stride of the first convolution and of the projection shortcut.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filter_num, stride=1, **kwargs):
        super(BasicBlock, self).__init__(**kwargs)
        self.filter_num = filter_num
        self.stride = stride

        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')

        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        # The shortcut depends on the input channel count, which is only known
        # in build(); it is created there.
        self.downsample = None

    def build(self, input_shape):
        """Create the shortcut branch once the input shape is known."""
        in_channels = input_shape[-1]
        # A projection is needed whenever the identity could not be added to the
        # main path: spatial downsampling, or a channel mismatch (when known).
        needs_projection = self.stride != 1 or (
            in_channels is not None and in_channels != self.filter_num
        )
        if needs_projection:
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(self.filter_num, (1, 1), strides=self.stride))
        else:
            self.downsample = lambda x: x
        super(BasicBlock, self).build(input_shape)

    def call(self, inputs, training=None):
        """Apply the block.

        Args:
            inputs: 4-D feature map (channels-last is assumed, as for Conv2D's
                default data format).
            training: forwarded to the BatchNormalization layers so they use
                batch statistics (True) or moving averages (False).
        """
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out, training=training)

        identity = self.downsample(inputs)

        output = layers.add([out, identity])
        return tf.nn.relu(output)


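定义 BasicBlock 类,它是 ResNet 的基本残差块。主路径包含两个 3×3 卷积层,每个卷积层后接批归一化,第一个卷积之后还接 ReLU 激活;捷径(shortcut)分支在 stride 不为 1 时用 1×1 卷积调整输出尺寸,否则直接使用恒等映射;两条分支相加后再经过一次 ReLU。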
class ResNet(tf.keras.Model):
    """ResNet assembled from ``BasicBlock`` stages.

    Layout: stem (3x3 conv, BN, ReLU, 2x2 max-pool with stride 1) ->
    four residual stages with 64/128/256/512 filters (stages 2-4 use stride 2)
    -> global average pooling -> dense layer.

    Args:
        layer_dims: number of ``BasicBlock``s per stage; the first four entries
            are used.
        num_classes: output size of the final dense layer.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, layer_dims, num_classes=100, **kwargs):
        super(ResNet, self).__init__(**kwargs)

        # The stem conv uses Conv2D's default 'valid' padding, so it trims one
        # pixel from each border; the max-pool (stride 1, 'same') keeps the size.
        self.stem = Sequential([
            layers.Conv2D(64, (3, 3), strides=(1, 1)),
            layers.BatchNormalization(),
            layers.Activation('relu'),
            layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same'),
        ])

        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)

        self.avgpool = layers.GlobalAveragePooling2D()
        # No activation: the model outputs unnormalised scores (logits).
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):
        """Run the network; ``training`` is forwarded to every sub-model."""
        x = self.stem(inputs, training=training)
        x = self.layer1(x, training=training)
        x = self.layer2(x, training=training)
        x = self.layer3(x, training=training)
        x = self.layer4(x, training=training)

        x = self.avgpool(x)
        return self.fc(x)

    def build_resblock(self, filter_num, blocks, stride=1):
        """Stack ``BasicBlock``s into one stage.

        The first block uses ``stride`` (and so may downsample); the remaining
        ``blocks - 1`` use stride 1. At least one block is always added, even
        if ``blocks`` is less than 1.
        """
        res_blocks = Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks


定义 ResNet 类,它是由多个残差块组成的网络结构,依次包括:初始卷积模块(卷积、批归一化、ReLU、最大池化)、四个残差块层(通道数分别为 64、128、256、512,后三层各用 stride=2 下采样)、全局平均池化层和全连接层。
def resnet18(num_classes=100):
    """Build a ResNet-18: four stages of two ``BasicBlock``s each.

    Args:
        num_classes: forwarded to ``ResNet``; defaults to 100 as before.
    """
    return ResNet([2, 2, 2, 2], num_classes=num_classes)

定义一个函数 resnet18(),用于创建 ResNet-18 模型:4 个残差块层各包含 2 个 BasicBlock,共 16 个 3×3 卷积层,加上初始卷积层和全连接层合计 18 层(不计捷径分支中的 1×1 卷积)。

你可能感兴趣的:(TensorFlow2代码解读,tensorflow,人工智能,python,深度学习)