Sartorius - Cell Instance Segmentation: a Kaggle cell segmentation competition

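Competition overview:

Background:

Summary: one neuronal cell type (the neuroblastoma cell line SH-SY5Y) consistently scores worst under existing models. The task is to find an approach that handles this data and improves the score.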

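Neurological disorders, including neurodegenerative diseases such as Alzheimer's and brain tumors, are a leading cause of death and disability worldwide. However, it is hard to quantify how well these deadly disorders respond to treatment. One accepted method is to examine neuronal cells under a light microscope, which is both accessible and non-invasive. Unfortunately, segmenting individual neuronal cells in microscopy images can be challenging and time-consuming. With the help of computer vision, accurate instance segmentation of these cells could lead to new and effective drug discoveries for the millions of people who suffer from these disorders.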

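Current solutions have particularly limited accuracy on neuronal cells. In internal studies to develop cell instance segmentation models, the neuroblastoma cell line SH-SY5Y consistently showed the lowest precision scores of the eight cancer cell types tested. This may be because neuronal cells have a very distinctive, irregular, and concave morphology, which makes them hard to segment with commonly used mask heads. In this competition, you detect and delineate distinct objects of interest in biological images that depict neuronal cell types commonly used in the study of neurological disorders. More specifically, you use phase contrast microscopy images to train and test a model for instance segmentation of neuronal cells. Successful models do this with a high level of accuracy.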

Data overview:

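The task is to segment the neuronal cells in each image.
The number of annotated images is small, but the number of annotated objects is quite high.
(Each image contains many annotated objects.)
The hidden test set contains roughly 240 images.
Note: although predictions are not allowed to overlap, the training labels are provided in full, including overlapping regions. This ensures that the model sees the complete data for every object. Removing overlap from the predictions is left to the competitor.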
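
Because overlapping predictions are not allowed, it can help to check a set of predicted masks before encoding them. The following is a minimal sketch, assuming the masks are same-shaped binary NumPy arrays; the helper name `count_overlapping_pixels` is hypothetical and not part of the competition code.

```python
import numpy as np


def count_overlapping_pixels(masks):
    """Return how many pixels are claimed by more than one binary mask.

    Hypothetical helper for illustration; assumes all masks share one shape.
    """
    if not masks:
        return 0
    # Stack the masks and count, per pixel, how many masks cover it.
    coverage = np.sum([np.asarray(m, dtype=bool) for m in masks], axis=0)
    return int((coverage > 1).sum())


# Two toy 4x4 masks that share exactly one pixel, at row 1, column 1.
a = np.zeros((4, 4), dtype=np.uint8)
a[0:2, 0:2] = 1
b = np.zeros((4, 4), dtype=np.uint8)
b[1:3, 1:3] = 1

overlap_count = count_overlapping_pixels([a, b])
print(overlap_count)
```

A result of zero for every test image means the predictions satisfy the no-overlap rule.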

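Files

train.csv - IDs and masks for all training objects. None of this metadata is provided for the test set.
id - unique identifier for the object
annotation - run-length encoded pixels for the identified neuronal cell
width - source image width
height - source image height
cell_type - the cell line
plate_time - time the plate was created
sample_date - date the sample was created
sample_id - sample identifier
elapsed_timedelta - time elapsed since the first image of the sample was taken
sample_submission.csv - a sample submission file in the correct format
train - training images in PNG format
test - test images in PNG format
(Only a small part of the test images can be downloaded.) This is a notebook competition.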
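
To make the `annotation`, `width`, and `height` columns concrete, here is a minimal decoding sketch. It assumes the run-length encoding consists of 1-indexed `start length` pairs counted in row-major order, which mirrors the `rle_encoding` helper used later in this post; the function name `rle_decode` is hypothetical.

```python
import numpy as np


def rle_decode(annotation, height, width):
    """Decode a 'start length start length ...' string into a binary mask.

    Assumption: starts are 1-indexed and pixels are counted row by row.
    """
    values = np.array(annotation.split(), dtype=np.int64)
    starts = values[0::2] - 1  # convert to 0-indexed positions
    lengths = values[1::2]
    flat = np.zeros(height * width, dtype=np.uint8)
    for start, length in zip(starts, lengths):
        flat[start:start + length] = 1
    return flat.reshape(height, width)


# Toy 3x4 image: pixels 2-3 lie on the first row, pixel 9 starts the third row.
toy_mask = rle_decode("2 2 9 1", height=3, width=4)
print(toy_mask)
```

In practice each row of train.csv would be decoded with its own `height` and `width` values, and the masks that share an `id` belong to the same image.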

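Notes:

When the submitted notebook runs, reading the test directory automatically picks up all of the test data.
train_semi_supervised - unlabeled images, provided in case you want to use additional data for a semi-supervised approach.
Additional dataset provided by the competition:
LIVECell_dataset_2021 - a mirror of the LIVECell dataset. LIVECell is the predecessor dataset to this competition. It contains extra data for the SH-SY5Y cell line, plus several other cell lines not covered in the competition dataset that may be useful for transfer learning.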

Baseline silver-medal solution walkthrough

Inference code on Kaggle using the MMDetection framework (the training code is still being tidied up)

To use MMDetection, we need to replace the preinstalled PyTorch 1.9 with PyTorch 1.7.

%cd /kaggle/working
# Environment setup
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torch-1.7.0+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchvision-0.8.1+cu110-cp37-cp37m-linux_x86_64.whl' --no-deps
!pip install '/kaggle/input/pytorch-170-cuda-toolkit-110221/torchaudio-0.7.0-cp37-cp37m-linux_x86_64.whl' --no-deps

# Copy MMDetection into the Kaggle working directory.
# MMDetection framework
!cp -r ../input/mmdetectionv2140/* /kaggle/working/
!cp -r ../input/mmdetection-new/* /kaggle/working/
# Your model config
!cp -r ../input/handudu/mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py /kaggle/working/
!cp -r ../input/timtim/timm-0.4.12-py3-none-any.whl /kaggle/working/
!pip install '/kaggle/working/timm-0.4.12-py3-none-any.whl' --no-deps
!pip install '/kaggle/working/addict-2.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/working/yapf-0.31.0-py2.py3-none-any.whl' --no-deps
!pip install '/kaggle/working/terminal-0.4.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/working/terminaltables-3.1.0-py3-none-any.whl' --no-deps
!pip install '/kaggle/working/mmcv_full-1_3_8-cu110-torch1_7_0/mmcv_full-1.3.8-cp37-cp37m-manylinux1_x86_64.whl' --no-deps
!pip install '/kaggle/working/pycocotools-2.0.2/pycocotools-2.0.2' --no-deps
!pip install '/kaggle/working/mmpycocotools-12.0.3/mmpycocotools-12.0.3' --no-deps
!mv /kaggle/working/CBNetV2-main  /kaggle/working/mmdetection
%cd /kaggle/working/mmdetection
!pip install -e .
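
After the installs, a quick sanity check of the package versions can catch a mismatch early. This is a minimal sketch, assuming only that each package exposes a `__version__` attribute; the helper `module_version` is hypothetical and simply reports `None` when a package cannot be imported.

```python
import importlib


def module_version(name):
    """Return a module's __version__, or None if it cannot be imported.

    Hypothetical diagnostic helper. Any import failure (missing package or
    a broken binary build) is treated as "not available".
    """
    try:
        module = importlib.import_module(name)
    except Exception:
        return None
    return getattr(module, "__version__", "unknown")


# Print what is actually importable in the current environment.
for name in ("torch", "torchvision", "mmcv", "mmdet", "timm"):
    print(name, module_version(name))
```

With the wheels installed above, one would expect torch to report a 1.7.0 build and mmcv to report 1.3.8; any `None` points at a failed install.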

Import the basic tools we need.

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
import sklearn
import torchvision
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import numpy as np
import cupy as cp  # GPU-accelerated NumPy replacement; see https://blog.csdn.net/qq_41185868/article/details/103479683

import gc
import pandas as pd
import os
import matplotlib.pyplot as plt
import PIL
import json
from PIL import Image, ImageEnhance  # image enhancement; see https://blog.csdn.net/update7/article/details/106593060
import albumentations as A  # image augmentation library, the most commonly used one
import mmdet
import mmcv
from albumentations.pytorch import ToTensorV2  # augmentation output as tensors; see https://blog.csdn.net/zhangyuexiang123/article/details/107705311
import seaborn as sns
from pathlib import Path
import pycocotools  # COCO API
from pycocotools import mask
import numpy.random
import random
from glob import glob
from tqdm.notebook import tqdm
import cv2
import re
import shutil
from mmdet.apis import inference_detector, init_detector, show_result_pyplot, set_random_seed

Add the post-processing functions, such as the RLE encoding trick used in cell competitions.


def rle_encoding(x):
    # Encode a binary mask as "start length" pairs (1-indexed, row-major order).
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))  # start a new run at pixel b + 1
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))


def get_mask_from_result(result):
    # Map the True/False values of a predicted mask to 1/0 in a CuPy array.
    d = {True: 1, False: 0}
    u, inv = np.unique(result, return_inverse=True)
    mk = cp.array([d[x] for x in u])[inv].reshape(result.shape)
    return mk


def does_overlap(mask, other_masks):
    # Return True if the mask shares at least one pixel with any other mask.
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            return True
    return False


def remove_overlapping_pixels(mask, other_masks):
    # Zero out every pixel of the mask that an already accepted mask covers.
    for other_mask in other_masks:
        if np.sum(np.logical_and(mask, other_mask)) > 0:
            print("Overlap detected")
            mask[np.logical_and(mask, other_mask)] = 0
    return mask

def one_hot(y, num_classes, dtype=cp.uint8): # GPU
    y = cp.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = int(cp.max(y)) + 1  # convert the 0-d CuPy result to a Python int
    n = y.shape[0]
    categorical = cp.zeros((n, num_classes), dtype=dtype)
    categorical[cp.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = cp.reshape(categorical, output_shape)
    return categorical
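
To see how the overlap removal and the RLE encoding interact, here is a small CPU-only sketch. It reimplements the two helpers with plain NumPy so it runs without a GPU; the toy masks are made up for illustration.

```python
import numpy as np


def rle_encoding(x):
    # Same logic as the helper above: 1-indexed "start length" pairs, row-major.
    dots = np.where(x.flatten() == 1)[0]
    run_lengths = []
    prev = -2
    for b in dots:
        if b > prev + 1:
            run_lengths.extend((b + 1, 0))
        run_lengths[-1] += 1
        prev = b
    return ' '.join(map(str, run_lengths))


def remove_overlapping_pixels(mask, other_masks):
    # Clear every pixel of mask that an earlier mask already covers.
    for other_mask in other_masks:
        overlap = np.logical_and(mask, other_mask)
        if overlap.any():
            mask[overlap] = 0
    return mask


# Two toy 2x3 masks that share the pixel at row 0, column 1.
first = np.array([[1, 1, 0],
                  [0, 0, 0]], dtype=np.uint8)
second = np.array([[0, 1, 1],
                   [0, 0, 1]], dtype=np.uint8)

second = remove_overlapping_pixels(second, [first])
first_rle = rle_encoding(first)
second_rle = rle_encoding(second)
print(first_rle)   # 1 2
print(second_rle)  # 3 1 6 1
```

The first mask keeps the shared pixel because it was accepted first; the second mask loses it and is therefore split into two runs of length one.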

Instantiate the model (trained locally, then uploaded to Kaggle)

config_file = 'mask_rcnn_cbv2_swin_tiny_patch4_window7_mstrain_480-800_adamw_3x_coco.py'
checkpoint_file = '../input/final-model-w/0317.pth'  # uploaded weights file, trained on a cloud server
model = init_detector(config_file, checkpoint_file, device='cuda:0')
masks = []
files = []
MIN_PIXELS = {0: 75, 1: 75, 2: 150}
confidence_thresholds = {0: 0.25, 1: 0.65, 2: 0.35}
for imgs in tqdm(glob('../input/sartorius-cell-instance-segmentation/test/*')):  # every image in the test set
    result = inference_detector(model, imgs)  # MMDetection inference API

    # Treat the class with the most detections as the cell type of the whole image.
    pred_class_ls = [len(result[0][0]), len(result[0][1]), len(result[0][2])]
    pred_class = pred_class_ls.index(max(pred_class_ls))
    msk = []
    for i, classe in enumerate(result[0]):
        if classe.shape != (0, 5):
            bbs = classe
            sgs = result[1][i]
            for bb, sg in zip(bbs, sgs):
                box = bb[:4]
                cnf = bb[4]
                # Keep a detection only if it clears the threshold of the image-level class.
                if cnf >= confidence_thresholds[pred_class]:
                    mask = get_mask_from_result(sg)
                    mask = remove_overlapping_pixels(mask, msk)
                    # Drop masks that are too small once the overlap is removed.
                    if mask.sum() >= MIN_PIXELS[pred_class]:
                        msk.append(mask)
    for mk in msk:
        rle_mask = rle_encoding(mk)  # RLE-encode each kept mask
        masks.append(rle_mask)
        files.append(str(imgs.split('/')[4].split('.')[0]))  # image id = file name without extension
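
The filtering logic in the loop above can be tried without a GPU or a trained model. The following is a minimal sketch on made-up data, assuming the result layout the loop relies on: one `(n, 5)` box array per class, with the score in the last column, and a matching list of boolean masks per class. The function name `filter_instances` and the toy thresholds are hypothetical.

```python
import numpy as np


def filter_instances(bbox_results, segm_results, thresholds, min_pixels):
    """Keep confident, non-overlapping, large-enough masks for one image."""
    # The class with the most raw detections decides which thresholds apply.
    counts = [len(boxes) for boxes in bbox_results]
    pred_class = int(np.argmax(counts))
    kept = []
    for boxes, segs in zip(bbox_results, segm_results):
        for box, seg in zip(boxes, segs):
            if box[4] < thresholds[pred_class]:
                continue  # score too low for this image's class
            mask = np.asarray(seg, dtype=np.uint8).copy()
            for other in kept:
                mask[other > 0] = 0  # earlier masks keep contested pixels
            if mask.sum() >= min_pixels[pred_class]:
                kept.append(mask)
    return pred_class, kept


def toy_mask(rows, cols):
    m = np.zeros((4, 4), dtype=bool)
    m[rows, cols] = True
    return m


# Class 0: no detections. Class 1: two detections. Class 2: one detection.
bbox_results = [
    np.zeros((0, 5)),
    np.array([[0, 0, 3, 2, 0.9], [0, 2, 2, 4, 0.3]]),
    np.array([[2, 1, 4, 3, 0.8]]),
]
segm_results = [
    [],
    [toy_mask(slice(0, 2), slice(0, 3)), toy_mask(slice(2, 4), slice(0, 2))],
    [toy_mask(slice(1, 3), slice(2, 4))],
]
thresholds = {0: 0.25, 1: 0.65, 2: 0.35}
min_pixels = {0: 2, 1: 4, 2: 2}  # toy values for a 4x4 image

pred_class, kept = filter_instances(bbox_results, segm_results, thresholds, min_pixels)
print(pred_class, len(kept))
```

Here the second class-1 detection is dropped for its low score, and the class-2 detection is dropped because only three of its pixels survive the overlap removal, which is below the toy minimum of four.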

Generate the submission file

files = pd.Series(files, name='id')
preds = pd.Series(masks, name='predicted')
submission_df = pd.concat([files, preds], axis=1)
submission_df.to_csv('submission.csv', index=False)
submission_df.head()

Recommended reading
MMDetection official site
MMDetection on Zhihu
