
Profile and speedup pipelines #9475

Draft · wants to merge 6 commits into main
Conversation

@a-r-r-o-w (Member) commented Sep 19, 2024

What does this PR do?

Currently tested only for CogVideoX. Noticing a ~42% speedup when comparing the compiled version of this branch vs. the compiled version of main. Will share more concrete numbers for our major pipelines/schedulers soon.

testing code
import argparse
import logging
import os
from platform import python_version

os.environ['TORCH_LOGS'] = '+graph_breaks,dynamo,recompiles'
os.environ['TORCHDYNAMO_VERBOSE'] = '1'

logging.basicConfig(level=logging.INFO)

import diffusers
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler


print(diffusers.__version__)
print(torch.__version__)
print(python_version())


torch.set_float32_matmul_precision("high")
# Inductor tuning flags (these live under torch._inductor.config)
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

PROMPT_EMBEDS_FILE = "prompt_embeds.pt"
NEGATIVE_PROMPT_EMBEDS_FILE = "negative_prompt_embeds.pt"


def profiler_runner(fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            on_trace_ready=tensorboard_trace_handler("./cogvideox_trace")
    ) as prof:
        result = fn(*args, **kwargs)
    
    # prof.export_chrome_trace("cogvideox_trace.json")
    return result


@torch.no_grad()
def generate_embeds(pipe):
    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=prompt,
        do_classifier_free_guidance=True,
        device="cuda",
    )
    torch.save(prompt_embeds, PROMPT_EMBEDS_FILE)
    torch.save(negative_prompt_embeds, NEGATIVE_PROMPT_EMBEDS_FILE)


@torch.no_grad()
def main(args):
    # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
    pipe = CogVideoXPipeline.from_pretrained("/raid/aryan/CogVideoX-trial", torch_dtype=torch.float16)
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.to("cuda")

    print(type(pipe.scheduler.alphas_cumprod))
    # print(pipe.scheduler.alphas_cumprod.device)
    # assert pipe.scheduler.alphas_cumprod.device == torch.device("cpu")
    # pipe.scheduler.alphas_cumprod = pipe.scheduler.alphas_cumprod.to("cuda")
    # pipe.scheduler.final_alpha_cumprod = pipe.scheduler.final_alpha_cumprod.to("cuda")

    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    
    if args.generate_embeds or not os.path.isfile(PROMPT_EMBEDS_FILE):
        generate_embeds(pipe)
    
    prompt_embeds = torch.load(PROMPT_EMBEDS_FILE).to("cuda")
    negative_prompt_embeds = torch.load(NEGATIVE_PROMPT_EMBEDS_FILE).to("cuda")

    def inference():
        for i in range(3):
            with record_function(f"pipeline_run_{i}"):
                video = pipe(
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    guidance_scale=6,
                    num_inference_steps=50,
                ).frames[0]
                export_to_video(video, "output.mp4", fps=8)
    
    # profiler_runner(inference)
    inference()


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--generate_embeds", action="store_true")
    parser.add_argument("--run_profile", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
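
To reproduce (script name is hypothetical): run `python profile_cogvideox.py --generate_embeds` once to cache the prompt embeds, then re-run without the flag to reuse them on subsequent runs.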

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@a-r-r-o-w (Member, Author) commented Sep 23, 2024

Problem

We have CUDA stream synchronizations after every inference step. It's as bad as having multiple graph breaks when running compiled inference. Even if not running compiled inference, limiting such syncs can help speed up inference quite a bit. In the case of CogVideoX, I notice a ~42% speedup when comparing the "fixed" implementation with diffusers:main after compiling.

Investigation

So, I've been benchmarking and profiling some of our pipelines. My observations hint at the possibility of massive speedups across pipelines. Let's get into the details.

Most of our pipelines have similar structure:

1. Check and prepare inputs
2. Start inference loop
3. Forward pass through denoiser
4. CFG
5. Scheduler step (aka, our culprit)
6. Repeat until N steps
7. Decode latents (also a culprit in many cases, with heavy CPU usage in the VAE [especially CogVideoX], but the scheduler is more important and the point of discussion for now)

Let's take a look at what our scheduler .step() looks like at a high level.

1. Index the current alphas and betas based on the timestep passed in. For all pipelines tested, alphas_cumprod is a torch.Tensor residing on the cpu, while the timestep is a torch.Tensor residing on cuda
2. x/eps/v-prediction
3. Computing prev sample and returning

Our culprit, causing the slowdown, is indexing alphas_cumprod with the current timestep. alphas_cumprod is a float32 tensor residing on the cpu, whereas timestep is an int64 tensor residing on cuda. This indexing causes timestep to first be item-ified and moved to the cpu, a blocking device-to-host copy. This is the code in question:

alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

In the trace, it looks something like:

[profiler trace screenshot]

Note: "2 scheduler" corresponds to the indexing operation.

So, the easy solution to fix this would be to move alphas_cumprod to cuda as well, yes? Unfortunately, I still see cuda stream synchronizations. I'm not sure why, but the trace stack looks extremely similar to above, if not the same.

Since the above doesn't work, I instead created a copy of timesteps that resides on the cpu always. So, self.timesteps would reside on the device set by the pipeline (always cuda when running our default pipeline examples on a GPU machine) and the copied attribute self.timesteps_cpu would reside on the cpu. We need to take care of reassigning this value correctly in methods like set_timesteps() too.

After doing this, we can see that cuda stream syncs no longer take place after each inference step. It only happens once after completing inference, before decoding latents. This speeds things up drastically in the denoiser, with or without compile.

Since it takes a long while to run profiles for different pipelines, I've only done so for SD, SDXL, CogVideoX with DDIMScheduler and LCMScheduler. In all cases, maintaining the copy of timesteps on cpu is what helped. #9485 mentions that keeping all variables on cuda would help, but, upon trying this, I still see cuda stream syncs which was unexpected. Perhaps someone else can help take a look in case I made a mistake.

To make sure that I wasn't imagining things, I tried replicating this behaviour at a smaller scale. Here's a minimal reproducer that is somewhat representative of what our pipelines look like. It can fully reproduce the above-mentioned behaviours:

  • alphas_cumprod on cpu, timesteps on cuda: This is what we have currently
  • alphas_cumprod on cuda, timesteps on cuda: Still has cuda stream syncs after each inference step
  • alphas_cumprod on cpu, copy of timesteps on cpu: Proposed solution
  • alphas_cumprod on cuda, copy of timesteps on cpu: Highly unlikely case so ignore
import numpy as np
import torch
from torch.profiler import profile, record_function, ProfilerActivity, tensorboard_trace_handler


def profiler_runner(fn, *args, trace_folder: str = "./traces", **kwargs):
    with torch.profiler.profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        on_trace_ready=tensorboard_trace_handler(trace_folder),
    ) as _:
        result = fn(*args, **kwargs)
    return result


def model_forward(x):
    for _ in range(5):
        x = torch.matmul(x, x)
        x = (x - x.min()) / (x.max() - x.min())
    return x


@torch.no_grad()
def test(alphas_cumprod, timesteps, model_input, timesteps_cpu=None):
    for i, t in enumerate(timesteps):
        with record_function(f"pipeline_step_{i}"):
            with record_function("step 1"):
                model_output = model_forward(model_input)

            with record_function("step 2"):
                if timesteps_cpu is not None:
                    alpha = alphas_cumprod[timesteps_cpu[i]]  # using cpu tensor as index
                else:
                    alpha = alphas_cumprod[timesteps[i]]  # using cuda tensor as index
                    # alpha = alphas_cumprod[t] # alternatively

            with record_function("step 3"):
                beta = 1 - alpha

            with record_function("step 4"):
                model_output = (alpha**0.5) * model_output - (beta**0.5) * model_input

            with record_function("step 5"):
                model_input = model_output


@torch.no_grad()
def main(alphas_device, should_timestep_be_on_cpu):
    for i in range(3):
        alphas = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5], device=alphas_device, dtype=torch.float32)
        alphas_cumprod = torch.cumprod(alphas, dim=0)

        timesteps_numpy = np.arange(0, 5)[::-1].astype(np.int64)

        timesteps = torch.from_numpy(timesteps_numpy)
        timesteps = timesteps.to("cuda")

        timesteps_cpu = None
        if should_timestep_be_on_cpu:
            timesteps_cpu = torch.from_numpy(timesteps_numpy)

        model_input = torch.randn((1, 4, 64, 64), device="cuda")

        with record_function(f"inference {i}"):
            test(alphas_cumprod, timesteps, model_input, timesteps_cpu)


# This is what we have currently
# cuda syncs happen after every inference step
profiler_runner(main, "cpu", False, trace_folder="./traces")

# I would expect this to not cause a cuda sync, but it still does, even though both timesteps
# and alphas_cumprod are on cuda
# cuda syncs happen after every inference step
profiler_runner(main, "cuda", False, trace_folder="./traces")

# Here, both alphas and timesteps are on the cpu. However, it is very rarely the case that timesteps
# would be on the cpu since they are used in the unet/transformer to get embeddings.
# cuda syncs only happen after all inference steps
profiler_runner(main, "cpu", True, trace_folder="./traces")

# This is a highly unlikely and unrealistic case, where alphas are on cuda but timesteps are on cpu.
# cuda syncs only happen after all inference steps
profiler_runner(main, "cuda", True, trace_folder="./traces")

Solution

I'm not really sure what the best thing to do here would be. I find it unintuitive why having both alphas and timesteps on cuda would cause synchronizations. In the case that keeping both of them on cpu is the best solution, I propose to have another attribute in the scheduler: self.timesteps_cpu = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64)). When doing .step(), we would then need to pass the cpu timestep instead of cuda timestep.
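
A rough sketch of what that could look like (class and method names are illustrative only, not the actual scheduler code):

import numpy as np
import torch

class TimestepsCpuSketch:
    """Toy scheduler keeping a cpu mirror of the timesteps so alphas_cumprod (cpu) is indexed cpu-by-cpu."""

    def __init__(self, num_train_timesteps: int = 1000):
        betas = torch.linspace(1e-4, 2e-2, num_train_timesteps)
        self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # stays on cpu
        self.timesteps_cpu = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))

    def set_timesteps(self, num_inference_steps: int, device=None):
        timesteps_np = np.linspace(0, 999, num_inference_steps).round()[::-1].astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps_np).to(device)  # device copy, fed to the denoiser
        self.timesteps_cpu = torch.from_numpy(timesteps_np)         # cpu copy, used only for indexing

    def alpha_for_step(self, step_index: int) -> torch.Tensor:
        # cpu tensor indexed with a cpu value: no device-to-host copy, no stream sync
        return self.alphas_cumprod[self.timesteps_cpu[step_index]]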

There are probably better ways of doing this. I haven't thought much about it yet.

@yiyixuxu @sayakpaul

@sayakpaul (Member) commented:

Thanks for the excellent analysis.

My primary question here is: do the syncs (after your fixes) still slow things down? There's likely a limit when it comes to reducing these syncs, right?

@a-r-r-o-w (Member, Author) commented Sep 23, 2024

I'm not sure I fully understand. What I meant to say above is that if there were 25 inference steps, there used to be 25 syncs. Now there's only 1 sync after all the inference steps. More syncs cause more slowdown, so now we're always faster.

@sayakpaul (Member) commented:

Okay. Then IIUC, in your comment, you are asking for the best way to do it? In that case, I don't see any disadvantages to maintaining a copy on the CPU. It's a pretty lightweight thing and I can't imagine it'd ever lead to any side-effects.

@a-r-r-o-w (Member, Author) commented:

> could you perform the indexing with torch.gather()?

So, I made the following replacements in the scheduler.

Previously,

alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

Now,

alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, prev_timestep) if prev_timestep >= 0 else self.final_alpha_cumprod

I made sure to move alphas_cumprod and final_alpha_cumprod to cuda; however, there are still calls made to cudaStreamSynchronize :(
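
For reference, one place a sync can still be triggered even with everything on cuda is the Python-level comparison in that same line; a minimal standalone illustration (hypothetical values, not taken from the trace):

import torch

prev_timestep = torch.tensor(980, device="cuda")

# Using a cuda tensor in Python control flow calls bool() on it, which requires a
# device-to-host copy and therefore a stream synchronization, independent of how
# the indexing itself (plain indexing or torch.gather) is performed.
if prev_timestep >= 0:
    pass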

@a-r-r-o-w (Member, Author) commented:

TLDR; We were measuring the profiling time incorrectly. There is no big speedup as I claimed in excitement earlier.

After some more hours of trying to profile our pipelines/schedulers, I found that there was a mistake in what we were trying to measure. Previously, we had cuda stream syncs after every step. After the latest changes, there are no longer syncs after every step, only a single one after all inference steps are complete. This should be good, right? Unfortunately, the results say something else. It is good advice from many folks to reduce the number of sync operations as much as possible to speed up training/inference. This is, in fact, what we achieved here, but it does not seem to help much in this case.

When trying to speed up PyTorch code, the recommendation, as mentioned here, is to make the CPU run as far ahead of the GPU as possible and put many tasks into the accelerator queue. This is essentially all we've done so far. We made it so that the cpu can queue up a lot of operations without waiting for a synchronization to happen after every inference step. Unfortunately, this does not lead to any speedups here (maybe a few milliseconds, but that's about it). The reason I claimed a ~42% speedup in CogVideoX, and similar observations in other pipelines, was a tiny bug in my profiling code. We were measuring the time it took from the start of pipeline execution to the moment when the CPU finished queuing up all tasks to the GPU. This part, queueing instructions to the GPU, is in fact much faster compared to diffusers:main... about 42% faster. But finishing queueing instructions to the GPU is one thing; it does not make the GPU complete its operations any faster. In fact, it takes almost the same amount of time as before, i.e. no speedup. This can be observed more clearly by running the fixed profiling implementation:

Code
import argparse
import os

# os.environ["TORCH_LOGS"] = "+dynamo,graph_breaks,recompiles"
# os.environ["TORCHDYNAMO_VERBOSE"] = "1"

import git
import torch
import torch.utils.benchmark as benchmark
from diffusers import CogVideoXPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMScheduler
from diffusers.training_utils import set_seed
from tabulate import tabulate


torch.set_float32_matmul_precision("high")
set_seed(42)

PROMPT = "An astronaut floating in space"
NEGATIVE_PROMPT = ""


def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"


def pretty_print_results(results, precision: int = 6):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def prepare_sd15(compile: bool = False):
    pipe = StableDiffusionPipeline.from_pretrained("emilianJR/epiCRealism", torch_dtype=torch.float16)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    if compile:
        pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
    
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=PROMPT,
        device="cuda",
        num_images_per_prompt=8,
        do_classifier_free_guidance=True,
        negative_prompt=NEGATIVE_PROMPT,
    )

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 512,
        "width": 512,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
    }

    return pipe, generation_kwargs


def prepare_sdxl(compile: bool = False):
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        cache_dir="/raid/.cache/huggingface",
        torch_dtype=torch.float16,
    )
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    if compile:
        pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
    
    prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
        prompt=PROMPT,
        device="cuda",
        num_images_per_prompt=8,
        do_classifier_free_guidance=True,
        negative_prompt=NEGATIVE_PROMPT,
    )

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
    }

    return pipe, generation_kwargs


def prepare_cogvideox(compile: bool = False):
    pipe = CogVideoXPipeline.from_pretrained("/raid/aryan/CogVideoX-5b-trial", torch_dtype=torch.bfloat16)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    if compile:
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
    
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=PROMPT,
        negative_prompt=NEGATIVE_PROMPT,
        do_classifier_free_guidance=True,
        device="cuda",
        dtype=torch.bfloat16,
    )

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "num_inference_steps": 50,
        "guidance_scale": 6,
    }

    return pipe, generation_kwargs


@torch.no_grad()
def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(3047)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id, compile):
    # 1. Load pipeline
    if model_id == "sd15":
        pipe, generation_kwargs = prepare_sd15(compile)
    elif model_id == "sdxl":
        pipe, generation_kwargs = prepare_sdxl(compile)
    elif model_id == "cogvideox":
        pipe, generation_kwargs = prepare_cogvideox(compile)
    else:
        raise ValueError("Invalid model_id for benchmarking.")

    # 2. Warmup
    num_warmups = 2
    for _ in range(num_warmups):
        output = run_inference(pipe, generation_kwargs)

    # 3. Benchmark
    time = benchmark_fn(run_inference, pipe, generation_kwargs)

    # 4. Save artifacts
    repo = git.Repo(search_parent_directories=True)
    branch = repo.active_branch
    
    info = {
        "model_id": model_id,
        "compile": compile,
        "time": time,
        "branch": branch,
    }
    pretty_print_results(info, precision=3)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="sd15",
        choices=["sd15", "sdxl", "cogvideox"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether or not to torch.compile the denoiser.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    main(args.model_id, args.compile)
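
To collect the numbers below, the script is run on each branch, e.g. (script name hypothetical): `python benchmark_pipelines.py --model_id cogvideox --compile`.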

We get the following results:

model_id    compile    time (s)    branch
sd15        False      7.396       cogvideox-profiling
sd15        True       6.149       cogvideox-profiling
cogvideox   False      253.024     cogvideox-profiling
cogvideox   True       207.292     cogvideox-profiling
sd15        False      7.398       main
sd15        True       6.171       main
cogvideox   False      253.973     main
cogvideox   True       207.393     main

We can see that there is no speedup. After making sure that all GPU calls finish executing, the inference takes about the same time. It's just the time-to-queue-instructions-from-CPU-to-GPU that is reduced, by "fixing" our schedulers to not make calls to cudaStreamSynchronize. I tried two different approaches and observed the same/similar results as above:

  • Maintaining a separate copy of timesteps on CPU as a new self.timesteps_cpu attribute and using it for indexing (a cpu tensor indexed by a cpu tensor)
  • Using the current approach in the PR, by @sayakpaul (indexing a cuda tensor with a cuda tensor, avoiding cpu overhead by making use of torch.gather and torch.where)

cc @xiang9156 @yiyixuxu @sayakpaul

@sayakpaul (Member) commented:

We observed something similar when profiling Flux a few days ago. Within the torch.compile() regions, there were a few syncs but they didn't contribute a whole lot to the latency actually.

There's a torch.compile() mode that specifically optimizes for overhead-bound scenarios -- "reduce-overhead" (see more here) -- but I don't think that is the case here, because the model being compiled (the transformer, for example) is compute-bound rather than overhead-bound.
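
For reference, a minimal sketch of opting into that mode (the module here is a stand-in, not the actual pipeline transformer):

import torch
import torch.nn as nn

model = nn.Linear(64, 64).cuda()

# "reduce-overhead" uses CUDA graphs to cut per-kernel-launch CPU overhead; it mainly pays off
# for small, launch-bound models rather than large compute-bound denoisers.
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
out = compiled_model(torch.randn(8, 64, device="cuda"))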

@yiyixuxu (Collaborator) commented:

@a-r-r-o-w
I haven't had time to look into this in detail, but this PR is for your reference: #4986

Basically, we have been trying to move away from using timestep to index anything; we implemented step_index and use that in all our popular schedulers.

Besides helping with the performance issue, it also allows users to use the schedulers with k-sigmas etc. For CogVideoX, I haven't looked into whether we can refactor all the calculations to only use sigma, but it does not have to be the case; in fact, it might be beneficial to not only use sigmas but to calculate all the needed variables once, as this comment points out: https://github.com/huggingface/diffusers/pull/4986/files#r1349529721

As long as we can use step_index, it should be good.
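
A rough, simplified sketch of the step_index pattern (names and update rule are illustrative, not the actual diffusers scheduler code):

import torch

class StepIndexSchedulerSketch:
    """Toy scheduler that indexes precomputed per-step values with a plain Python counter."""

    def __init__(self, sigmas: torch.Tensor):
        self.sigmas = sigmas        # computed once in set_timesteps; can live on any device
        self._step_index = 0        # plain int, so indexing never touches the timestep tensor

    def step(self, model_output: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
        sigma = self.sigmas[self._step_index]
        sigma_next = self.sigmas[self._step_index + 1]
        prev_sample = sample + (sigma_next - sigma) * model_output  # Euler-style update
        self._step_index += 1
        return prev_sample

# usage: sigmas has num_inference_steps + 1 entries, ending in 0
scheduler = StepIndexSchedulerSketch(torch.linspace(1.0, 0.0, 51, device="cuda"))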

@dianyo (Contributor) commented Sep 27, 2024

Just redirected here from #9425. I think the high-level idea is the same: to reduce unneeded CudaMemcpy/sync calls in the diffusion loop. Following the previous comment from @yiyixuxu, I think it'd be better to create a new PR that tries to:

  • Check whether all schedulers support step_index; if not, implement it
  • Calculate all the needed variables once for all schedulers
  • Profile the pipelines to do the comparisons

@a-r-r-o-w @sayakpaul @yiyixuxu
I'm happy to take charge of this if the plan above makes sense to you.
