Analyzing A BAD Result In WANXDiscussion G-U-N And Rectified Diffusion
Introduction
In this article, we delve into a specific implementation related to generative models and rectified diffusion, focusing on a potentially problematic outcome (A BAD Result) reported in the WANXDiscussion category. The discussion revolves around the code snippets provided, which showcase classes such as `EulerSolver`, `EMA` (Exponential Moving Average), and `LightningModelForTrain`. These components are central to G-U-N (Generative-Uncertainty Network) and rectified diffusion models, particularly in video generation pipelines. We dissect the code, explain what each class does, and analyze potential reasons for the suboptimal result, offering a comprehensive walkthrough for practitioners and enthusiasts working on AI-driven video synthesis.
Understanding the Code Components
EulerSolver Class
The `EulerSolver` class is a critical component of the diffusion process, specifically designed to handle the reverse diffusion steps using an Euler method. Let's break down its functionality:
```python
class EulerSolver:
    def __init__(self, sigmas, timesteps=1000, euler_timesteps=50):
        self.step_ratio = timesteps // euler_timesteps
        self.euler_timesteps = (np.arange(1, euler_timesteps + 1) * self.step_ratio).round().astype(np.int64) - 1
        self.euler_timesteps_prev = np.asarray([0] + self.euler_timesteps[:-1].tolist())
        self.sigmas = sigmas[self.euler_timesteps]
        self.sigmas_prev = np.asarray([sigmas[0]] +
                                      sigmas[self.euler_timesteps[:-1]].tolist())  # either use sigma0 or 0

        self.euler_timesteps = torch.from_numpy(self.euler_timesteps).long()
        self.euler_timesteps_prev = torch.from_numpy(self.euler_timesteps_prev).long()
        self.sigmas = torch.from_numpy(self.sigmas)
        self.sigmas_prev = torch.from_numpy(self.sigmas_prev)

    def to(self, device):
        self.euler_timesteps = self.euler_timesteps.to(device)
        self.euler_timesteps_prev = self.euler_timesteps_prev.to(device)
        self.sigmas = self.sigmas.to(device)
        self.sigmas_prev = self.sigmas_prev.to(device)
        return self

    def euler_step(self, sample, model_pred, timestep_index):
        sigma = extract_into_tensor(self.sigmas, timestep_index, model_pred.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index, model_pred.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred
        return x_prev

    def euler_style_multiphase_pred(
        self,
        sample,
        model_pred,
        timestep_index,
        multiphase,
        is_target=False,
    ):
        inference_indices = np.linspace(0, len(self.euler_timesteps), num=multiphase, endpoint=False)
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (torch.from_numpy(inference_indices).long().to(self.euler_timesteps.device))
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(-1, inference_indices.size(0))
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index_end = inference_indices[last_valid_index]
        if is_target:
            sigma = extract_into_tensor(self.sigmas_prev, timestep_index, sample.shape)
        else:
            sigma = extract_into_tensor(self.sigmas, timestep_index, sample.shape)
        sigma_prev = extract_into_tensor(self.sigmas_prev, timestep_index_end, sample.shape)
        x_prev = sample + (sigma_prev - sigma) * model_pred
        return x_prev, timestep_index_end
```
- Initialization (`__init__`): The constructor takes `sigmas` (noise levels), `timesteps` (total diffusion steps), and `euler_timesteps` (number of Euler steps) as input. It computes the step ratio and discretizes the timesteps for the Euler method. `sigmas` and `sigmas_prev` hold the noise levels at the current and previous timesteps, respectively; they guide the reverse diffusion process, letting the model gradually reconstruct the data from noise. The values are built with `numpy` and then converted to `torch` tensors for compatibility with the PyTorch framework.
- Device Placement (`to`): This method moves the relevant tensors (`euler_timesteps`, `sigmas`, and so on) to the specified device (CPU or GPU). Efficient device management is paramount for performance, especially with large models and datasets; moving these tensors to the GPU leverages its parallel processing capabilities and significantly accelerates computation.
- Euler Step (`euler_step`): This function performs a single Euler step, updating the sample based on the model's prediction. It extracts the noise levels (`sigma` and `sigma_prev`) for the current timestep and computes the updated sample (`x_prev`). The Euler step is the core of the solver, iteratively refining the sample to reduce noise and reconstruct the original data. The formula `x_prev = sample + (sigma_prev - sigma) * model_pred` is a discrete approximation of the reverse diffusion process, where the model's prediction (`model_pred`) guides the denoising.
- Euler Style Multiphase Prediction (`euler_style_multiphase_pred`): This method implements a multiphase prediction strategy, which can help refine the output at different stages of the diffusion process. It computes inference indices that divide the timesteps into phases and predicts the sample for each phase (see the numeric illustration after this list). The multiphase approach gives more nuanced control over the denoising process and can yield higher-quality results; adjusting the number of phases trades off computational cost against output fidelity. The method also supports a target-specific prediction (`is_target`), which uses the previous-step noise levels instead of the current ones.
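To make the phase selection concrete, here is a tiny numeric illustration of how the phase boundaries are computed (the values assume the 50 Euler steps used elsewhere in this code):

```python
import numpy as np

# With 50 Euler steps split into 4 phases, the phase boundaries are [0, 12, 25, 37];
# each sample's timestep index is then mapped to the largest boundary not exceeding it.
num_euler_timesteps = 50
multiphase = 4
inference_indices = np.floor(
    np.linspace(0, num_euler_timesteps, num=multiphase, endpoint=False)
).astype(np.int64)
print(inference_indices)  # [ 0 12 25 37]
```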
The `EulerSolver` class is essential for managing the reverse diffusion process by discretizing the continuous diffusion trajectory into manageable steps. Its methods allow for precise control over the denoising process, making it a key component in generative models.
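The snippet also relies on an `extract_into_tensor` helper that is not shown above; a common definition, assumed here, gathers one value per sample and reshapes it for broadcasting. Under that assumption, a minimal usage sketch of the class looks like this (toy shapes and a toy sigma schedule, purely for illustration):

```python
import numpy as np
import torch

def extract_into_tensor(a, t, x_shape):
    # Assumed helper (not shown in the original snippet): pick a[t] for each
    # sample and reshape to (batch, 1, 1, ...) so it broadcasts over x_shape.
    out = a.to(t.device).gather(0, t)
    return out.reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))

# Toy sigma schedule: 1000 training sigmas decreasing from 1 to 0.
sigmas = np.linspace(1.0, 0.0, 1000, dtype=np.float32)
solver = EulerSolver(sigmas, timesteps=1000, euler_timesteps=50)  # class as defined above

sample = torch.randn(2, 4, 8, 8)        # toy latents (batch, channels, h, w)
model_pred = torch.randn_like(sample)   # stand-in for a network prediction
index = torch.tensor([49, 10])          # per-sample Euler step indices in [0, 50)
x_prev = solver.euler_step(sample, model_pred, index)
print(x_prev.shape)                     # torch.Size([2, 4, 8, 8])
```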
EMA Class
The `EMA` (Exponential Moving Average) class is used to maintain a moving average of the model parameters during training. This technique often leads to more stable and generalizable models. Here's a detailed look:
```python
class EMA:
    """Exponential Moving Average for model parameters"""

    def __init__(self, model, decay=0.9999, updates=0):
        trainable_param_names = list(
            filter(lambda named_param: named_param[1].requires_grad, model.named_parameters())
        )
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = model.state_dict()
        lora_state_dict = {}
        self.need_train_cnt = 0
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
                self.need_train_cnt += 1
        self.state_dict = copy.deepcopy(lora_state_dict)
        self.decay = decay
        self.updates = updates

    def update(self, model):
        trainable_param_names = list(
            filter(lambda named_param: named_param[1].requires_grad, model.named_parameters())
        )
        print('trainable_param_names list', len(trainable_param_names))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        print('trainable_param_names set', len(trainable_param_names))
        state_dict = model.state_dict()
        lora_state_dict = {}
        for name, param in state_dict.items():
            if name in trainable_param_names:
                lora_state_dict[name] = param
        f = False
        cnt = 0
        with torch.no_grad():
            for name, param in lora_state_dict.items():
                self.state_dict[name] = self.state_dict[name].to(param.device)
                if self.state_dict[name].shape == param.shape:
                    self.state_dict[name] = self.state_dict[name] * self.decay + (1 - self.decay) * param
                    # self.state_dict[name].mul_(d).add_(param, alpha=1. - d)
                    f = True
                    cnt += 1
                else:
                    print('name', name, self.state_dict[name].shape, param.shape)
        if f:
            self.updates += 1
```
- Initialization (`__init__`): The constructor takes the model, a decay rate (`decay`), and an initial update count (`updates`). It filters out the trainable parameters of the model and creates a deep copy of their state dictionary. The `decay` parameter controls the weight given to the previous EMA parameters versus the current model parameters: a higher decay rate means the EMA weights change more slowly, providing more stability. The deep copy ensures the EMA copy starts from the same parameters as the original model but then evolves independently.
- Update (`update`): This method updates the EMA parameters using the current model parameters. It iterates over the trainable parameters and applies the moving-average formula `self.state_dict[name] = self.state_dict[name] * self.decay + (1 - self.decay) * param`, which blends the old EMA values with the current values, giving more weight to the old values according to the decay rate (see the numeric illustration after this list). The shape-equality check (`self.state_dict[name].shape == param.shape`) is crucial to avoid errors when parameter shapes change during training, for example when using techniques like LoRA (Low-Rank Adaptation). The `torch.no_grad()` context ensures these updates do not affect the gradient computation, keeping the EMA update separate from the training process.
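As a quick sanity check on the decay rate (see the Update bullet above), the following toy loop shows how slowly an EMA value with `decay = 0.9999` tracks a sudden change in the underlying weight; the effective averaging window is roughly `1 / (1 - decay) = 10,000` updates:

```python
# ema_new = decay * ema_old + (1 - decay) * param
decay = 0.9999
ema, param = 1.0, 0.0  # pretend the raw weight jumps from 1.0 to 0.0 and stays there
for step in range(1, 10001):
    ema = decay * ema + (1 - decay) * param
    if step in (1, 100, 1000, 10000):
        print(step, round(ema, 4))  # ~0.9999, ~0.990, ~0.905, ~0.368
```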
The `EMA` class is a valuable tool for stabilizing training and improving model generalization. By maintaining a moving average of the model parameters, it helps smooth out fluctuations and can lead to better performance, particularly in generative models.
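Here is a minimal, self-contained usage sketch of the `EMA` class defined above, using a tiny linear model and synthetic data purely for illustration (the real pipeline trains a video DiT, not a linear layer):

```python
import copy   # required by the EMA class above
import torch

model = torch.nn.Linear(4, 1)
ema = EMA(model, decay=0.99)  # EMA class as defined above
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(20):
    x = torch.randn(8, 4)
    loss = (model(x) - 1.0).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # blend the stored EMA weights toward the new weights

# The smoothed parameters live in ema.state_dict and can be loaded into a copy
# of the model for evaluation or checkpointing.
eval_model = copy.deepcopy(model)
eval_model.load_state_dict(ema.state_dict, strict=False)
```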
LightningModelForTrain Class
The `LightningModelForTrain` class is a PyTorch Lightning module that encapsulates the training logic for the diffusion model. It integrates the `EulerSolver`, `EMA`, and the diffusion pipeline into a cohesive training framework. Let's examine its components:
```python
class LightningModelForTrain(pl.LightningModule):
    def __init__(
        self,
        dit_path,
        learning_rate=1e-5,
        lora_rank=4,
        lora_alpha=4,
        train_architecture="lora",
        lora_target_modules="q,k,v,o,ffn.0,ffn.2",
        init_lora_weights="kaiming",
        deepspeed_offload=False,
        use_gradient_checkpointing=True,
        use_gradient_checkpointing_offload=False,
        pretrained_lora_path=None,
        args=None,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.not_apply_cfg_solver = False
        print('multi_phased_distill_schedule', args.multi_phased_distill_schedule)
        print('shift', args.shift)
        print('decay', args.decay)
        print('learning_rate', args.learning_rate)
        print('self.not_apply_cfg_solver', self.not_apply_cfg_solver)
        self.args = args

        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([dit_path])
        # self.not_apply_cfg_solver = args.not_apply_cfg_solver
        # if os.path.isfile(dit_path):
        #     model_manager.load_models([dit_path])
        # else:
        #     dit_path = dit_path.split(",")
        #     model_manager.load_models([dit_path])
        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(args.num_euler_timesteps)

        self.teacher_transformer = deepcopy(self.pipe.denoising_model())
        self.teacher_transformer.requires_grad_(False)
        self.teacher_transformer.eval()
        self.teacher_transformer.to(self.device, dtype=self.pipe.torch_dtype)

        self.huber_loss = torch.nn.HuberLoss(delta=0.001)
        self.freeze_parameters()
        if train_architecture == "lora":
            self.add_lora_to_model(
                self.pipe.denoising_model(),
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_target_modules=lora_target_modules,
                init_lora_weights=init_lora_weights,
                pretrained_lora_path=pretrained_lora_path,
            )
        else:
            self.pipe.denoising_model().requires_grad_(True)

        if self.global_rank == 0:
            self.ema = EMA(self.pipe.denoising_model(), args.decay)
            self.ema_transformer = deepcopy(self.pipe.denoising_model())
            self.ema_transformer.requires_grad_(False)
            self.ema_transformer.eval()
            self.ema_transformer.to(self.device, dtype=self.pipe.torch_dtype)

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
        self.deepspeed_offload = deepspeed_offload

        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=26)
        # self.noise_scheduler = FlowMatchScheduler(shift=args.shift, sigma_min=0.0, extra_one_step=True)
        # self.noise_scheduler.set_timesteps(1000, training=True)
        # self.noise_scheduler = self.pipe.scheduler
        self.multi_phased_distill_schedule = args.multi_phased_distill_schedule
        sigmas = self.noise_scheduler.sigmas
        self.solver = EulerSolver(
            sigmas.numpy()[::-1],
            self.noise_scheduler.num_train_timesteps,
            euler_timesteps=args.num_euler_timesteps,
        )
        self.solver.to(self.device)

    def freeze_parameters(self):
        # Freeze parameters
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def add_lora_to_model(
        self,
        model,
        lora_rank=4,
        lora_alpha=4,
        lora_target_modules="q,k,v,o,ffn.0,ffn.2",
        init_lora_weights="kaiming",
        pretrained_lora_path=None,
        state_dict_converter=None,
    ):
        # Add LoRA to UNet
        self.lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
            init_lora_weights = True
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        print('dit model', model)
        model = inject_adapter_in_model(lora_config, model)
        for param in model.parameters():
            # Upcast LoRA parameters into fp32
            if param.requires_grad:
                param.data = param.to(torch.float32)
        print('lora dit model', model)

        # Lora pretrained lora weights
        if pretrained_lora_path is not None:
            pretrained_lora_path = pretrained_lora_path[0] if isinstance(pretrained_lora_path, list) else pretrained_lora_path
            print("pretrained_lora_path", pretrained_lora_path)
            state_dict = load_state_dict(pretrained_lora_path)
            if state_dict_converter is not None:
                state_dict = state_dict_converter(state_dict)
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            all_keys = [i for i, _ in model.named_parameters()]
            num_updated_keys = len(all_keys) - len(missing_keys)
            num_unexpected_keys = len(unexpected_keys)
            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")

    def training_step(self, batch, batch_idx):
        multiphase = get_num_phases(self.multi_phased_distill_schedule, self.global_step)

        # Data
        latents = batch["latents"][:, 0].to(self.device, dtype=self.pipe.torch_dtype).transpose(1, 2)
        # print('latents', latents.shape)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = prompt_emb["context"][:, 0].to(self.device, dtype=self.pipe.torch_dtype)
        unprompt_emb = batch["unprompt_emb"]
        unprompt_emb["context"] = unprompt_emb["context"][:, 0].to(self.device, dtype=self.pipe.torch_dtype)
        self.solver.to(self.device)

        extra_input = {}
        distill_cfg = 3.0
        model_input = latents
        noise = torch.randn_like(latents)

        if self.global_step % 10 == 0:
            shift = random.randint(5, 50)
            self.noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
            sigmas = self.noise_scheduler.sigmas
            self.solver = EulerSolver(
                sigmas.numpy()[::-1],
                self.noise_scheduler.num_train_timesteps,
                euler_timesteps=args.num_euler_timesteps,
            )
            self.solver.to(self.device)

        bsz = model_input.shape[0]
        index = torch.randint(0, self.args.num_euler_timesteps, (bsz,), device=model_input.device).long()
        sigmas = extract_into_tensor(self.solver.sigmas, index, model_input.shape)
        sigmas_prev = extract_into_tensor(self.solver.sigmas_prev, index, model_input.shape)
        timesteps = (sigmas * self.noise_scheduler.num_train_timesteps).view(-1)
        timesteps_prev = (sigmas_prev * self.noise_scheduler.num_train_timesteps).view(-1)
        noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input

        loss = 0.0
        if random.random() < 0.1:
            prompt_emb = unprompt_emb

        # try:
        if True:
            w_max = 15.0
            w_min = 3.0
            w = (w_max - w_min) * random.random() + w_min
            latent_model_input = noisy_model_input
            index_ = index.clone().detach()
            with torch.no_grad():
                for _ in range(10):
                    index_ = torch.where(
                        index_ < 0,
                        torch.zeros_like(index_),
                        index_,
                    )
                    sigmas = extract_into_tensor(self.solver.sigmas, index_, model_input.shape)
                    timesteps_ = (sigmas * self.noise_scheduler.num_train_timesteps).view(-1)
                    timesteps_ = timesteps_.to(self.device, dtype=self.pipe.torch_dtype)
                    latent_model_input = latent_model_input.to(self.device, dtype=self.pipe.torch_dtype)
                    # predict noise model_output
                    with torch.cuda.amp.autocast(dtype=self.pipe.torch_dtype):
                        cond_teacher_output = self.teacher_transformer(
                            latent_model_input,
                            timestep=timesteps_,
                            **prompt_emb,
                            **extra_input,
                            use_gradient_checkpointing=self.use_gradient_checkpointing,
                            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
                        )
                        uncond_teacher_output = self.teacher_transformer(
                            latent_model_input,
                            timestep=timesteps_,
                            **unprompt_emb,
                            **extra_input,
                            use_gradient_checkpointing=self.use_gradient_checkpointing,
                            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
                        )
                        teacher_output = uncond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
                        latent_model_input = self.solver.euler_step(latent_model_input, teacher_output, index)
                    index_ = index_ - 1
                index = index + torch.randint(0, 10, (bsz,), device=model_input.device).long()
                index = torch.clamp(index, min=0, max=49).long()
            target = latent_model_input

            sigmas = extract_into_tensor(self.solver.sigmas, index, model_input.shape)
            timesteps = (sigmas * self.noise_scheduler.num_train_timesteps).view(-1)
            noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
            noisy_model_input = noisy_model_input.to(self.device, dtype=self.pipe.torch_dtype)
            timesteps = timesteps.to(self.device, dtype=self.pipe.torch_dtype)
            with torch.cuda.amp.autocast(dtype=self.pipe.torch_dtype):
                model_pred = self.pipe.denoising_model()(
                    noisy_model_input,
                    timestep=timesteps,
                    **prompt_emb,
                    **extra_input,
                    use_gradient_checkpointing=self.use_gradient_checkpointing,
                    use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
                )
            # model_pred, end_index = self.solver.euler_style_multiphase_pred(noisy_model_input, model_pred, index, multiphase)
            model_pred = self.solver.euler_step(noisy_model_input, model_pred, index)

            huber_c = 0.001
            # loss = loss.mean()
            huber_loss = torch.mean(torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c ** 2) - huber_c)

            # Record log
            loss_mo = torch.nn.functional.mse_loss(model_pred[:, :, 1:].float(), model_pred[:, :, :-1].float()) * 0.01
            loss_tc = torch.nn.functional.mse_loss((model_pred[:, :, 5:] - model_pred[:, :, :-5]).float(), (latents[:, :, 5:] - latents[:, :, :-5]).float()) * 0.005
            loss_tc2 = torch.nn.functional.mse_loss((model_pred - latents).float(), (target - latents).float()) * 0.1
            g_loss = huber_loss + loss_tc + loss_mo + loss_tc2

            # Record log
            self.log("loss_tc", loss_tc, prog_bar=True)
            self.log("loss_mo", loss_mo, prog_bar=True)
            self.log("loss_tc2", loss_tc2, prog_bar=True)
            self.log("huber_loss", huber_loss, prog_bar=True)
            self.log("train_loss", g_loss, prog_bar=True)
            # self.log("multiphase", multiphase, prog_bar=True)
            # self.log("train_loss", loss)
            # self.log("multiphase", multiphase)

        return g_loss
```
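Before walking through the class, the following self-contained sketch isolates the two pieces of math at the heart of `training_step`: the flow-matching interpolation that builds the noisy input, and the pseudo-Huber loss between the student's one-step estimate and the teacher's rollout target. The tensor shapes are toy stand-ins, not the real video latents:

```python
import torch

# Toy stand-ins for video latents with shape (batch, channels, frames, h, w).
x0 = torch.randn(2, 16, 9, 32, 32)   # "clean" latents
noise = torch.randn_like(x0)
sigma = torch.rand(2, 1, 1, 1, 1)    # per-sample noise level in [0, 1)

# Flow-matching style interpolation used to build the student's input.
noisy_model_input = sigma * noise + (1.0 - sigma) * x0

# Pseudo-Huber loss between student prediction and teacher target
# (random tensors here, purely illustrative).
model_pred = torch.randn_like(x0)
target = torch.randn_like(x0)
huber_c = 0.001
huber_loss = torch.mean(torch.sqrt((model_pred - target) ** 2 + huber_c ** 2) - huber_c)
print(huber_loss.item())
```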
- Initialization (`__init__`): This method initializes the model, loading the pre-trained DiT (Diffusion Transformer) weights, setting up the `WanVideoPipeline`, and configuring the teacher transformer for knowledge distillation. It also handles LoRA (Low-Rank Adaptation) if specified, a technique for efficiently adapting pre-trained models to new tasks. Key steps include:
  - Loading the DiT model using a `ModelManager`. This involves specifying the path (`dit_path`) and loading the weights in `torch.bfloat16` for mixed-precision training.
  - Creating a `WanVideoPipeline` from the loaded model. This pipeline encapsulates the entire video generation process, including the scheduler that manages diffusion timesteps.
  - Setting up the teacher transformer, a copy of the denoising model used to guide the student model (the model being trained) during knowledge distillation. It is crucial that the teacher is kept in eval mode with gradients disabled (`requires_grad_(False)`) so it is never inadvertently trained.
  - Freezing the parameters of the pipeline so they are not directly updated during training, a common practice when using LoRA or other parameter-efficient fine-tuning methods.
  - Adding LoRA adapters to the denoising model when `train_architecture` is set to `"lora"`, via `add_lora_to_model`, which injects adapters into the configured target modules, upcasts the trainable LoRA parameters to fp32, and optionally loads pretrained LoRA weights (a toy version of this injection step is sketched below).
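The `LoraConfig` and `inject_adapter_in_model` calls appear to come from the PEFT library (an assumption; the snippet does not show its imports). Under that assumption, here is a toy sketch of the injection step from `add_lora_to_model`, applied to a small stand-in module rather than the real DiT:

```python
import torch
from peft import LoraConfig, inject_adapter_in_model  # assumed source of these helpers

# Illustrative stand-in for the DiT: a block with projections named like the LoRA targets.
class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(64, 64)
        self.k = torch.nn.Linear(64, 64)
        self.v = torch.nn.Linear(64, 64)
        self.o = torch.nn.Linear(64, 64)

model = ToyBlock()
model.requires_grad_(False)  # freeze the base weights, as freeze_parameters() does

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights=True,
    target_modules="q,k,v,o".split(","),
)
model = inject_adapter_in_model(lora_config, model)

# Only the injected LoRA parameters require gradients; upcast them to fp32
# as the original add_lora_to_model does.
for name, param in model.named_parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)
        print(name, tuple(param.shape))
```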