Profile and speedup pipelines #9475

Draft: wants to merge 6 commits into base: main
Changes from 1 commit
update
a-r-r-o-w committed Sep 24, 2024
commit 8e78c9d1d1b677a058e5d6f9146b7df7204dc489
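For context: `torch.profiler.record_function(...)` wraps a region of code so that it shows up as a named range in a `torch.profiler` trace, which is how the per-step ranges in this diff (`transformer_iteration_{i}`, `1.1 scheduler`, and so on) are surfaced. A minimal, self-contained sketch of the pattern; the tiny linear model and range names here are illustrative stand-ins, not code from this PR:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

model = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for i in range(3):
        with record_function(f"transformer_iteration_{i}"):
            y = model(x)  # each labeled block becomes its own row in the trace
        with record_function("1.1 scheduler"):
            y = y * 0.5

# Named ranges appear alongside the usual op-level entries.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```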
src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py (102 changes: 31 additions & 71 deletions)
@@ -18,7 +18,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torch.profiler import record_function
 from transformers import T5EncoderModel, T5Tokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -682,77 +681,39 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
-                with record_function(f"transformer_iteration_{i}"):
-                    noise_pred = self.transformer(
-                        hidden_states=latent_model_input,
-                        encoder_hidden_states=prompt_embeds,
-                        timestep=timestep,
-                        image_rotary_emb=image_rotary_emb,
-                        attention_kwargs=attention_kwargs,
-                        return_dict=False,
-                    )[0]
-                    # noise_pred = noise_pred.float()
+                noise_pred = self.transformer(
+                    hidden_states=latent_model_input,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep=timestep,
+                    image_rotary_emb=image_rotary_emb,
+                    attention_kwargs=attention_kwargs,
+                    return_dict=False,
+                )[0]
+                noise_pred = noise_pred.float()
 
                 # perform guidance
-                with record_function(f"guidance_{i}"):
-                    if use_dynamic_cfg:
-                        self._guidance_scale = 1 + guidance_scale * (
-                            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0))
-                            / 2
-                        )
-                    if do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                with record_function("1.1 scheduler"):
-                    prev_timestep = (
-                        self.scheduler.timesteps_numpy[i]
-                        - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
-                    )
+                if use_dynamic_cfg:
+                    self._guidance_scale = 1 + guidance_scale * (
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
+                    )
-
-                with record_function("1.2 scheduler"):
-                    # alpha_prod_t = self.scheduler.alphas_cumprod[t]
-                    alpha_prod_t = self.scheduler.alphas_cumprod[self.scheduler.timesteps_numpy[i]]
-
-                with record_function("1.3 scheduler"):
-                    alpha_prod_t_prev = (
-                        self.scheduler.alphas_cumprod[prev_timestep]
-                        if prev_timestep >= 0
-                        else self.scheduler.final_alpha_cumprod
-                    )
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
+                # compute the previous noisy sample x_t -> x_t-1
+                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        noise_pred,
+                        old_pred_original_sample,
+                        t,
+                        timesteps[i - 1] if i > 0 else None,
+                        latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
 
-                with record_function("1.4 scheduler"):
-                    beta_prod_t = 1 - alpha_prod_t
-
-                with record_function("1.5 scheduler"):
-                    pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred
-
-                with record_function("1.6 scheduler"):
-                    a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
-
-                with record_function("1.7 scheduler"):
-                    b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
-
-                with record_function("1.8 scheduler"):
-                    prev_sample = a_t * latents + b_t * pred_original_sample
-
-                latents = prev_sample
-
-                # # compute the previous noisy sample x_t -> x_t-1
-                # with record_function(f"scheduler_step_{i}"):
-                #     if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                #         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                #     else:
-                #         latents, old_pred_original_sample = self.scheduler.step(
-                #             noise_pred,
-                #             old_pred_original_sample,
-                #             t,
-                #             timesteps[i - 1] if i > 0 else None,
-                #             latents,
-                #             **extra_step_kwargs,
-                #             return_dict=False,
-                #         )
-                # # latents = latents.to(prompt_embeds.dtype)
+                latents = latents.to(prompt_embeds.dtype)
 
                 # call the callback, if provided
                 if callback_on_step_end is not None:
@@ -769,9 +730,8 @@ def __call__(
                     progress_bar.update()
 
         if not output_type == "latent":
-            with record_function("decode_latents"):
-                video = self.decode_latents(latents)
-                video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+            video = self.decode_latents(latents)
+            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
 
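The dynamic classifier-free-guidance schedule unwrapped above is worth reading on its own: the effective guidance scale follows a cosine ramp that stays near 1 early in sampling and approaches 1 + guidance_scale at the end. A standalone sketch of the same formula; the step count and scale values below are illustrative:

```python
import math

def dynamic_guidance_scale(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Mirrors the pipeline's expression: the ** 5.0 exponent keeps the
    # ramp flat early on, then guidance strengthens as t approaches 0.
    progress = ((num_inference_steps - t) / num_inference_steps) ** 5.0
    return 1 + guidance_scale * (1 - math.cos(math.pi * progress)) / 2

for t in (50, 40, 20, 0):  # timesteps counting down across 50 steps
    print(t, round(dynamic_guidance_scale(6.0, t, 50), 3))
```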
src/diffusers/schedulers/scheduling_ddim.py (50 changes: 22 additions & 28 deletions)
@@ -21,13 +21,14 @@
 
 import numpy as np
 import torch
-from torch.profiler import record_function
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
+from torch.profiler import record_function
+
 
 @dataclass
 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
@@ -232,10 +233,9 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-
         # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
-        self.timesteps_numpy = np.arange(0, num_train_timesteps)[::-1].astype(np.int64)
+        self.timesteps_cpu = np.arange(0, num_train_timesteps)[::-1].astype(np.int64)
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
@@ -256,10 +256,11 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
 
         safe_prev_timestep = torch.clamp(prev_timestep, min=0)
-        gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
-        alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod)
+        # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
@@ -345,12 +346,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         )
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.timesteps_cpu = timesteps
         self.alphas_cumprod = self.alphas_cumprod.to(device)
         self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
 
     def step(
         self,
         model_output: torch.Tensor,
-        timestep: Union[torch.Tensor, float],
+        timestep: int,
         sample: torch.Tensor,
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
@@ -365,7 +367,7 @@
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`torch.Tensor` or `float`):
+            timestep (`float`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
@@ -413,13 +415,11 @@
         # 2. compute alphas, betas
         with record_function("2 scheduler"):
             alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
-            # alpha_prod_t = self.alphas_cumprod[timestep]
 
         with record_function("3 scheduler"):
             safe_prev_timestep = torch.clamp(prev_timestep, min=0)
-            gathered = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
-            alpha_prod_t_prev = torch.where(prev_timestep >= 0, gathered, self.final_alpha_cumprod)
-            # alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+            safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+            alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         with record_function("4 scheduler"):
             beta_prod_t = 1 - alpha_prod_t
@@ -454,26 +454,20 @@
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
         with record_function("7 scheduler"):
-            with record_function("7.1 scheduler"):
-                variance = self._get_variance(timestep, prev_timestep)
+            variance = self._get_variance(timestep, prev_timestep)
+            std_dev_t = eta * variance ** (0.5)
 
-            with record_function("7.2 scheduler"):
-                std_dev_t = eta * variance ** (0.5)
-
-            with record_function("7.3 scheduler"):
-                if use_clipped_model_output:
-                    # the pred_epsilon is always re-derived from the clipped x_0 in Glide
-                    pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+            if use_clipped_model_output:
+                # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+                pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
 
-        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        with record_function("8 scheduler"):
+            # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
 
-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        with record_function("9 scheduler"):
+            # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
             prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
-        with record_function("10 scheduler"):
+        with record_function("8 scheduler"):
             if eta > 0:
                 if variance_noise is not None and generator is not None:
                     raise ValueError(
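For reference, the quantities named in the `record_function` ranges above implement formulas (12) and (16) of the DDIM paper (https://arxiv.org/pdf/2010.02502.pdf). A compact sketch of the deterministic (eta = 0) update for epsilon prediction, with toy tensors standing in for the real model outputs:

```python
import torch

sample = torch.randn(4)        # x_t
model_output = torch.randn(4)  # predicted noise eps_theta(x_t)
alpha_prod_t = torch.tensor(0.5)
alpha_prod_t_prev = torch.tensor(0.7)
std_dev_t = torch.tensor(0.0)  # eta = 0 -> deterministic DDIM
beta_prod_t = 1 - alpha_prod_t

# predicted x_0, the inverse of the pred_epsilon expression above
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# "direction pointing to x_t" of formula (12)
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# x_{t-1} without random noise
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
print(prev_sample)
```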
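The recurring clamp/gather/where pattern in this file replaces the data-dependent branch kept as a comment in the diff (`alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod`), which would force a host-device sync when `prev_timestep` lives on the GPU. A minimal sketch of the same lookup against a toy schedule; the values below are illustrative:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)  # toy noise schedule
final_alpha_cumprod = torch.tensor(1.0)

def lookup_alpha_prod_t_prev(prev_timestep: torch.Tensor) -> torch.Tensor:
    # Clamp so a negative index is safe to gather...
    safe_prev_timestep = torch.clamp(prev_timestep, min=0)
    safe_alpha_prod_t_prev = torch.gather(alphas_cumprod, 0, safe_prev_timestep)
    # ...then patch the out-of-range case without leaving the device.
    return torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, final_alpha_cumprod)

print(lookup_alpha_prod_t_prev(torch.tensor([980])))  # ordinary lookup
print(lookup_alpha_prod_t_prev(torch.tensor([-20])))  # falls back to final_alpha_cumprod
```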