112 changes: 112 additions & 0 deletions invokeai/app/invocations/z_image_control.py
@@ -0,0 +1,112 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Z-Image Control invocation for spatial conditioning."""

from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
InputField,
OutputField,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


class ZImageControlField(BaseModel):
"""A Z-Image control conditioning field for spatial control (Canny, HED, Depth, Pose, MLSD)."""

image_name: str = Field(description="The name of the preprocessed control image")
control_model: ModelIdentifierField = Field(description="The Z-Image ControlNet adapter model")
control_context_scale: float = Field(
default=0.75,
ge=0.0,
le=2.0,
description="The strength of the control signal. Recommended range: 0.65-0.80.",
)
begin_step_percent: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="When the control is first applied (% of total steps)",
)
end_step_percent: float = Field(
default=1.0,
ge=0.0,
le=1.0,
description="When the control is last applied (% of total steps)",
)


@invocation_output("z_image_control_output")
class ZImageControlOutput(BaseInvocationOutput):
"""Z-Image Control output containing control configuration."""

control: ZImageControlField = OutputField(description="Z-Image control conditioning")


@invocation(
"z_image_control",
title="Z-Image ControlNet",
tags=["image", "z-image", "control", "controlnet"],
category="control",
version="1.1.0",
classification=Classification.Prototype,
)
class ZImageControlInvocation(BaseInvocation):
"""Configure Z-Image ControlNet for spatial conditioning.

Takes a preprocessed control image (e.g., Canny edges, depth map, pose)
and a Z-Image ControlNet adapter model to enable spatial control.

Supports 5 control modes: Canny, HED, Depth, Pose, MLSD.
Recommended control_context_scale: 0.65-0.80.
"""

image: ImageField = InputField(
description="The preprocessed control image (Canny, HED, Depth, Pose, or MLSD)",
)
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model,
title="Control Model",
ui_model_base=BaseModelType.ZImage,
ui_model_type=ModelType.ControlNet,
)
control_context_scale: float = InputField(
default=0.75,
ge=0.0,
le=2.0,
description="Strength of the control signal. Recommended range: 0.65-0.80.",
title="Control Scale",
)
begin_step_percent: float = InputField(
default=0.0,
ge=0.0,
le=1.0,
description="When the control is first applied (% of total steps)",
)
end_step_percent: float = InputField(
default=1.0,
ge=0.0,
le=1.0,
description="When the control is last applied (% of total steps)",
)

def invoke(self, context: InvocationContext) -> ZImageControlOutput:
return ZImageControlOutput(
control=ZImageControlField(
image_name=self.image.image_name,
control_model=self.control_model,
control_context_scale=self.control_context_scale,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
)
)
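
A minimal usage sketch (editor's illustration, not part of the diff; `adapter_id` and the image name are hypothetical placeholders, while `ZImageControlField` and its defaults are exactly as defined above):

# Hypothetical wiring sketch, assuming adapter_id is a valid ModelIdentifierField
# for a Z-Image ControlNet adapter and "canny_edges.png" is a preprocessed image.
control = ZImageControlField(
    image_name="canny_edges.png",
    control_model=adapter_id,
    control_context_scale=0.75,  # within the recommended 0.65-0.80 range
)
# begin_step_percent / end_step_percent keep their 0.0 / 1.0 defaults, so the
# control signal stays active for the full denoising schedule.

This is the object that the `control` input of `ZImageDenoiseInvocation` consumes in the second file below.
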
190 changes: 167 additions & 23 deletions invokeai/app/invocations/z_image_denoise.py
@@ -1,8 +1,11 @@
import math
from contextlib import ExitStack
from typing import Callable, Iterator, Optional, Tuple

import einops
import torch
import torchvision.transforms as tv_transforms
from PIL import Image
from torchvision.transforms.functional import resize as tv_resize
from tqdm import tqdm

@@ -16,8 +19,10 @@
LatentsField,
ZImageConditioningField,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.invocations.z_image_control import ZImageControlField
from invokeai.app.invocations.z_image_image_to_latents import ZImageImageToLatentsInvocation
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat
from invokeai.backend.patches.layer_patcher import LayerPatcher
@@ -27,6 +32,11 @@
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ZImageConditioningInfo
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.z_image.z_image_control_adapter import ZImageControlAdapter
from invokeai.backend.z_image.z_image_controlnet_extension import (
ZImageControlNetExtension,
z_image_forward_with_control,
)


@invocation(
@@ -59,18 +69,31 @@ class ZImageDenoiseInvocation(BaseInvocation):
negative_conditioning: Optional[ZImageConditioningField] = InputField(
default=None, description=FieldDescriptions.negative_cond, input=Input.Connection
)
# Z-Image-Turbo uses guidance_scale=0.0 by default (no CFG)
# Z-Image-Turbo works best without CFG (guidance_scale=1.0)
guidance_scale: float = InputField(
default=0.0,
ge=0.0,
description="Guidance scale for classifier-free guidance. Use 0.0 for Z-Image-Turbo.",
default=1.0,
ge=1.0,
description="Guidance scale for classifier-free guidance. 1.0 = no CFG (recommended for Z-Image-Turbo). "
"Values > 1.0 amplify guidance.",
title="Guidance Scale",
)
width: int = InputField(default=1024, multiple_of=16, description="Width of the generated image.")
height: int = InputField(default=1024, multiple_of=16, description="Height of the generated image.")
# Z-Image-Turbo uses 8 steps by default
steps: int = InputField(default=8, gt=0, description="Number of denoising steps. 8 recommended for Z-Image-Turbo.")
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
# Z-Image Control support
control: Optional[ZImageControlField] = InputField(
default=None,
description="Z-Image control conditioning for spatial control (Canny, HED, Depth, Pose, MLSD).",
input=Input.Connection,
)
# VAE for encoding control images (required when using control)
vae: Optional[VAEField] = InputField(
default=None,
description=FieldDescriptions.vae + " Required for control conditioning.",
input=Input.Connection,
)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -206,12 +229,15 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
device=device,
)

# Load negative conditioning if provided and guidance_scale > 0
# Load negative conditioning if provided and guidance_scale != 1.0
# CFG formula: pred = pred_uncond + cfg_scale * (pred_cond - pred_uncond)
# At cfg_scale=1.0: pred = pred_cond (no effect, skip uncond computation)
# This matches FLUX's convention where 1.0 means "no CFG"
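        # Worked example (illustrative): at cfg_scale=3.0,
        #   pred = pred_uncond + 3.0 * (pred_cond - pred_uncond),
        # i.e. the prediction is pushed 3x along the conditional direction.
        # At cfg_scale=1.0 the uncond terms cancel, pred == pred_cond, and the
        # second (unconditional) forward pass can be skipped entirely.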
neg_prompt_embeds: torch.Tensor | None = None
do_classifier_free_guidance = self.guidance_scale > 0.0 and self.negative_conditioning is not None
do_classifier_free_guidance = not math.isclose(self.guidance_scale, 1.0) and self.negative_conditioning is not None
if do_classifier_free_guidance:
if self.negative_conditioning is None:
raise ValueError("Negative conditioning is required when guidance_scale > 0")
raise ValueError("Negative conditioning is required when guidance_scale != 1.0")
neg_prompt_embeds = self._load_text_conditioning(
context=context,
conditioning_name=self.negative_conditioning.conditioning_name,
@@ -293,9 +319,6 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
)

with ExitStack() as exit_stack:
# Load transformer and apply LoRA patches
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())

# Get transformer config to determine if it's quantized
transformer_config = context.models.get_config(self.transformer.transformer)

@@ -309,6 +332,102 @@
else:
raise ValueError(f"Unsupported Z-Image model format: {transformer_config.format}")

# Load transformer - always use base transformer, control is handled via extension
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())

# Prepare control extension if control is provided
control_extension: ZImageControlNetExtension | None = None

if self.control is not None:
# Load control adapter using context manager (proper GPU memory management)
control_model_info = context.models.load(self.control.control_model)
(_, control_adapter) = exit_stack.enter_context(control_model_info.model_on_device())
assert isinstance(control_adapter, ZImageControlAdapter)

# Get control_in_dim from adapter config (16 for V1, 33 for V2.0)
adapter_config = control_adapter.config
control_in_dim = adapter_config.get("control_in_dim", 16)
num_control_blocks = adapter_config.get("num_control_blocks", 6)

# Log control configuration for debugging
version = "V2.0" if control_in_dim > 16 else "V1"
context.util.signal_progress(
f"Using Z-Image ControlNet {version} (Extension): control_in_dim={control_in_dim}, "
f"num_blocks={num_control_blocks}, scale={self.control.control_context_scale}"
)

# Load and prepare control image - must be VAE-encoded!
if self.vae is None:
raise ValueError("VAE is required when using Z-Image Control. Connect a VAE to the 'vae' input.")

control_image = context.images.get_pil(self.control.image_name)

# Resize control image to match output dimensions
control_image = control_image.convert("RGB")
control_image = control_image.resize((self.width, self.height), Image.Resampling.LANCZOS)

# Convert to tensor format for VAE encoding
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor

control_image_tensor = image_resized_to_grid_as_tensor(control_image)
if control_image_tensor.dim() == 3:
control_image_tensor = einops.rearrange(control_image_tensor, "c h w -> 1 c h w")

# Encode control image through VAE to get latents
vae_info = context.models.load(self.vae.vae)
control_latents = ZImageImageToLatentsInvocation.vae_encode(
vae_info=vae_info,
image_tensor=control_image_tensor,
)

# Move to inference device/dtype
control_latents = control_latents.to(device=device, dtype=inference_dtype)

# Add frame dimension: [B, C, H, W] -> [C, 1, H, W] (single image)
control_latents = control_latents.squeeze(0).unsqueeze(1)
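                # Shape sketch (assuming an 8x VAE downsample; unverified for Z-Image):
                # a 1024x1024 control image encodes to [1, 16, 128, 128], which this
                # reshape turns into [16, 1, 128, 128], i.e. (C, F=1, H, W).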

# Prepare control_cond based on control_in_dim
# V1: 16 channels (just control latents)
# V2.0: 33 channels = 16 control + 16 reference + 1 mask
# - Channels 0-15: control image latents (from VAE encoding)
# - Channels 16-31: reference/inpaint image latents (zeros for pure control)
# - Channel 32: inpaint mask (1.0 = don't inpaint, 0.0 = inpaint region)
# For pure control (no inpainting), we set mask=1 to tell model "use control, don't inpaint"
c, f, h, w = control_latents.shape
if c < control_in_dim:
padding_channels = control_in_dim - c
if padding_channels == 17:
# V2.0: 16 reference channels (zeros) + 1 mask channel (ones)
ref_padding = torch.zeros(
(16, f, h, w),
device=device,
dtype=inference_dtype,
)
# Mask channel = 1.0 means "don't inpaint this region, use control signal"
mask_channel = torch.ones(
(1, f, h, w),
device=device,
dtype=inference_dtype,
)
control_latents = torch.cat([control_latents, ref_padding, mask_channel], dim=0)
else:
# Generic padding with zeros for other cases
zero_padding = torch.zeros(
(padding_channels, f, h, w),
device=device,
dtype=inference_dtype,
)
control_latents = torch.cat([control_latents, zero_padding], dim=0)
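                # Either way control_latents now carries control_in_dim channels:
                # [16, 1, h, w] for a V1 adapter, [33, 1, h, w] for V2.0.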

# Create control extension (adapter is already on device from model_on_device)
control_extension = ZImageControlNetExtension(
control_adapter=control_adapter,
control_cond=control_latents,
weight=self.control.control_context_scale,
begin_step_percent=self.control.begin_step_percent,
end_step_percent=self.control.end_step_percent,
)

# Apply LoRA models to the transformer.
# Note: We apply the LoRA after the transformer has been moved to its target device for faster patching.
exit_stack.enter_context(
@@ -340,25 +459,50 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
latent_model_input = latent_model_input.unsqueeze(2) # Add frame dimension
latent_model_input_list = list(latent_model_input.unbind(dim=0))

# Transformer returns (List[torch.Tensor], dict) - we only need the tensor list
model_output = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
# Determine if control should be applied at this step
apply_control = (
control_extension is not None and control_extension.should_apply(step_idx, total_steps)
)
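                # Assumed semantics of should_apply (defined in the extension, not
                # shown here): roughly begin_step_percent <= step_idx / total_steps
                # <= end_step_percent.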
model_out_list = model_output[0] # Extract list of tensors from tuple

# Run forward pass - use custom forward with control if extension is active
if apply_control:
model_out_list, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
control_extension=control_extension,
)
else:
model_output = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[pos_prompt_embeds],
)
model_out_list = model_output[0] # Extract list of tensors from tuple

noise_pred_cond = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred_cond = noise_pred_cond.squeeze(2) # Remove frame dimension
noise_pred_cond = -noise_pred_cond # Z-Image uses v-prediction with negation
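                # After stacking the per-sample outputs and dropping the frame dim,
                # noise_pred_cond is a float32 [B, C, H, W] velocity prediction.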

# Apply CFG if enabled
if do_classifier_free_guidance and neg_prompt_embeds is not None:
model_output_uncond = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
)
model_out_list_uncond = model_output_uncond[0] # Extract list of tensors from tuple
if apply_control:
model_out_list_uncond, _ = z_image_forward_with_control(
transformer=transformer,
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
control_extension=control_extension,
)
else:
model_output_uncond = transformer(
x=latent_model_input_list,
t=timestep,
cap_feats=[neg_prompt_embeds],
)
model_out_list_uncond = model_output_uncond[0] # Extract list of tensors from tuple

noise_pred_uncond = torch.stack([t.float() for t in model_out_list_uncond], dim=0)
noise_pred_uncond = noise_pred_uncond.squeeze(2)
noise_pred_uncond = -noise_pred_uncond