Skip to content

Commit

Permalink
feat: remove scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
junhsss committed Mar 27, 2023
1 parent e1634e4 commit ee6a349
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 239 deletions.
19 changes: 7 additions & 12 deletions consistency/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid, save_image

from consistency.diffusers import ConsistencyPipeline, ConsistencyScheduler
from consistency.pipeline import ConsistencyPipeline

with suppress(ImportError):
import wandb
Expand Down Expand Up @@ -58,7 +58,7 @@ def __init__(
save_samples_every_n_epoch: int = 10,
num_samples: int = 16,
sample_steps: int = 1,
sample_ema: bool = False,
use_ema: bool = False,
sample_seed: int = 0,
push_to_hub: bool = False,
model_id: Optional[str] = None,
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(
self.save_samples_every_n_epoch = save_samples_every_n_epoch
self.num_samples = num_samples
self.sample_steps = sample_steps
self.sample_ema = sample_ema
self.use_ema = use_ema
self.sample_seed = sample_seed

if push_to_hub:
Expand Down Expand Up @@ -246,18 +246,13 @@ def on_train_batch_end(self, *args, **kwargs):
and self.trainer.global_step > 0
):
pipeline = ConsistencyPipeline(
unet=self.model_ema.unet if self.sample_ema else self.model.unet,
scheduler=ConsistencyScheduler(
time_min=self.time_min,
time_max=self.time_max,
data_std=self.data_std,
),
unet=self.model_ema.unet if self.use_ema else self.model.unet,
)

pipeline.save_pretrained(self.model_id)

self.repo.push_to_hub(
commit_message=f"Epoch {self.current_epoch}",
commit_message=f"Step {self.global_step}",
blocking=False,
)

Expand Down Expand Up @@ -321,7 +316,7 @@ def on_train_start(self) -> None:
num_samples=self.num_samples,
steps=self.sample_steps,
generator=torch.Generator(device=self.device).manual_seed(self.sample_seed),
use_ema=self.sample_ema,
use_ema=self.use_ema,
)

@rank_zero_only
Expand All @@ -338,7 +333,7 @@ def on_train_epoch_end(self) -> None:
generator=torch.Generator(device=self.device).manual_seed(
self.sample_seed
),
use_ema=self.sample_ema,
use_ema=self.use_ema,
)

@torch.no_grad()
Expand Down
4 changes: 0 additions & 4 deletions consistency/diffusers/__init__.py

This file was deleted.

82 changes: 0 additions & 82 deletions consistency/diffusers/pipeline_consistency.py

This file was deleted.

105 changes: 0 additions & 105 deletions consistency/diffusers/scheduling_consistency.py

This file was deleted.

70 changes: 70 additions & 0 deletions consistency/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import math
from typing import List, Optional, Tuple, Union

import torch
from diffusers import DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils import randn_tensor


class ConsistencyPipeline(DiffusionPipeline):
unet: UNet2DModel

def __init__(
self,
unet: UNet2DModel,
) -> None:
super().__init__()
self.register_modules(unet=unet)

@torch.no_grad()
def __call__(
self,
steps: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
time_min: float = 0.002,
time_max: float = 80.0,
data_std: float = 0.5,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
) -> Union[Tuple, ImagePipelineOutput]:
img_size = self.unet.config.sample_size
shape = (1, 3, img_size, img_size)

model = self.unet

time: float = time_max

sample = randn_tensor(shape, generator=generator) * time

for step in self.progress_bar(range(steps)):
if step > 0:
time = self.search_previous_time(time)
sigma = math.sqrt(time**2 - time_min**2 + 1e-6)
sample = sample + sigma * randn_tensor(
sample.shape, device=sample.device, generator=generator
)

out = model(sample, torch.tensor([time], device=sample.device)).sample

skip_coef = data_std**2 / ((time - time_min) ** 2 + data_std**2)
out_coef = data_std * time / (time**2 + data_std**2) ** (0.5)

sample = (sample * skip_coef + out * out_coef).clamp(-1.0, 1.0)

sample = (sample / 2 + 0.5).clamp(0, 1)
image = sample.cpu().permute(0, 2, 3, 1).numpy()

if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return (image,)

return ImagePipelineOutput(images=image)

# TODO: Implement greedy search on FID
def search_previous_time(
self, time, time_min: float = 0.002, time_max: float = 80.0
):
return (2 * time + time_min) / 3
2 changes: 1 addition & 1 deletion examples/consistency_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,4 @@
"outputs": []
}
]
}
}
35 changes: 0 additions & 35 deletions tests/test_compat.py

This file was deleted.

0 comments on commit ee6a349

Please sign in to comment.