10X单细胞(10X空间转录组)批次效应去除大盘点2 & AWGAN

hello,大家好,今天我们要再度深入认识一下批次效应,关于批次效应,我们之前分享了很多了,文章列在这里,供大家参考

10X单细胞(10X空间转录组)多样本批次效应去除分析之RCA2

10X单细胞(10X空间转录组)整合分析批次处理之细节(harmony)

10X单细胞(10X空间转录组)批次去除(整合)分析之Scanorama

10X单细胞(10X空间转录组)批次效应去除大盘点

10X单细胞(空间转录组)数据整合分析批次矫正之liger

单细胞数据用Harmony算法进行批次矫正

批次效应

10X单细胞(10X空间转录组)数据分析总结之各种NMF。

今天我们要深入一下批次去除的方法AWGAN,参考文献在AWGAN: A Powerful Batch Correction Model for scRNA-seq Data

目前去除批次效应的两种思路

  • 第一个是假设来自不同批次的数据遵循相同的分布,这样就可以确定分布中的参数,并通过统计方法去除数据集中的批次效应,这样的方法包括 Limma(Limma
    powers differential expression analyses for rna-sequencing and microarray studies),
    Combat(Adjusting batch effects in microarray expression data using empirical bayes methods), Liger(Single-cell multi-omic integration compares and contrasts features of brain cell identity), and scVI(Bayesian inference for a generative model of transcriptome profiles from single-cell rna sequencing)
  • 第二种是选择一个批次作为参考批次,其他批次作为query批次,通过构建从query批次到参考批次的映射,得到批次效应校正数据,这类方法包括Mutual Nearest Neighbor (MNN) , Seurat v4, Batch balanced kNN (BBKNN),iMAP, and Harmony

Wasserstein Generative Adversarial Network (WGAN) combined with an attention mechanism to reduce the differences among batches.

图片.png

AWGAN的分析步骤,three key steps: attention-driven data preprocessing, AWGAN training, and model evaluation

  • 第一步,跨批次选择高度可变的基因作为数据集成的共同特征。此外,数据集不能直接输入到训练模型中,因为需要确保分析的方法可以保留生物信息。因此,训练数据集中的细胞数量需要兼容为#(SubReferData) = #(QueryData)。可以通过选择最近邻对来生成与query数据匹配的参考数据子集,因为训练 GAN 非常棘手并且要求两个样本分布尽可能相似,利用匹配对生成训练集很重要。当有两批时,只需要训练一次模型。当目标数据集中有 k(k > 2) 个批次时,通过运行attention algorithm和模型训练过程 k -1 次来利用增量匹配学习策略。每一轮之后,一个query数据集会被整合到总参考批次中,使参考批次越来越大。采用这种策略是因为不同的query批次可能有不同的分布。
  • 第二步,在生成训练数据集后,我们利用具有梯度惩罚的 WGAN(一个软件) 生成一个映射,该映射可以将query数据转换为类似于参考数据的分布。 此外,为了减少模型崩溃的可能性,还允许自适应机制为梯度惩罚寻找合适的正则化系数并调整训练时期的数量。 WGAN 由一个生成器模型和一个判别器模型组成。 前一个模型用于从query数据中生成遵循新分布的数据,而后一个模型用于区分校正后的查询数据是否与参考数据具有相似的分布。 此外,鉴别器可以找到批次特定的细胞类型并保持它们的属性。 分析采用对抗训练的策略来提升两个模型的能力,最终实现它们之间的Nash equilibrium。
  • 在最后一步,在去除批量效应后生成一个集成的 scRNA-seq 数据集。 为了将结果可视化,将 UMAP 作为一种流形学习选择。此外,为了量化去除批量效应后的结果,分析考虑了四个指标:ASW、kBET、LISI 和 Graph Connectivity,用于定量评估批量效应去除的性能。

各个软件之间的比较

Small-scale scRNA-seq Datasets

图片.png
  • 注:CL(Pure Cell Lines) 数据集不同批次效应去除方法的性能。 (a) 在去除批次效应之前,带有批次标签的数据分布的可视化。 (b) 在去除批次效应之前,带有细胞类型标签的数据分布的可视化。 (c) 去除批次效应后带有批次标签的数据分布的可视化。 (d) 去除批次效应后带有细胞类型标签的数据分布的可视化。 (e) CL 数据集上不同批次效应去除方法的 ASW 评估。 (f) CL 数据集上不同批次效应去除方法的 LISI 评估。


    图片.png
  • DC(Human Dendritic Cells) 数据集不同批次效应去除方法的性能。 (a) 在去除批次效应之前,带有批次标签的数据分布的可视化。 (b) 在去除批次效应之前,带有细胞类型标签的数据分布的可视化。 (c) 去除批次效应后带有批次标签的数据分布的可视化。 (d) 去除批次效应后带有细胞类型标签的数据分布的可视化。 (e) DC 数据集上不同批次效应去除方法的 ASW 评估。 (f) DC 数据集上不同批次效应去除方法的 LISI 评估


    图片.png
  • 不同批次效应去除方法对Human Pancreas dataset的性能。 (a) 在去除批次效应之前,带有批次标签的数据分布的可视化。 (b) 在去除批次效应之前,带有细胞类型标签的数据分布的可视化。 (c) 去除批次效应后带有批次标签的数据分布的可视化。 (d) 去除批次效应后带有细胞类型标签的数据分布的可视化。 (e) 胰腺 rm 数据集上不同批次效应去除方法的 ASW 评估。 (f) 对 Pancreas rm 数据集的不同批次效应去除方法的 LISI 评估。


    图片.png
  • The performance of different batch effect removal methods for the PBMC 3&68K.
    (a) Visualization for the data distribution with batch label before the batch effect removal. (b) Visualization for the data distribution with cell type label before the batch effect removal. (c) Visualization for the data distribution with batch label after the batch effect removal. (d) Visualization for the data distribution with cell type label after the batch effect removal. (e) ASW assessments of different batch effect removal methods on the PBMC3&68K dataset. (f) LISI assessments of different batch effect removal methods on the PBMC3&68K dataset.

Large-scale scRNA-seq Datasets

图片.png

图片.png

示例代码(python)

import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
import torch.utils.data as Data  #Data是用来批训练的模块
from torchvision.utils import save_image
import numpy as np
import os
import pandas as pd
import torch.optim.lr_scheduler as lr_s 
from collections import Counter
import loompy
from scipy.spatial.distance import cdist
import scprep
import imap  #used for feature detected
import numpy as np
import squidpy as sq
import pandas as pd
import matplotlib.pyplot as plt
import phate
import graphtools as gt
import magic
import os
import datetime
import scanpy as sc
from skmisc.loess import loess
import sklearn.preprocessing as preprocessing
import umap.umap_ as umap
from numba import jit
from sklearn.metrics import silhouette_score
import random

def silhouette_coeff_ASW(adata, method_use='raw',save_dir='', save_fn='', percent_extract=0.8):
    random.seed(0)
    asw_fscore = []
    asw_bn = []
    asw_bn_sub = []
    asw_ctn = [] 
    iters = []
    for i in range(20):
        iters.append('iteration_'+str(i+1))
        rand_cidx = np.random.choice(adata.obs_names, size=int(len(adata.obs_names) * percent_extract), replace=False)
        print('nb extracted cells: ',len(rand_cidx))
        adata_ext = adata[rand_cidx,:]
        asw_batch = silhouette_score(adata_ext.X, adata_ext.obs['batch'])
        asw_celltype = silhouette_score(adata_ext.X, adata_ext.obs['louvain'])
        min_val = -1
        max_val = 1
        asw_batch_norm = (asw_batch - min_val) / (max_val - min_val)
        asw_celltype_norm = (asw_celltype - min_val) / (max_val - min_val)
        
        fscoreASW = (2 * (1 - asw_batch_norm)*(asw_celltype_norm))/(1 - asw_batch_norm + asw_celltype_norm)
        asw_fscore.append(fscoreASW)
        asw_bn.append(asw_batch_norm)
        asw_bn_sub.append(1-asw_batch_norm)
        asw_ctn.append(asw_celltype_norm)
    
#     iters.append('median_value')
#     asw_fscore.append(np.round(np.median(fscoreASW),3))
#     asw_bn.append(np.round(np.median(asw_batch_norm),3))
#     asw_bn_sub.append(np.round(1 - np.median(asw_batch_norm),3))
#     asw_ctn.append(np.round(np.median(asw_celltype_norm),3))
    df = pd.DataFrame({'asw_batch_norm':asw_bn, 'asw_batch_norm_sub': asw_bn_sub,
                       'asw_celltype_norm': asw_ctn, 'fscore':asw_fscore,
                       'method_use':np.repeat(method_use, len(asw_fscore))})
    df.to_csv(save_dir + save_fn + '.csv')
    print('Save output of pca in: ',save_dir)
    print(df.values.shape)
    print(df.keys())
    return df

Real AWGAN

adata = sc.read_loom('CRC_CONCAT.loom', sparse=False)
#preprocessing, same as the preprocessing code in the model
adata = imap.stage1.data_preprocess(adata)
res1 = sq.gr.ligrec(
    adata,
    n_perms=2000,
    cluster_key="celltype",
    copy=True,
    use_raw=False,
    transmitter_params={"categories": "ligand"},
    receiver_params={"categories": "receptor"}
)
res1

{'means': cluster_1 B cell ... Myeloid cell
cluster_2 B cell CD4 T cell ... ILC Myeloid cell
source target ...
FYN NTRK2 0.000000 0.020737 ... 0.000000 0.000000
CSF1 NTRK2 0.000000 0.003346 ... 0.000000 0.000000
HGF NTRK2 0.000000 0.002333 ... 0.000000 0.000000
AREG NTRK2 0.000000 0.177876 ... 0.000000 0.000000
PDGFC NTRK2 0.000000 0.001336 ... 0.000000 0.000000
... ... ... ... ... ...
SERPINF1 PLXDC2 0.009675 0.009378 ... 0.045745 0.186307
HPGDS PTGDR 0.004825 0.017215 ... 0.280152 0.043597
PTGDR2 0.000000 0.002751 ... 0.000000 0.045051
EBI3 IL12RB2 0.012528 0.015942 ... 0.017023 0.014924
VSTM1 ADGRG3 0.000805 0.000315 ... 0.075256 0.070944
[1167 rows x 25 columns],
'metadata': aspect_intercell_source ... uniprot_intercell_target
source target ...
FYN NTRK2 functional ... Q16620
CSF1 NTRK2 functional ... Q16620
HGF NTRK2 functional ... Q16620
AREG NTRK2 functional ... Q16620
PDGFC NTRK2 functional ... Q16620
... ... ... ...
SERPINF1 PLXDC2 functional ... Q6UX71
HPGDS PTGDR functional ... Q13258
PTGDR2 functional ... Q9Y5Y4
EBI3 IL12RB2 functional ... COMPLEX:P40189_Q99665
VSTM1 ADGRG3 functional ... Q86Y34
[1167 rows x 42 columns],
'pvalues': cluster_1 B cell ... Myeloid cell
cluster_2 B cell CD4 T cell CD8 T cell ... CD8 T cell ILC Myeloid cell
source target ...
FYN NTRK2 NaN NaN NaN ... NaN NaN NaN
CSF1 NTRK2 NaN NaN NaN ... NaN NaN NaN
HGF NTRK2 NaN NaN NaN ... NaN NaN NaN
AREG NTRK2 NaN NaN NaN ... NaN NaN NaN
PDGFC NTRK2 NaN NaN NaN ... NaN NaN NaN
... ... ... ... ... ... ... ...
SERPINF1 PLXDC2 NaN NaN NaN ... NaN 0.998 0.0
HPGDS PTGDR NaN NaN NaN ... 0.0 0.000 NaN
PTGDR2 NaN NaN NaN ... NaN NaN NaN
EBI3 IL12RB2 NaN NaN NaN ... NaN NaN NaN
VSTM1 ADGRG3 NaN NaN NaN ... NaN NaN NaN
[1167 rows x 25 columns]}

