Conversation


@andrew-k-park andrew-k-park commented Dec 4, 2025

Description

Optimize video frame preprocessing for the LLaVA-NeXT-Video-7B model on GPU by creating an OpenVINO preprocessing model that moves preprocessing operations from CPU to GPU.

Ticket: CVS-177558

Average first-token latency (1280×720, 5 s video = 32 frames, 100 input tokens, 128 generated tokens):

CPP preprocessing (GPU)  2906.118 ms
OV preprocessing (GPU)   845.6711 ms
CPP preprocessing (CPU)  15321.59 ms
OV preprocessing (CPU)   14327.6  ms

WWB results with video input (--model-type visual-video-text):

CPP preprocessing (GPU)  0.880806
OV preprocessing (GPU)   0.860603
CPP preprocessing (CPU)  0.918247
OV preprocessing (CPU)   0.906167

WWB results with image input (--model-type visual-text):

CPP preprocessing (GPU)  0.899575
OV preprocessing (GPU)   0.881532
CPP preprocessing (CPU)  0.903088
OV preprocessing (CPU)   0.902678

Checklist:

  • Tests have been updated or added to cover the new code.
  • This patch fully addresses the ticket.
  • I have made corresponding changes to the documentation.

Copilot AI review requested due to automatic review settings December 4, 2025 13:23
@github-actions github-actions bot added the category: visual language Visual language pipeline label Dec 4, 2025
Contributor

Copilot AI left a comment

Pull request overview

This PR optimizes video frame preprocessing for the LLaVA-NeXT-Video-7B model by implementing GPU-accelerated preprocessing using OpenVINO operations instead of CPU-based preprocessing. The change provides significant performance improvements, reducing first-token latency from ~15 s (CPU) to ~845 ms (GPU with OV preprocessing).

Key changes:

  • Added an OpenVINO-based preprocessing model that performs resize, crop, and normalization on GPU (see the sketch after this list)
  • Implemented environment variable control to switch between CPU and GPU preprocessing
  • Refactored preprocessing logic to support both CPU (preprocess_frames_cpp) and GPU (preprocess_frames_ov) paths
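
As a rough illustration of the OV path, here is a minimal sketch of such a preprocessing subgraph built with the OpenVINO op API. It follows the details discussed later in this thread (bicubic resize with cube_coeff = -0.5 and ASYMMETRIC coordinate mode, (x/255 − mean)/std normalization, NHWC→NCHW transpose); the function name and parameter shapes are illustrative assumptions rather than the PR's exact code, and the center-crop step (a Slice with starts at (dim − crop)/2) is omitted for brevity.

#include <openvino/openvino.hpp>
#include <openvino/op/ops.hpp>

using namespace ov::op;

std::shared_ptr<ov::Model> make_preprocess_sketch(const std::vector<float>& image_mean,
                                                  const std::vector<float>& image_std) {
    // Batched u8 frames in NHWC; batch, height and width stay dynamic.
    auto frames = std::make_shared<v0::Parameter>(ov::element::u8, ov::PartialShape{-1, -1, -1, 3});
    auto target_size = std::make_shared<v0::Parameter>(ov::element::i64, ov::Shape{2});  // [H, W]
    auto input_f32 = std::make_shared<v0::Convert>(frames, ov::element::f32);

    // Bicubic resize on the spatial axes (Catmull-Rom kernel, matching the CPU path).
    v11::Interpolate::InterpolateAttrs attrs;
    attrs.mode = v11::Interpolate::InterpolateMode::CUBIC;
    attrs.shape_calculation_mode = v11::Interpolate::ShapeCalcMode::SIZES;
    attrs.coordinate_transformation_mode = v11::Interpolate::CoordinateTransformMode::ASYMMETRIC;
    attrs.cube_coeff = -0.5f;
    auto axes = v0::Constant::create(ov::element::i64, ov::Shape{2}, {1, 2});  // H, W in NHWC
    auto resized = std::make_shared<v11::Interpolate>(input_f32, target_size, axes, attrs);

    // Per-channel normalization: (x / 255 - mean) / std, broadcast over NHWC.
    auto scale = v0::Constant::create(ov::element::f32, ov::Shape{1}, {255.0f});
    auto mean_c = v0::Constant::create(ov::element::f32, ov::Shape{1, 1, 1, 3}, image_mean);
    auto std_c = v0::Constant::create(ov::element::f32, ov::Shape{1, 1, 1, 3}, image_std);
    auto scaled = std::make_shared<v1::Divide>(resized, scale);
    auto normalized = std::make_shared<v1::Divide>(std::make_shared<v1::Subtract>(scaled, mean_c), std_c);

    // NHWC -> NCHW, the layout the vision encoder expects.
    auto order = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 3, 1, 2});
    auto nchw = std::make_shared<v1::Transpose>(normalized, order);
    return std::make_shared<ov::Model>(ov::OutputVector{nchw}, ov::ParameterVector{frames, target_size});
}

Compiling such a model for GPU once and reusing its infer requests is what moves this work off the CPU.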

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
src/cpp/src/visual_language/llava_next_video/classes.hpp Added new methods for CPU and GPU preprocessing, added preprocessing model infrastructure and use flag
src/cpp/src/visual_language/llava_next_video/classes.cpp Implemented OpenVINO preprocessing model creation and GPU-accelerated frame preprocessing logic

@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch 2 times, most recently from 8037bec to 89aa9b7 on December 8, 2025 07:58
Copilot AI review requested due to automatic review settings December 8, 2025 07:58
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.


@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from 89aa9b7 to c6480b2 on December 9, 2025 12:20
Copilot AI review requested due to automatic review settings December 10, 2025 05:04
@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from c6480b2 to 212b086 on December 10, 2025 05:04
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.


@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from b6c840a to ae8bcbd on December 10, 2025 05:35
Copilot AI review requested due to automatic review settings December 10, 2025 05:35
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.


@andrew-k-park
Author

@yatarkan @Wovchena Could you review this PR?

return sliced;
}

std::shared_ptr<ov::Model> create_video_preprocess_model(const ProcessorConfig& config) {
Contributor

@yatarkan yatarkan Dec 11, 2025

Can we apply the same OV preprocessing to images as well (in parent class)?

Author

Preprocessing for images and videos is similar, but the normalization formula and the functions used differ and need verification. Once confirmed, they will likely be applied consistently. Currently the focus is on video frame preprocessing, so after validation this can be reflected in a follow-up PR.

@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from ae8bcbd to c99b28d on December 14, 2025 12:10
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated no new comments.


@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from 082ab54 to 8862644 on December 15, 2025 11:48
@andrew-k-park
Author

@yatarkan Could you review this PR again?

Comment on lines 533 to 541
ov::Shape concat_shape = preprocessed_frames[0].get_shape();
concat_shape[0] = preprocessed_frames.size();
ov::Tensor concatenated_frames = ov::Tensor(preprocessed_frames[0].get_element_type(), concat_shape);

// Copy each preprocessed frame into the batched tensor back to back.
float* frames_data = concatenated_frames.data<float>();
for (size_t i = 0; i < preprocessed_frames.size(); i++) {
    memcpy(frames_data, preprocessed_frames[i].data(), preprocessed_frames[i].get_byte_size());
    frames_data += ov::shape_size(preprocessed_frames[i].get_shape());
}
Contributor

Does it make sense to move tensor concatenation to preprocess_frames so it can also be optimized with OV processing?

Author

I moved tensor concatenation into each preprocess_frames function; for OV processing, the concatenated frames are now handled as a single batch.

Comment on lines 221 to 227
auto preprocess_model = create_video_preprocess_model(m_processor_config);
auto compiled_preprocess = utils::singleton_core().compile_model(preprocess_model, device, properties);
m_ireq_queue_preprocess = std::make_unique<CircularBufferQueue<ov::InferRequest>>(
    compiled_preprocess.get_property(ov::optimal_number_of_infer_requests),
    [&compiled_preprocess]() -> ov::InferRequest {
        return compiled_preprocess.create_infer_request();
    });
Contributor

In similar PRs we had a pattern where preprocessing was patched into the vision encoder model (see example).
I would follow the same approach if it does not contradict or interfere with image input preprocessing.

Author

Following the example's approach, I removed the dedicated preprocessing model and patched the preprocessing pipeline directly into the vision encoder.
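
For readers unfamiliar with that pattern, a minimal sketch of such patching with the core graph API could look as follows; build_preprocess_subgraph, the "pixel_values" input name, and the parameter list are assumptions for illustration, not the PR's exact code (and note that this approach is revisited later in the thread):

// Assumed helper that builds the resize/crop/normalize/transpose chain from
// the new parameters and returns its final node.
std::shared_ptr<ov::Node> build_preprocess_subgraph(const ov::ParameterVector& params);

std::shared_ptr<ov::Model> patch_preprocess_into_encoder(std::shared_ptr<ov::Model> encoder) {
    // The encoder's original image input; the name "pixel_values" is assumed here.
    auto pixel_values = std::dynamic_pointer_cast<ov::op::v0::Parameter>(
        encoder->input("pixel_values").get_node_shared_ptr());

    // New raw inputs: u8 NHWC frames plus resize and crop sizes.
    ov::ParameterVector params{
        std::make_shared<ov::op::v0::Parameter>(ov::element::u8, ov::PartialShape{-1, -1, -1, 3}),
        std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{2}),   // resize_target_size
        std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{2})};  // crop_size
    auto preprocessed = build_preprocess_subgraph(params);

    // Rewire all former consumers of pixel_values to the preprocessing output.
    ov::replace_node(pixel_values, preprocessed);
    encoder->remove_parameter(pixel_values);
    encoder->add_parameters(params);
    encoder->validate_nodes_and_infer_types();
    return encoder;
}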

Copilot AI review requested due to automatic review settings December 17, 2025 01:28
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.


Copilot AI review requested due to automatic review settings December 17, 2025 01:54
@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from aac54d3 to e493f3a on December 17, 2025 01:54
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.


@github-actions github-actions bot added the category: GGUF GGUF file reader label Dec 17, 2025
@andrew-k-park
Author

@yatarkan I've addressed the final comments and CI has passed. Could you review the PR once more?

Comment on lines 67 to 72
# Disable OV preprocessing for video in tests to avoid input parameter conflicts
# The integrated preprocessing model changes the vision encoder inputs from a single
# 'pixel_values' parameter to multiple parameters (video_frames, resize_target_size,
# crop_height, crop_width), which conflicts with image encoding that still expects
# the original 'pixel_values' input
os.environ["VIDEO_PREPROCESS"] = "CPP"
Contributor

As the default is OV preprocessing (not CPP), users will face issues when running the model with image input. This breaks expected functionality and behavior.
As noted in #3097 (comment), patching the vision encoder model makes sense only if there are no conflicts with image processing, but there are. So let's not follow the vision-encoder-patching approach until preprocessing is used and aligned for both images and videos.

Author

The implementation has been updated so that preprocessing is aligned for both images and videos. The WWB results for image input have been added to the description.

std::vector<ov::genai::EncodedVideo> encoded_videos;
for (const auto video: videos) {
    std::vector<ov::Tensor> frames = to_single_image_tensors({video});
    auto vision_encoder = std::static_pointer_cast<VisionEncoderLLaVANextVideo>(m_vision_encoder);
Contributor

Should be moved outside the for-loop?

Author

Moved.

        memcpy(frames_data, preprocessed_frames[i].data(), preprocessed_frames[i].get_byte_size());
        frames_data += ov::shape_size(preprocessed_frames[i].get_shape());
    }
    auto config = vision_encoder->get_processor_config();
Contributor

Should be moved outside the for-loop?

Author

Moved.


std::vector<ov::genai::EncodedVideo> InputsEmbedderLLaVANextVideo::encode_videos(const std::vector<ov::Tensor>& videos) {
    std::vector<ov::genai::EncodedVideo> encoded_videos;
    for (const auto video: videos) {
Contributor

Actually, I agree with some of the Copilot comments/suggestions even though they were marked as resolved.

Taking into account that patching the vision encoder conflicts with image preprocessing, I suggest the following flow:

for (const auto video: videos) {
    ImageSize original_size = ...;
    ImageSize target_size = get_resize_target_size(original_size, config); // utility function, to be reused in preprocess_clip_image_llava_next_video(...)
    size_t num_frames = video.get_shape().at(0);

    ov::Tensor pixel_values;
    size_t num_video_tokens;
    if (vision_encoder->get_use_ov_preprocess()) {
        // We don't need here to split video tensor into vector of single frames, ov_video_preprocess_model will handle batched frames
        ov::Tensor target_size_tensor(ov::element::i64, {2});
        target_size_tensor.data<int64_t>()[0] = target_size.height;
        target_size_tensor.data<int64_t>()[1] = target_size.width;
        
        ov::Tensor crop_size_tensor(ov::element::i64, {2});
        crop_size_tensor.data<int64_t>()[0] = config.crop_size_height;
        crop_size_tensor.data<int64_t>()[1] = config.crop_size_width;

        // Pass video, target size, crop size to ov_video_preprocess_model -> pixel_values for vision_encoder model
        ov_video_preprocess_model.set_input_tensor(0, video);
        ov_video_preprocess_model.set_input_tensor(1, target_size_tensor);
        ov_video_preprocess_model.set_input_tensor(2, crop_size_tensor); // both crop_height and crop_width as for target_size_tensor input
        ov_video_preprocess_model.infer();
        pixel_values = ov_video_preprocess_model.get_output_tensor();
        num_video_tokens = vision_encoder->get_num_video_tokens(target_size, num_frames); // utility method, to be reused in preprocess_frames(...)
    } else {
        // Follow original CPP flow
        std::vector<ov::Tensor> frames = to_single_image_tensors({video});
        // Preprocess and concatenate preprocessed frames to single tensor -> pixel_values for vision_encoder model
        pixel_values = vision_encoder->preprocess_frames(frames); // concatenate preprocessed frames to single tensor inside
        num_video_tokens = vision_encoder->get_num_video_tokens(target_size, num_frames);
    }
    vision_encoder_infer_request.set_tensor("pixel_values", pixel_values);
    // ...
}

Author

Applied the changes to support preprocessing for both video and image, refactored the code by extracting a helper function for the duplicated logic, and cleaned up the overall code.

Copilot AI review requested due to automatic review settings December 18, 2025 05:26
@github-actions github-actions bot removed the category: GGUF GGUF file reader label Dec 18, 2025
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

src/cpp/src/visual_language/llava_next_video/classes.cpp:415

  • The variable searched_pos is declared but never used in the function. This appears to be dead code that should be removed.
    size_t searched_pos = 0;

@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from 2229b5f to 46e1de9 on December 18, 2025 05:27
Comment on lines 123 to 124
    size_t orig_height,
    size_t orig_width,
Contributor

Why not use the ov::genai::ImageSize struct, e.g.

Suggested change
    size_t orig_height,
    size_t orig_width,
    ImageSize original_size,

Contributor

And for the return value:

std::pair<int64_t, int64_t> calculate_resize_dimensions(
    size_t orig_height,
    size_t orig_width,
    int target_shortest_edge) {
Contributor

config.size_shortest_edge has size_t type

Comment on lines 136 to 137
    size_t orig_height,
    size_t orig_width,
Contributor

Let's use the ov::genai::ImageSize struct here and in other places where possible.

Comment on lines 561 to 562
size_t num_video_tokens = ((config.crop_size_height / m_patch_size) *
                           (config.crop_size_width / m_patch_size) / 4) * num_frames;
Contributor

It seems num_video_tokens can be calculated outside the preprocess_frames_ov/preprocess_frames_cpp methods, which would remove this duplication.
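
For example, with a 336×336 crop and 14-pixel patches (assuming the CLIP-ViT-L/14 defaults for this model), that formula gives (336/14 × 336/14 / 4) × num_frames = 144 × 32 = 4608 tokens for a 32-frame clip. Since the value depends only on the config and the frame count, not on the preprocessed pixels, it can indeed be hoisted out of both methods.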

Comment on lines 576 to 578
auto [concatenated_frames, num_video_tokens] = vision_encoder->get_use_ov_preprocess()
    ? vision_encoder->preprocess_frames_ov(frames)
    : vision_encoder->preprocess_frames_cpp(frames);
Contributor

It seems preprocess_frames_ov is not really needed. If we calculate num_video_tokens outside the preprocess_frames_* methods, frame concatenation applies only to CPP processing.
For OV processing we can pass the video tensor to set_preprocess_parameters below instead of concatenated_frames.

}

bool can_use_ov_preprocess() {
    const char* env = std::getenv("VISION_PREPROCESS");
Contributor

Could you please align the env var name with other models (qwen2vl, phi3_vision)?
Actually, I find your name more suitable as it relates to both image and video inputs, so I would prefer replacing IMAGE_PREPROCESS with VISION_PREPROCESS.
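
For reference, assembling the two hunks quoted in this thread, the gate reads as follows (OV preprocessing is the default; setting VISION_PREPROCESS=CPP opts out):

bool can_use_ov_preprocess() {
    // Default to OV preprocessing unless the user explicitly requests CPP.
    const char* env = std::getenv("VISION_PREPROCESS");
    return !(env && std::string(env) == "CPP");
}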

Copilot AI review requested due to automatic review settings December 22, 2025 04:25
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.


namespace ov::genai {

namespace {

Copilot AI Dec 22, 2025

The create_bicubic_resize function lacks documentation explaining its purpose, parameters, and the significance of the bicubic resize configuration. Add a docstring explaining that this creates a bicubic resize operation for NHWC format inputs, the meaning of the cube_coeff value (-0.5 for Catmull-Rom), and why ASYMMETRIC coordinate transformation mode is used.

Suggested change
// Creates a bicubic resize operation for NHWC-formatted inputs.
//
// Parameters:
// - input: Input tensor in NHWC layout (N: batch, H: height, W: width, C: channels).
// - target_size: 1D tensor with two elements: [new_height, new_width] used with
// ShapeCalcMode::SIZES to define the output spatial size.
//
// The interpolation is configured to:
// - Operate on the spatial axes [1, 2] corresponding to H and W in NHWC.
// - Use cubic interpolation with cube_coeff = -0.5f, which corresponds to the
// Catmull-Rom bicubic kernel (a = -0.5) and is chosen to match CPU preprocessing.
// - Use CoordinateTransformMode::ASYMMETRIC so that source and target coordinates
// are mapped without half-pixel offsets, aligning with common preprocessing
// behavior in vision models and OpenVINO-based CPU pipelines.

    return std::make_shared<v11::Interpolate>(input_f32, target_size, axes, attrs);
}

Copilot AI Dec 22, 2025

The create_mean_scale function lacks documentation explaining the normalization formula being implemented and why the conversion logic differs from the context example. Add a docstring explaining the per-channel normalization formula: (x/255.0 - mean) / std.

Suggested change
/**
* Builds an OpenVINO subgraph that applies per-channel image normalization.
*
* The implemented formula matches the original mean_scale() preprocessing logic:
*
* y[c] = ( x[c] / 255.0f - image_mean[c] ) / image_std[c]
*
* where:
* - x is the input pixel value in the range [0, 255] when provided as uint8,
* - image_mean[c] and image_std[c] are channel-wise mean and std values taken
* from ProcessorConfig::image_mean and ProcessorConfig::image_std,
* - the operation is performed per channel c with broadcasting for NHWC
* tensors using constants of shape [1, 1, 1, 3].
*
* Unlike some context examples that always start from uint8 tensors, this helper
* accepts either u8 or f32 input:
* - if the input is u8, it is first converted to f32 to faithfully reproduce
* the original mean_scale() behavior: float(x) / 255.0f;
* - if the input is already f32 (e.g., pre-scaled elsewhere), it is used
* directly to avoid redundant conversions while still applying the same
* (x/255.0 - mean) / std normalization formula via OV ops.
*/

return result;
}

Copilot AI Dec 22, 2025

The create_channels_first function lacks documentation explaining the transpose operation. Add a docstring indicating this converts from NHWC to NCHW layout.

Suggested change
/// Creates a transpose node that converts an input tensor from NHWC to NCHW layout.

    auto transpose_order = v0::Constant::create(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
    return std::make_shared<v1::Transpose>(input_nhwc, transpose_order);
}

Copilot AI Dec 22, 2025

The create_center_crop function lacks documentation explaining its purpose and the crop calculation logic. Add a docstring explaining that this performs center cropping by calculating start positions as (dimension - crop_size) / 2.

Suggested change
/**
* Perform a center crop on the spatial dimensions of an NHWC input tensor.
*
* The requested crop size (height, width) is taken from {@code crop_size}, and
* the crop region is positioned at the center of the input by computing the
* starting coordinates as:
* start_y = (H - crop_height) / 2
* start_x = (W - crop_width) / 2
* where H and W are the input tensor height and width. The function then
* slices the input tensor using these start positions and the given crop size.
*/

    const char* env = std::getenv("VISION_PREPROCESS");
    return !(env && std::string(env) == "CPP");
}

Copilot AI Dec 22, 2025

The patch_preprocess_into_vision_encoder_model function lacks documentation explaining its purpose and the preprocessing pipeline it creates. Add a docstring describing that this integrates bicubic resize, center crop, normalization, and channel transpose operations into the vision encoder model.

Suggested change
/**
* @brief Integrates a preprocessing pipeline into a vision encoder model.
*
* This function wraps the provided @p vision_encoder_model with an OpenVINO subgraph
* that performs the image preprocessing typically done on the CPU. The new model
* exposes three inputs:
* - input_frames: concatenated image/video frames in NHWC uint8 format
* - resize_target_size: target spatial size [height, width] for bicubic resize
* - crop_size: center crop size [height, width]
*
* The injected preprocessing pipeline consists of:
* 1. Bicubic resize (Interpolate with CUBIC mode) to @p resize_target_size.
* 2. Center crop to @p crop_size.
* 3. Per-channel normalization using mean/scale parameters from @p config.
* 4. Channel transpose from NHWC to NCHW (channels-first layout).
*
* The output of this pipeline is connected to the original encoder's first input
* (typically "pixel_values"), so that the returned model directly accepts raw
* uint8 frames and produces the same outputs as the original encoder.
*
* @param vision_encoder_model Original vision encoder model to be patched.
* @param config Processor configuration providing normalization parameters and sizes.
* @return A new model with preprocessing integrated into the encoder graph.
*/

Comment on lines 457 to 458
ov::Tensor image_newline;
size_t searched_pos = 0;
std::vector<ov::Tensor> image_embeds;
Copilot AI Dec 22, 2025

The variable image_newline is declared but never used in the get_inputs_embeds function. Remove this unused variable declaration.

@andrew-k-park andrew-k-park force-pushed the preproc_opt_for_llava_next_video branch from 922fe4f to c6dcf71 on December 22, 2025 04:27