Profile and speedup pipelines #9475

Draft
wants to merge 6 commits into base: main
Changes from 1 commit
profile cogvideox
a-r-r-o-w committed Sep 19, 2024
commit 5687dc6d39fb27b0bad1654f012919f657362e8f
99 changes: 52 additions & 47 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -17,6 +17,7 @@

import torch
from torch import nn
from torch.profiler import record_function

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
@@ -433,49 +434,52 @@ def forward(
batch_size, num_frames, channels, height, width = hidden_states.shape

# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
with record_function("time embedding"):
timesteps = timestep
t_emb = self.time_proj(timesteps)

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)

# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
with record_function("patch embedding"):
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)

text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
text_seq_length = encoder_hidden_states.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]

# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)
with record_function("blocks"):
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
)

if not self.config.use_rotary_positional_embeddings:
# CogVideoX-2B
@@ -487,16 +491,17 @@ def custom_forward(*inputs):
hidden_states = hidden_states[:, text_seq_length:]

# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)

# 5. Unpatchify
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
with record_function("final output"):
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)

# 5. Unpatchify
# Note: we use `-1` instead of `channels`:
# - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
# - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
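Note: the `record_function` labels added above are inert on their own; they only surface when the forward pass runs under `torch.profiler`. A minimal, self-contained sketch of the mechanics (the tiny model and shapes below are illustrative stand-ins, not CogVideoX):

```python
# Illustrative stand-in model: shows how record_function regions like
# "patch embedding" / "blocks" / "final output" appear in a profiler summary.
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, record_function


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(64, 64)
        self.blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(4)])
        self.proj_out = nn.Linear(64, 64)

    def forward(self, x):
        with record_function("patch embedding"):
            x = self.proj_in(x)
        with record_function("blocks"):
            for block in self.blocks:
                x = block(x)
        with record_function("final output"):
            return self.proj_out(x)


model = TinyModel()
x = torch.randn(2, 128, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with torch.no_grad():
        model(x)

# Each record_function label becomes a named row, aggregated over its calls.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```

Each labeled region shows up as a single aggregated row, which is what makes the coarse time embedding / patch embedding / blocks / final output split useful as a first-pass breakdown of the forward pass.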
100 changes: 68 additions & 32 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -18,6 +18,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch.profiler import record_function
from transformers import T5EncoderModel, T5Tokenizer

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -679,39 +680,73 @@ def __call__(
timestep = t.expand(latent_model_input.shape[0])

# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
with record_function(f"transformer_iteration_{i}"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# noise_pred = noise_pred.float()

# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
with record_function(f"guidance_{i}"):
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0))
/ 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

with record_function("1.1 scheduler"):
prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps

with record_function("1.2 scheduler"):
alpha_prod_t = self.scheduler.alphas_cumprod[t]

with record_function("1.3 scheduler"):
alpha_prod_t_prev = (
self.scheduler.alphas_cumprod[prev_timestep]
if prev_timestep >= 0
else self.scheduler.final_alpha_cumprod
)
latents = latents.to(prompt_embeds.dtype)

with record_function("1.4 scheduler"):
beta_prod_t = 1 - alpha_prod_t

with record_function("1.5 scheduler"):
pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred

with record_function("1.6 scheduler"):
a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5

with record_function("1.7 scheduler"):
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

with record_function("1.8 scheduler"):
prev_sample = a_t * latents + b_t * pred_original_sample

latents = prev_sample

# # compute the previous noisy sample x_t -> x_t-1
# with record_function(f"scheduler_step_{i}"):
# if not isinstance(self.scheduler, CogVideoXDPMScheduler):
# latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# else:
# latents, old_pred_original_sample = self.scheduler.step(
# noise_pred,
# old_pred_original_sample,
# t,
# timesteps[i - 1] if i > 0 else None,
# latents,
# **extra_step_kwargs,
# return_dict=False,
# )
# # latents = latents.to(prompt_embeds.dtype)

# call the callback, if provided
if callback_on_step_end is not None:
@@ -728,8 +763,9 @@ def __call__(
progress_bar.update()

if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
with record_function("decode_latents"):
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents

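To collect the per-step labels added here (`transformer_iteration_{i}`, `guidance_{i}`, the inlined `1.x scheduler` regions, `decode_latents`), one way is to run a short generation under the profiler. A sketch, assuming the public `THUDM/CogVideoX-2b` checkpoint and illustrative call arguments rather than the exact settings behind this PR:

```python
# Sketch: run a short CogVideoX generation under torch.profiler so the
# per-iteration record_function regions show up in the summary and trace.
# Checkpoint id, dtype, and call arguments are assumptions for illustration.
import torch
from torch.profiler import ProfilerActivity, profile

from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.to("cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    pipe(prompt="a panda playing a guitar", num_inference_steps=2, num_frames=9)

# transformer_iteration_{i}, guidance_{i}, the "1.x scheduler" regions and
# decode_latents all appear as named rows / trace spans.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))
prof.export_chrome_trace("cogvideox_profile.json")
```

The exported trace can be opened in chrome://tracing or Perfetto to see how the labeled regions line up against kernel launches and any host-side stalls.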
99 changes: 67 additions & 32 deletions src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -22,6 +22,7 @@

import numpy as np
import torch
from torch.profiler import record_function

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
@@ -362,41 +363,75 @@ def step(
# - pred_prev_sample -> "x_t-1"

# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
with record_function("get original prediction"):
with record_function("step 1"):
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

# 2. compute alphas, betas
with record_function("step 2"):
print(self.alphas_cumprod.device, self.alphas_cumprod.dtype)
print(timestep.device, timestep.dtype)
print(prev_timestep.device, prev_timestep.dtype)
with record_function("step 2.1"):
alpha_prod_t = self.alphas_cumprod[timestep]

with record_function("step 2.2"):
alpha_prod_t_prev = (
self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
)

with record_function("step 2.3"):
beta_prod_t = 1 - alpha_prod_t
print(beta_prod_t.device, beta_prod_t.dtype)
print("======")

with record_function("step 3"):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
print(
"vpred:",
sample.dtype,
model_output.dtype,
alpha_prod_t.dtype,
beta_prod_t.dtype,
pred_original_sample.dtype,
)
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

with record_function("compute prev sample"):
a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

prev_sample = a_t * sample + b_t * pred_original_sample
print(
"prevsample devices:",
a_t.device,
b_t.device,
sample.device,
pred_original_sample.device,
prev_sample.device,
)
print("prevsample:", a_t.dtype, b_t.dtype, sample.dtype, pred_original_sample.dtype, prev_sample.dtype)
print("=== done ===")

a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

prev_sample = a_t * sample + b_t * pred_original_sample

if not return_dict:
return (prev_sample,)
if not return_dict:
return (prev_sample,)

return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
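For reference, the math wrapped in the `compute prev sample` region is the CogVideoX DDIM update. A standalone numerical sketch with dummy schedule values and shapes (only the formulas mirror the code above):

```python
# Dummy cumulative-alpha schedule and tensors; only the formulas follow the
# scheduler code above (v-prediction branch plus the a_t / b_t update).
import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)  # stand-in schedule
timestep, prev_timestep = 500, 450  # e.g. 20 inference steps over 1000 train steps

sample = torch.randn(1, 4, 8, 8)          # current latents x_t
model_output = torch.randn_like(sample)   # v-prediction from the transformer

alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = alphas_cumprod[prev_timestep]
beta_prod_t = 1 - alpha_prod_t

# "v_prediction" branch: recover the predicted x_0
pred_original_sample = alpha_prod_t**0.5 * sample - beta_prod_t**0.5 * model_output

# CogVideoX DDIM update ("compute prev sample")
a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
prev_sample = a_t * sample + b_t * pred_original_sample

print(prev_sample.shape, prev_sample.dtype)
```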