@@ -206,75 +206,6 @@ enum class RoundingMode
206206#endif
207207};
208208
209- /* * Rounding functions for decimal values
210- */
211-
212- template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
213- struct DecimalRoundingComputation
214- {
215- static_assert (IsDecimal<T>);
216- static const size_t data_count = 1 ;
217- static size_t prepare (size_t scale)
218- {
219- return scale;
220- }
221- // compute need decimal_scale to interpret decimals
222- static inline void compute (const T * __restrict in, size_t scale, OutputType * __restrict out, ScaleType decimal_scale)
223- {
224- static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
225- Float64 val = in->template toFloat <Float64>(decimal_scale);
226-
227- if constexpr (scale_mode == ScaleMode::Positive)
228- {
229- val = val * scale;
230- }
231- else if constexpr (scale_mode == ScaleMode::Negative)
232- {
233- val = val / scale;
234- }
235-
236- if constexpr (rounding_mode == RoundingMode::Round)
237- {
238- val = round (val);
239- }
240- else if constexpr (rounding_mode == RoundingMode::Floor)
241- {
242- val = floor (val);
243- }
244- else if constexpr (rounding_mode == RoundingMode::Ceil)
245- {
246- val = ceil (val);
247- }
248- else if constexpr (rounding_mode == RoundingMode::Trunc)
249- {
250- val = trunc (val);
251- }
252-
253-
254- if constexpr (scale_mode == ScaleMode::Positive)
255- {
256- val = val / scale;
257- }
258- else if constexpr (scale_mode == ScaleMode::Negative)
259- {
260- val = val * scale;
261- }
262-
263- if constexpr (std::is_same_v<T, OutputType>)
264- {
265- *out = ToDecimal<Float64, T>(val, decimal_scale);
266- }
267- else if constexpr (std::is_same_v<OutputType, Int64>)
268- {
269- *out = static_cast <Int64>(val);
270- }
271- else
272- {
273- ; // never arrived here
274- }
275- }
276- };
277-
278209
279210/* * Rounding functions for integer values.
280211 */
@@ -336,12 +267,74 @@ struct IntegerRoundingComputation
336267 }
337268 }
338269
339- static ALWAYS_INLINE void compute (const T * __restrict in, size_t scale, T * __restrict out)
270+ static ALWAYS_INLINE void compute (const T * __restrict in, T scale, T * __restrict out)
340271 {
341272 *out = compute (*in, scale);
342273 }
343274};
344275
276+ /* * Rounding functions for decimal values
277+ */
278+
279+ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
280+ struct DecimalRoundingComputation
281+ {
282+ static_assert (IsDecimal<T>);
283+ using NativeType = typename T::NativeType;
284+ static const size_t data_count = 1 ;
285+ static size_t prepare (size_t scale) { return scale; }
286+ // compute need decimal_scale to interpret decimals
287+ static inline void compute (
288+ const T * __restrict in,
289+ size_t scale,
290+ OutputType * __restrict out,
291+ NativeType decimal_scale)
292+ {
293+ static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
294+ // Currently, we only use DecimalRoundingComputation for floor/ceil.
295+ // As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
296+ // So, we only handle ScaleMode::Zero here.
297+ if constexpr (scale_mode == ScaleMode::Zero)
298+ {
299+ try
300+ {
301+ if constexpr (rounding_mode == RoundingMode::Floor)
302+ {
303+ auto x = in->value ;
304+ if (x < 0 )
305+ x -= decimal_scale - 1 ;
306+ *out = static_cast <OutputType>(x / decimal_scale);
307+ }
308+ else if constexpr (rounding_mode == RoundingMode::Ceil)
309+ {
310+ auto x = in->value ;
311+ if (x >= 0 )
312+ x += decimal_scale - 1 ;
313+ *out = static_cast <OutputType>(x / decimal_scale);
314+ }
315+ else
316+ {
317+ throw Exception (
318+ " Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation" ,
319+ ErrorCodes::LOGICAL_ERROR);
320+ }
321+ }
322+ catch (const std::overflow_error & e)
323+ {
324+ throw Exception (
325+ " Logical error: unexpected overflow in DecimalRoundingComputation" ,
326+ ErrorCodes::LOGICAL_ERROR);
327+ }
328+ }
329+ else
330+ {
331+ throw Exception (
332+ " Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
333+ + toString (scale),
334+ ErrorCodes::LOGICAL_ERROR);
335+ }
336+ }
337+ };
345338
346339#if __SSE4_1__
347340
@@ -554,7 +547,7 @@ struct IntegerRoundingImpl
554547
555548 while (p_in < end_in)
556549 {
557- Op::compute (p_in, scale, p_out);
550+ Op::compute (p_in, static_cast <T>( scale) , p_out);
558551 ++p_in;
559552 ++p_out;
560553 }
@@ -620,14 +613,18 @@ struct DecimalRoundingImpl;
620613template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
621614struct DecimalRoundingImpl <T, rounding_mode, scale_mode, Int64>
622615{
616+ static_assert (IsDecimal<T>);
617+ using NativeType = typename T::NativeType;
618+
623619private:
624620 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
625621 using Data = T;
626622
627623public:
628624 static NO_INLINE void apply (const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnVector<Int64>::Container & out)
629625 {
630- ScaleType decimal_scale = in.getScale ();
626+ ScaleType in_scale = in.getScale ();
627+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
631628 const T * end_in = in.data () + in.size ();
632629
633630 const T * __restrict p_in = in.data ();
@@ -645,14 +642,18 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
645642template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
646643struct DecimalRoundingImpl <T, rounding_mode, scale_mode, T>
647644{
645+ static_assert (IsDecimal<T>);
646+ using NativeType = typename T::NativeType;
647+
648648private:
649649 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
650650 using Data = T;
651651
652652public:
653653 static NO_INLINE void apply (const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnDecimal<T>::Container & out)
654654 {
655- ScaleType decimal_scale = in.getScale ();
655+ ScaleType in_scale = in.getScale ();
656+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
656657 const T * end_in = in.data () + in.size ();
657658
658659 const T * __restrict p_in = in.data ();
@@ -705,7 +706,12 @@ struct Dispatcher
705706
706707 if constexpr (IsDecimal<OutputType>)
707708 {
708- auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), col->getData ().getScale ());
709+ UInt32 res_scale = 0 ;
710+ if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
711+ {
712+ res_scale = col->getData ().getScale ();
713+ }
714+ auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), res_scale);
709715 typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData ();
710716 applyInternal (col, vec_res, col_res, block, scale_arg, result);
711717 }
@@ -808,6 +814,20 @@ class FunctionRounding : public IFunction
808814 fmt::format (" Illegal type {} of argument of function {}" , arguments[0 ]->getName (), getName ()),
809815 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
810816
817+ if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
818+ {
819+ if (arguments[0 ]->isDecimal ())
820+ {
821+ if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0 ].get ()))
822+ return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec (), 0 );
823+ else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0 ].get ()))
824+ return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec (), 0 );
825+ else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0 ].get ()))
826+ return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec (), 0 );
827+ else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0 ].get ()))
828+ return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec (), 0 );
829+ }
830+ }
811831 return arguments[0 ];
812832 }
813833
0 commit comments