-
Notifications
You must be signed in to change notification settings - Fork 74
feat(sq4): implement for sq4's simd #1391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Reviewer's GuideImplements native SQ4 SIMD kernels for SSE, AVX, AVX2, and AVX512 backends and extends tests/benchmarks to validate and measure the new implementations, while making a few minor style cleanups (float suffixes, spacing, namespace comments). Flow diagram for SQ4ComputeCodesL2Sqr AVX512 kernelflowchart LR
C1[codes1 uint8_t*]
C2[codes2 uint8_t*]
LB[lower_bound float*]
DF[diff float*]
subgraph loop32[Loop over 32-dimensional blocks]
U1[unpack_4bit_to_m512 on codes1]
U2[unpack_4bit_to_m512 on codes2]
D10[decoded0 codes1]
D11[decoded1 codes1]
D20[decoded0 codes2]
D21[decoded1 codes2]
L0[Load lower_bound block 0..15]
L1[Load lower_bound block 16..31]
F0[Load diff block 0..15]
F1[Load diff block 16..31]
A10[Affine codes1 block0]
A11[Affine codes1 block1]
A20[Affine codes2 block0]
A21[Affine codes2 block1]
S0[Subtract A20 - A10]
S1[Subtract A21 - A11]
Q0[Square and accumulate block0]
Q1[Square and accumulate block1]
end
C1 --> U1
C2 --> U2
U1 --> D10
U1 --> D11
U2 --> D20
U2 --> D21
LB --> L0
LB --> L1
DF --> F0
DF --> F1
D10 --> A10
F0 --> A10
L0 --> A10
D11 --> A11
F1 --> A11
L1 --> A11
D20 --> A20
F0 --> A20
L0 --> A20
D21 --> A21
F1 --> A21
L1 --> A21
A20 --> S0
A10 --> S0
A21 --> S1
A11 --> S1
S0 --> Q0
S1 --> Q1
File-Level Changes
Possibly linked issues
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
Summary of ChangesHello @LHT129, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the performance of 4-bit scalar quantization (SQ4) operations by introducing optimized SIMD implementations. It provides specialized code paths for SSE, AVX, AVX2, and AVX512 instruction sets, ensuring that distance calculations (Inner Product and L2 Squared) are executed with maximum efficiency on compatible hardware. The changes include detailed vectorization logic for processing 4-bit quantized data, along with comprehensive testing to validate accuracy and measure performance improvements. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey there - I've reviewed your changes and found some issues that need to be addressed.
Prompt for AI Agents
Please address the comments from this code review:
## Individual Comments
### Comment 1
<location> `src/simd/avx2.cpp:789-790` </location>
<code_context>
+ // Process 16 values at a time (8 bytes containing 16 4-bit values)
+ for (; d + 15 < dim; d += 16) {
+ // Load 8 bytes (16 4-bit values)
+ __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1)));
+ __m128i code_vec2 = _mm_loadl_epi64((__m128i*)(codes + (d >> 1) + 4));
+
+ // Extract low nibbles (values 0,2,4,6,8,10,12,14) - even indices
</code_context>
<issue_to_address>
**issue (bug_risk):** Potential out-of-bounds reads of packed 4‑bit codes for small dimensions.
In the AVX2 SQ4 path, each loop iteration processes 16 values (d += 16) but reads 16 bytes of codes via two `_mm_loadl_epi64` calls, while the packed SQ4 buffer only guarantees `ceil(dim/2)` bytes. For `dim == 16`, that means the loop runs once and reads past the nominal 8-byte buffer. The same pattern exists in `SQ4ComputeIP`, `SQ4ComputeL2Sqr`, `SQ4ComputeCodesIP`, and `SQ4ComputeCodesL2Sqr`, so unless callers overallocate `codes`, this is undefined behavior.
Please either tighten the loop condition based on `bytes = (dim + 1) >> 1`, reduce each iteration to only load valid bytes (e.g., a single `_mm_loadl_epi64` and unpack/reuse), or route small `dim` (e.g., `<= 16`) to the SSE/generic path. The key is to ensure the maximum read index never exceeds `((dim + 1) >> 1) - 1` for any SQ4 AVX2 loop.
</issue_to_address>
### Comment 2
<location> `src/simd/avx.cpp:748-749` </location>
<code_context>
+ // Process 16 values at a time (8 bytes containing 16 4-bit values)
+ for (; d + 15 < dim; d += 16) {
+ // Load 8 bytes (16 4-bit values)
+ __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1)));
+ __m128i code_vec2 = _mm_loadl_epi64((__m128i*)(codes + (d >> 1) + 4));
+
+ // Extract low nibbles (values 0,2,4,6,8,10,12,14) - even indices
</code_context>
<issue_to_address>
**issue (bug_risk):** AVX SQ4 loops have the same packed-code overread risk as the AVX2 implementation.
This loop still overreads `codes` for small dims. For `dim == 16`, you process one 16‑value iteration but only have 8 bytes of codes (`(dim + 1) >> 1`), while the two `_mm_loadl_epi64` calls read 16 bytes total. That’s the same UB risk as in the AVX2 path.
Please adjust the loop bound (e.g., based on the number of available code bytes, or by reducing bytes read per iteration) and apply the same fix across all new AVX SQ4 functions (`SQ4ComputeIP`, `SQ4ComputeL2Sqr`, `SQ4ComputeCodesIP`, `SQ4ComputeCodesL2Sqr`) to keep AVX and AVX2 consistent and safe.
</issue_to_address>
### Comment 3
<location> `src/simd/sse.cpp:727` </location>
<code_context>
+ // Process 16 values at a time (8 bytes containing 16 4-bit values)
+ for (; d + 15 < dim; d += 16) {
+ // Load 8 bytes (16 4-bit values)
+ __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1)));
+ __m128i code_vec2 = _mm_loadl_epi64((__m128i*)(codes + (d >> 1) + 4));
+
</code_context>
<issue_to_address>
**issue (bug_risk):** SSE SQ4 implementations also read more bytes than minimally required for dim == 8.
These loops advance by 8 elements (`d += 8`) and load with `_mm_loadl_epi64(codes + (d >> 1))`. For `dim == 8`, only 4 bytes are logically valid (`(dim + 1) >> 1 = 4`), but the condition `d + 7 < dim` still allows one iteration and `_mm_loadl_epi64` reads 8 bytes from the base, which is out of bounds if the caller only provides the minimal buffer. The same pattern exists in:
- `SQ4ComputeIP`
- `SQ4ComputeL2Sqr`
- `SQ4ComputeCodesIP`
- `SQ4ComputeCodesL2Sqr`
Since only the low 4 bytes are actually used, consider either loading only 4 valid bytes (e.g., via a smaller load/memcpy into a local then `_mm_loadl_epi64`) or tightening the loop so it runs only when at least 8 bytes of `codes` are available and handling the 8-dim case in the scalar path. This avoids UB and keeps behavior consistent across SIMD variants.
</issue_to_address>
### Comment 4
<location> `src/simd/sq4_simd_test.cpp:140` </location>
<code_context>
+ }
+
TEST_CASE("SQ4 SIMD Compute Codes", "[ut][simd]") {
const std::vector<uint32_t> dims = {1, 8, 16, 32, 97, 129, 256};
int64_t count = 100;
</code_context>
<issue_to_address>
**suggestion (testing):** Add explicit coverage for dim == 0, which now has a dedicated code path in all SQ4 SIMD implementations.
The new SQ4 SIMD paths (SSE/AVX/AVX2/AVX512) add a `dim == 0` early return, but current tests only cover positive dims. Please add coverage for `dim = 0` so that all variants (`generic`, `sse`, `avx`, `avx2`, `avx512`, `neon`, `sve`) are checked to return 0 without reading input buffers. For example, either extend `dims` to include 0 in both "SQ4 SIMD Compute Codes" and "SQ4 SIMD Compute" test cases, or add a small dedicated `TEST_CASE` that calls each `SQ4*` function with `dim = 0` and minimal buffers.
Suggested implementation:
```cpp
TEST_CASE("SQ4 SIMD Compute Codes", "[ut][simd]") {
const std::vector<uint32_t> dims = {0, 1, 8, 16, 32, 97, 129, 256};
int64_t count = 100;
```
To fully implement your review comment, similar coverage for `dim == 0` should be added to the `"SQ4 SIMD Compute"` test case elsewhere in this file. The simplest aligned change is to:
1. Locate the `TEST_CASE("SQ4 SIMD Compute", ...)` block.
2. Update its `dims` vector (or equivalent loop over dimensions) to include `0` as the first element, e.g. `{0, 1, 8, 16, ...}`.
This will ensure both "Compute Codes" and "Compute" tests exercise the new early-return path for `dim == 0` across all SIMD variants.
</issue_to_address>Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces SIMD optimizations for SQ4 quantization across SSE, AVX, AVX2, and AVX512, which is a great performance improvement. The implementations for AVX512 are well-structured with helper functions. However, the SSE, AVX, and AVX2 implementations contain significant code duplication for unpacking 4-bit values and performing horizontal sums. I've suggested refactoring this duplicated logic into helper functions to improve maintainability and code quality. The test updates are also a valuable addition, correcting previous issues and expanding coverage.
| for (; d + 15 < dim; d += 16) { | ||
| // Load 8 bytes (16 4-bit values) | ||
| __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1))); | ||
| __m128i code_vec2 = _mm_loadl_epi64((__m128i*)(codes + (d >> 1) + 4)); | ||
|
|
||
| // Extract low nibbles (values 0,2,4,6,8,10,12,14) - even indices | ||
| __m128i low_nibbles1 = _mm_and_si128(code_vec, _mm_set1_epi8(0x0F)); | ||
| __m128i low_nibbles2 = _mm_and_si128(code_vec2, _mm_set1_epi8(0x0F)); | ||
|
|
||
| // Extract high nibbles (values 1,3,5,7,9,11,13,15) - odd indices | ||
| __m128i high_nibbles1 = _mm_and_si128(_mm_srli_epi16(code_vec, 4), _mm_set1_epi8(0x0F)); | ||
| __m128i high_nibbles2 = _mm_and_si128(_mm_srli_epi16(code_vec2, 4), _mm_set1_epi8(0x0F)); | ||
|
|
||
| // Interleave low and high nibbles to get correct order | ||
| __m128i interleaved1 = _mm_unpacklo_epi8(low_nibbles1, high_nibbles1); | ||
| __m128i interleaved2 = _mm_unpacklo_epi8(low_nibbles2, high_nibbles2); | ||
|
|
||
| // Convert to float and scale - first 8 values | ||
| __m128i low_part1 = _mm_cvtepu8_epi32(interleaved1); | ||
| __m128i high_part1 = _mm_cvtepu8_epi32(_mm_srli_si128(interleaved1, 4)); | ||
| __m128 values0 = _mm_cvtepi32_ps(low_part1); | ||
| __m128 values1 = _mm_cvtepi32_ps(high_part1); | ||
|
|
||
| // Convert to float and scale - next 8 values | ||
| __m128i low_part2 = _mm_cvtepu8_epi32(interleaved2); | ||
| __m128i high_part2 = _mm_cvtepu8_epi32(_mm_srli_si128(interleaved2, 4)); | ||
| __m128 values2 = _mm_cvtepi32_ps(low_part2); | ||
| __m128 values3 = _mm_cvtepi32_ps(high_part2); | ||
|
|
||
| // Combine into AVX vectors | ||
| __m256 values01 = _mm256_set_m128(values1, values0); | ||
| __m256 values23 = _mm256_set_m128(values3, values2); | ||
|
|
||
| // Scale by 1/15.0 | ||
| __m256 scale = _mm256_set1_ps(1.0f / 15.0f); | ||
| values01 = _mm256_mul_ps(values01, scale); | ||
| values23 = _mm256_mul_ps(values23, scale); | ||
|
|
||
| // Apply diff and lower_bound | ||
| __m256 diff_vec0 = _mm256_loadu_ps(diff + d); | ||
| __m256 diff_vec1 = _mm256_loadu_ps(diff + d + 8); | ||
| __m256 lb_vec0 = _mm256_loadu_ps(lower_bound + d); | ||
| __m256 lb_vec1 = _mm256_loadu_ps(lower_bound + d + 8); | ||
|
|
||
| values01 = _mm256_add_ps(_mm256_mul_ps(values01, diff_vec0), lb_vec0); | ||
| values23 = _mm256_add_ps(_mm256_mul_ps(values23, diff_vec1), lb_vec1); | ||
|
|
||
| // Load query vectors | ||
| __m256 query_vec0 = _mm256_loadu_ps(query + d); | ||
| __m256 query_vec1 = _mm256_loadu_ps(query + d + 8); | ||
|
|
||
| // Compute dot products | ||
| __m256 prod0 = _mm256_mul_ps(query_vec0, values01); | ||
| __m256 prod1 = _mm256_mul_ps(query_vec1, values23); | ||
|
|
||
| // Horizontal sum | ||
| __m256 sum = _mm256_add_ps(prod0, prod1); | ||
| __m128 sum_low = _mm256_castps256_ps128(sum); | ||
| __m128 sum_high = _mm256_extractf128_ps(sum, 1); | ||
| __m128 sum01 = _mm_add_ps(sum_low, sum_high); | ||
| __m128 sum23 = _mm_shuffle_ps(sum01, sum01, _MM_SHUFFLE(2, 3, 0, 1)); | ||
| __m128 sum0123 = _mm_add_ps(sum01, sum23); | ||
| __m128 sum4567 = _mm_movehl_ps(sum0123, sum0123); | ||
| __m128 total_sum = _mm_add_ss(sum0123, sum4567); | ||
|
|
||
| result += _mm_cvtss_f32(total_sum); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is significant code duplication across the newly added SQ4 functions (SQ4ComputeIP, SQ4ComputeL2Sqr, SQ4ComputeCodesIP, SQ4ComputeCodesL2Sqr). The logic for unpacking 4-bit codes and for horizontal summation is repeated in each function.
To improve maintainability and reduce code size, consider extracting this logic into static inline helper functions. This is similar to the approach taken in avx512.cpp with unpack_4bit_to_m512.
You could create two helpers:
- A function to unpack and convert 16 4-bit values into two
__m256float vectors. - A function for the horizontal sum of a
__m256vector.
For example:
static inline void unpack_16_4bit_to_2_m256(const uint8_t* codes, __m256* out0, __m256* out1) {
// ... unpacking logic from lines 747-777 ...
}
static inline float hsum_float_m256(__m256 v) {
// ... horizontal sum logic from lines 802-810 ...
}Using these helpers would make the main functions much cleaner and easier to maintain.
| for (; d + 15 < dim; d += 16) { | ||
| // Load 8 bytes (16 4-bit values) | ||
| __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1))); | ||
| __m128i code_vec2 = _mm_loadl_epi64((__m128i*)(codes + (d >> 1) + 4)); | ||
|
|
||
| // Extract low nibbles (values 0,2,4,6,8,10,12,14) - even indices | ||
| __m128i low_nibbles1 = _mm_and_si128(code_vec, _mm_set1_epi8(0x0F)); | ||
| __m128i low_nibbles2 = _mm_and_si128(code_vec2, _mm_set1_epi8(0x0F)); | ||
|
|
||
| // Extract high nibbles (values 1,3,5,7,9,11,13,15) - odd indices | ||
| __m128i high_nibbles1 = _mm_and_si128(_mm_srli_epi16(code_vec, 4), _mm_set1_epi8(0x0F)); | ||
| __m128i high_nibbles2 = _mm_and_si128(_mm_srli_epi16(code_vec2, 4), _mm_set1_epi8(0x0F)); | ||
|
|
||
| // Interleave low and high nibbles to get correct order | ||
| __m128i interleaved1 = _mm_unpacklo_epi8(low_nibbles1, high_nibbles1); | ||
| __m128i interleaved2 = _mm_unpacklo_epi8(low_nibbles2, high_nibbles2); | ||
|
|
||
| // Convert to float and scale - first 8 values | ||
| __m128i low_part1 = _mm_cvtepu8_epi32(interleaved1); | ||
| __m128i high_part1 = _mm_cvtepu8_epi32(_mm_srli_si128(interleaved1, 4)); | ||
| __m128 values0 = _mm_cvtepi32_ps(low_part1); | ||
| __m128 values1 = _mm_cvtepi32_ps(high_part1); | ||
|
|
||
| // Convert to float and scale - next 8 values | ||
| __m128i low_part2 = _mm_cvtepu8_epi32(interleaved2); | ||
| __m128i high_part2 = _mm_cvtepu8_epi32(_mm_srli_si128(interleaved2, 4)); | ||
| __m128 values2 = _mm_cvtepi32_ps(low_part2); | ||
| __m128 values3 = _mm_cvtepi32_ps(high_part2); | ||
|
|
||
| // Combine into AVX vectors | ||
| __m256 values01 = _mm256_set_m128(values1, values0); | ||
| __m256 values23 = _mm256_set_m128(values3, values2); | ||
|
|
||
| // Scale by 1/15.0 | ||
| __m256 scale = _mm256_set1_ps(1.0f / 15.0f); | ||
| values01 = _mm256_mul_ps(values01, scale); | ||
| values23 = _mm256_mul_ps(values23, scale); | ||
|
|
||
| // Apply diff and lower_bound | ||
| __m256 diff_vec0 = _mm256_loadu_ps(diff + d); | ||
| __m256 diff_vec1 = _mm256_loadu_ps(diff + d + 8); | ||
| __m256 lb_vec0 = _mm256_loadu_ps(lower_bound + d); | ||
| __m256 lb_vec1 = _mm256_loadu_ps(lower_bound + d + 8); | ||
|
|
||
| values01 = _mm256_fmadd_ps(values01, diff_vec0, lb_vec0); | ||
| values23 = _mm256_fmadd_ps(values23, diff_vec1, lb_vec1); | ||
|
|
||
| // Load query vectors | ||
| __m256 query_vec0 = _mm256_loadu_ps(query + d); | ||
| __m256 query_vec1 = _mm256_loadu_ps(query + d + 8); | ||
|
|
||
| // Compute dot products | ||
| __m256 prod0 = _mm256_mul_ps(query_vec0, values01); | ||
| __m256 prod1 = _mm256_mul_ps(query_vec1, values23); | ||
|
|
||
| // Horizontal sum | ||
| __m256 sum = _mm256_add_ps(prod0, prod1); | ||
| __m128 sum_low = _mm256_castps256_ps128(sum); | ||
| __m128 sum_high = _mm256_extractf128_ps(sum, 1); | ||
| __m128 sum01 = _mm_add_ps(sum_low, sum_high); | ||
| __m128 sum23 = _mm_shuffle_ps(sum01, sum01, _MM_SHUFFLE(2, 3, 0, 1)); | ||
| __m128 sum0123 = _mm_add_ps(sum01, sum23); | ||
| __m128 sum4567 = _mm_movehl_ps(sum0123, sum0123); | ||
| __m128 total_sum = _mm_add_ss(sum0123, sum4567); | ||
|
|
||
| result += _mm_cvtss_f32(total_sum); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar to the AVX implementation, there is a lot of duplicated code for unpacking 4-bit values and for horizontal summation across the new SQ4 functions.
To improve code quality and maintainability, please consider refactoring this duplicated logic into static inline helper functions. This would be consistent with the cleaner implementation in avx512.cpp.
You could introduce helpers for:
- Unpacking 16 4-bit values into two
__m256float vectors. - Performing a horizontal sum on a
__m256vector.
This will make the SQ4... functions more concise and less prone to errors if changes are needed in the future.
| for (; d + 7 < dim; d += 8) { | ||
| // Load 4 bytes (8 4-bit values) | ||
| __m128i code_vec = _mm_loadl_epi64((__m128i*)(codes + (d >> 1))); | ||
|
|
||
| // Extract low nibbles (values 0,2,4,6) - even indices | ||
| __m128i low_nibbles = _mm_and_si128(code_vec, _mm_set1_epi8(0x0F)); | ||
| // Extract high nibbles (values 1,3,5,7) - odd indices | ||
| __m128i high_nibbles = _mm_and_si128(_mm_srli_epi16(code_vec, 4), _mm_set1_epi8(0x0F)); | ||
|
|
||
| // Interleave low and high nibbles to get correct order: [d0, d1, d2, d3, d4, d5, d6, d7] | ||
| __m128i interleaved = _mm_unpacklo_epi8(low_nibbles, high_nibbles); | ||
|
|
||
| // Convert to float and scale | ||
| __m128i low_part = _mm_cvtepu8_epi32(interleaved); | ||
| __m128i high_part = _mm_cvtepu8_epi32(_mm_srli_si128(interleaved, 4)); | ||
| __m128 values0 = _mm_cvtepi32_ps(low_part); | ||
| __m128 values1 = _mm_cvtepi32_ps(high_part); | ||
|
|
||
| // Scale by 1/15.0 | ||
| __m128 scale = _mm_set1_ps(1.0f / 15.0f); | ||
| values0 = _mm_mul_ps(values0, scale); | ||
| values1 = _mm_mul_ps(values1, scale); | ||
|
|
||
| // Apply diff and lower_bound | ||
| __m128 diff_vec0 = _mm_loadu_ps(diff + d); | ||
| __m128 diff_vec1 = _mm_loadu_ps(diff + d + 4); | ||
| __m128 lb_vec0 = _mm_loadu_ps(lower_bound + d); | ||
| __m128 lb_vec1 = _mm_loadu_ps(lower_bound + d + 4); | ||
|
|
||
| values0 = _mm_add_ps(_mm_mul_ps(values0, diff_vec0), lb_vec0); | ||
| values1 = _mm_add_ps(_mm_mul_ps(values1, diff_vec1), lb_vec1); | ||
|
|
||
| // Load query vectors | ||
| __m128 query_vec0 = _mm_loadu_ps(query + d); | ||
| __m128 query_vec1 = _mm_loadu_ps(query + d + 4); | ||
|
|
||
| // Compute dot products | ||
| __m128 prod0 = _mm_mul_ps(query_vec0, values0); | ||
| __m128 prod1 = _mm_mul_ps(query_vec1, values1); | ||
|
|
||
| // Horizontal sum | ||
| __m128 sum01 = _mm_add_ps(prod0, prod1); | ||
| __m128 sum23 = _mm_shuffle_ps(sum01, sum01, _MM_SHUFFLE(2, 3, 0, 1)); | ||
| __m128 sum0123 = _mm_add_ps(sum01, sum23); | ||
| __m128 sum4567 = _mm_movehl_ps(sum0123, sum0123); | ||
| __m128 sum = _mm_add_ss(sum0123, sum4567); | ||
|
|
||
| result += _mm_cvtss_f32(sum); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a lot of code duplication in the new SSE implementations for SQ4 computations. The logic for unpacking 4-bit values and for horizontal summation is repeated in SQ4ComputeIP, SQ4ComputeL2Sqr, SQ4ComputeCodesIP, and SQ4ComputeCodesL2Sqr.
To improve maintainability, I suggest extracting this common logic into static inline helper functions.
For example, you could create:
- A function to unpack 8 4-bit values into two
__m128float vectors. - A function for the horizontal sum of a
__m128vector.
This would make the code cleaner, more readable, and easier to maintain, similar to the pattern used in the AVX512 implementation.
Signed-off-by: LHT129 <[email protected]>
Codecov Report✅ All modified and coverable lines are covered by tests. @@ Coverage Diff @@
## main #1391 +/- ##
==========================================
- Coverage 91.62% 91.60% -0.02%
==========================================
Files 322 322
Lines 18334 18334
==========================================
- Hits 16798 16795 -3
- Misses 1536 1539 +3
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report in Codecov by Sentry.
🚀 New features to boost your workflow:
|
wxyucs
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
inabao
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Summary by Sourcery
Optimize SQ4 quantization distance and inner-product computations with SIMD implementations across SSE, AVX, AVX2, and AVX512 backends and extend corresponding tests and benchmarks.
Enhancements: