Textual-inversion fine-tuning with diffusers — annotated walkthrough of the
reference training script `textual_inversion.py`.

# Set up the Accelerate context and load the frozen Stable Diffusion components.
accelerator = Accelerator()

# NOTE(review): the original notes omitted the checkpoint path; assuming
# args.pretrained_model_name_or_path as in the diffusers reference script — confirm.
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Bug fix: the text encoder is a CLIPTextModel, not a CLIPTokenizer.
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
# Bug fix: the method is from_pretrained (lowercase "p").
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

# Register the placeholder token(s) with the tokenizer and seed their embeddings
# from the embedding of the user-supplied initializer token.
placeholder_tokens = [args.placeholder_token]
additional_tokens = []  # bug fix: was appended to without being initialized
for i in range(1, args.num_vectors):
    # Bug fix: the attribute is args.placeholder_token, not args.placeholder.
    additional_tokens.append(f"{args.placeholder_token}_{i}")
placeholder_tokens += additional_tokens

num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
# Guard: every placeholder must be genuinely new, otherwise its embedding row
# would silently alias an existing token.
if num_added_tokens != args.num_vectors:
    raise ValueError(
        "The placeholder token(s) already exist in the tokenizer; "
        "choose a different placeholder_token."
    )

token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
# The initializer must map to exactly one token so a single embedding can be copied.
if len(token_ids) > 1:
    raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

# Grow the embedding matrix to cover the new vocabulary, then initialize each
# new row with the initializer token's embedding (no gradients needed here).
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
with torch.no_grad():
    for token_id in placeholder_token_ids:
        token_embeds[token_id] = token_embeds[initializer_token_id].clone()

# Freeze everything except the token-embedding table: the VAE, the UNet, and
# every text-encoder submodule other than the input embeddings stay fixed.
# Bug fix: the in-place torch method is requires_grad_, not require_grads_.
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.text_model.encoder.requires_grad_(False)
text_encoder.text_model.final_layer_norm.requires_grad_(False)
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)

# Scale the base learning rate by the effective batch size (accumulation steps
# x per-device batch x number of processes), as in the reference script.
args.learning_rate = (
    args.learning_rate
    * args.gradient_accumulation_steps
    * args.train_batch_size
    * accelerator.num_processes
)
# Only the (resized) input-embedding table is optimized; all other weights are frozen.
optimizer = torch.optim.AdamW(
    text_encoder.get_input_embeddings().parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    # Bug fix: args.adam_epsilon was passed positionally AFTER keyword
    # arguments, which is a SyntaxError; it must be the eps= keyword.
    eps=args.adam_epsilon,
)

# Build the training dataset. Per the notes, each __getitem__ item contains:
#   - input_ids: the placeholder token formatted into a randomly chosen prompt
#     template, tokenized with padding/truncation to model_max_length;
#   - pixel_values: the image as a CHW float tensor (permuted from HWC numpy).
# Bug fix: resolution/repeats were bare undefined names — they come from args.
train_dataset = TextualInversionDataset(
    args.train_data_dir,
    tokenizer,
    args.resolution,
    args.placeholder_token,
    args.repeats,
    args.learnable_property,
    args.center_crop,
)
# Bug fix: batch_size/shuffle were bare undefined names.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.train_batch_size, shuffle=True
)

# One optimizer update happens every gradient_accumulation_steps dataloader steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

# Bug fix: the schedule name was the bare undefined identifier lr_scheduler;
# it comes from args. Warmup/total steps are scaled by process count, as in
# the diffusers reference script.
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
)

# Hand everything to Accelerate (device placement, DDP wrapping, etc.).
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    text_encoder, optimizer, train_dataloader, lr_scheduler
)

total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

# Snapshot the full embedding table so frozen rows can be restored after each
# optimizer step (only the placeholder rows are allowed to change).
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

for epoch in range(first_epoch, args.num_train_epochs):
    text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(text_encoder):
            # Bug fix: parentheses were misplaced — .latent_dist.sample() applies
            # to the VAE encoder output, not to the input pixel tensor.
            latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
            latents = latents * vae.config.scaling_factor

            # Sample noise and a diffusion timestep for each image in the batch.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Bug fix: timesteps are integer draws in [0, num_train_timesteps),
            # so torch.randint — not Gaussian torch.randn.
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Condition the UNet on the text embedding of the templated prompt.
            encoder_hidden_states = text_encoder(batch["input_ids"])[0]
            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            # Epsilon objective: the UNet predicts the noise that was added.
            target = noise
            loss = F.mse_loss(model_pred.float(), target.float())
            # Bug fix: backward() must be given the loss tensor.
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            # Bug fix: the second optimizer.step() was a duplicate; gradients
            # must be cleared here instead.
            optimizer.zero_grad()

            # Restore every embedding row EXCEPT the newly added placeholder
            # tokens, so only the new vectors are effectively trained.
            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
            index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
            with torch.no_grad():
                accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds_params[index_no_updates]

Related topics: large models, multimodal generation, stable diffusion, embeddings