microsoft/SimMIM 原始碼:
https://github.com/microsoft/SimMIM/tree/main
microsoft/Swin-Transformer:
https://github.com/microsoft/Swin-Transformer/tree/main
上面兩個 repo 的模型應該是一樣的,模型定義在 models/simmim.py。由於兩者都沒有提供視覺化的程式,我參考 MAE 的視覺化代碼改寫;另外 Xiang Li 等人的 UM-MAE 也有對應的代碼可以參考。
( config請根據模型自己建立一個dict )
( build_simmim(config) 從 models/simmim.py 拿)
simMIM(改寫輸出)
class SimMIM(nn.Module):
    """Masked-image-modeling wrapper that returns the reconstruction and the loss.

    The encoder output is passed through a 1x1 convolution followed by a
    pixel shuffle, and an L1 loss is averaged over the masked pixels.

    Args:
        config: Nested mapping; ``config['MODEL']['norm_target']`` and, when
            that flag is truthy, ``config['MODEL']['norm_patch_size']`` are
            read in ``forward``.
        encoder: Module called as ``encoder(x, mask)``; it must expose a
            ``num_features`` attribute.
        encoder_stride: Upscale factor of the pixel-shuffle decoder.
        in_chans: Divisor applied to the masked loss.
        patch_size: Repeat factor used to expand the mask along dims 1 and 2.
    """

    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.in_chans = in_chans
        self.patch_size = patch_size

        # Decoder: project the features to encoder_stride**2 * 3 channels,
        # then rearrange those channels into space with a pixel shuffle.
        to_pixels = nn.Conv2d(
            in_channels=self.encoder.num_features,
            out_channels=self.encoder_stride ** 2 * 3,
            kernel_size=1,
        )
        unfold = nn.PixelShuffle(self.encoder_stride)
        self.decoder = nn.Sequential(to_pixels, unfold)

    def forward(self, x, mask):
        """Reconstruct ``x`` and compute the masked L1 loss.

        Args:
            x: Input batch passed to the encoder
                (presumably (B, C, H, W) -- TODO confirm against the encoder).
            mask: Patch mask; expanded by ``patch_size`` along dims 1 and 2
                and given a new dim 1 before weighting the loss.

        Returns:
            Tuple ``(x_rec, loss)``: the decoder output and the scalar loss.
        """
        latent = self.encoder(x, mask)
        x_rec = self.decoder(latent)

        # Expand the mask by patch_size along dims 1 and 2, then add a
        # singleton dim 1 so it broadcasts against the per-element loss.
        pixel_mask = mask.repeat_interleave(self.patch_size, 1)
        pixel_mask = pixel_mask.repeat_interleave(self.patch_size, 2)
        pixel_mask = pixel_mask.unsqueeze(1).contiguous()

        # Optionally normalize the target, as requested by the config.
        model_cfg = self.config['MODEL']
        if model_cfg['norm_target']:
            x = norm_targets(x, model_cfg['norm_patch_size'])

        per_element_l1 = F.l1_loss(x, x_rec, reduction='none')
        masked_total = (per_element_l1 * pixel_mask).sum()
        # The 1e-5 keeps the division finite when nothing is masked.
        loss = masked_total / (pixel_mask.sum() + 1e-5) / self.in_chans

        # Note: changed from the upstream model so that x_rec is returned as
        # an output (needed for visualization) alongside the loss.
        return x_rec, loss
Utilities
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
# Per-channel mean and std used to normalize the input image and to undo
# that normalization when plotting (see show_image).
# Replace these with the mean and std that match your own dataset.
image_mean = np.array([0.485, 0.456, 0.406])
image_std = np.array([0.229, 0.224, 0.225])
class MyTransform:
    """Turn an image into a tensor and pair it with a freshly generated mask.

    Args:
        config: Nested mapping; reads ``config['MODEL']['patch_size']``,
            ``config['DATA']['input_size']`` and
            ``config['DATA']['mask_patch_size']``.
        mask_ratio: Forwarded unchanged to ``MaskGenerator``.
    """

    def __init__(self, config, mask_ratio):
        self.transform_img = T.ToTensor()
        patch = config['MODEL']['patch_size']
        data_cfg = config['DATA']
        self.mask_generator = MaskGenerator(
            input_size=data_cfg['input_size'],
            mask_patch_size=data_cfg['mask_patch_size'],
            model_patch_size=patch,
            mask_ratio=mask_ratio,
        )

    def __call__(self, img):
        """Return ``(image_tensor, mask_tensor)`` for a single image."""
        tensor_img = self.transform_img(img)
        raw_mask = self.mask_generator()
        # NOTE(review): the image converter is reused on the mask, presumably
        # to obtain a tensor with a leading dimension -- confirm.
        return tensor_img, self.transform_img(raw_mask)
