Profile and speedup pipelines #9475
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Problem

We have CUDA stream synchronizations after every inference step. It's as bad as having multiple graph breaks when running compiled inference. Even when not running compiled inference, limiting such syncs can help speed up inference quite a bit. In the case of CogVideoX, I notice a ~42% speedup when comparing the "fixed" implementation with main.

Investigation

So, I've been benchmarking and profiling some of our pipelines. My observations hint towards a massive possibility for speedups across pipelines. Let's get into the details.

Most of our pipelines have a similar structure:
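A minimal, runnable sketch of that structure (the denoiser below is a random-noise stand-in, not real pipeline code; actual pipelines call the UNet/transformer and then decode with the VAE):

import torch
from diffusers import DDIMScheduler

def denoiser(latents, t):
    # Stand-in for the UNet/transformer forward pass.
    return torch.randn_like(latents)

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50, device="cuda")

latents = torch.randn(1, 4, 64, 64, device="cuda")
for t in scheduler.timesteps:  # timesteps live on the GPU
    noise_pred = denoiser(latents, t)
    # scheduler.step() is where alphas_cumprod gets indexed with the timestep
    latents = scheduler.step(noise_pred, t, latents).prev_sample
# ...followed by VAE decoding of the latents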
Let's take a look at what our scheduler does in its step method.
Our culprit, causing the slowdown, is this indexing (diffusers/src/diffusers/schedulers/scheduling_ddim.py, lines 405 to 406 at ba5af5a):
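Reproduced approximately, the referenced lines look like this; timestep and prev_timestep arrive as 0-dim CUDA tensors while self.alphas_cumprod stays on the CPU, so each lookup forces a stream sync:

# Approximate reproduction of the referenced DDIMScheduler.step() lines.
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod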
In the trace, it looks something like this. Note: "2 scheduler" corresponds to the indexing operation.

So, the easy solution to fix this would be to move

Since the above doesn't work, I instead created a copy of timesteps that always resides on the CPU. So,

After doing this, we can see that CUDA stream syncs are no longer taking place after each inference step. It only happens once after completing inference, before decoding latents. This speeds things up drastically in the denoiser, with or without compile.

Since it takes a long while to run profiles for different pipelines, I've only done so for SD, SDXL and CogVideoX with DDIMScheduler and LCMScheduler. In all cases, maintaining the copy of timesteps on the CPU is what helped.

#9485 mentions that keeping all variables on cuda would help, but, upon trying this, I still see CUDA stream syncs, which was unexpected. Perhaps someone else can help take a look in case I made a mistake.

To make sure that I wasn't imagining things, I tried replicating this behaviour at a smaller scale. Here's a minimal reproducer that is somewhat representative of what our pipelines look like. It can fully reproduce the above-mentioned behaviours:
import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler
def profiler_runner(fn, *args, trace_folder: str = "./traces", **kwargs):
with torch.profiler.profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
on_trace_ready=tensorboard_trace_handler(trace_folder),
) as _:
result = fn(*args, **kwargs)
return result
def model_forward(x):
for _ in range(5):
x = torch.matmul(x, x)
x = (x - x.min()) / (x.max() - x.min())
return x
@torch.no_grad()
def test(alphas_cumprod, timesteps, model_input, timesteps_cpu=None):
for i, t in enumerate(timesteps):
with record_function(f"pipeline_step_{i}"):
with record_function("step 1"):
model_output = model_forward(model_input)
with record_function("step 2"):
if timesteps_cpu is not None:
alpha = alphas_cumprod[timesteps_cpu[i]] # using cpu tensor as index
else:
alpha = alphas_cumprod[timesteps[i]] # using cuda tensor as index
# alpha = alphas_cumprod[t] # alternatively
with record_function("step 3"):
beta = 1 - alpha
with record_function("step 4"):
model_output = (alpha**0.5) * model_output - (beta**0.5) * model_input
with record_function("step 5"):
model_input = model_output
@torch.no_grad()
def main(alphas_device, should_timestep_be_on_cpu):
for i in range(3):
alphas = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], device=alphas_device, dtype=torch.float32)
alphas_cumprod = torch.cumprod(alphas, dim=0)
timesteps_numpy = np.arange(0, 5)[::-1].astype(np.int64)
timesteps = torch.from_numpy(timesteps_numpy)
timesteps = timesteps.to("cuda")
timesteps_cpu = None
if should_timestep_be_on_cpu:
timesteps_cpu = torch.from_numpy(timesteps_numpy)
model_input = torch.randn((1, 4, 64, 64), device="cuda")
with record_function(f"inference {i}"):
test(alphas_cumprod, timesteps, model_input, timesteps_cpu)
# This is what we have currently
# cuda syncs happen after every inference step
profiler_runner(main, "cpu", False, trace_folder="./traces")
# I would expect this to not cause cuda syncs, but it still does, even though both timesteps
# and alphas_cumprod are on cuda
# cuda syncs happen after every inference step
profiler_runner(main, "cuda", False, trace_folder="./traces")
# Here, both alphas and timesteps are on the cpu. However, it is very rarely the case that timesteps
# would be on the cpu since they are used in the unet/transformer to get embeddings.
# cuda syncs only happen after all inference steps
profiler_runner(main, "cpu", True, trace_folder="./traces")
# This is highly unlikely case and not realistic either where alphas are on cuda but timesteps are on cpu.
# cuda syncs only happen after all inference steps
profiler_runner(main, "cuda", True, trace_folder="./traces")

Solution

I'm not really sure what the best thing to do here would be. I find it unintuitive why having both alphas and timesteps on cuda would cause synchronizations. In case keeping both of them on the CPU is the best solution, I propose to have another attribute in the scheduler:

There are probably better ways of doing this. I haven't thought much about it yet.
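For illustration only (the attribute name timesteps_cpu below is hypothetical, not something that exists in the schedulers today), the idea could look roughly like this:

from diffusers import DDIMScheduler

class DDIMSchedulerWithCpuTimesteps(DDIMScheduler):
    # Hypothetical sketch: keep a CPU copy of the timesteps next to the regular
    # (possibly CUDA) self.timesteps, so that step() can index alphas_cumprod
    # with a CPU value instead of a CUDA tensor and avoid the device sync.
    def set_timesteps(self, num_inference_steps, device=None):
        super().set_timesteps(num_inference_steps, device=device)
        self.timesteps_cpu = self.timesteps.to("cpu")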
Thanks for the excellent analysis. My primary question here: do the syncs (after your fixes) still slow things down? There's likely a limit when it comes to reducing these syncs?
I'm not sure I fully understand. What I meant to say above is that, if there were 25 inference steps, there would be 25 syncs. That was before. Now there's only 1 sync after all inference steps. More syncs cause more slowdown, so now we're always faster.
Okay. Then IIUC, in your comment, you are asking for the best way to do it? In that case, I don't see any disadvantages to maintaining a copy on the CPU. It's a pretty lightweight thing and I can't imagine it'd ever lead to any side effects.
@a-r-r-o-w could you perform the indexing with

Relevant threads:
So, I made the following replacements in the scheduler. Previously,
Now,
I made sure to move alphas_cumprod and final_alpha_cumprod to cuda; however, there are still calls made to cudaStreamSynchronize :(
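As an aside, one way to pinpoint which op is responsible is PyTorch's sync debug mode; a small sketch (my assumption being that the blocking device-to-host copy caused by this kind of indexing gets flagged):

import torch

# Warn whenever an operation forces a CUDA synchronization (can also be set to "error").
torch.cuda.set_sync_debug_mode("warn")

alphas_cumprod = torch.linspace(0.999, 0.001, 1000)   # CPU buffer, like in the schedulers
timestep = torch.tensor(999, device="cuda")           # CUDA timestep, like in the pipelines
alpha_prod_t = alphas_cumprod[timestep]               # expected to trigger a sync warning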
TL;DR: We were measuring the profiling time incorrectly. There is no big speedup, as I claimed in excitement earlier.

After some more hours of trying to profile our pipelines/schedulers, I found that there was a mistake in what we were trying to measure. Previously, we had CUDA stream syncs after every step. After the latest changes, there are no longer syncs after every step, except for a single time after all inference steps are complete. This should be good, right? Unfortunately, the results say something else.

It is good advice, given by many folks, to reduce the number of sync operations as much as possible to speed up training/inference. This is, in fact, what we achieved here, but it does not seem to help much in this case. When trying to speed up PyTorch code, the recommendation, as mentioned here, is to make the CPU run as far ahead of the GPU as possible and put many tasks into the accelerator queue. This is essentially all we've done so far: we made it so that the CPU can queue up a lot of operations without waiting for a synchronization after every inference step. Unfortunately, this does not lead to any speedups here (maybe a few milliseconds, but that's about it).

The reason why I claimed a ~42% speedup in CogVideoX, and similar observations in other pipelines, was because I had a tiny bug in my profiling code. We were measuring the time it took from the start of pipeline execution to the moment when the CPU finished queuing up all tasks to the GPU. This part, queueing instructions to the GPU, is in fact much faster when comparing to

Code

import argparse
import os
# os.environ["TORCH_LOGS"] = "+dynamo,graph_breaks,recompiles"
# os.environ["TORCHDYNAMO_VERBOSE"] = "1"
import git
import torch
import torch.utils.benchmark as benchmark
from diffusers import CogVideoXPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMScheduler
from diffusers.training_utils import set_seed
from tabulate import tabulate
torch.set_float32_matmul_precision("high")
set_seed(42)
PROMPT = "An astronaut floating in space"
NEGATIVE_PROMPT = ""
def benchmark_fn(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"
def pretty_print_results(results, precision: int = 6):
def format_value(value):
if isinstance(value, float):
return f"{value:.{precision}f}"
return value
filtered_table = {k: format_value(v) for k, v in results.items()}
print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))
def prepare_sd15(compile: bool = False):
pipe = StableDiffusionPipeline.from_pretrained("emilianJR/epiCRealism", torch_dtype=torch.float16)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
if compile:
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
prompt=PROMPT,
device="cuda",
num_images_per_prompt=8,
do_classifier_free_guidance=True,
negative_prompt=NEGATIVE_PROMPT,
)
generation_kwargs = {
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"height": 512,
"width": 512,
"num_inference_steps": 50,
"guidance_scale": 7.5,
}
return pipe, generation_kwargs
def prepare_sdxl(compile: bool = False):
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
cache_dir="/raid/.cache/huggingface",
torch_dtype=torch.float16,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
if compile:
pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
prompt=PROMPT,
device="cuda",
num_images_per_prompt=8,
do_classifier_free_guidance=True,
negative_prompt=NEGATIVE_PROMPT,
)
generation_kwargs = {
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
"height": 1024,
"width": 1024,
"num_inference_steps": 50,
"guidance_scale": 7.5,
}
return pipe, generation_kwargs
def prepare_cogvideox(compile: bool = False):
pipe = CogVideoXPipeline.from_pretrained("/raid/aryan/CogVideoX-5b-trial", torch_dtype=torch.bfloat16)
pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
if compile:
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
prompt=PROMPT,
negative_prompt=NEGATIVE_PROMPT,
do_classifier_free_guidance=True,
device="cuda",
dtype=torch.bfloat16,
)
generation_kwargs = {
"prompt_embeds": prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"num_inference_steps": 50,
"guidance_scale": 6,
}
return pipe, generation_kwargs
@torch.no_grad()
def run_inference(pipe, generation_kwargs):
generator = torch.Generator().manual_seed(3047)
output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
torch.cuda.synchronize()
return output
@torch.no_grad()
def main(model_id, compile):
# 1. Load pipeline
if model_id == "sd15":
pipe, generation_kwargs = prepare_sd15(compile)
elif model_id == "sdxl":
pipe, generation_kwargs = prepare_sdxl(compile)
elif model_id == "cogvideox":
pipe, generation_kwargs = prepare_cogvideox(compile)
else:
raise ValueError("Invalid model_id for benchmarking.")
# 2. Warmup
num_warmups = 2
for _ in range(num_warmups):
output = run_inference(pipe, generation_kwargs)
# 3. Benchmark
time = benchmark_fn(run_inference, pipe, generation_kwargs)
# 4. Save artifacts
repo = git.Repo(search_parent_directories=True)
branch = repo.active_branch
info = {
"model_id": model_id,
"compile": compile,
"time": time,
"branch": branch,
}
pretty_print_results(info, precision=3)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_id",
type=str,
default="sd15",
choices=["sd15", "sdxl", "cogvideox"],
help="Model to run benchmark for.",
)
parser.add_argument(
"--compile",
action="store_true",
default=False,
help="Whether or not to torch.compile the denoiser.",
)
return parser.parse_args()
if __name__ == "__main__":
args = get_args()
main(args.model_id, args.compile)

We get the following results:
We can see that there is no speedup. After making sure that all GPU calls finish executing, the inference takes about the same time. It's just the time-to-queue-instructions-from-CPU-to-GPU that is reduced by "fixing" our schedulers to not cause calls to cudaStreamSynchronize.
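To make the measurement bug concrete: without a synchronize before stopping the timer, you only measure how fast the CPU can queue work. A toy sketch (not the actual benchmark script above):

import time
import torch

def work(x):
    # Toy GPU workload standing in for the denoising loop.
    for _ in range(100):
        x = x @ x
    return x

x = torch.randn(2048, 2048, device="cuda")

t0 = time.perf_counter()
work(x)
queue_time = time.perf_counter() - t0        # buggy: only measures kernel queuing

t0 = time.perf_counter()
work(x)
torch.cuda.synchronize()                     # wait for the GPU to actually finish
wall_time = time.perf_counter() - t0

print(f"queue-only: {queue_time:.4f}s, synchronized: {wall_time:.4f}s")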
We observed something similar when profiling Flux a few days ago. Within the

There's a torch.compile() mode that specifically optimizes for overhead-bound scenarios -- "reduce-overhead" (see more here) -- but I don't think that is the case here, because the model being compiled (the transformer, for example) is compute-bound rather than overhead-bound.
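For reference, that mode is just a different argument to torch.compile (generic usage, nothing diffusers-specific):

import torch

model = torch.nn.Linear(64, 64).cuda()

# "reduce-overhead" targets launch-overhead-bound workloads (it uses CUDA graphs);
# a large diffusion transformer is compute-bound, so the default or "max-autotune"
# modes are the more relevant choices there.
compiled = torch.compile(model, mode="reduce-overhead")
out = compiled(torch.randn(8, 64, device="cuda"))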
@a-r-r-o-w Basically, we have been trying to move away from using the timestep to index anything; we implemented

Other than the performance issue it helps with, it also allows the user to use the schedulers with k-sigmas etc. For CogVideoX, I haven't looked into whether we can refactor all the calculation to only use

as long as we can use
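To illustrate the difference (a schematic contrast, not actual scheduler code): indexing with the timestep tensor versus indexing precomputed per-step values with a plain integer counter:

import torch

num_train_timesteps, num_inference_steps = 1000, 50
alphas_cumprod = torch.linspace(0.999, 0.001, num_train_timesteps)            # CPU buffer
timesteps = torch.linspace(num_train_timesteps - 1, 0, num_inference_steps,
                           device="cuda").long()                              # what pipelines iterate over

# (a) timestep-as-index: a CUDA tensor index into a CPU buffer -> device sync every step
alpha_t = alphas_cumprod[timesteps[10]]

# (b) counter-as-index: precompute per-step values once, then index with a Python int
# that the scheduler increments itself -> no CUDA tensor involved, no sync
alphas_for_steps = alphas_cumprod[timesteps.cpu()]
step_index = 10
alpha_t = alphas_for_steps[step_index]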
Just redirected here from #9425; I think the high-level idea is the same: to reduce unneeded CudaMemcpy/CUDA sync calls in the diffusion loop. Following the previous comment from @yiyixuxu, I think it'd be better to create a new PR that tries to
@a-r-r-o-w @sayakpaul @yiyixuxu
What does this PR do?
Currently tested only for CogVideoX. Noticing a ~42% speedup comparing the compiled version of this branch vs. the compiled version of main. Will share more concrete numbers for our major pipelines/schedulers soon.

testing code
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.