train_dreambooth.py
Code:
accelerator = Accelerator()
# Generate class images if prior preservation is enabled
if args.with_prior_preservation:
    if cur_class_images < args.num_class_images:
        num_new_images = args.num_class_images - cur_class_images
        sample_dataset = PromptDataset(args.class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
        for example in sample_dataloader:
            images = pipeline(example["prompt"]).images
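Inside this sampling loop the script also writes the generated class images into args.class_data_dir so later runs can reuse them; a rough sketch of that step (the hashing/filename scheme follows the upstream script and may differ by version):

from pathlib import Path
import hashlib

class_images_dir = Path(args.class_data_dir)
for i, image in enumerate(images):
    # Hash the image bytes so filenames stay unique across runs
    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
    image.save(image_filename)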
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
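save_model_hook and load_model_hook are callbacks registered with Accelerate so that accelerator.save_state / load_state checkpoint the models in the diffusers folder layout; a simplified sketch of the idea, assuming only the UNet is tracked (the actual hooks in the script also handle the text encoder):

def save_model_hook(models, weights, output_dir):
    # Write each tracked model in diffusers format, then pop its weights so
    # Accelerate does not additionally save them in its own default format
    for model in models:
        model.save_pretrained(f"{output_dir}/unet")
        weights.pop()

def load_model_hook(models, input_dir):
    # Restore each tracked model from the checkpoint directory
    while len(models) > 0:
        model = models.pop()
        loaded = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
        model.load_state_dict(loaded.state_dict())
        del loaded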
vae.requires_grad_(False)
if not args.train_text_encoder:
    text_encoder.requires_grad_(False)
if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()
    if args.train_text_encoder:
        text_encoder.gradient_checkpointing_enable()
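weight_dtype, used below when encoding pixel values, follows the Accelerate mixed-precision setting, and the frozen models are moved to it; roughly:

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
# Frozen models can live in the lower precision; trained weights stay in fp32
vae.to(accelerator.device, dtype=weight_dtype)
if not args.train_text_encoder:
    text_encoder.to(accelerator.device, dtype=weight_dtype)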
optimizer_class = torch.optim.AdamW
params_to_optimize = (itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters())
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, encoder_hidden_states=pre_computed_encoder_hidden_states, instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, tokenizer_max_length=args.tokenizer_max_length)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation))
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps, num_cycles=args.lr_num_cycles, power=args.lr_power)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
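collate_fn (referenced in the DataLoader above but not shown in these notes) is what makes the prior-preservation torch.chunk in the loss possible: with prior preservation on, the class examples are appended after the instance examples along the batch dimension. A simplified sketch following the upstream script (the real one also stacks attention masks):

def collate_fn(examples, with_prior_preservation=False):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    if with_prior_preservation:
        # Class (prior) examples go after the instance examples, so
        # torch.chunk(..., 2, dim=0) later splits them back apart
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
    input_ids = torch.cat(input_ids, dim=0)
    return {"input_ids": input_ids, "pixel_values": pixel_values}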
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
            # Encode images into latents and scale by the VAE factor
            model_input = vae.encode(pixel_values).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor
            # Sample noise and a random timestep per image, then add the noise
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
            encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"], batch["attention_mask"], text_encoder_use_attention_mask=args.text_encoder_use_attention_mask)
            class_labels = timesteps if accelerator.unwrap_model(unet).config.class_embed_type == "timestep" else None
            model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels).sample
            target = noise  # the default "epsilon" prediction type
            if args.with_prior_preservation:
                # Instance and class (prior) halves were concatenated along dim 0 in collate_fn
                model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                target, target_prior = torch.chunk(target, 2, dim=0)
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            if args.with_prior_preservation:
                prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
                loss = loss + args.prior_loss_weight * prior_loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
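encode_prompt, used in the loop above, is a small helper defined in the script; it essentially just runs the text encoder over the tokenized prompts (simplified sketch):

def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    text_input_ids = input_ids.to(text_encoder.device)
    # The attention mask is only forwarded when the corresponding flag is set
    attention_mask = attention_mask.to(text_encoder.device) if text_encoder_use_attention_mask else None
    prompt_embeds = text_encoder(text_input_ids, attention_mask=attention_mask)
    return prompt_embeds[0]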
train_dreambooth_lora.py
accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# In the LoRA script all base weights stay frozen; only the injected LoRA layers are trained
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers
unet_lora_attn_procs = {}
unet_lora_parameters = []
for name, attn_processor in unet.attn_processors.items():
    if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
        lora_attn_processor_class = LoRAAttnAddedKVProcessor
    else:
        lora_attn_processor_class = (LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor)
    module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank)
    unet_lora_attn_procs[name] = module
    unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
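In the loop above, hidden_size and cross_attention_dim are derived from each processor's name; roughly (following the upstream script):

# attn1 is self-attention, so it has no cross-attention dimension
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
    hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
    block_id = int(name[len("up_blocks.")])
    hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
    block_id = int(name[len("down_blocks.")])
    hidden_size = unet.config.block_out_channels[block_id]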
if args.train_text_encoder:
    text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder)
optimizer_class = torch.optim.AdamW
params_to_optimize = (itertools.chain(unet_lora_parameters, text_lora_parameters) if args.train_text_encoder else unet_lora_parameters)
optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon)
train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, encoder_hidden_states=pre_computed_encoder_hidden_states, instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, tokenizer_max_length=args.tokenizer_max_length)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation))
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps, num_cycles=args.lr_num_cycles, power=args.lr_power)
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
for epoch in range(first_epoch, args.num_train_epochs):
    unet.train()
    if args.train_text_encoder:
        text_encoder.train()
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(unet):
            pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
            model_input = vae.encode(pixel_values).latent_dist.sample()
            model_input = model_input * vae.config.scaling_factor
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device).long()
            noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
            encoder_hidden_states = encode_prompt(text_encoder, batch["input_ids"], batch["attention_mask"], text_encoder_use_attention_mask=args.text_encoder_use_attention_mask)
            class_labels = timesteps if accelerator.unwrap_model(unet).config.class_embed_type == "timestep" else None
            model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels).sample
            target = noise  # the default "epsilon" prediction type
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
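After training, the LoRA script extracts just the trained LoRA layers and saves them; a sketch of that final step (helper names follow the upstream script for this diffusers version and may differ in others):

accelerator.wait_for_everyone()
if accelerator.is_main_process:
    unet = accelerator.unwrap_model(unet)
    unet_lora_layers = unet_attn_processors_state_dict(unet)
    if args.train_text_encoder:
        text_encoder = accelerator.unwrap_model(text_encoder)
        text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
    else:
        text_encoder_lora_layers = None
    LoraLoaderMixin.save_lora_weights(
        save_directory=args.output_dir,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=text_encoder_lora_layers,
    )

At inference the saved weights can then be attached to a pipeline with pipe.load_lora_weights(args.output_dir).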