def show_image(image, title=''):
    """Draw a normalized image on the current matplotlib axes.

    The image is de-normalized with the module-level ``image_std`` and
    ``image_mean``, scaled to 0..255, clipped and cast to int before display.

    Args:
        image: Tensor laid out as [H, W, 3].
        title: Text shown above the image.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    restored = image * image_std + image_mean
    as_pixels = torch.clip(restored * 255, 0, 255).int()
    plt.imshow(as_pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')
def prepare_model(chkpt_dir):
    """Build the SimMIM model, load a checkpoint into it and switch to eval mode.

    Args:
        chkpt_dir: Path of the checkpoint file; the loaded object must hold
            the weights under the ``'model'`` key.

    Returns:
        The model built by ``build_simmim(config)`` with the weights loaded
        non-strictly, in eval mode.
    """
    # Build the model from the module-level config.
    model = build_simmim(config)

    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    weights = checkpoint['model']

    # Rename every key containing "rpe_mlp" to use "cpb_mlp" instead. The key
    # list is taken up front because the dict is modified in the loop.
    legacy_names = [name for name in weights.keys() if "rpe_mlp" in name]
    for old_name in legacy_names:
        tensor = weights.pop(old_name)
        weights[old_name.replace('rpe_mlp', 'cpb_mlp')] = tensor

    # strict=False: the printed report lists missing and unexpected keys.
    report = model.load_state_dict(weights, strict=False)
    print(report)

    del weights, checkpoint
    model.eval()
    return model
def run_one_image(img, model, mask_ratio=0.65):
    """Mask one image, reconstruct it with the model and plot the result.

    Four panels are shown side by side: the original image, the masked
    image, the raw reconstruction, and the reconstruction pasted together
    with the visible patches.

    Args:
        img: Image accepted by ``MyTransform`` (presumably the normalized
            H x W x 3 array prepared by the calling script -- TODO confirm).
        model: Model called as ``model(batch, mask)`` and returning
            ``(reconstruction, loss)``; it must expose ``patch_size``.
        mask_ratio: Mask ratio forwarded to ``MyTransform``.
    """
    # Transform the image and generate its mask (uses the module-level config).
    transform = MyTransform(config, mask_ratio)
    x, mask = transform(img)

    # Run the model. This is inference only, so autograd is disabled to avoid
    # building a graph that is discarded straight away.
    with torch.no_grad():
        y, _ = model(x.unsqueeze(dim=0).float(), mask)
    y = y.detach().squeeze(0)
    print(y.shape)

    # Expand the mask by patch_size along dims 1 and 2 so it can be combined
    # with the image tensors.
    mask = mask.repeat_interleave(model.patch_size, 1)
    mask = mask.repeat_interleave(model.patch_size, 2).contiguous()
    visible = 1 - mask

    im_masked = x * visible
    # Reconstruction pasted together with the visible patches.
    im_paste = x * visible + y * mask

    # Make the figure larger.
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = (
        (x, "original"),
        (im_masked, "masked"),
        (y, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    )
    for position, (tensor, caption) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        # show_image expects channels last.
        show_image(torch.einsum('chw->hwc', tensor), caption)

    plt.show()
# If you have your own checkpoint, replace this path with it.
model = prepare_model('../out_dir/pretrain/simMIM_pt_base_192_w6-45.pth')

# Prepare the image: force RGB, then take a random resized 192x192 crop.
transform = T.Compose([
    T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
    # T.RandomCrop((192, 192)),
    T.RandomResizedCrop((192, 192)),
])

# Point this at the image you want to test.
img = transform(Image.open('../Fabrics/Quixel/001/oi2uhyp_2K_Roughness.jpg'))
img = np.array(img) / 255.
assert img.shape == (192, 192, 3)

# Normalize with the mean and std defined above.
img = (img - image_mean) / image_std

plt.rcParams['figure.figsize'] = [3, 3]
show_image(torch.tensor(img))

# Visualize.
# NOTE(review): this seeds only torch's RNG, and the random crop above has
# already been drawn. If MaskGenerator uses NumPy's RNG the mask will not be
# reproducible either -- confirm.
torch.manual_seed(123456)
print('simMIM with pixel reconstruction:')
run_one_image(img, model, mask_ratio=0.65)
轉載請標記出處。
另外,我有將程式碼改寫成可在單張 GPU、Jupyter 上執行的 SimMIM 版本,如果有需要歡迎詢問。