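Every community pipeline touched in this diff gets the same change: each denoiser forward pass is wrapped in a cache_context block that tags the conditional branch as "cond" and the unconditional (or inversion) branch as "uncond", so cache helpers that key their state per branch (diffusers' CacheMixin machinery, e.g. FasterCache) never mix activations across guidance branches. The `# TODO-context` comments left in pipeline_flux_semantic_guidance.py and pipeline_stable_diffusion_3_instruct_pix2pix.py appear to mark transformer calls that still need wrapping. A minimal runnable sketch of the pattern follows; TinyTransformer and its cache_context are illustrative stand-ins, not the diffusers implementation:

import contextlib

import torch


class TinyTransformer(torch.nn.Module):
    # Stand-in for a diffusers transformer. Real models inherit
    # cache_context from CacheMixin; this toy version only records the
    # active branch name, which is all the sketch needs.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)
        self.cache_branch = None

    @contextlib.contextmanager
    def cache_context(self, name):
        # Cache backends key their per-branch state on this name.
        self.cache_branch = name
        try:
            yield
        finally:
            self.cache_branch = None

    def forward(self, hidden_states, timestep):
        del timestep  # unused by this toy denoiser
        return self.proj(hidden_states)


transformer = TinyTransformer()
latents = torch.randn(1, 16)
guidance_scale = 3.5

for t in torch.linspace(1.0, 0.1, 4):
    timestep = t.expand(latents.shape[0])

    # Tag the conditional pass...
    with transformer.cache_context("cond"):
        noise_pred = transformer(latents, timestep)

    # ...and the unconditional pass, so cached state from one branch is
    # never reused for the other.
    with transformer.cache_context("uncond"):
        neg_noise_pred = transformer(latents, timestep)

    noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
    latents = latents - 0.1 * noise_pred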
17 changes: 9 additions & 8 deletions examples/community/cogvideox_ddim_inversion.py
@@ -522,14 +522,15 @@ def sample(
     timestep = t.expand(latent_model_input.shape[0])

     # predict noise model_output
-    noise_pred = self.transformer(
-        hidden_states=latent_model_input,
-        encoder_hidden_states=prompt_embeds,
-        timestep=timestep,
-        image_rotary_emb=image_rotary_emb,
-        attention_kwargs=attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latent_model_input,
+            encoder_hidden_states=prompt_embeds,
+            timestep=timestep,
+            image_rotary_emb=image_rotary_emb,
+            attention_kwargs=attention_kwargs,
+            return_dict=False,
+        )[0]
     noise_pred = noise_pred.float()

     if reference_latents is not None:  # Recover the original batch size
23 changes: 12 additions & 11 deletions examples/community/pipeline_flux_differential_img2img.py
@@ -954,17 +954,18 @@ def __call__(

     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
     timestep = t.expand(latents.shape[0]).to(latents.dtype)
-    noise_pred = self.transformer(
-        hidden_states=latents,
-        timestep=timestep / 1000,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_image_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     # compute the previous noisy sample x_t -> x_t-1
     latents_dtype = latents.dtype
42 changes: 22 additions & 20 deletions examples/community/pipeline_flux_kontext_multiple_images.py
@@ -1150,33 +1150,35 @@ def __call__(
     latent_model_input = torch.cat([latents, image_latents], dim=1)
     timestep = t.expand(latents.shape[0]).to(latents.dtype)

-    noise_pred = self.transformer(
-        hidden_states=latent_model_input,
-        timestep=timestep / 1000,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]
     noise_pred = noise_pred[:, : latents.size(1)]

     if do_true_cfg:
         if negative_image_embeds is not None:
             self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-        neg_noise_pred = self.transformer(
-            hidden_states=latent_model_input,
-            timestep=timestep / 1000,
-            guidance=guidance,
-            pooled_projections=negative_pooled_prompt_embeds,
-            encoder_hidden_states=negative_prompt_embeds,
-            txt_ids=negative_text_ids,
-            img_ids=latent_ids,
-            joint_attention_kwargs=self.joint_attention_kwargs,
-            return_dict=False,
-        )[0]
+        with self.transformer.cache_context("uncond"):
+            neg_noise_pred = self.transformer(
+                hidden_states=latent_model_input,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=negative_pooled_prompt_embeds,
+                encoder_hidden_states=negative_prompt_embeds,
+                txt_ids=negative_text_ids,
+                img_ids=latent_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
         neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
         noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

46 changes: 24 additions & 22 deletions examples/community/pipeline_flux_rf_inversion.py
@@ -888,17 +888,18 @@ def __call__(
     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
     timestep = t.expand(latents.shape[0]).to(latents.dtype)

-    noise_pred = self.transformer(
-        hidden_states=latents,
-        timestep=timestep / 1000,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_image_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     latents_dtype = latents.dtype
     if do_rf_inversion:

@@ -1058,17 +1059,18 @@ def invert(
     timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)

     # get the unconditional vector field
-    u_t_i = self.transformer(
-        hidden_states=Y_t,
-        timestep=timestep,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_image_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("uncond"):
+        u_t_i = self.transformer(
+            hidden_states=Y_t,
+            timestep=timestep,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     # get the conditional vector field
     u_t_i_cond = (y_1 - Y_t) / (1 - t_i)
47 changes: 25 additions & 22 deletions examples/community/pipeline_flux_semantic_guidance.py
@@ -1135,23 +1135,25 @@ def __call__(
     else:
         guidance = None

-    noise_pred = self.transformer(
-        hidden_states=latents,
-        timestep=timestep / 1000,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_image_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latents,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
         noise_pred_edit_concepts = []
         for e_embed, pooled_e_embed, e_text_id in zip(
             editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
         ):
+            # TODO-context
             noise_pred_edit = self.transformer(
                 hidden_states=latents,
                 timestep=timestep / 1000,

@@ -1168,17 +1170,18 @@ def __call__(
     if do_true_cfg:
         if negative_image_embeds is not None:
             self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-        noise_pred_uncond = self.transformer(
-            hidden_states=latents,
-            timestep=timestep / 1000,
-            guidance=guidance,
-            pooled_projections=negative_pooled_prompt_embeds,
-            encoder_hidden_states=negative_prompt_embeds,
-            txt_ids=text_ids,
-            img_ids=latent_image_ids,
-            joint_attention_kwargs=self.joint_attention_kwargs,
-            return_dict=False,
-        )[0]
+        with self.transformer.cache_context("uncond"):
+            noise_pred_uncond = self.transformer(
+                hidden_states=latents,
+                timestep=timestep / 1000,
+                guidance=guidance,
+                pooled_projections=negative_pooled_prompt_embeds,
+                encoder_hidden_states=negative_prompt_embeds,
+                txt_ids=text_ids,
+                img_ids=latent_image_ids,
+                joint_attention_kwargs=self.joint_attention_kwargs,
+                return_dict=False,
+            )[0]
         noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond)
     else:
         noise_pred_uncond = noise_pred
23 changes: 12 additions & 11 deletions examples/community/pipeline_flux_with_cfg.py
@@ -815,17 +815,18 @@ def __call__(
     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
     timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

-    noise_pred = self.transformer(
-        hidden_states=latent_model_input,
-        timestep=timestep / 1000,
-        guidance=guidance,
-        pooled_projections=pooled_prompt_embeds,
-        encoder_hidden_states=prompt_embeds,
-        txt_ids=text_ids,
-        img_ids=latent_image_ids,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep / 1000,
+            guidance=guidance,
+            pooled_projections=pooled_prompt_embeds,
+            encoder_hidden_states=prompt_embeds,
+            txt_ids=text_ids,
+            img_ids=latent_image_ids,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     if do_true_cfg:
         neg_noise_pred, noise_pred = noise_pred.chunk(2)
25 changes: 13 additions & 12 deletions examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -1074,18 +1074,19 @@ def __call__(
     )

     # predict the noise residual
-    noise_pred = self.transformer(
-        latent_model_input,
-        t_expand,
-        encoder_hidden_states=prompt_embeds,
-        text_embedding_mask=prompt_attention_mask,
-        encoder_hidden_states_t5=prompt_embeds_2,
-        text_embedding_mask_t5=prompt_attention_mask_2,
-        image_meta_size=add_time_ids,
-        style=style,
-        image_rotary_emb=image_rotary_emb,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            latent_model_input,
+            t_expand,
+            encoder_hidden_states=prompt_embeds,
+            text_embedding_mask=prompt_attention_mask,
+            encoder_hidden_states_t5=prompt_embeds_2,
+            text_embedding_mask_t5=prompt_attention_mask_2,
+            image_meta_size=add_time_ids,
+            style=style,
+            image_rotary_emb=image_rotary_emb,
+            return_dict=False,
+        )[0]

     noise_pred, _ = noise_pred.chunk(2, dim=1)

@@ -918,13 +918,14 @@ def __call__(
     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
     timestep = t.expand(latent_model_input.shape[0])

-    noise_pred = self.transformer(
-        hidden_states=latent_model_input,
-        timestep=timestep,
-        encoder_hidden_states=prompt_embeds,
-        pooled_projections=pooled_prompt_embeds,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latent_model_input,
+            timestep=timestep,
+            encoder_hidden_states=prompt_embeds,
+            pooled_projections=pooled_prompt_embeds,
+            return_dict=False,
+        )[0]

     # perform guidance
     if self.do_classifier_free_guidance:
18 changes: 10 additions & 8 deletions examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
@@ -1178,14 +1178,15 @@ def __call__(
     timestep = t.expand(latent_model_input.shape[0])
     scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

-    noise_pred = self.transformer(
-        hidden_states=scaled_latent_model_input,
-        timestep=timestep,
-        encoder_hidden_states=prompt_embeds,
-        pooled_projections=pooled_prompt_embeds,
-        joint_attention_kwargs=self.joint_attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=scaled_latent_model_input,
+            timestep=timestep,
+            encoder_hidden_states=prompt_embeds,
+            pooled_projections=pooled_prompt_embeds,
+            joint_attention_kwargs=self.joint_attention_kwargs,
+            return_dict=False,
+        )[0]

     # perform guidance
     if self.do_classifier_free_guidance:

@@ -1204,6 +1205,7 @@ def __call__(
     if skip_guidance_layers is not None and should_skip_layers:
         timestep = t.expand(latents.shape[0])
         latent_model_input = latents
+        # TODO-context
         noise_pred_skip_layers = self.transformer(
             hidden_states=latent_model_input,
             timestep=timestep,
17 changes: 9 additions & 8 deletions examples/community/pipeline_stg_cogvideox.py
@@ -793,14 +793,15 @@ def __call__(
     timestep = t.expand(latent_model_input.shape[0])

     # predict noise model_output
-    noise_pred = self.transformer(
-        hidden_states=latent_model_input,
-        encoder_hidden_states=prompt_embeds,
-        timestep=timestep,
-        image_rotary_emb=image_rotary_emb,
-        attention_kwargs=attention_kwargs,
-        return_dict=False,
-    )[0]
+    with self.transformer.cache_context("cond"):
+        noise_pred = self.transformer(
+            hidden_states=latent_model_input,
+            encoder_hidden_states=prompt_embeds,
+            timestep=timestep,
+            image_rotary_emb=image_rotary_emb,
+            attention_kwargs=attention_kwargs,
+            return_dict=False,
+        )[0]
     noise_pred = noise_pred.float()

     # perform guidance
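The tags above only change behavior once a cache is actually enabled on the model; without one, cache_context is effectively a no-op. A usage sketch, assuming a diffusers release that ships CacheMixin.enable_cache and FasterCacheConfig (the values follow the documented FasterCache example for CogVideoX, but exact fields may differ by version):

import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# FasterCache skips and reuses attention states across timesteps, keeping
# separate state for the "cond" and "uncond" branches tagged in the
# pipelines above.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
)
pipe.transformer.enable_cache(config)

video = pipe("A panda playing a guitar by a campfire").frames[0]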