Notes on PyTorch's BatchSampler

Welcome to follow our group's WeChat official account, 碳硅数据, for more articles!
Today, while reading some code, I came across a nice implementation of a BatchSampler. I ran a few tests on it and am writing them up here.

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import Sampler
from torch.utils.data import DataLoader
import numpy as np 

class SingleCellDataset(Dataset):
    """
    Dataloader of single-cell data
    """
    def __init__(self, adata, use_layer='X'):
        """
        create a SingleCellDataset object
            
        Parameters
        ----------
        adata
            AnnData object wrapping the single-cell data matrix
        """
        self.adata = adata
        self.shape = adata.shape
        self.use_layer = use_layer
        
    def __len__(self):
        return self.adata.shape[0]
    
    def __getitem__(self, idx):
        if self.use_layer == 'X':
            # adata.X may be dense (ndarray) or sparse; densify sparse rows
            if isinstance(self.adata.X[idx], np.ndarray):
                x = self.adata.X[idx].squeeze().astype(float)
            else:
                x = self.adata.X[idx].toarray().squeeze().astype(float)
        else:
            # fall back to a named layer first, then to obsm
            if self.use_layer in self.adata.layers:
                x = self.adata.layers[self.use_layer][idx]
            else:
                x = self.adata.obsm[self.use_layer][idx]

        # integer code of this sample's batch label
        domain_id = self.adata.obs['batch'].cat.codes.iloc[idx]
        return x, domain_id, idx

class BatchSampler(Sampler):
    """
    Batch-specific Sampler
    sampled data of each batch is from the same dataset.
    """
    def __init__(self, batch_size, batch_id, drop_last=False):
        """
        create a BatchSampler object
        
        Parameters
        ----------
        batch_size
            batch size for each sampling
        batch_id
            batch id of all samples
        drop_last
            whether to drop trailing samples that do not fill a complete batch
            
        """
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.batch_id = batch_id

    def __iter__(self):
        batch = {}
        # shuffle all sample indices, then bucket them by batch id
        sampler = np.random.permutation(len(self.batch_id))
        for idx in sampler:
            c = self.batch_id[idx]
            if c not in batch:
                batch[c] = []
            batch[c].append(idx)

            # a bucket is full: yield it as one mini-batch and reset it
            if len(batch[c]) == self.batch_size:
                yield batch[c]
                batch[c] = []

        # finally, yield each batch id's leftover partial mini-batch
        for c in batch.keys():
            if len(batch[c]) > 0 and not self.drop_last:
                yield batch[c]
            
    def __len__(self):
        # Mini-batches never mix batch ids, so batches must be counted
        # per id; a single global division would miscount the remainders.
        _, counts = np.unique(self.batch_id, return_counts=True)
        if self.drop_last:
            return int(sum(c // self.batch_size for c in counts))
        else:
            return int(sum((c + self.batch_size - 1) // self.batch_size for c in counts))
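A quick sanity check of that per-id length computation, with toy numbers invented purely for illustration: 7 samples in batch 0 and 5 in batch 1, with batch_size 4, yield mini-batches of 4+3 and 4+1, i.e. 4 in total, whereas a global ceil(12 / 4) would report only 3.

toy_ids = np.array([0] * 7 + [1] * 5)
toy_sampler = BatchSampler(4, toy_ids, drop_last=False)
assert len(toy_sampler) == 4                # 2 + 2, counted per batch id
assert sum(1 for _ in toy_sampler) == 4     # matches the actual iteration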
        
# scdata = SingleCellDataset(adata, use_layer="X") # Wrap AnnData into Pytorch Dataset
# batch_sampler = BatchSampler(64, adata.obs['batch'], drop_last=False)
# testloader = DataLoader(scdata, batch_sampler=batch_sampler, num_workers=0)
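To see the pieces fit together end to end, here is a minimal runnable sketch along the lines of the commented usage above. It assumes the anndata package is installed; the toy AnnData is invented purely for illustration, and batch_id is passed as a plain numpy array via .values so that the sampler's integer indexing is unambiguous regardless of the obs index.

import anndata
import pandas as pd

# toy AnnData: 10 cells x 4 genes, split across two sequencing batches
X = np.random.rand(10, 4).astype(np.float32)
obs = pd.DataFrame({'batch': pd.Categorical(['a'] * 6 + ['b'] * 4)},
                   index=[f'cell{i}' for i in range(10)])
adata = anndata.AnnData(X=X, obs=obs)

scdata = SingleCellDataset(adata, use_layer='X')            # wrap AnnData
batch_sampler = BatchSampler(3, adata.obs['batch'].values, drop_last=False)
testloader = DataLoader(scdata, batch_sampler=batch_sampler, num_workers=0)

for x, domain_id, idx in testloader:
    # each x holds cells from a single batch ('a' or 'b')
    print(x.shape, domain_id.tolist(), idx.tolist())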

First, a test of the built-in SequentialSampler:

from torch.utils.data import sampler

# define some data and a sequential sampler over it
data = [17, 22, 3, 41, 8]
seq_sampler = sampler.SequentialSampler(data_source=data)

# iterate over the indices produced by the sampler
for index in seq_sampler:
    print("index: {}, data: {}".format(index, data[index]))

The result is (SequentialSampler simply walks the indices in order):

index: 0, data: 17
index: 1, data: 22
index: 2, data: 3
index: 3, data: 41
index: 4, data: 8

The first thing to get straight is the relationship between a sampler and a DataLoader. As the example above shows, seq_sampler can be iterated directly, on its own, to inspect the indices it produces; that was an important realization for me.
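The same relationship determines how a sampler plugs into DataLoader: a plain Sampler yields one index at a time and DataLoader groups them using its own batch_size argument, while a batch sampler yields a whole list of indices per step and is passed via batch_sampler instead. A small sketch using PyTorch's built-in samplers:

from torch.utils.data import DataLoader, sampler

data = torch.arange(10)

# plain sampler: yields single indices; DataLoader does the batching
seq_sampler = sampler.SequentialSampler(data)
loader1 = DataLoader(data, sampler=seq_sampler, batch_size=4)

# batch sampler: yields index lists; batch_size must be left unset
seq_batch_sampler = sampler.BatchSampler(seq_sampler, batch_size=4, drop_last=False)
loader2 = DataLoader(data, batch_sampler=seq_batch_sampler)

for b1, b2 in zip(loader1, loader2):
    print(b1.tolist(), b2.tolist())   # identical batches either way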

The custom BatchSampler above can be tested the same way:

# 100 samples, each randomly assigned to batch 0 or batch 1
batch_id = np.random.choice([0, 1], 100)
data = np.array(range(100))
batch_sampler = BatchSampler(5, batch_id, drop_last=False)

# iterate over the index lists produced by the batch sampler
for index in batch_sampler:
    print("index: {}, data={}, batch_id={}".format(index, data[index], batch_id[index]))


The result is:

[output screenshots: every printed mini-batch has a single batch_id value throughout]
As the output shows, what this BatchSampler does is guarantee that each mini-batch is drawn entirely from a single batch. I could not tell this just from reading the code; I had assumed it would sample equally from the two batches. So you really do have to test things, or you never truly understand them.
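To confirm this behaviour programmatically rather than by eyeballing the printout, here is a small check with a fixed seed so the run is reproducible:

np.random.seed(0)  # fixed seed for reproducibility
batch_id = np.random.choice([0, 1], 100)
batch_sampler = BatchSampler(5, batch_id, drop_last=False)

for indices in batch_sampler:
    # every mini-batch must contain exactly one distinct batch id
    assert len(set(batch_id[indices])) == 1

print("all mini-batches are homogeneous in batch_id")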
