OpenMMLab 提供了许多非常实用的框架:目标检测方面 mmdetection 十分好用,而语义分割方面则首推 mmsegmentation。这篇博客介绍如何在 mmsegmentation 中将输入图像的通道数修改为单通道。如果只修改模型的输入通道数而不修改数据处理流程,会出现如下报错:
RuntimeError: Given groups=1, weight of size [32, 1, 3, 3], expected input[4, 3, 320, 220] to have 1 channels, but got 3 channels instead
# Constructor signature edited so that the network accepts single-channel
# (grayscale) images instead of 3-channel images. Only the signature is shown
# here; the body is not part of this snippet.
# NOTE(review): the parameter list looks like mmseg's FastSCNN backbone —
# confirm which class this belongs to.
def __init__(self,
# Changed from 3 to 1 so the default matches single-channel input.
in_channels=1,
# in_channels=3,  (previous default, kept for reference)
downsample_dw_channels=(32, 48),
global_in_channels=64,
global_block_channels=(64, 96, 128),
global_block_strides=(2, 2, 1),
global_out_channels=128,
higher_in_channels=64,
lower_in_channels=128,
fusion_out_channels=128,
out_indices=(0, 1, 2),
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False,
dw_act_cfg=None,
init_cfg=None):
# Training data pipeline; the steps are listed in the order they appear below.
# `img_scale`, `crop_size` and `img_norm_cfg` are not defined in this snippet
# and must be defined earlier in the config file.
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.85),
# NOTE(review): prob=0.9 presumably flips the vast majority of samples —
# confirm this is intended.
dict(type='RandomFlip', prob=0.9),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
# Added step: reduce the image to a single channel so that it matches a model
# built with in_channels=1. It is placed after PhotoMetricDistortion,
# Normalize and Pad.
# NOTE(review): presumably those earlier steps expect a 3-channel image, which
# is why the conversion comes this late — confirm.
dict(type='RGB2Gray', out_channels=1),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
]
# Test/inference data pipeline. `img_scale` and `img_norm_cfg` are not defined
# in this snippet and must be defined earlier in the config file.
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=img_scale,
# The list of extra scale ratios is left disabled (commented out).
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
# Added step: reduce the image to a single channel so that it matches a model
# built with in_channels=1; placed after Normalize, as in the training
# pipeline.
dict(type='RGB2Gray', out_channels=1),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
网上关于这个问题的解释很多,直接找到对应的位置进行修改即可:进入 mmseg 下的 segmentors 目录,打开 encoder_decoder.py,在 extract_feat 和 encode_decode 中添加 img = img.float()(注意要带括号,把输入从 double 转为 float,使其与权重的类型一致):
def extract_feat(self, img):
    """Extract features from images.

    Args:
        img: Input image batch; must provide a ``.float()`` method
            (e.g. a ``torch.Tensor``).

    Returns:
        The output of ``self.backbone``, additionally passed through
        ``self.neck`` when ``self.with_neck`` is true.
    """
    # Cast to float32 so the input dtype matches the float weights; without
    # this a double input fails with "Input type (torch.cuda.DoubleTensor)
    # and weight type (torch.cuda.FloatTensor) should be the same".
    img = img.float()
    x = self.backbone(img)
    if self.with_neck:
        x = self.neck(x)
    return x
def encode_decode(self, img, img_metas):
    """Encode images with backbone and decode into a semantic segmentation
    map of the same size as input.

    Args:
        img: Input image batch; must provide ``.float()`` and ``.shape``
            (e.g. a ``torch.Tensor``).
        img_metas: Per-image meta information, forwarded unchanged to
            ``self._decode_head_forward_test``.

    Returns:
        The decode-head output resized with bilinear interpolation to
        ``img.shape[2:]``.
    """
    # Cast to float32 so the input dtype matches the float weights (avoids
    # the DoubleTensor / FloatTensor mismatch error). Kept here as well so
    # the method does not rely on what ``extract_feat`` does internally.
    img = img.float()
    x = self.extract_feat(img)
    out = self._decode_head_forward_test(x, img_metas)
    # Bring the prediction back to the size given by the input's trailing
    # dimensions.
    out = resize(
        input=out,
        size=img.shape[2:],
        mode='bilinear',
        align_corners=self.align_corners)
    return out