Skip to content

Commit d3ac95c

Browse files
committed
Fix decimal floor/ceil (#10365)
1 parent c9d93fc commit d3ac95c

File tree

4 files changed

+720
-84
lines changed

4 files changed

+720
-84
lines changed

dbms/src/Functions/FunctionsRound.h

Lines changed: 94 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -206,75 +206,6 @@ enum class RoundingMode
206206
#endif
207207
};
208208

209-
/** Rounding functions for decimal values
210-
*/
211-
212-
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
213-
struct DecimalRoundingComputation
214-
{
215-
static_assert(IsDecimal<T>);
216-
static const size_t data_count = 1;
217-
static size_t prepare(size_t scale)
218-
{
219-
return scale;
220-
}
221-
// compute need decimal_scale to interpret decimals
222-
static inline void compute(const T * __restrict in, size_t scale, OutputType * __restrict out, ScaleType decimal_scale)
223-
{
224-
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
225-
Float64 val = in->template toFloat<Float64>(decimal_scale);
226-
227-
if constexpr (scale_mode == ScaleMode::Positive)
228-
{
229-
val = val * scale;
230-
}
231-
else if constexpr (scale_mode == ScaleMode::Negative)
232-
{
233-
val = val / scale;
234-
}
235-
236-
if constexpr (rounding_mode == RoundingMode::Round)
237-
{
238-
val = round(val);
239-
}
240-
else if constexpr (rounding_mode == RoundingMode::Floor)
241-
{
242-
val = floor(val);
243-
}
244-
else if constexpr (rounding_mode == RoundingMode::Ceil)
245-
{
246-
val = ceil(val);
247-
}
248-
else if constexpr (rounding_mode == RoundingMode::Trunc)
249-
{
250-
val = trunc(val);
251-
}
252-
253-
254-
if constexpr (scale_mode == ScaleMode::Positive)
255-
{
256-
val = val / scale;
257-
}
258-
else if constexpr (scale_mode == ScaleMode::Negative)
259-
{
260-
val = val * scale;
261-
}
262-
263-
if constexpr (std::is_same_v<T, OutputType>)
264-
{
265-
*out = ToDecimal<Float64, T>(val, decimal_scale);
266-
}
267-
else if constexpr (std::is_same_v<OutputType, Int64>)
268-
{
269-
*out = static_cast<Int64>(val);
270-
}
271-
else
272-
{
273-
; // never arrived here
274-
}
275-
}
276-
};
277-
278209

279210
/** Rounding functions for integer values.
280211
*/
@@ -336,12 +267,74 @@ struct IntegerRoundingComputation
336267
}
337268
}
338269

339-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
270+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
340271
{
341272
*out = compute(*in, scale);
342273
}
343274
};
344275

276+
/** Rounding functions for decimal values
277+
*/
278+
279+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
280+
struct DecimalRoundingComputation
281+
{
282+
static_assert(IsDecimal<T>);
283+
using NativeType = typename T::NativeType;
284+
static const size_t data_count = 1;
285+
static size_t prepare(size_t scale) { return scale; }
286+
// compute need decimal_scale to interpret decimals
287+
static inline void compute(
288+
const T * __restrict in,
289+
size_t scale,
290+
OutputType * __restrict out,
291+
NativeType decimal_scale)
292+
{
293+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
294+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
295+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
296+
// So, we only handle ScaleMode::Zero here.
297+
if constexpr (scale_mode == ScaleMode::Zero)
298+
{
299+
try
300+
{
301+
if constexpr (rounding_mode == RoundingMode::Floor)
302+
{
303+
auto x = in->value;
304+
if (x < 0)
305+
x -= decimal_scale - 1;
306+
*out = static_cast<OutputType>(x / decimal_scale);
307+
}
308+
else if constexpr (rounding_mode == RoundingMode::Ceil)
309+
{
310+
auto x = in->value;
311+
if (x >= 0)
312+
x += decimal_scale - 1;
313+
*out = static_cast<OutputType>(x / decimal_scale);
314+
}
315+
else
316+
{
317+
throw Exception(
318+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
319+
ErrorCodes::LOGICAL_ERROR);
320+
}
321+
}
322+
catch (const std::overflow_error & e)
323+
{
324+
throw Exception(
325+
"Logical error: unexpected overflow in DecimalRoundingComputation",
326+
ErrorCodes::LOGICAL_ERROR);
327+
}
328+
}
329+
else
330+
{
331+
throw Exception(
332+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
333+
+ toString(scale),
334+
ErrorCodes::LOGICAL_ERROR);
335+
}
336+
}
337+
};
345338

346339
#if __SSE4_1__
347340

@@ -554,7 +547,7 @@ struct IntegerRoundingImpl
554547

555548
while (p_in < end_in)
556549
{
557-
Op::compute(p_in, scale, p_out);
550+
Op::compute(p_in, static_cast<T>(scale), p_out);
558551
++p_in;
559552
++p_out;
560553
}
@@ -620,14 +613,18 @@ struct DecimalRoundingImpl;
620613
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
621614
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
622615
{
616+
static_assert(IsDecimal<T>);
617+
using NativeType = typename T::NativeType;
618+
623619
private:
624620
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
625621
using Data = T;
626622

627623
public:
628624
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnVector<Int64>::Container & out)
629625
{
630-
ScaleType decimal_scale = in.getScale();
626+
ScaleType in_scale = in.getScale();
627+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
631628
const T * end_in = in.data() + in.size();
632629

633630
const T * __restrict p_in = in.data();
@@ -645,14 +642,18 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
645642
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
646643
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
647644
{
645+
static_assert(IsDecimal<T>);
646+
using NativeType = typename T::NativeType;
647+
648648
private:
649649
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
650650
using Data = T;
651651

652652
public:
653653
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnDecimal<T>::Container & out)
654654
{
655-
ScaleType decimal_scale = in.getScale();
655+
ScaleType in_scale = in.getScale();
656+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
656657
const T * end_in = in.data() + in.size();
657658

658659
const T * __restrict p_in = in.data();
@@ -705,7 +706,12 @@ struct Dispatcher
705706

706707
if constexpr (IsDecimal<OutputType>)
707708
{
708-
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), col->getData().getScale());
709+
UInt32 res_scale = 0;
710+
if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
711+
{
712+
res_scale = col->getData().getScale();
713+
}
714+
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), res_scale);
709715
typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData();
710716
applyInternal(col, vec_res, col_res, block, scale_arg, result);
711717
}
@@ -808,6 +814,20 @@ class FunctionRounding : public IFunction
808814
fmt::format("Illegal type {} of argument of function {}", arguments[0]->getName(), getName()),
809815
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
810816

817+
if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
818+
{
819+
if (arguments[0]->isDecimal())
820+
{
821+
if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0].get()))
822+
return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec(), 0);
823+
else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0].get()))
824+
return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec(), 0);
825+
else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0].get()))
826+
return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec(), 0);
827+
else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0].get()))
828+
return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec(), 0);
829+
}
830+
}
811831
return arguments[0];
812832
}
813833

0 commit comments

Comments
 (0)