train_text_to_image.py
Code:
accelerator = Accelerator()->
noise_scheduler = DDPMScheduler.from_pretrained(..., subfolder="scheduler")->
tokenizer = CLIPTokenizer.from_pretrained(..., subfolder="tokenizer")->
text_encoder = CLIPTextModel.from_pretrained(..., subfolder="text_encoder")->
vae = AutoencoderKL.from_pretrained(..., subfolder="vae")->
unet = UNet2DConditionModel.from_pretrained(..., subfolder="unet")->
vae.requires_grad_(False)->
text_encoder.requires_grad_(False)->
unet.enable_gradient_checkpointing()->
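A minimal loading sketch of the steps above, assuming args.pretrained_model_name_or_path points at a Stable Diffusion checkpoint (the argument name follows the upstream script):

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")

# Only the UNet is trained; the VAE and text encoder stay frozen.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.enable_gradient_checkpointing()  # trades extra compute for lower activation memory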
optimizer_cls = torch.optim.AdamW->
optimizer = optimizer_cls(unet.parameters(), lr=..., betas=..., weight_decay=..., eps=...)->
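A hedged example of the optimizer call; the values shown are the upstream script's argparse defaults, not fixed constants:

optimizer = torch.optim.AdamW(
    unet.parameters(),
    lr=1e-4,             # args.learning_rate
    betas=(0.9, 0.999),  # (args.adam_beta1, args.adam_beta2)
    weight_decay=1e-2,   # args.adam_weight_decay
    eps=1e-8,            # args.adam_epsilon
)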
dataset = load_dataset(args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir)->
train_transforms = transforms.Compose([])->
train_dataset = dataset['train'].with_transform(preprocess_train)->
- images = [image.convert("RGB") for image in examples[image_column]]
- examples["pixel_values"] = [train_transforms(image) for image in images]
- examples["input_ids"] = tokenize_captions(examples)
-- inputs = tokenizer(captions, max_length=..., padding=..., truncation=True)->
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=args.train_batch_size)->
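A sketch of preprocess_train and tokenize_captions close to the upstream script; image_column and caption_column are dataset-dependent names, so treat them as placeholders:

def tokenize_captions(examples):
    captions = list(examples[caption_column])
    inputs = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    return inputs.input_ids

def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["pixel_values"] = [train_transforms(image) for image in images]
    examples["input_ids"] = tokenize_captions(examples)
    return examples

train_dataset = dataset["train"].with_transform(preprocess_train)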
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)  # tied to gradient accumulation and batch size->
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch->
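Worked example with illustrative numbers: 1,000 dataloader batches per epoch and gradient_accumulation_steps=4 give ceil(1000 / 4) = 250 optimizer updates per epoch, so with num_train_epochs=10 the default max_train_steps comes out to 2,500.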
lr_scheduler = get_scheduler()->
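A hedged sketch of the scheduler call; get_scheduler comes from diffusers.optimization, and the exact warmup/total-step bookkeeping varies slightly across script versions:

from diffusers.optimization import get_scheduler

lr_scheduler = get_scheduler(
    args.lr_scheduler,  # e.g. "constant", "linear", "cosine"
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=args.max_train_steps,
)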
unet,optimizer,train_dataloader,lr_scheduler = accelerator.prepare(unet,optimizer,train_dataloader,lr_scheduler)->
text_encoder.to(accelerator.device)->
vae.to(accelerator.device)->
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
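For example, train_batch_size=4 on 2 GPUs with gradient_accumulation_steps=4 yields an effective total_batch_size of 4 * 2 * 4 = 32 samples per optimizer update.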
for epoch in range(first_epoch,args.num_train_epochs):
unet.train()
for step,batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
noise = torch.randn_like(latents)
bsz = latents.shape[0]  # batch size, used to draw one timestep per sample
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
target = noise  # default "epsilon" prediction type; see the note after the loop
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
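target = noise corresponds to the default "epsilon" prediction type; the upstream script also supports v-prediction, roughly:

if noise_scheduler.config.prediction_type == "epsilon":
    target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
    target = noise_scheduler.get_velocity(latents, noise, timesteps)

It also clips gradients before the optimizer step, but only when gradients are actually synced (i.e. at accumulation boundaries):

if accelerator.sync_gradients:
    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)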
train_text_to_image_lora.py
Code:
accelerator = Accelerator()->
noise_scheduler = DDPMScheduler.from_pretrained(..., subfolder="scheduler")->
tokenizer = CLIPTokenizer.from_pretrained(..., subfolder="tokenizer")->
text_encoder = CLIPTextModel.from_pretrained(..., subfolder="text_encoder")->
vae = AutoencoderKL.from_pretrained(..., subfolder="vae")->
unet = UNet2DConditionModel.from_pretrained(..., subfolder="unet")->
unet.requires_grad_(False)->
vae.requires_grad_(False)->
text_encoder.requires_grad_(False)->
unet.to(accelerator.device)->
vae.to(accelerator.device)->
text_encoder.to(accelerator.device)->
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
# => 32 layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=..., cross_attention_dim=..., rank=args.rank)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)->
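A fuller sketch of the LoRA wiring, matching older diffusers releases (newer releases route LoRA through PEFT instead). hidden_size is derived from the block the processor sits in, and attn1 (self-attention) gets no cross_attention_dim:

from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)  # wraps just the LoRA weights so only they are optimized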
optimizer_cls = torch.optim.AdamW->
optimizer = optimizer_cls(lora_layers.parameters(), lr=..., betas=..., weight_decay=..., eps=...)->
dataset = load_dataset(args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir)->
train_transforms = transforms.Compose([])->
train_dataset = dataset['train'].with_transform(preprocess_train)->
- images = [image.convert("RGB") for image in examples[image_column]]
- examples["pixel_values"] = [train_transforms(image) for image in images]
- examples["input_ids"] = tokenize_captions(examples)
-- inputs = tokenizer(captions, max_length=..., padding=..., truncation=True)->
train_dataloader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=args.train_batch_size)->
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)  # tied to gradient accumulation and batch size->
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch->
lr_scheduler = get_scheduler()->
lora_layers,optimizer,train_dataloader,lr_scheduler = accelerator.prepare(lora_layers,optimizer,train_dataloader,lr_scheduler)->
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
for epoch in range(first_epoch,args.num_train_epochs):
unet.train()
for step,batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
noise = torch.randn_like(latents)
bsz = latents.shape[0]  # batch size, used to draw one timestep per sample
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
target = noise
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
pipeline = DiffusionPipeline.from_pretrained()->
pipeline.unet.load_attn_procs()->
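A hedged inference sketch: load the base pipeline, then attach the trained LoRA weights; args.output_dir is assumed to be where the training script saved its pytorch_lora_weights file:

import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16)
pipeline.to("cuda")
pipeline.unet.load_attn_procs(args.output_dir)  # attaches the LoRA attention processors
image = pipeline("a prompt for the fine-tuned concept", num_inference_steps=30).images[0]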