如何理解Transformer缺乏像CNN那样的归纳偏置


具体示例:“数字位置分类任务”

我们设计一个简单的任务来对比 CNN 和 Transformer 对位置变化的处理能力:

任务设定
  • 输入:28x28 灰度图像,包含一个手写数字(0~9),但数字位置可能出现在图像任意位置(而非固定居中)。
  • 目标:模型需要同时完成两个任务:
    1. 分类:识别数字类别(0~9)。
    2. 定位:预测数字的中心坐标(x, y,取值范围 [0, 27])。
  • 训练数据:仅包含数字出现在图像左侧半区的样本(x ≤ 13)。
  • 测试数据:数字出现在图像右侧半区(x > 13),测试模型对未见过位置的泛化能力。

1. CNN 模型设计

import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 特征提取(隐含平移不变性)
        self.features = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),  # 3x3卷积核,滑动检测局部特征
            nn.ReLU(),
            nn.MaxPool2d(2),                 # 14x14
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)                  # 7x7
        )
        # 分类头
        self.classifier = nn.Sequential(
            nn.Linear(32*7*

你可能感兴趣的:(transformer,cnn,深度学习)