diff --git a/src/cpp/src/rag/text_embedding_pipeline.cpp b/src/cpp/src/rag/text_embedding_pipeline.cpp
index b2d44ded2f..0c3d0ebeab 100644
--- a/src/cpp/src/rag/text_embedding_pipeline.cpp
+++ b/src/cpp/src/rag/text_embedding_pipeline.cpp
@@ -45,6 +45,11 @@ bool has_token_type_ids_input(const T& inputs) {
     return false;
 }
 
+void set_node_name(std::shared_ptr<ov::Node> node, const std::string& name) {
+    node->set_friendly_name(name);
+    node->get_output_tensor(0).set_names({name});
+}
+
 /**
  * CLS pooling slices first element from seq_length dimension
  * [batch_size, seq_length, hidden_size] -> [batch_size, seq_length[0], hidden_size]
@@ -62,12 +67,10 @@ std::shared_ptr get_cls_pooling_op(const ov::Output& last_hidd
     return std::make_shared(slice, squeeze_axis);
 }
 
-std::shared_ptr<ov::Node> get_mean_pooling_op(std::shared_ptr<ov::Model> model,
-                                              const ov::Output<ov::Node>& last_hidden_state_node) {
+std::shared_ptr<ov::Node> get_mean_pooling_op(const ov::Output<ov::Node>& last_hidden_state_node,
+                                              const ov::Output<ov::Node>& attention_mask) {
     auto shape_of = std::make_shared(last_hidden_state_node);
 
-    auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
-
     auto unsqueze_axis = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{-1});
     auto unsqueze = std::make_shared(attention_mask, unsqueze_axis);
 
@@ -95,8 +98,8 @@ std::shared_ptr get_mean_pooling_op(std::shared_ptr model,
     return std::make_shared(sum_hidden_state, max_expanded_mask);
 }
 
-std::shared_ptr<ov::Node> get_last_token_pooling_op(std::shared_ptr<ov::Model> model,
-                                                    const ov::Output<ov::Node>& last_hidden_state_node,
+std::shared_ptr<ov::Node> get_last_token_pooling_op(const ov::Output<ov::Node>& last_hidden_state_node,
+                                                    const ov::Output<ov::Node>& attention_mask,
                                                     const TextEmbeddingPipeline::Config& config) {
     const auto left_padding = config.padding_side.has_value() && config.padding_side.value() == "left";
 
@@ -115,8 +118,6 @@ std::shared_ptr get_last_token_pooling_op(std::shared_ptr model,
         return std::make_shared(slice, squeeze_axis);
     }
 
-    auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
-
     auto axis_1 = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
     auto reduce_sum = std::make_shared(attention_mask, axis_1);
     auto subtract_1 = std::make_shared(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1});
@@ -125,31 +126,71 @@ std::shared_ptr get_last_token_pooling_op(std::shared_ptr model,
     return std::make_shared(last_hidden_state_node, subtract, axis_1, 1);
 }
 
+std::shared_ptr<ov::Node> create_post_ops(const ov::Output<ov::Node>& input,
+                                          const ov::Output<ov::Node>& attention_mask,
+                                          const TextEmbeddingPipeline::Config& config) {
+    if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
+        return get_cls_pooling_op(input);
+    } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
+        return get_mean_pooling_op(input, attention_mask);
+    } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::LAST_TOKEN) {
+        return get_last_token_pooling_op(input, attention_mask, config);
+    }
+
+    OPENVINO_THROW("Pooling type is not supported");
+}
+
+std::shared_ptr<ov::Node> create_normalize_ops(const ov::Output<ov::Node>& input,
+                                               const TextEmbeddingPipeline::Config& config) {
+    if (config.normalize) {
+        auto axis_const = std::make_shared(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{1});
+        return std::make_shared(input, axis_const, 1e-12, op::EpsMode::MAX);
+    }
+    return std::dynamic_pointer_cast<ov::Node>(input.get_node_shared_ptr());
+}
+
 std::shared_ptr<ov::Model> apply_postprocessing(std::shared_ptr<ov::Model> model, const TextEmbeddingPipeline::Config& config) {
     ov::preprocess::PrePostProcessor processor(model);
 
     processor.output().postprocess().custom([model, &config](const ov::Output<ov::Node>& node) {
-        if (config.pooling_type == TextEmbeddingPipeline::PoolingType::CLS) {
-            return get_cls_pooling_op(node);
-        } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::MEAN) {
-            return get_mean_pooling_op(model, node);
-        } else if (config.pooling_type == TextEmbeddingPipeline::PoolingType::LAST_TOKEN) {
-            return get_last_token_pooling_op(model, node, config);
-        }
-
-        OPENVINO_THROW("Pooling type is not supported");
+        auto attention_mask = model->input("attention_mask").get_node()->outputs()[0];
+        return create_post_ops(node, attention_mask, config);
     });
 
     if (config.normalize) {
-        processor.output().postprocess().custom([](const ov::Output<ov::Node>& node) {
-            auto axis_const = std::make_shared(ov::element::i32, ov::Shape{1}, std::vector<int32_t>{1});
-            return std::make_shared(node, axis_const, 1e-12, op::EpsMode::MAX);
+        processor.output().postprocess().custom([&config](const ov::Output<ov::Node>& node) {
+            return create_normalize_ops(node, config);
         });
     }
 
     return processor.build();
 }
 
+std::shared_ptr<ov::Model> create_post_model(std::shared_ptr<ov::Model> model,
+                                             const TextEmbeddingPipeline::Config& config,
+                                             ov::Dimension::value_type max_prompt_size) {
+    auto output_node = model->outputs()[0];
+    auto output_shape = output_node.get_partial_shape();
+    auto input_param =
+        std::make_shared(output_node.get_element_type(), ov::PartialShape{1, max_prompt_size, output_shape[2]});
+    set_node_name(input_param, "input_ids");
+
+    auto attention_mask = std::make_shared(ov::element::i64, ov::PartialShape{1, max_prompt_size});
+    set_node_name(attention_mask, "attention_mask");
+
+    auto post_output = create_post_ops(input_param, attention_mask, config);
+    auto post_normalize_output = create_normalize_ops(post_output, config);
+    OPENVINO_ASSERT(post_normalize_output != nullptr);
+
+    auto result_node = std::make_shared(post_normalize_output);
+    set_node_name(result_node, "last_hidden_state");
+    auto post_model =
+        std::make_shared(ov::OutputVector{result_node}, ov::ParameterVector{input_param, attention_mask});
+    post_model->set_friendly_name(model->get_friendly_name() + "_post_process");
+    post_model->validate_nodes_and_infer_types();
+    return post_model;
+}
+
 std::optional read_max_position_embeddings(const std::filesystem::path& models_path) {
     // config.json not found. Skip parameters initialization from file, use defaults.
     const std::filesystem::path& json_path = models_path / "config.json";
@@ -211,32 +252,53 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         auto model = core.read_model(models_path / "openvino_model.xml", {}, properties);
 
-        const bool should_reshape = m_config.batch_size.has_value() || m_config.max_length.has_value();
-        if (should_reshape) {
-            reshape_model(model);
-        }
-
-        if (device == "NPU") {
-            OPENVINO_ASSERT(!model->is_dynamic(),
-                            "NPU device does not support dynamic shapes. In order to fix model shape, set batch_size, "
-                            "max_length and pad_to_max_length in the configuration.");
-        }
-
-        model = apply_postprocessing(model, m_config);
-
+        bool is_seq_len_fixed = true;
         if (m_config.max_length) {
             m_tokenization_params.insert({max_length.name(), *m_config.max_length});
+        } else {
+            is_seq_len_fixed = false;
         }
 
         if (m_config.pad_to_max_length) {
             m_tokenization_params.insert({pad_to_max_length.name(), *m_config.pad_to_max_length});
+            is_seq_len_fixed &= m_config.pad_to_max_length.value();
+        } else {
+            is_seq_len_fixed = false;
         }
 
         if (m_config.padding_side) {
             m_tokenization_params.insert({padding_side.name(), *m_config.padding_side});
         }
 
-        ov::CompiledModel compiled_model = core.compile_model(model, device, properties);
+        bool should_reshape_non_npu =
+            (device != "NPU" && (m_config.batch_size.has_value() || m_config.max_length.has_value()));
+        bool should_reshape_npu = (device == "NPU" && m_config.batch_size.has_value() && is_seq_len_fixed);
+        if (should_reshape_non_npu || should_reshape_npu) {
+            reshape_model(model);
+        }
+
+        ov::CompiledModel compiled_model;
+        if (device == "NPU" && model->is_dynamic()) {
+            OPENVINO_ASSERT(m_config.max_length.has_value(), "The parameter max_length is not set");
+
+            bool is_padding_on_left = m_config.padding_side.has_value() && m_config.padding_side.value() == "left";
+            if (is_padding_on_left && is_seq_len_fixed &&
+                config.pooling_type != TextEmbeddingPipeline::PoolingType::MEAN) {
+                OPENVINO_THROW("Padding on left is only supported for the MEAN pooling type");
+            }
+
+            auto kv_pos = ov::genai::utils::get_kv_axes_pos(model);
+            utils::KVDesc kv_desc;
+            std::tie(compiled_model, kv_desc) =
+                utils::compile_decoder_for_npu_text_embedding(model, properties, kv_pos, m_config);
+
+            auto post_model = create_post_model(model, m_config, m_config.max_length.value());
+            auto post_compiled_model = core.compile_model(post_model, "CPU");
+            m_post_request = post_compiled_model.create_infer_request();
+        } else {
+            model = apply_postprocessing(model, m_config);
+            compiled_model = core.compile_model(model, device, properties);
+        }
 
         utils::print_compiled_model_properties(compiled_model, "text embedding model");
         m_request = compiled_model.create_infer_request();
@@ -281,9 +343,11 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
 private:
     Tokenizer m_tokenizer;
     InferRequest m_request;
+    InferRequest m_post_request;
     Config m_config;
     AnyMap m_tokenization_params;
     std::optional m_max_position_embeddings;
+    ov::Tensor m_attention_mask;
 
     void reshape_model(std::shared_ptr<ov::Model>& model) {
         ov::PartialShape target_shape{ov::Dimension::dynamic(), ov::Dimension::dynamic()};
@@ -321,6 +385,28 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         model->reshape(input_name_to_shape);
     }
 
+    ov::Tensor post_model_infer(ov::Tensor input) {
+        if (m_post_request) {
+            m_post_request.set_tensor("input_ids", input);
+
+            auto attention_mask_tensor = m_post_request.get_tensor("attention_mask");
+
+            std::copy_n(m_attention_mask.data<int64_t>(),
+                        m_attention_mask.get_size(),
+                        attention_mask_tensor.data<int64_t>());
+            if (m_attention_mask.get_size() < attention_mask_tensor.get_size()) {
+                std::fill_n(attention_mask_tensor.data<int64_t>() + m_attention_mask.get_size(),
+                            attention_mask_tensor.get_size() - m_attention_mask.get_size(),
+                            0);
+            }
+
+            m_post_request.infer();
+            return m_post_request.get_tensor("last_hidden_state");
+        }
+
+        return input;
+    }
+
     void start_embed_async(std::vector<std::string>& texts) {
         if (m_config.batch_size.has_value()) {
             // if batch_size is set, model shape is fixed
@@ -332,10 +418,11 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         }
 
         const auto encoded = m_tokenizer.encode(texts, m_tokenization_params);
-
         m_request.set_tensor("input_ids", encoded.input_ids);
         m_request.set_tensor("attention_mask", encoded.attention_mask);
+        m_attention_mask = encoded.attention_mask;
+
         // fill token_type_ids
         // todo: pass token_type_ids from tokenizer
         if (has_token_type_ids_input(m_request.get_compiled_model().inputs())) {
@@ -351,9 +438,8 @@ class TextEmbeddingPipeline::TextEmbeddingPipelineImpl {
         m_request.wait();
 
         // [batch_size, hidden_size]
-        const Tensor last_hidden_state = m_request.get_tensor("last_hidden_state");
-
-        return to_embedding_result(last_hidden_state);
+        const auto last_hidden_state = m_request.get_tensor("last_hidden_state");
+        return to_embedding_result(post_model_infer(last_hidden_state));
     };
 
     std::vector<std::string> format_texts(const std::vector<std::string>& texts) {
diff --git a/src/cpp/src/utils.cpp b/src/cpp/src/utils.cpp
index 8948832f26..b420628865 100644
--- a/src/cpp/src/utils.cpp
+++ b/src/cpp/src/utils.cpp
@@ -116,6 +116,21 @@ void update_npu_config_whisper(ov::AnyMap& config,
     update_config(config, {"NPUW_LLM_PREFILL_HINT", "STATIC"});
 }
 
+void update_npu_config_text_embedding(ov::AnyMap& config,
+                                      const ov::genai::utils::KVAxesPosition& kv_pos,
+                                      const ov::genai::utils::KVDesc& kv_desc) {
+    update_config(config, {"NPU_USE_NPUW", "YES"});
+    update_config(config, {"NPUW_LLM", "YES"});
+    update_config(config, {"NPUW_LLM_BATCH_DIM", kv_pos.batch});
+    update_config(config, {"NPUW_LLM_SEQ_LEN_DIM", kv_pos.seq_len});
+
+    update_config(config, {"NPUW_LLM_MAX_PROMPT_LEN", kv_desc.max_prompt_len});
+    update_config(config, {"NPUW_LLM_MIN_RESPONSE_LEN", kv_desc.min_response_len});
+    update_config(config, {"NPUW_LLM_SHARED_HEAD", "NO"});
+
+    update_config(config, {"NPUW_TEXT_EMBED", "YES"});
+}
+
 inline bool is_paged_attention_available() {
 #if defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
     return true;
@@ -130,6 +145,8 @@ namespace ov {
 namespace genai {
 namespace utils {
 
+enum class ModelType { Default, Whisper, TextEmbedding };
+
 Tensor init_attention_mask(const Tensor& input_ids) {
     auto shape = input_ids.get_shape();
     auto attention_mask = ov::Tensor{input_ids.get_element_type(), shape};
@@ -570,11 +587,72 @@ void print_scheduler_config_info(const SchedulerConfig &scheduler_config) {
     std::cout << scheduler_config.to_string() << std::endl;
 }
 
-std::pair<ov::CompiledModel, KVDesc>
-compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
-                        const ov::AnyMap& config,
-                        const KVAxesPosition& kv_pos,
-                        const bool is_whisper) {
+void import_npu_model(ov::CompiledModel& compiled,
+                      KVDesc& kv_desc,
+                      const ov::AnyMap& config,
+                      const std::string& blob_path) {
+    if (!std::filesystem::exists(blob_path)) {
+        OPENVINO_THROW("Blob file is not found at: " + blob_path);
+    }
+    std::ifstream fin(blob_path, std::ios::in | std::ios::binary);
+    if (!fin.is_open()) {
+        OPENVINO_THROW("Blob file can't be opened: " + blob_path);
+    }
+    compiled = ov::genai::utils::singleton_core().import_model(fin, "NPU", config);
+    kv_desc.max_prompt_len = compiled.get_property("NPUW_LLM_MAX_PROMPT_LEN").as<uint32_t>();
+    kv_desc.min_response_len = compiled.get_property("NPUW_LLM_MIN_RESPONSE_LEN").as<uint32_t>();
+}
+
+void export_npu_model(ov::CompiledModel& compiled, const std::string& blob_path) {
+    // Check the path is full
+    const int EXT_SIZE = 5; // ".blob"
+    if (blob_path.size() < EXT_SIZE) {
+        OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path);
+    }
+    if (strncmp(&blob_path[blob_path.size() - EXT_SIZE], ".blob", EXT_SIZE) != 0) {
+        OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path);
+    }
+    std::ofstream fout(blob_path, std::ios::out | std::ios::binary);
+    if (!fout.is_open()) {
+        OPENVINO_THROW("Blob file can't be exported to: " + blob_path);
+    }
+    compiled.export_model(fout);
+}
+
+void get_npu_model_config(ov::AnyMap& properties,
+                          const KVAxesPosition& kv_pos,
+                          KVDesc& kv_desc,
+                          const bool is_whisper) {
+    if (is_whisper) {
+        kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(4u);
+        // kvcache size for Whisper = 448u (MAX_PROMPT_LEN + MIN_RESPONSE_LEN)
+        kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(444u);
+        update_npu_config_whisper(properties, kv_pos, kv_desc);
+    } else {
+        kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
+        kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
+        update_npu_config(properties, kv_pos, kv_desc);
+    }
+}
+
+void get_npu_text_embedding_config(ov::AnyMap& properties,
+                                   const KVAxesPosition& kv_pos,
+                                   KVDesc& kv_desc,
+                                   const TextEmbeddingPipeline::Config& text_embed_config) {
+    if (text_embed_config.max_length.has_value()) {
+        kv_desc.max_prompt_len = text_embed_config.max_length.value();
+    } else {
+        kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
+    }
+    kv_desc.min_response_len = kv_desc.max_prompt_len;
+    update_npu_config_text_embedding(properties, kv_pos, kv_desc);
+}
+
+std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu_impl(const std::shared_ptr<ov::Model>& model,
+                                                                  const ov::AnyMap& config,
+                                                                  const KVAxesPosition& kv_pos,
+                                                                  ModelType model_type,
+                                                                  const TextEmbeddingPipeline::Config& text_embed_config = {}) {
     ov::CompiledModel compiled;
     ov::AnyMap properties = config;
     KVDesc kv_desc;
@@ -584,49 +662,46 @@ compile_decoder_for_npu(const std::shared_ptr& model,
     const bool do_import = (!blob_path.empty() && !export_blob);
 
     if (do_import) {
-        if (!std::filesystem::exists(blob_path)) {
-            OPENVINO_THROW("Blob file is not found at: " + blob_path);
-        }
-        std::ifstream fin(blob_path, std::ios::in | std::ios::binary);
-        if (!fin.is_open()) {
-            OPENVINO_THROW("Blob file can't be opened: " + blob_path);
-        }
-        compiled = ov::genai::utils::singleton_core().import_model(fin, "NPU", config);
-        kv_desc.max_prompt_len = compiled.get_property("NPUW_LLM_MAX_PROMPT_LEN").as<uint32_t>();
-        kv_desc.min_response_len = compiled.get_property("NPUW_LLM_MIN_RESPONSE_LEN").as<uint32_t>();
+        import_npu_model(compiled, kv_desc, properties, blob_path);
     } else {
-        if (is_whisper) {
-            kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(4u);
-            // kvcache size for Whisper = 448u (MAX_PROMPT_LEN + MIN_RESPONSE_LEN)
-            kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(444u);
-            update_npu_config_whisper(properties, kv_pos, kv_desc);
-        } else {
-            kv_desc.max_prompt_len = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u);
-            kv_desc.min_response_len = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u);
-            update_npu_config(properties, kv_pos, kv_desc);
+        switch (model_type) {
+        case ModelType::TextEmbedding:
+            get_npu_text_embedding_config(properties, kv_pos, kv_desc, text_embed_config);
+            break;
+        case ModelType::Whisper:
+            get_npu_model_config(properties, kv_pos, kv_desc, true);
+            break;
+        case ModelType::Default:
+        default:
+            get_npu_model_config(properties, kv_pos, kv_desc, false);
+            break;
         }
+
         compiled = ov::genai::utils::singleton_core().compile_model(model, "NPU", properties);
         // Also export compiled model if required
         if (export_blob) {
             if (blob_path.empty()) {
                 blob_path = "openvino_model.blob";
             }
-            // Check the path is full
-            const int EXT_SIZE = 5; // ".blob"
-            if (blob_path.size() < EXT_SIZE) {
-                OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path);
-            }
-            if (strncmp(".blob", &blob_path[blob_path.size() - EXT_SIZE], EXT_SIZE) != 0) {
-                OPENVINO_THROW("Please provide a full path to blob file in BLOB_PATH: " + blob_path);
-            }
-            std::ofstream fout(blob_path, std::ios::out | std::ios::binary);
-            if (!fout.is_open()) {
-                OPENVINO_THROW("Blob file can't be exported to: " + blob_path);
-            }
-            compiled.export_model(fout);
+            export_npu_model(compiled, blob_path);
         }
     }
-    return { compiled, kv_desc };
+
+    return {compiled, kv_desc};
+}
+
+std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu(const std::shared_ptr<ov::Model>& model,
+                                                             const ov::AnyMap& config,
+                                                             const KVAxesPosition& kv_pos,
+                                                             const bool is_whisper) {
+    return compile_decoder_for_npu_impl(model, config, kv_pos, is_whisper ? ModelType::Whisper : ModelType::Default);
+}
+
+std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu_text_embedding(const std::shared_ptr<ov::Model>& model,
+                                                                            const ov::AnyMap& config,
+                                                                            const KVAxesPosition& kv_pos,
+                                                                            const TextEmbeddingPipeline::Config& text_embed_config) {
+    return compile_decoder_for_npu_impl(model, config, kv_pos, ModelType::TextEmbedding, text_embed_config);
 }
 
 std::optional pop_option(ov::AnyMap& config, const std::string& option_name) {
diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp
index 14106ef8f7..e2443690e1 100644
--- a/src/cpp/src/utils.hpp
+++ b/src/cpp/src/utils.hpp
@@ -9,6 +9,7 @@
 #include "openvino/genai/llm_pipeline.hpp"
 #include "openvino/genai/visual_language/pipeline.hpp"
+#include "openvino/genai/rag/text_embedding_pipeline.hpp"
 #include "openvino/runtime/core.hpp"
 
 #include "openvino/genai/generation_handle.hpp"
@@ -196,6 +197,11 @@ std::pair compile_decoder_for_npu(const std::shared_p
                                                       const KVAxesPosition& kv_pos,
                                                       const bool is_whisper = false);
 
+std::pair<ov::CompiledModel, KVDesc> compile_decoder_for_npu_text_embedding(const std::shared_ptr<ov::Model>& model,
+                                                                            const ov::AnyMap& config,
+                                                                            const KVAxesPosition& kv_pos,
+                                                                            const ov::genai::TextEmbeddingPipeline::Config& text_embed_config);
+
 /// @brief SharedOptional is a wrapper around a reference to an existing object and an optional shared alternative value.
 /// The difference from std::optional is that the default state is not empty and contains a reference to an existing object outside the class.
 /// Another difference is that the alternative value is shared between all instances of SharedOptional like std::shared_ptr.
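
For context on the C++ changes above, here is a minimal user-level sketch of how the new NPU path would be exercised. It is illustrative only: the model directory is a placeholder and the constructor overload taking (models_path, device, Config) is assumed; the Config fields shown are the ones referenced in the patch.

    #include "openvino/genai/rag/text_embedding_pipeline.hpp"

    int main() {
        ov::genai::TextEmbeddingPipeline::Config config;
        // max_length must be set when the model stays dynamic and the device is NPU
        // (see the OPENVINO_ASSERT added to the pipeline constructor above).
        config.max_length = 192;
        config.pooling_type = ov::genai::TextEmbeddingPipeline::PoolingType::MEAN;
        config.normalize = true;

        // "models/qwen3-embedding-0.6b" is a placeholder path to an exported OpenVINO embedding model.
        ov::genai::TextEmbeddingPipeline pipeline("models/qwen3-embedding-0.6b", "NPU", config);

        // The transformer runs through the NPUW path, while pooling/normalization run in the
        // small CPU post-processing model built by create_post_model().
        auto embeddings = pipeline.embed_documents({"OpenVINO GenAI text embeddings on NPU"});
        return 0;
    }

The Python test added below drives the same flow through run_text_embedding_genai with device="NPU" and NPUW fallback properties.
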
diff --git a/tests/python_tests/test_rag.py b/tests/python_tests/test_rag.py
index 013a925908..bccc638027 100644
--- a/tests/python_tests/test_rag.py
+++ b/tests/python_tests/test_rag.py
@@ -107,11 +107,16 @@ def run_text_embedding_genai(
     documents: list[str],
     config: TextEmbeddingPipeline.Config | None = None,
     task: Literal["embed_documents", "embed_query"] = "embed_documents",
+    device: str = "CPU",
+    properties: dict | None = None,
 ):
     if not config:
         config = TextEmbeddingPipeline.Config()
 
-    pipeline = TextEmbeddingPipeline(models_path, "CPU", config)
+    if properties:
+        pipeline = TextEmbeddingPipeline(models_path, device, config, **properties)
+    else:
+        pipeline = TextEmbeddingPipeline(models_path, device, config)
 
     if config.batch_size:
         documents = documents[: config.batch_size]
@@ -192,13 +197,13 @@ def run_qwen3_embedding_optimum(
 MAX_EMBEDDING_ERROR = 2e-6 if sys.platform != "darwin" else 0.02  # ARM64 macs have different results
 
 
-def validate_embedding_results(result_1: EmbeddingResult, result_2: EmbeddingResult):
+def validate_embedding_results(result_1: EmbeddingResult, result_2: EmbeddingResult, threshold: float = MAX_EMBEDDING_ERROR):
     __tracebackhide__ = True
     np_result_1 = np.array(result_1)
     np_result_2 = np.array(result_2)
     max_error = np.abs(np_result_1 - np_result_2).max()
-    assert max_error < MAX_EMBEDDING_ERROR, f"Max error: {max_error} is greater than allowed {MAX_EMBEDDING_ERROR}"
+    assert max_error < threshold, f"Max error: {max_error} is greater than allowed {threshold}"
@@ -351,6 +356,142 @@ def test_qwen3_embedding(emb_model, dataset_documents, config):
     validate_embedding_results(embeddings_genai, embeddings_opt.tolist())
 
 
+@pytest.mark.parametrize(
+    "emb_model",
+    ["Qwen/Qwen3-Embedding-0.6B"],
+    indirect=True,
+)
+@pytest.mark.parametrize(
+    ("config", "chunk_size", "threshold"),
+    [
+        # Chunk disabled
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.CLS,
+            padding_side="right"
+        ), 0, 2e-4),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.LAST_TOKEN,
+            padding_side="right"
+        ), 0, 2e-4),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.MEAN,
+            padding_side="right"
+        ), 0, 2e-4),
+
+        # Chunk enabled
+        # 33 tokens handled by a chunk of 128
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 256,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.CLS,
+            padding_side="right"
+        ), 128, 2e-4),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 256,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.LAST_TOKEN,
+            padding_side="right"
+        ), 128, 2e-4),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 256,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.MEAN,
+            padding_side="right"
+        ), 128, 2e-4),
+
+        # 33 tokens handled by 3 chunks of 16
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.CLS,
+            padding_side="right"
+        ), 16, 6e-3),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.LAST_TOKEN,
+            padding_side="right"
+        ), 16, 6e-3),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = False,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.MEAN,
+            padding_side="right"
+        ), 16, 6e-3),
+
+        # normalize = True, 33 tokens handled by 3 chunks of 16
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = True,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.CLS,
+            padding_side="right"
+        ), 16, 7e-5),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = True,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.LAST_TOKEN,
+            padding_side="right"
+        ), 16, 7e-5),
+        (TextEmbeddingPipeline.Config(
+            batch_size = 1,
+            max_length = 192,
+            normalize = True,
+            pad_to_max_length = False,
+            pooling_type=TextEmbeddingPipeline.PoolingType.MEAN,
+            padding_side="right"
+        ), 16, 7e-5),
+    ],
+)
+@pytest.mark.xfail(condition=(sys.platform == "darwin"), reason="Ticket - 174635")
+def test_qwen3_embedding_npu(emb_model, dataset_documents, config, chunk_size, threshold):
+    NPU_FALLBACK_PROPERTIES = {"NPUW_DEVICES": "CPU", "NPUW_F16IC": "False", "NPUW_LLM_PREFILL_CHUNK_SIZE" : chunk_size}
+
+    embeddings_genai_cpu = run_text_embedding_genai(
+        emb_model.models_path,
+        dataset_documents,
+        config,
+        "embed_documents",
+        device="CPU"
+    )
+    embeddings_genai_npu = run_text_embedding_genai(
+        emb_model.models_path,
+        dataset_documents,
+        config,
+        "embed_documents",
+        device="NPU",
+        properties=NPU_FALLBACK_PROPERTIES
+    )
+    validate_embedding_results(embeddings_genai_npu, embeddings_genai_cpu, threshold)
+
+
 @pytest.mark.parametrize("emb_model", ["BAAI/bge-small-en-v1.5"], indirect=True)
 def test_embedding_constructors(emb_model):
     models_path = emb_model.models_path
diff --git a/tools/llm_bench/benchmark.py b/tools/llm_bench/benchmark.py
index f4f63a137c..0483dc402c 100644
--- a/tools/llm_bench/benchmark.py
+++ b/tools/llm_bench/benchmark.py
@@ -210,6 +210,7 @@ def get_argprser():
                         help="Pooling type CLS or MEAN for encoders, LAST_TOKEN for decoders. "
                              "Different post-processing is applied depending on the padding side. Applicable only for text embeddings")
     parser.add_argument("--embedding_normalize", action="store_true", help="Normalize embeddings. Applicable only for text embeddings")
+    parser.add_argument("--embedding_pad_to_max_length", action="store_true", help="Pad embeddings. Applicable only for text embeddings")
     parser.add_argument("--embedding_max_length", type=int, default=None,
                         help="Max length for text embeddings. Input text will be padded or truncated to specified value")
     parser.add_argument("--embedding_padding_side", choices=["left", "right"], default=None,
diff --git a/tools/llm_bench/llm_bench_utils/model_utils.py b/tools/llm_bench/llm_bench_utils/model_utils.py
index 1a9ccaf237..233d7575b4 100644
--- a/tools/llm_bench/llm_bench_utils/model_utils.py
+++ b/tools/llm_bench/llm_bench_utils/model_utils.py
@@ -136,6 +136,7 @@ def analyze_args(args):
     model_args['emb_normalize'] = args.embedding_normalize
     model_args["emb_max_length"] = args.embedding_max_length
     model_args["emb_padding_side"] = args.embedding_padding_side
+    model_args["emb_pad_to_max_length"] = args.embedding_pad_to_max_length
     model_args['rerank_max_length'] = args.reranking_max_length
     model_args["rerank_top_n"] = args.reranking_top_n
     model_args["rerank_texts"] = args.texts
diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py
index ec57706dd4..ba36175560 100644
--- a/tools/llm_bench/llm_bench_utils/ov_utils.py
+++ b/tools/llm_bench/llm_bench_utils/ov_utils.py
@@ -666,7 +666,7 @@ def create_genai_text_embed_model(model_path, device, memory_data_collector, **k
     pooling_type = kwargs.get("emb_pooling_type")
     max_length = kwargs.get("emb_max_length")
-    padding_side = kwargs.get("embedding_padding_side")
+    padding_side = kwargs.get("emb_padding_side")
 
     config = openvino_genai.TextEmbeddingPipeline.Config()
     if pooling_type is not None:
@@ -678,7 +678,8 @@ def create_genai_text_embed_model(model_path, device, memory_data_collector, **k
             config.pooling_type = openvino_genai.TextEmbeddingPipeline.PoolingType.CLS
     if max_length is not None:
         config.max_length = max_length
-        config.pad_to_max_length = True
+
+    config.pad_to_max_length = kwargs.get("emb_pad_to_max_length", False)
     config.normalize = kwargs.get("emb_normalize", False)
     if padding_side:
         config.padding_side = padding_side
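
Closing illustration for the hand-off described in post_model_infer() above: when the NPU model is compiled for a fixed max_length, the real attention mask is copied into the fixed-shape tensor and the tail is zero-filled before the CPU post-processing model runs. Below is a minimal stand-alone sketch of that padding scheme, using std::vector in place of ov::Tensor and assuming an i64 mask; the 33-token / 192-length sizes mirror the test cases above.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        const std::size_t real_tokens = 33;   // tokens actually produced by the tokenizer
        const std::size_t max_length = 192;   // fixed sequence length of the post-processing model

        std::vector<int64_t> real_mask(real_tokens, 1);   // mask returned with the encoded input
        std::vector<int64_t> fixed_mask(max_length, 0);   // mask buffer of the fixed-shape model

        // Copy the real mask; the remaining positions stay zero, mirroring the
        // std::copy_n / std::fill_n logic in post_model_infer().
        std::copy_n(real_mask.begin(), real_mask.size(), fixed_mask.begin());

        std::cout << "active positions: "
                  << std::count(fixed_mask.begin(), fixed_mask.end(), int64_t{1}) << "\n";  // prints 33
        return 0;
    }
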