-
Notifications
You must be signed in to change notification settings - Fork 650
proposed more optimized versions of next_pow2 and prev_pow2 in util.h #6083
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -125,26 +125,106 @@ static_assert(align_up(17, 16) == 32, "Should align up"); | |||||||||||||||||||||||||||||||
| static_assert(align_up(8, 8) == 8, "Should be already aligned"); | ||||||||||||||||||||||||||||||||
| static_assert(align_up(5, 8) == 8, "Should align"); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||
| * @brief Calculates the smallest power of 2 that is greater than or equal to n. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * For n <= 1 the result is 1. | ||||||||||||||||||||||||||||||||
| * If n is already a power of 2, the result is n itself. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * @return The least power of 2 ≥ n. | ||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||
| DALI_HOST_DEV DALI_FORCEINLINE | ||||||||||||||||||||||||||||||||
| constexpr std::enable_if_t<std::is_integral<T>::value, T> next_pow2(T n) { | ||||||||||||||||||||||||||||||||
| T pow2 = 1; | ||||||||||||||||||||||||||||||||
| while (n > pow2) { | ||||||||||||||||||||||||||||||||
| pow2 += pow2; | ||||||||||||||||||||||||||||||||
| using U = std::make_unsigned_t<T>; | ||||||||||||||||||||||||||||||||
| U x = static_cast<U>(n); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if (x <= 1) | ||||||||||||||||||||||||||||||||
| return 1; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| #if defined(__CUDA_ARCH__) | ||||||||||||||||||||||||||||||||
| // CUDA DEVICE PATH | ||||||||||||||||||||||||||||||||
| constexpr int bits = sizeof(U) * 8; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if constexpr (sizeof(U) == 8) { | ||||||||||||||||||||||||||||||||
| // 64-bit version: uses __clzll() | ||||||||||||||||||||||||||||||||
| unsigned long long y = static_cast<unsigned long long>(x - 1); | ||||||||||||||||||||||||||||||||
| int lz = __clzll(y); | ||||||||||||||||||||||||||||||||
| int pos = 63 - lz; | ||||||||||||||||||||||||||||||||
| return static_cast<T>(U(1) << (pos + 1)); | ||||||||||||||||||||||||||||||||
| } else if constexpr (sizeof(U) == 4) { | ||||||||||||||||||||||||||||||||
| // 32-bit version: uses __clz() | ||||||||||||||||||||||||||||||||
| unsigned int y = static_cast<unsigned int>(x - 1); | ||||||||||||||||||||||||||||||||
| int lz = __clz(y); | ||||||||||||||||||||||||||||||||
| int pos = 31 - lz; | ||||||||||||||||||||||||||||||||
| return static_cast<T>(U(1) << (pos + 1)); | ||||||||||||||||||||||||||||||||
|
Comment on lines
+160
to
+161
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| } else if constexpr (sizeof(U) == 2 || sizeof(U) == 1) { | ||||||||||||||||||||||||||||||||
| // 8- and 16-bit version: CUDA does not provide __clz for 8/16 bits, | ||||||||||||||||||||||||||||||||
| // so they need to be safely widened to 32-bit unsigned and __clz() used. | ||||||||||||||||||||||||||||||||
| unsigned int y = static_cast<unsigned int>(x - 1); | ||||||||||||||||||||||||||||||||
| int lz = __clz(y); | ||||||||||||||||||||||||||||||||
| int pos = 31 - lz; // position of the most significant bit in a 32-bit container | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // normalize to the actual number of bits of the original type | ||||||||||||||||||||||||||||||||
| int shift = pos - (32 - bits); | ||||||||||||||||||||||||||||||||
| return static_cast<T>(U(1) << (shift + 1)); | ||||||||||||||||||||||||||||||||
|
Comment on lines
+163
to
+171
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||
| // fallback, in case of exotic sizes | ||||||||||||||||||||||||||||||||
| unsigned int y = static_cast<unsigned int>(x - 1); | ||||||||||||||||||||||||||||||||
| int lz = __clz(y); | ||||||||||||||||||||||||||||||||
| int pos = 31 - lz; | ||||||||||||||||||||||||||||||||
| return static_cast<T>(U(1) << (pos + 1)); | ||||||||||||||||||||||||||||||||
|
Comment on lines
+176
to
+177
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| return pow2; | ||||||||||||||||||||||||||||||||
| #else | ||||||||||||||||||||||||||||||||
| // CPU fallback (portable bitwise version with loop) | ||||||||||||||||||||||||||||||||
| x--; | ||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if this CPU fallback provides any difference over:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 This code is way more complex than it used to be, has ths same theoretical complexity but a much larger constant factor.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Description corrected.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean that this loop is replaced with CPU clz build in?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The basic assumption is the loop
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I see - the code is indeed |
||||||||||||||||||||||||||||||||
| #if defined(__clang__) || defined(__GNUC__) | ||||||||||||||||||||||||||||||||
| #pragma unroll | ||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||
| for (unsigned i = 1; i < sizeof(U) * 8; i <<= 1) { | ||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. logic: this loop breaks constexpr evaluation since it's not a constant expression - the function is marked constexpr but won't work at compile time for CPU builds |
||||||||||||||||||||||||||||||||
| x |= x >> i; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| return static_cast<T>(x + 1); | ||||||||||||||||||||||||||||||||
| #endif | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||
| * @brief Calculates the largest power of 2 that is less than or equal to n. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * For n == 0 the result is 0. | ||||||||||||||||||||||||||||||||
| * For n == 1 or when n is already a power of 2, the result is n itself. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * @return The greatest power of 2 ≤ n. | ||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||
| DALI_HOST_DEV DALI_FORCEINLINE | ||||||||||||||||||||||||||||||||
| constexpr std::enable_if_t<std::is_integral<T>::value, T> prev_pow2(T n) { | ||||||||||||||||||||||||||||||||
| T pow2 = 1; | ||||||||||||||||||||||||||||||||
| while (n - pow2 > pow2) { // avoids overflow | ||||||||||||||||||||||||||||||||
| pow2 += pow2; | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
| return pow2; | ||||||||||||||||||||||||||||||||
| using U = std::make_unsigned_t<T>; | ||||||||||||||||||||||||||||||||
| U x = static_cast<U>(n); | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if (x == 0) | ||||||||||||||||||||||||||||||||
| return 0; // no pow2 <= 0 | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // if n is pow2 | ||||||||||||||||||||||||||||||||
| U np2 = next_pow2(x); | ||||||||||||||||||||||||||||||||
| if (np2 == x) | ||||||||||||||||||||||||||||||||
| return x; | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| // otherwise <prev_pow2> is a half of <next_pow2> | ||||||||||||||||||||||||||||||||
| return static_cast<T>(np2 >> 1); | ||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| /** | ||||||||||||||||||||||||||||||||
| * @brief Checks whether the given integer n is a power of 2. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * Returns true if n has exactly one bit set in its binary representation. | ||||||||||||||||||||||||||||||||
| * Note that for n <= 0 the result is false. | ||||||||||||||||||||||||||||||||
| * | ||||||||||||||||||||||||||||||||
| * @return true if n is a power of 2, false otherwise. | ||||||||||||||||||||||||||||||||
| */ | ||||||||||||||||||||||||||||||||
| template <typename T> | ||||||||||||||||||||||||||||||||
| DALI_HOST_DEV DALI_FORCEINLINE | ||||||||||||||||||||||||||||||||
| constexpr std::enable_if_t<std::is_integral<T>::value, bool> is_pow2(T n) { | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.