Fix sampling for batch sizes that are different from the total number of images to sample
crowsonkb committed Dec 23, 2021
1 parent 9083c4d commit ac808e7
Showing 1 changed file with 8 additions and 3 deletions.
clip_sample.py (8 additions, 3 deletions)
@@ -141,7 +141,7 @@ def main():
 
     torch.manual_seed(args.seed)
 
-    def cond_fn(x, t, pred, **kwargs):
+    def cond_fn(x, t, pred, clip_embed):
         clip_in = normalize(make_cutouts((pred + 1) / 2))
         image_embeds = clip_model.encode_image(clip_in).view([args.cutn, x.shape[0], -1])
         losses = spherical_dist_loss(image_embeds, clip_embed[None])
@@ -152,10 +152,15 @@ def cond_fn(x, t, pred, **kwargs):
     def run(x, clip_embed):
         t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
         steps = utils.get_spliced_ddpm_cosine_schedule(t)
-        extra_args = {'clip_embed': clip_embed} if hasattr(model, 'clip_model') else {}
+        if hasattr(model, 'clip_model'):
+            extra_args = {'clip_embed': clip_embed}
+            cond_fn_ = cond_fn
+        else:
+            extra_args = {}
+            cond_fn_ = partial(cond_fn, clip_embed=clip_embed)
         if not args.clip_guidance_scale:
             return sampling.sample(model, x, steps, args.eta, extra_args)
-        return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn)
+        return sampling.cond_sample(model, x, steps, args.eta, extra_args, cond_fn_)
 
     def run_all(n, batch_size):
         x = torch.randn([args.n, 3, side_y, side_x], device=device)
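Context for the change, as the diff suggests: the old cond_fn took **kwargs and read clip_embed from the enclosing scope, which presumably held the embeddings for all args.n images, so it stopped matching x once sampling ran in batches smaller than n. Making clip_embed an explicit parameter lets run() hand the per-batch embedding to the callback. The sampler also appears to forward the same extra_args dict to both the model and cond_fn, so a checkpoint that is not CLIP-conditioned cannot receive clip_embed through extra_args; in that case the embedding is baked into the callback with partial (presumably functools.partial). Below is a minimal, self-contained sketch of that binding pattern; the toy sampler step, function names, and tensor shapes are illustrative and not the repository's actual API.

from functools import partial

import torch


def cond_fn(x, t, pred, clip_embed):
    # Toy stand-in for the CLIP guidance gradient; the only point is that it
    # needs clip_embed, however it arrives.
    return torch.zeros_like(x) * clip_embed.mean()


def toy_cond_sample_step(model_is_clip_conditioned, x, t, clip_embed):
    pred = x  # toy "denoised prediction"
    if model_is_clip_conditioned:
        # The model consumes clip_embed via extra_args, and the same dict is
        # forwarded to the callback, so cond_fn receives it automatically.
        extra_args = {'clip_embed': clip_embed}
        cond_fn_ = cond_fn
    else:
        # The model would reject an unexpected clip_embed kwarg, so keep
        # extra_args empty and bind the embedding into the callback up front.
        extra_args = {}
        cond_fn_ = partial(cond_fn, clip_embed=clip_embed)
    return cond_fn_(x, t, pred, **extra_args)


x = torch.randn(2, 3, 8, 8)
t = torch.zeros(2)
clip_embed = torch.randn(2, 512)
assert toy_cond_sample_step(True, x, t, clip_embed).shape == x.shape
assert toy_cond_sample_step(False, x, t, clip_embed).shape == x.shape

Either way, cond_fn_ ends up with the signature the sampler expects while seeing only the current batch's embeddings.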
