diff --git a/conditioner.hpp b/conditioner.hpp index a4e84aa3b..3c31bec27 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -10,9 +10,14 @@ struct SDCondition { struct ggml_tensor* c_vector = nullptr; // aka y struct ggml_tensor* c_concat = nullptr; + std::vector extra_c_crossattns; + SDCondition() = default; - SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat) - : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {} + SDCondition(struct ggml_tensor* c_crossattn, + struct ggml_tensor* c_vector, + struct ggml_tensor* c_concat, + const std::vector& extra_c_crossattns = {}) + : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {} }; struct ConditionerParams { @@ -1657,18 +1662,23 @@ struct LLMEmbedder : public Conditioner { } std::tuple, std::vector> tokenize(std::string text, - std::pair attn_range, + const std::pair& attn_range, size_t max_length = 0, bool padding = false) { std::vector> parsed_attention; - parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); - if (attn_range.second - attn_range.first > 0) { - auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); - parsed_attention.insert(parsed_attention.end(), - new_parsed_attention.begin(), - new_parsed_attention.end()); - } - parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + if (attn_range.first >= 0 && attn_range.second > 0) { + parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f); + if (attn_range.second - attn_range.first > 0) { + auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first)); + parsed_attention.insert(parsed_attention.end(), + new_parsed_attention.begin(), + new_parsed_attention.end()); + } + parsed_attention.emplace_back(text.substr(attn_range.second), 1.f); + } else { + parsed_attention.emplace_back(text, 1.f); + } + { std::stringstream ss; ss << "["; @@ -1699,156 +1709,27 @@ struct LLMEmbedder : public Conditioner { return {tokens, weights}; } - SDCondition get_learned_condition(ggml_context* work_ctx, - int n_threads, - const ConditionerParams& conditioner_params) override { - std::string prompt; - std::vector> image_embeds; - std::pair prompt_attn_range; - int prompt_template_encode_start_idx = 34; - int max_length = 0; - std::set out_layers; - std::vector tokens; - std::vector weights; + ggml_tensor* encode_prompt(ggml_context* work_ctx, + int n_threads, + const std::string prompt, + const std::pair& prompt_attn_range, + int max_length, + int min_length, + std::vector> image_embeds, + const std::set& out_layers, + int prompt_template_encode_start_idx) { + auto tokens_and_weights = tokenize(prompt, prompt_attn_range); + auto& tokens = std::get<0>(tokens_and_weights); + auto& weights = std::get<1>(tokens_and_weights); std::vector mask; - if (llm->enable_vision && conditioner_params.ref_images.size() > 0) { - LOG_INFO("QwenImageEditPlusPipeline"); - prompt_template_encode_start_idx = 64; - int image_embed_idx = 64 + 6; - - int min_pixels = 384 * 384; - int max_pixels = 560 * 560; - std::string placeholder = "<|image_pad|>"; - std::string img_prompt; - - for (int i = 0; i < conditioner_params.ref_images.size(); i++) { - sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); - double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; - int height = image.height; - int width = image.width; - int h_bar = static_cast(std::round(height / factor) * factor); - int w_bar = static_cast(std::round(width / factor) * factor); - - if (static_cast(h_bar) * w_bar > max_pixels) { - double beta = std::sqrt((height * width) / static_cast(max_pixels)); - h_bar = std::max(static_cast(factor), - static_cast(std::floor(height / beta / factor)) * static_cast(factor)); - w_bar = std::max(static_cast(factor), - static_cast(std::floor(width / beta / factor)) * static_cast(factor)); - } else if (static_cast(h_bar) * w_bar < min_pixels) { - double beta = std::sqrt(static_cast(min_pixels) / (height * width)); - h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); - w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); - } - - LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); - - sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); - free(image.data); - image.data = nullptr; - - ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); - sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); - free(resized_image.data); - resized_image.data = nullptr; - - ggml_tensor* image_embed = nullptr; - llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); - image_embeds.emplace_back(image_embed_idx, image_embed); - image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; - - img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] - int64_t num_image_tokens = image_embed->ne[1]; - img_prompt.reserve(num_image_tokens * placeholder.size()); - for (int j = 0; j < num_image_tokens; j++) { - img_prompt += placeholder; - } - img_prompt += "<|vision_end|>"; - } - - prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; - prompt += img_prompt; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "<|im_end|>\n<|im_start|>assistant\n"; - } else if (version == VERSION_FLUX2) { - prompt_template_encode_start_idx = 0; - out_layers = {10, 20, 30}; - - prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "[/INST]"; - } else if (sd_version_is_z_image(version)) { - prompt_template_encode_start_idx = 0; - out_layers = {35}; // -2 - - prompt = "<|im_start|>user\n"; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "<|im_end|>\n<|im_start|>assistant\n"; - } else if (version == VERSION_FLUX2_KLEIN) { - prompt_template_encode_start_idx = 0; - max_length = 512; - out_layers = {9, 18, 27}; - - prompt = "<|im_start|>user\n"; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; - - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false); - tokens = std::get<0>(tokens_and_weights); - weights = std::get<1>(tokens_and_weights); + if (max_length > 0 && tokens.size() < max_length) { mask.insert(mask.end(), tokens.size(), 1.f); - if (tokens.size() < max_length) { - mask.insert(mask.end(), max_length - tokens.size(), 0.f); - tokenizer->pad_tokens(tokens, weights, max_length, true); - } - } else if (version == VERSION_OVIS_IMAGE) { - prompt_template_encode_start_idx = 28; - max_length = prompt_template_encode_start_idx + 256; - - prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:"; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += " " + conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; - } else { - prompt_template_encode_start_idx = 34; - - prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n"; - - prompt_attn_range.first = static_cast(prompt.size()); - prompt += conditioner_params.text; - prompt_attn_range.second = static_cast(prompt.size()); - - prompt += "<|im_end|>\n<|im_start|>assistant\n"; + mask.insert(mask.end(), max_length - tokens.size(), 0.f); + tokenizer->pad_tokens(tokens, weights, max_length, true); } - if (tokens.empty()) { - auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0); - tokens = std::get<0>(tokens_and_weights); - weights = std::get<1>(tokens_and_weights); - } - - int64_t t0 = ggml_time_ms(); - struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584] + struct ggml_tensor* hidden_states = nullptr; // [N, n_token, hidden_size] auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens); @@ -1891,11 +1772,6 @@ struct LLMEmbedder : public Conditioner { GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx); - int64_t min_length = 0; - if (version == VERSION_FLUX2) { - min_length = 512; - } - int64_t zero_pad_len = 0; if (min_length > 0) { if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) { @@ -1917,11 +1793,186 @@ struct LLMEmbedder : public Conditioner { ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3); }); - // print_ggml_tensor(new_hidden_states); + return new_hidden_states; + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const ConditionerParams& conditioner_params) override { + std::string prompt; + std::pair prompt_attn_range; + std::vector extra_prompts; + std::vector> extra_prompts_attn_range; + std::vector> image_embeds; + int prompt_template_encode_start_idx = 34; + int max_length = 0; // pad tokens + int min_length = 0; // zero pad hidden_states + std::set out_layers; + + int64_t t0 = ggml_time_ms(); + + if (sd_version_is_qwen_image(version)) { + if (llm->enable_vision && !conditioner_params.ref_images.empty()) { + LOG_INFO("QwenImageEditPlusPipeline"); + prompt_template_encode_start_idx = 64; + int image_embed_idx = 64 + 6; + + int min_pixels = 384 * 384; + int max_pixels = 560 * 560; + std::string placeholder = "<|image_pad|>"; + std::string img_prompt; + + for (int i = 0; i < conditioner_params.ref_images.size(); i++) { + sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]); + double factor = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size; + int height = image.height; + int width = image.width; + int h_bar = static_cast(std::round(height / factor) * factor); + int w_bar = static_cast(std::round(width / factor) * factor); + + if (static_cast(h_bar) * w_bar > max_pixels) { + double beta = std::sqrt((height * width) / static_cast(max_pixels)); + h_bar = std::max(static_cast(factor), + static_cast(std::floor(height / beta / factor)) * static_cast(factor)); + w_bar = std::max(static_cast(factor), + static_cast(std::floor(width / beta / factor)) * static_cast(factor)); + } else if (static_cast(h_bar) * w_bar < min_pixels) { + double beta = std::sqrt(static_cast(min_pixels) / (height * width)); + h_bar = static_cast(std::ceil(height * beta / factor)) * static_cast(factor); + w_bar = static_cast(std::ceil(width * beta / factor)) * static_cast(factor); + } + + LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar); + + sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar); + free(image.data); + image.data = nullptr; + + ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1); + sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false); + free(resized_image.data); + resized_image.data = nullptr; + + ggml_tensor* image_embed = nullptr; + llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx); + image_embeds.emplace_back(image_embed_idx, image_embed); + image_embed_idx += 1 + static_cast(image_embed->ne[1]) + 6; + + img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>"; // [24669, 220, index, 25, 220, 151652] + int64_t num_image_tokens = image_embed->ne[1]; + img_prompt.reserve(num_image_tokens * placeholder.size()); + for (int j = 0; j < num_image_tokens; j++) { + img_prompt += placeholder; + } + img_prompt += "<|vision_end|>"; + } + + prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n"; + prompt += img_prompt; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } else { + prompt_template_encode_start_idx = 34; + + prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } + } else if (version == VERSION_FLUX2) { + prompt_template_encode_start_idx = 0; + min_length = 512; + out_layers = {10, 20, 30}; + + prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "[/INST]"; + } else if (sd_version_is_z_image(version)) { + prompt_template_encode_start_idx = 0; + out_layers = {35}; // -2 + + if (!conditioner_params.ref_images.empty()) { + LOG_INFO("ZImageOmniPipeline"); + prompt = "<|im_start|>user\n<|vision_start|>"; + for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) { + extra_prompts.push_back("<|vision_end|><|vision_start|>"); + } + extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"); + extra_prompts.push_back("<|vision_end|><|im_end|>"); + } else { + prompt = "<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n"; + } + } else if (version == VERSION_FLUX2_KLEIN) { + prompt_template_encode_start_idx = 0; + max_length = 512; + out_layers = {9, 18, 27}; + + prompt = "<|im_start|>user\n"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else if (version == VERSION_OVIS_IMAGE) { + prompt_template_encode_start_idx = 28; + max_length = prompt_template_encode_start_idx + 256; + + prompt = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background:"; + + prompt_attn_range.first = static_cast(prompt.size()); + prompt += " " + conditioner_params.text; + prompt_attn_range.second = static_cast(prompt.size()); + + prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"; + } else { + GGML_ABORT("unknown version %d", version); + } + + auto hidden_states = encode_prompt(work_ctx, + n_threads, + prompt, + prompt_attn_range, + max_length, + min_length, + image_embeds, + out_layers, + prompt_template_encode_start_idx); + + std::vector extra_hidden_states_vec; + for (int i = 0; i < extra_prompts.size(); i++) { + auto extra_hidden_states = encode_prompt(work_ctx, + n_threads, + extra_prompts[i], + extra_prompts_attn_range[i], + max_length, + min_length, + image_embeds, + out_layers, + prompt_template_encode_start_idx); + extra_hidden_states_vec.push_back(extra_hidden_states); + } int64_t t1 = ggml_time_ms(); LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); - return {new_hidden_states, nullptr, nullptr}; + return {hidden_states, nullptr, nullptr, extra_hidden_states_vec}; } }; diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 06cbecc28..89a26a4bc 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -23,6 +23,8 @@ struct DiffusionParams { struct ggml_tensor* vace_context = nullptr; float vace_strength = 1.f; std::vector skip_layers = {}; + std::vector extra_contexts; // for z-image-omni + std::vector ref_clip_feats; // for z-image-omni }; struct DiffusionModel { @@ -436,12 +438,14 @@ struct ZImageModel : public DiffusionModel { DiffusionParams diffusion_params, struct ggml_tensor** output = nullptr, struct ggml_context* output_ctx = nullptr) override { + std::vector contexts = {diffusion_params.context}; + contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end()); return z_image.compute(n_threads, diffusion_params.x, diffusion_params.timesteps, - diffusion_params.context, + contexts, diffusion_params.ref_latents, - true, // increase_ref_index + diffusion_params.ref_clip_feats, output, output_ctx); } diff --git a/model.cpp b/model.cpp index 253dd25cd..4a6fad896 100644 --- a/model.cpp +++ b/model.cpp @@ -1039,6 +1039,8 @@ SDVersion ModelLoader::get_sd_version() { bool is_xl = false; bool is_flux = false; bool is_flux2 = false; + bool is_z_image = false; + bool is_z_image_omni = false; bool has_single_block_47 = false; bool is_wan = false; int64_t patch_embedding_channels = 0; @@ -1071,7 +1073,10 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_OVIS_IMAGE; } if (tensor_storage.name.find("model.diffusion_model.cap_embedder.0.weight") != std::string::npos) { - return VERSION_Z_IMAGE; + is_z_image = true; + } + if (tensor_storage.name.find("model.diffusion_model.siglip_embedder.0.weight") != std::string::npos) { + is_z_image_omni = true; } if (tensor_storage.name.find("model.diffusion_model.blocks.0.cross_attn.norm_k.weight") != std::string::npos) { is_wan = true; @@ -1174,6 +1179,13 @@ SDVersion ModelLoader::get_sd_version() { return VERSION_FLUX2_KLEIN; } + if (is_z_image) { + if (is_z_image_omni) { + return VERSION_Z_IMAGE_OMNI; + } + return VERSION_Z_IMAGE; + } + if (token_embedding_weight.ne[0] == 768) { if (is_inpaint) { return VERSION_SD1_INPAINT; diff --git a/model.h b/model.h index e16ac3a07..553e15961 100644 --- a/model.h +++ b/model.h @@ -48,6 +48,7 @@ enum SDVersion { VERSION_FLUX2, VERSION_FLUX2_KLEIN, VERSION_Z_IMAGE, + VERSION_Z_IMAGE_OMNI, VERSION_OVIS_IMAGE, VERSION_COUNT, }; @@ -123,7 +124,7 @@ static inline bool sd_version_is_qwen_image(SDVersion version) { } static inline bool sd_version_is_z_image(SDVersion version) { - if (version == VERSION_Z_IMAGE) { + if (version == VERSION_Z_IMAGE || version == VERSION_Z_IMAGE_OMNI) { return true; } return false; diff --git a/rope.hpp b/rope.hpp index 45e88c831..deeb8d61f 100644 --- a/rope.hpp +++ b/rope.hpp @@ -518,60 +518,117 @@ namespace Rope { return (m - (a % m)) % m; } - __STATIC_INLINE__ std::vector> gen_z_image_ids(int h, - int w, + __STATIC_INLINE__ std::vector> gen_z_image_ids(ggml_tensor* x, + const std::vector& contexts, + const std::vector& ref_latents, + const std::vector& siglip_feats, int patch_size, - int bs, - int context_len, int seq_multi_of, - const std::vector& ref_latents, - bool increase_ref_index) { - int padded_context_len = context_len + bound_mod(context_len, seq_multi_of); - auto txt_ids = std::vector>(bs * padded_context_len, std::vector(3, 0.0f)); - for (int i = 0; i < bs * padded_context_len; i++) { - txt_ids[i][0] = (i % padded_context_len) + 1.f; + int bs) { + GGML_ASSERT(contexts.size() > ref_latents.size()); + GGML_ASSERT(contexts.size() >= siglip_feats.size()); + int context_cu_len = 1; + std::vector context_end_pos; + std::vector> txt_ids; + for (auto context : contexts) { + int padded_context_len = static_cast(context->ne[1]) + bound_mod(static_cast(context->ne[1]), seq_multi_of); + auto curr_txt_ids = std::vector>(bs * padded_context_len, std::vector(3, 0.0f)); + for (int i = 0; i < bs * padded_context_len; i++) { + curr_txt_ids[i][0] = static_cast((i % padded_context_len) + context_cu_len); + } + context_cu_len += padded_context_len; + context_end_pos.push_back(context_cu_len); + context_cu_len += 2; // for image and siglip tokens + txt_ids = concat_ids(txt_ids, curr_txt_ids, bs); } - int axes_dim_num = 3; - int index = padded_context_len + 1; - auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index); + std::vector> img_ids; + std::vector all_img = ref_latents; + all_img.push_back(x); + for (int i = 0; i < all_img.size(); i++) { + int axes_dim_num = 3; + int index = context_end_pos[i]; + auto curr_img_ids = gen_flux_img_ids(static_cast(all_img[i]->ne[1]), static_cast(all_img[i]->ne[0]), patch_size, bs, axes_dim_num, index); + + int img_pad_len = bound_mod(static_cast(curr_img_ids.size() / bs), seq_multi_of); + if (img_pad_len > 0) { + std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); + curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs); + } + img_ids = concat_ids(img_ids, curr_img_ids, bs); + } + + std::vector> sig_ids; + for (int i = 0; i < siglip_feats.size(); i++) { + int axes_dim_num = 3; + int index = context_end_pos[i] + 1; + int h_len = static_cast(siglip_feats[i]->ne[1]); + int w_len = static_cast(siglip_feats[i]->ne[0]); + + std::vector> curr_sig_ids(bs * h_len * w_len, std::vector(axes_dim_num, 0.0)); + + // scale position IDs to match img resolution + std::vector row_ids = linspace(0, static_cast(all_img[i]->ne[1]) - 1.f, h_len); + std::vector col_ids = linspace(0, static_cast(all_img[i]->ne[0]) - 1.f, w_len); + + for (int ib = 0; ib < bs; ++ib) { + for (int ih = 0; ih < h_len; ++ih) { + for (int iw = 0; iw < w_len; ++iw) { + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = static_cast(index); + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih]; + curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw]; + } + } + } - int img_pad_len = bound_mod(static_cast(img_ids.size() / bs), seq_multi_of); - if (img_pad_len > 0) { - std::vector> img_pad_ids(bs * img_pad_len, std::vector(3, 0.f)); - img_ids = concat_ids(img_ids, img_pad_ids, bs); + int sig_pad_len = bound_mod(static_cast(curr_sig_ids.size() / bs), seq_multi_of); + if (sig_pad_len > 0) { + std::vector> sig_pad_ids(bs * sig_pad_len, std::vector(3, 0.f)); + curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs); + } + sig_ids = concat_ids(sig_ids, curr_sig_ids, bs); } auto ids = concat_ids(txt_ids, img_ids, bs); - // ignore ref_latents for now + if (!sig_ids.empty()) { + ids = concat_ids(ids, sig_ids, bs); + } + return ids; } // Generate z_image positional embeddings - __STATIC_INLINE__ std::vector gen_z_image_pe(int h, - int w, + __STATIC_INLINE__ std::vector gen_z_image_pe(ggml_tensor* x, + const std::vector& contexts, + const std::vector& ref_latents, + const std::vector& siglip_feats, int patch_size, - int bs, - int context_len, int seq_multi_of, - const std::vector& ref_latents, - bool increase_ref_index, int theta, + const std::vector& axes_dim, bool circular_h, bool circular_w, - const std::vector& axes_dim) { - std::vector> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index); + int bs) { + std::vector> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs); std::vector> wrap_dims; if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) { + int context_len = 0; + for (auto context : contexts) { + int padded_context_len = static_cast(context->ne[1]) + bound_mod(static_cast(context->ne[1]), seq_multi_of); + context_len += padded_context_len; + } + int h = static_cast(x->ne[1]); + int w = static_cast(x->ne[0]); int pad_h = (patch_size - (h % patch_size)) % patch_size; int pad_w = (patch_size - (w % patch_size)) % patch_size; int h_len = (h + pad_h) / patch_size; int w_len = (w + pad_w) / patch_size; + if (h_len > 0 && w_len > 0) { size_t pos_len = ids.size() / bs; wrap_dims.assign(axes_dim.size(), std::vector(pos_len, 0)); - size_t cursor = context_len + bound_mod(context_len, seq_multi_of); // skip text (and its padding) + size_t cursor = context_len; // skip text (and its padding) size_t img_tokens = static_cast(h_len) * static_cast(w_len); for (size_t token_i = 0; token_i < img_tokens; ++token_i) { if (circular_h) { diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index b181f994b..1eb1c80bb 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -51,6 +51,7 @@ const char* model_version_to_str[] = { "Flux.2", "Flux.2 klein", "Z-Image", + "Z-Image-Omni", "Ovis Image", }; @@ -1615,12 +1616,13 @@ class StableDiffusionGGML { const std::vector& sigmas, int start_merge_step, SDCondition id_cond, - std::vector ref_latents = {}, - bool increase_ref_index = false, - ggml_tensor* denoise_mask = nullptr, - ggml_tensor* vace_context = nullptr, - float vace_strength = 1.f, - const sd_cache_params_t* cache_params = nullptr) { + std::vector ref_latents = {}, + std::vector ref_clip_feats = {}, + bool increase_ref_index = false, + ggml_tensor* denoise_mask = nullptr, + ggml_tensor* vace_context = nullptr, + float vace_strength = 1.f, + const sd_cache_params_t* cache_params = nullptr) { if (shifted_timestep > 0 && !sd_version_is_sdxl(version)) { LOG_WARN("timestep shifting is only supported for SDXL models!"); shifted_timestep = 0; @@ -2009,6 +2011,7 @@ class StableDiffusionGGML { diffusion_params.timesteps = timesteps; diffusion_params.guidance = guidance_tensor; diffusion_params.ref_latents = ref_latents; + diffusion_params.ref_clip_feats = ref_clip_feats; diffusion_params.increase_ref_index = increase_ref_index; diffusion_params.controls = controls; diffusion_params.control_strength = control_strength; @@ -2019,10 +2022,11 @@ class StableDiffusionGGML { struct ggml_tensor** active_output = &out_cond; if (start_merge_step == -1 || step <= start_merge_step) { // cond - diffusion_params.context = cond.c_crossattn; - diffusion_params.c_concat = cond.c_concat; - diffusion_params.y = cond.c_vector; - active_condition = &cond; + diffusion_params.context = cond.c_crossattn; + diffusion_params.extra_contexts = cond.extra_c_crossattns; + diffusion_params.c_concat = cond.c_concat; + diffusion_params.y = cond.c_vector; + active_condition = &cond; } else { diffusion_params.context = id_cond.c_crossattn; diffusion_params.c_concat = cond.c_concat; @@ -2053,12 +2057,13 @@ class StableDiffusionGGML { LOG_ERROR("controlnet compute failed"); } } - current_step_skipped = cache_step_is_skipped(); - diffusion_params.controls = controls; - diffusion_params.context = uncond.c_crossattn; - diffusion_params.c_concat = uncond.c_concat; - diffusion_params.y = uncond.c_vector; - bool skip_uncond = cache_before_condition(&uncond, out_uncond); + current_step_skipped = cache_step_is_skipped(); + diffusion_params.controls = controls; + diffusion_params.context = uncond.c_crossattn; + diffusion_params.extra_contexts = uncond.extra_c_crossattns; + diffusion_params.c_concat = uncond.c_concat; + diffusion_params.y = uncond.c_vector; + bool skip_uncond = cache_before_condition(&uncond, out_uncond); if (!skip_uncond) { if (!work_diffusion_model->compute(n_threads, diffusion_params, @@ -2073,10 +2078,11 @@ class StableDiffusionGGML { float* img_cond_data = nullptr; if (has_img_cond) { - diffusion_params.context = img_cond.c_crossattn; - diffusion_params.c_concat = img_cond.c_concat; - diffusion_params.y = img_cond.c_vector; - bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond); + diffusion_params.context = img_cond.c_crossattn; + diffusion_params.extra_contexts = img_cond.extra_c_crossattns; + diffusion_params.c_concat = img_cond.c_concat; + diffusion_params.y = img_cond.c_vector; + bool skip_img_cond = cache_before_condition(&img_cond, out_img_cond); if (!skip_img_cond) { if (!work_diffusion_model->compute(n_threads, diffusion_params, @@ -3196,6 +3202,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, sd_pm_params_t pm_params, std::vector ref_images, std::vector ref_latents, + std::vector ref_clip_feats, bool increase_ref_index, ggml_tensor* concat_latent = nullptr, ggml_tensor* denoise_mask = nullptr, @@ -3392,6 +3399,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, start_merge_step, id_cond, ref_latents, + ref_clip_feats, increase_ref_index, denoise_mask, nullptr, @@ -3658,6 +3666,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g } std::vector ref_latents; + std::vector ref_clip_feats; for (int i = 0; i < ref_images.size(); i++) { ggml_tensor* img; if (sd_img_gen_params->auto_resize_ref_image) { @@ -3704,6 +3713,11 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img); ref_latents.push_back(latent); + + if (sd_ctx->sd->version == VERSION_Z_IMAGE_OMNI) { + auto clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, *ref_images[i], false, -2); + ref_clip_feats.push_back(clip_vision_output); + } } if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) { @@ -3731,6 +3745,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g sd_img_gen_params->pm_params, ref_images, ref_latents, + ref_clip_feats, sd_img_gen_params->increase_ref_index, concat_latent, denoise_mask, @@ -4100,8 +4115,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s high_noise_sample_method, high_noise_sigmas, -1, - {}, - {}, + {}, // id_cond + {}, // ref_latents + {}, // ref_clip_feats false, denoise_mask, vace_context, @@ -4137,8 +4153,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s sample_method, sigmas, -1, - {}, - {}, + {}, // id_cond + {}, // ref_latents + {}, // ref_clip_feats false, denoise_mask, vace_context, diff --git a/z_image.hpp b/z_image.hpp index cee23833a..d138b1e18 100644 --- a/z_image.hpp +++ b/z_image.hpp @@ -140,14 +140,37 @@ namespace ZImage { __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx, struct ggml_tensor* x, - struct ggml_tensor* scale) { + struct ggml_tensor* scale, + bool skip_reshape = false) { // x: [N, L, C] - // scale: [N, C] - scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] - x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); + // scale: [N, C] or [N, L, C] + if (!skip_reshape) { + scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]); // [N, 1, C] + } + x = ggml_add(ctx, x, ggml_mul(ctx, x, scale)); return x; } + __STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx, + struct ggml_tensor* index, + struct ggml_tensor* mod_0, + struct ggml_tensor* mod_1) { + // index: [N, L] + // mod_0/mod_1: [N, C] + // return: [N, L, C] + // mod_result = torch.where(index == 0, mod_0, mod_1) + // mod_result = (1 - index)*mod_0 + index*mod_1 + index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]); + index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1); // [N, L, C] + mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]); // [N, 1, C] + mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]); // [N, 1, C] + + mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0)); // [N, L, C] + mod_1 = ggml_mul(ctx, index, mod_1); // [N, L, C] + auto mod_result = ggml_add(ctx, mod_0, mod_1); + return mod_result; + } + struct JointTransformerBlock : public GGMLBlock { protected: bool modulation; @@ -179,7 +202,10 @@ namespace ZImage { struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask = nullptr, - struct ggml_tensor* adaln_input = nullptr) { + struct ggml_tensor* adaln_input = nullptr, + struct ggml_tensor* noise_mask = nullptr, + struct ggml_tensor* adaln_noisy = nullptr, + struct ggml_tensor* adaln_clean = nullptr) { auto attention = std::dynamic_pointer_cast(blocks["attention"]); auto feed_forward = std::dynamic_pointer_cast(blocks["feed_forward"]); auto attention_norm1 = std::dynamic_pointer_cast(blocks["attention_norm1"]); @@ -188,32 +214,55 @@ namespace ZImage { auto ffn_norm2 = std::dynamic_pointer_cast(blocks["ffn_norm2"]); if (modulation) { - GGML_ASSERT(adaln_input != nullptr); auto adaLN_modulation_0 = std::dynamic_pointer_cast(blocks["adaLN_modulation.0"]); - auto m = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] - auto mods = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0); - auto scale_msa = mods[0]; - auto gate_msa = mods[1]; - auto scale_mlp = mods[2]; - auto gate_mlp = mods[3]; + struct ggml_tensor* scale_msa = nullptr; + struct ggml_tensor* gate_msa = nullptr; + struct ggml_tensor* scale_mlp = nullptr; + struct ggml_tensor* gate_mlp = nullptr; + bool skip_reshape = false; + + if (noise_mask != nullptr) { + GGML_ASSERT(adaln_noisy != nullptr); + GGML_ASSERT(adaln_clean != nullptr); + + auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy); // [N, 4 * hidden_size] + auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean); // [N, 4 * hidden_size] + + auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0); + auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0); + + scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]); + gate_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]); + scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]); + gate_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]); + + skip_reshape = true; + } else { + GGML_ASSERT(adaln_input != nullptr); + + auto mod = adaLN_modulation_0->forward(ctx, adaln_input); // [N, 4 * hidden_size] + auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0); + scale_msa = mod_vec[0]; + gate_msa = mod_vec[1]; + scale_mlp = mod_vec[2]; + gate_mlp = mod_vec[3]; + } auto residual = x; - x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa); + x = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape); x = attention->forward(ctx, x, pe, mask); x = attention_norm2->forward(ctx, x); x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa)); x = ggml_add(ctx->ggml_ctx, x, residual); residual = x; - x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp); + x = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape); x = feed_forward->forward(ctx, x); x = ffn_norm2->forward(ctx, x); x = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp)); x = ggml_add(ctx->ggml_ctx, x, residual); } else { - GGML_ASSERT(adaln_input == nullptr); - auto residual = x; x = attention_norm1->forward(ctx, x); x = attention->forward(ctx, x, pe, mask); @@ -243,7 +292,10 @@ namespace ZImage { struct ggml_tensor* forward(GGMLRunnerContext* ctx, struct ggml_tensor* x, - struct ggml_tensor* c) { + struct ggml_tensor* c, + struct ggml_tensor* noise_mask = nullptr, + struct ggml_tensor* c_noisy = nullptr, + struct ggml_tensor* c_clean = nullptr) { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] @@ -251,10 +303,28 @@ namespace ZImage { auto linear = std::dynamic_pointer_cast(blocks["linear"]); auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] - x = norm_final->forward(ctx, x); - x = modulate(ctx->ggml_ctx, x, scale); - x = linear->forward(ctx, x); + struct ggml_tensor* scale = nullptr; + bool skip_reshape = false; + + if (noise_mask != nullptr) { + GGML_ASSERT(c_noisy != nullptr); + GGML_ASSERT(c_clean != nullptr); + + auto scale_noisy = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_noisy)); // [N, hidden_size] + auto scale_clean = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_clean)); // [N, hidden_size] + + scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy); + + skip_reshape = true; + } else { + GGML_ASSERT(c != nullptr); + + scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c)); // [N, hidden_size] + } + + x = norm_final->forward(ctx, x); + x = modulate(ctx->ggml_ctx, x, scale, skip_reshape); + x = linear->forward(ctx, x); return x; } @@ -275,6 +345,7 @@ namespace ZImage { float norm_eps = 1e-5f; bool qk_norm = true; int64_t cap_feat_dim = 2560; + int64_t siglip_feat_dim = 0; int theta = 256; std::vector axes_dim = {32, 48, 48}; int64_t axes_dim_sum = 128; @@ -287,6 +358,10 @@ namespace ZImage { void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override { params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); params["x_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); + + if (z_image_params.siglip_feat_dim > 0) { + params["siglip_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size); + } } public: @@ -328,6 +403,26 @@ namespace ZImage { blocks["context_refiner." + std::to_string(i)] = block; } + if (z_image_params.siglip_feat_dim > 0) { + blocks["siglip_embedder.0"] = std::make_shared(z_image_params.siglip_feat_dim, z_image_params.norm_eps); + blocks["siglip_embedder.1"] = std::make_shared(z_image_params.siglip_feat_dim, z_image_params.hidden_size); + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::make_shared(2000 + i, + z_image_params.hidden_size, + z_image_params.head_dim, + z_image_params.num_heads, + z_image_params.num_kv_heads, + z_image_params.multiple_of, + z_image_params.ffn_dim_multiplier, + z_image_params.norm_eps, + z_image_params.qk_norm, + false); + + blocks["siglip_refiner." + std::to_string(i)] = block; + } + } + for (int i = 0; i < z_image_params.num_layers; i++) { auto block = std::make_shared(i, z_image_params.hidden_size, @@ -409,11 +504,32 @@ namespace ZImage { return x; } - struct ggml_tensor* forward_core(GGMLRunnerContext* ctx, - struct ggml_tensor* x, - struct ggml_tensor* timestep, - struct ggml_tensor* context, - struct ggml_tensor* pe) { + std::pair _pad_and_gen_noise_mask(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* pad_token, + int N, + float noise_mask_value = 1.f) { + int64_t n_pad_token = Rope::bound_mod(static_cast(x->ne[1]), SEQ_MULTI_OF); + if (n_pad_token > 0) { + auto pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, pad_token, pad_token->ne[0], n_pad_token, N, 1); + x = ggml_concat(ctx->ggml_ctx, x, pad_tokens, 1); // [N, n_token + n_pad_token, hidden_size] + } + ggml_tensor* noise_mask = nullptr; + if (noise_mask_value == 0.f) { + noise_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[1], 1, 1, 1); + } else if (noise_mask_value == 1.f) { + noise_mask = ggml_ext_ones(ctx->ggml_ctx, x->ne[1], 1, 1, 1); + } + return {x, noise_mask}; + } + + struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx, + ggml_tensor* x, + ggml_tensor* timestep, + std::vector contexts, + ggml_tensor* pe, + std::vector ref_latents, + std::vector siglip_feats) { auto x_embedder = std::dynamic_pointer_cast(blocks["x_embedder"]); auto t_embedder = std::dynamic_pointer_cast(blocks["t_embedder"]); auto cap_embedder_0 = std::dynamic_pointer_cast(blocks["cap_embedder.0"]); @@ -424,31 +540,145 @@ namespace ZImage { auto txt_pad_token = params["cap_pad_token"]; auto img_pad_token = params["x_pad_token"]; - int64_t N = x->ne[2]; - int64_t n_img_token = x->ne[1]; - int64_t n_txt_token = context->ne[1]; + bool omni_mode = ref_latents.size() > 0; + + int64_t N = x->ne[2]; + + // noise mask of img: 0 for condition images (clean), 1 for target image (noisy) + // noise mask of txg/sig: same as the corresponding img. If there is no corresponding img, set to 1 + + ggml_tensor* txt = nullptr; + ggml_tensor* txt_noise_mask = nullptr; + for (int i = 0; i < contexts.size(); i++) { + auto curr_txt_raw = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, contexts[i])); // [N, n_txt_token, hidden_size] + + float noise_mask_value = -1.f; // empty noise mask + if (omni_mode) { + noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f); + } + + auto [curr_txt, curr_txt_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_txt_raw, txt_pad_token, static_cast(N), noise_mask_value); + if (txt == nullptr) { + txt = curr_txt; + } else { + txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1); + } + + if (omni_mode) { + if (txt_noise_mask == nullptr) { + txt_noise_mask = curr_txt_noise_mask; + } else { + txt_noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, curr_txt_noise_mask, 0); + } + } + } + + ggml_tensor* img = nullptr; + ggml_tensor* img_noise_mask = nullptr; + for (ggml_tensor* ref : ref_latents) { + auto curr_img_raw = x_embedder->forward(ctx, ref); // [N, n_img_token, hidden_size] + + float noise_mask_value = -1.f; // empty noise mask + if (omni_mode) { + noise_mask_value = 0.f; + } + + auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, static_cast(N), noise_mask_value); + if (img == nullptr) { + img = curr_img; + } else { + img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1); + } + + if (omni_mode) { + if (img_noise_mask == nullptr) { + img_noise_mask = curr_img_noise_mask; + } else { + img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0); + } + } + } + + int64_t final_img_offset = (img ? img->ne[1] : 0); + int64_t final_img_pad_len = 0; + + { + auto curr_img_raw = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size] + + float noise_mask_value = -1.f; // empty noise mask + if (omni_mode) { + noise_mask_value = 0.f; + } + + auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, static_cast(N), noise_mask_value); + if (img == nullptr) { + img = curr_img; + } else { + img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1); + } + + if (omni_mode) { + if (img_noise_mask == nullptr) { + img_noise_mask = curr_img_noise_mask; + } else { + img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0); + } + } + + final_img_pad_len = Rope::bound_mod(static_cast(curr_img_raw->ne[1]), SEQ_MULTI_OF); + } + + ggml_tensor* sig = nullptr; + ggml_tensor* sig_noise_mask = nullptr; + for (int i = 0; i < siglip_feats.size(); i++) { + auto sig_pad_token = params["siglip_pad_token"]; + auto siglip_embedder_0 = std::dynamic_pointer_cast(blocks["siglip_embedder.0"]); + auto siglip_embedder_1 = std::dynamic_pointer_cast(blocks["siglip_embedder.1"]); + + auto curr_sig_raw = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feats[i])); // [N, n_sig_token, hidden_size] - auto t_emb = t_embedder->forward(ctx, timestep); + float noise_mask_value = -1.f; // empty noise mask + if (omni_mode) { + noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f); + } - auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context)); // [N, n_txt_token, hidden_size] - auto img = x_embedder->forward(ctx, x); // [N, n_img_token, hidden_size] + auto [curr_sig, curr_sig_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_sig_raw, sig_pad_token, static_cast(N), noise_mask_value); + if (sig == nullptr) { + sig = curr_sig; + } else { + sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1); + } - int64_t n_txt_pad_token = Rope::bound_mod(static_cast(n_txt_token), SEQ_MULTI_OF); - if (n_txt_pad_token > 0) { - auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1); - txt = ggml_concat(ctx->ggml_ctx, txt, txt_pad_tokens, 1); // [N, n_txt_token + n_txt_pad_token, hidden_size] + if (omni_mode) { + if (sig_noise_mask == nullptr) { + sig_noise_mask = curr_sig_noise_mask; + } else { + sig_noise_mask = ggml_concat(ctx->ggml_ctx, sig_noise_mask, curr_sig_noise_mask, 0); + } + } } - int64_t n_img_pad_token = Rope::bound_mod(static_cast(n_img_token), SEQ_MULTI_OF); - if (n_img_pad_token > 0) { - auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1); - img = ggml_concat(ctx->ggml_ctx, img, img_pad_tokens, 1); // [N, n_img_token + n_img_pad_token, hidden_size] + ggml_tensor* t_emb = nullptr; + ggml_tensor* t_noisy = nullptr; + ggml_tensor* t_clean = nullptr; + if (omni_mode) { + t_noisy = t_embedder->forward(ctx, timestep); + t_clean = t_embedder->forward(ctx, + ggml_scale(ctx->ggml_ctx, + ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]), + 1000.f)); + } else { + t_emb = t_embedder->forward(ctx, timestep); } - GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]); + if (sig) { + GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]); + } else { + GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]); + } auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]); - auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]); + auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]); for (int i = 0; i < z_image_params.num_refiner_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["context_refiner." + std::to_string(i)]); @@ -459,30 +689,50 @@ namespace ZImage { for (int i = 0; i < z_image_params.num_refiner_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["noise_refiner." + std::to_string(i)]); - img = block->forward(ctx, img, img_pe, nullptr, t_emb); + img = block->forward(ctx, img, img_pe, nullptr, t_emb, img_noise_mask, t_noisy, t_clean); } - auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size] + auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] + + ggml_tensor* noise_mask = nullptr; + if (omni_mode) { + noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, img_noise_mask, 0); // [N, n_txt_token + n_img_token] + } + + ggml_tensor* sig_pe = nullptr; + if (sig) { + sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]); + + for (int i = 0; i < z_image_params.num_refiner_layers; i++) { + auto block = std::dynamic_pointer_cast(blocks["siglip_refiner." + std::to_string(i)]); + + sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr); + } + + unified = ggml_concat(ctx->ggml_ctx, unified, sig, 1); // [N, n_txt_token + n_img_token + n_sig_token, hidden_size] + noise_mask = ggml_concat(ctx->ggml_ctx, noise_mask, sig_noise_mask, 0); // [N, n_txt_token + n_img_token + n_sig_token] + } for (int i = 0; i < z_image_params.num_layers; i++) { auto block = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); - txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb); + unified = block->forward(ctx, unified, pe, nullptr, t_emb, noise_mask, t_noisy, t_clean); } - txt_img = final_layer->forward(ctx, txt_img, t_emb); // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, ph*pw*C] + unified = final_layer->forward(ctx, unified, t_emb, noise_mask, t_noisy, t_clean); // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C] - img = ggml_ext_slice(ctx->ggml_ctx, txt_img, 1, n_txt_token + n_txt_pad_token, n_txt_token + n_txt_pad_token + n_img_token); // [N, n_img_token, ph*pw*C] + img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, txt->ne[1] + img->ne[1] - final_img_pad_len); // [N, n_final_img_token, ph*pw*C] return img; } struct ggml_tensor* forward(GGMLRunnerContext* ctx, - struct ggml_tensor* x, - struct ggml_tensor* timestep, - struct ggml_tensor* context, - struct ggml_tensor* pe, - std::vector ref_latents = {}) { + ggml_tensor* x, + ggml_tensor* timestep, + std::vector contexts, + ggml_tensor* pe, + std::vector ref_latents = {}, + std::vector siglip_feats = {}) { // Forward pass of DiT. // x: [N, C, H, W] // timestep: [N,] @@ -495,23 +745,20 @@ namespace ZImage { int64_t C = x->ne[2]; int64_t N = x->ne[3]; - auto img = process_img(ctx, x); - uint64_t n_img_token = img->ne[1]; - - if (ref_latents.size() > 0) { - for (ggml_tensor* ref : ref_latents) { - ref = process_img(ctx, ref); - img = ggml_concat(ctx->ggml_ctx, img, ref, 1); - } - } + auto img = process_img(ctx, x); int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size); int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size); - auto out = forward_core(ctx, img, timestep, context, pe); + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = process_img(ctx, ref_latents[i]); + } + + auto out = forward_omni(ctx, img, timestep, contexts, pe, ref_latents, siglip_feats); // [N, n_img_token, ph*pw*C] + + // auto out = forward_basic(ctx, img, timestep, contexts[0], pe); // [N, n_img_token, ph*pw*C] - out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token); // [N, n_img_token, ph*pw*C] - out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] + out = unpatchify(ctx->ggml_ctx, out, h_len, w_len); // [N, C, H + pad_h, W + pad_w] // slice out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H); // [N, C, H, W + pad_w] @@ -549,34 +796,37 @@ namespace ZImage { z_image.get_param_tensors(tensors, prefix); } - struct ggml_cgraph* build_graph(struct ggml_tensor* x, - struct ggml_tensor* timesteps, - struct ggml_tensor* context, - std::vector ref_latents = {}, - bool increase_ref_index = false) { + struct ggml_cgraph* build_graph(ggml_tensor* x, + ggml_tensor* timesteps, + std::vector contexts, + std::vector ref_latents = {}, + std::vector siglip_feats = {}) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE); - x = to_backend(x); - context = to_backend(context); + x = to_backend(x); + + for (int i = 0; i < contexts.size(); i++) { + contexts[i] = to_backend(contexts[i]); + } + timesteps = to_backend(timesteps); for (int i = 0; i < ref_latents.size(); i++) { ref_latents[i] = to_backend(ref_latents[i]); } - pe_vec = Rope::gen_z_image_pe(static_cast(x->ne[1]), - static_cast(x->ne[0]), + pe_vec = Rope::gen_z_image_pe(x, + contexts, + ref_latents, + siglip_feats, z_image_params.patch_size, - static_cast(x->ne[3]), - static_cast(context->ne[1]), SEQ_MULTI_OF, - ref_latents, - increase_ref_index, z_image_params.theta, + z_image_params.axes_dim, circular_y_enabled, circular_x_enabled, - z_image_params.axes_dim); + static_cast(x->ne[3])); int pos_len = static_cast(pe_vec.size() / z_image_params.axes_dim_sum / 2); // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len); @@ -589,7 +839,7 @@ namespace ZImage { struct ggml_tensor* out = z_image.forward(&runner_ctx, x, timesteps, - context, + contexts, pe, ref_latents); @@ -601,16 +851,16 @@ namespace ZImage { bool compute(int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, - struct ggml_tensor* context, - std::vector ref_latents = {}, - bool increase_ref_index = false, - struct ggml_tensor** output = nullptr, - struct ggml_context* output_ctx = nullptr) { + std::vector contexts, + std::vector ref_latents = {}, + std::vector siglip_feats = {}, + struct ggml_tensor** output = nullptr, + struct ggml_context* output_ctx = nullptr) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, ref_latents, increase_ref_index); + return build_graph(x, timesteps, contexts, ref_latents, siglip_feats); }; return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -642,7 +892,7 @@ namespace ZImage { struct ggml_tensor* out = nullptr; int64_t t0 = ggml_time_ms(); - compute(8, x, timesteps, context, {}, false, &out, work_ctx); + compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx); int64_t t1 = ggml_time_ms(); print_ggml_tensor(out);