sq.pl.ligrec(res1, alpha=0.005, save='crc_ligen_origin.pdf')
index.png
adata1 = adata[adata.obs['batch'] == 'batch2']
adata2 = adata[adata.obs['batch'] == 'batch1']
adata = adata1.concatenate(adata2, batch_categories=['batch2','batch1'])
data_umap = umap.UMAP().fit_transform(adata.X)
scprep.plot.scatter2d(data_umap,adata.obs['batch'],figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch before removal')
index.png
scprep.plot.scatter2d(data_umap,adata.obs['celltype'],figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype before removal')
index.png
adata_all = adata
adata_all.X = adata_all.X.todense()
#calculate cos distence
@jit(nopython=True)
def pdist(vec1,vec2):
  return np.dot(vec1,vec2)/(np.linalg.norm(vec1)*np.linalg.norm(vec2))
#calculate correlation index
@jit(nopython=True)
def find_correlation_index(frame1, frame2):
  result=[(1,1) for _ in range(len(frame2))]
  for i in range(len(frame2)):
    max_dist = -10
    it1=0
    it2=0
    for j in range(len(frame1)):
      dist = pdist(frame2[i],frame1[j])
      if dist>max_dist:
        max_dist = dist
        it1 = i
        it2 = j 
    result[i] = (it1, it2)
  return result
# A new approach to get the index, what is faster based on our research.
def find_correlation_index(frame1, frame2):
  distlist =  cdist(frame2,frame1,metric='cosine')
  result = np.argmin(distlist,axis=1)
  result2 = []
  for i in range(len(frame2)):
    result2.append((i,result[i]))
  return result2
adata_all = adata.copy()
adata3 = adata_all.copy()
ref_adata = adata3[adata3.obs['batch'] != adata3.obs['batch'][0]]
batch_adata = adata3[adata3.obs['batch'] == adata3.obs['batch'][0]]
c=Counter(adata3.obs['batch'])
c=dict(c)
ind_list = find_correlation_index(ref_adata.X, batch_adata.X)
common_pair = ind_list
donar_1_d = ref_adata.X
donar_2_d = batch_adata.X
result=[]
result1=[]
for i in common_pair:
  result.append(donar_1_d[i[1],:])
  result1.append(donar_2_d[i[0],:])
donar_1_t=np.array(result)
donar_2_t=np.array(result1)
train_data = donar_2_t
train_label = donar_1_t
def training_set_generator(frame1,frame2,ref,batch):
  common_pair = find_correlation_index(frame1,frame2)
  result = []
  result1 = []
  for i in common_pair:
    result.append(ref[i[1],:])
    result1.append(batch[i[0],:])
  return np.array(result),np.array(result1)
np.random.seed(999)
torch.manual_seed(999)
torch.cuda.manual_seed_all(999)
class Mish(nn.Module):
  def __init__(self):
    super().__init__()

  def forward(self,x):
    return x*torch.tanh(F.softplus(x))
#WGAN model, and it does not need to use bath normalization based on WGAN paper.
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(2000, 1024),  
            Mish(),
            nn.Linear(1024, 512),  
            Mish(),
            nn.Linear(512, 256),  
            Mish(),
            nn.Linear(256, 128),  
            Mish(),
            nn.Linear(128, 1),  
            Mish()

        )

    def forward(self, x):
        x = self.dis(x)
        return x
 
 
# WGAN generator
# Require batch normalization
class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.relu_l = nn.ReLU(True)
        self.gen = nn.Sequential(
         
            nn.Linear(2000, 1024),  
            nn.BatchNorm1d(1024, eps = 1e-7, momentum=0.01),
            nn.Dropout(0.5),
            Mish(),

            nn.Linear(1024, 512),  
            nn.BatchNorm1d(512, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(512, 256),  
            nn.BatchNorm1d(256, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(256, 512),  
            nn.BatchNorm1d(512, eps = 1e-7, momentum=0.01),
            Mish(),
  

            nn.Linear(512, 1024),  
            nn.BatchNorm1d(1024, eps = 1e-7, momentum=0.01),
            Mish(),

            nn.Linear(1024, 2000), 
            nn.Dropout(0.5) 
           
        )

    def forward(self, x):
        gre = self.gen(x)
        return self.relu_l(gre+x)    #residual network
 
 
# 创建对象
D = discriminator()
G = generator()

if torch.cuda.is_available():
  D = D.cuda()
  G = G.cuda()
# calculate gradient penalty
def calculate_gradient_penalty(real_data, fake_data, D): 
  eta = torch.FloatTensor(real_data.size(0),1).uniform_(0,1) 
  eta = eta.expand(real_data.size(0), real_data.size(1)) 
  cuda = True if torch.cuda.is_available() else False 
  if cuda: 
    eta = eta.cuda() 
  else: 
    eta = eta 
  interpolated = eta * real_data + ((1 - eta) * fake_data) 
  if cuda: 
    interpolated = interpolated.cuda() 
  else: 
    interpolated = interpolated 
   # define it to calculate gradient 
  interpolated = Variable(interpolated, requires_grad=True) 
   # calculate probability of interpolated examples 
  prob_interpolated = D(interpolated) 
  # calculate gradients of probabilities with respect to examples 
  gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated, 
  grad_outputs=torch.ones( 
  prob_interpolated.size()).cuda() if cuda else torch.ones( 
  prob_interpolated.size()), 
  create_graph=True, retain_graph=True)[0] 
  grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 
  return grad_penalty
@jit(nopython = True)
def determine_batch(val1):
  val_list =[32,40,52,64,128,256]
  for i in val_list:
    if val1%i !=1:
      return i
    else:
      continue
  return val1
# parameters
EPOCH = 40
MAX_ITER = train_data.shape[0]
batch = 128
b1 = 0.9
b2 = 0.999
lambda_1 = 1/100


d_optimizer = torch.optim.AdamW(D.parameters(), lr=0.0001)
g_optimizer = torch.optim.AdamW(G.parameters(), lr=0.0001)

c=Counter(adata3.obs['batch'])
c=dict(c)

err_G = []
err_D = []
stop = 0
iter = 0
#####################Since we only have two batches, so we adopt easier structure
for epoch in range(EPOCH):
  print(epoch)
  for time in range(0,MAX_ITER,batch):
    true_data = torch.FloatTensor(train_label[time:time+batch,:]).cuda()
    false_data = torch.FloatTensor(train_data[time:time+batch,:]).cuda()
    

    #train d at first

    d_optimizer.zero_grad()

    real_out = D(true_data)
    real_label_loss = -torch.mean(real_out)

    err_D.append(real_label_loss.cpu().float())

    # train use WGAN

    fake_out_new = G(false_data).detach()
    fake_out = D(fake_out_new)

    div = calculate_gradient_penalty(true_data, fake_out_new, D)

    label_loss = real_label_loss+torch.mean(fake_out)+div/lambda_1
    label_loss.backward()

    err_D.append(label_loss.cpu().item())
    

    d_optimizer.step()
  
    #train G

    real_out = G(false_data)
    real_output = D(real_out)

    real_loss1 = -torch.mean(real_output)
    err_G.append(real_loss1.cpu().item())

    g_optimizer.zero_grad()

    real_loss1.backward()
    g_optimizer.step()

    if(time%100==0):
      print("g step loss",real_loss1)
    iter += 1

  if stop == 1:
    break

G.eval()
test_data = torch.FloatTensor(train_data).cuda()
remove_batch_data = G(test_data).detach().cpu().numpy()

data =np.vstack([remove_batch_data, donar_1_d])

data.shape

(53018, 3892)

data_umap = umap.UMAP().fit_transform(data)

scprep.plot.scatter2d(data_umap, c=adata3.obs['batch'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Batch after removal')

scprep.plot.scatter2d(data_umap, c=adata3.obs['celltype'], figsize=(12,8), cmap="Spectral",
                      ticks=False, label_prefix="UMAP", s = 20, title = 'Celltype after removal')

obs = pd.DataFrame()

obs['batch'] = adata3.obs['batch']
obs['louvain'] = adata3.obs['celltype']

funcdata = sc.AnnData(data, obs)


funcdata = sc.AnnData(data, obs)
silhouette_coeff_ASW(funcdata)

Evaluation

sc.set_figure_params(dpi=100,color_map = 'viridis_r',fontsize=25)
sc.settings.verbosity = 1
sc.logging.print_header()
adata_gold = sc.read_loom('CRC_gold.loom', sparse=False)
adata = sc.read_loom('CRC_CONCAT.loom', sparse=False)
adata_all = imap.stage1.data_preprocess(adata,'batch')
adata_gold.obs['celltype'] = adata_gold.obs.louvain.copy()
adata_gold.var_names = adata_all.var_names.copy()
adata_gold.obs_names = adata_all.obs_names.copy()
sc.tl.rank_genes_groups(adata_gold, groupby='celltype', method='wilcoxon')
sc.pl.rank_genes_groups_heatmap(adata_gold, n_genes=2, use_raw=False, swap_axes=True, vmin=-3, vmax=3, cmap='bwr',figsize=(20,7), show=False)
index.png
sc.pl.rank_genes_groups_tracksplot(adata_gold, n_genes=2,figsize=(25,7))
index.png
adata_new = adata_gold
res = sq.gr.ligrec(
    adata_new,
    n_perms=1000,
    cluster_key="celltype",
    copy=True,
    use_raw=False,
    transmitter_params={"categories": "ligand"},
    receiver_params={"categories": "receptor"}
)
df_new.to_csv('cellphone_new10x.csv')
sq.pl.ligrec(res, alpha=0.005,save='crc_ligen.pdf')
图片.png

生活很好,有你更好

你可能感兴趣的:(10X单细胞(10X空间转录组)批次效应去除大盘点2 & AWGAN)