Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 94 additions & 74 deletions dbms/src/Functions/FunctionsRound.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,75 +206,6 @@ enum class RoundingMode
#endif
};

/** Rounding functions for decimal values
*/

template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
struct DecimalRoundingComputation
{
static_assert(IsDecimal<T>);
static const size_t data_count = 1;
static size_t prepare(size_t scale)
{
return scale;
}
// compute need decimal_scale to interpret decimals
static inline void compute(const T * __restrict in, size_t scale, OutputType * __restrict out, ScaleType decimal_scale)
{
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
Float64 val = in->template toFloat<Float64>(decimal_scale);

if constexpr (scale_mode == ScaleMode::Positive)
{
val = val * scale;
}
else if constexpr (scale_mode == ScaleMode::Negative)
{
val = val / scale;
}

if constexpr (rounding_mode == RoundingMode::Round)
{
val = round(val);
}
else if constexpr (rounding_mode == RoundingMode::Floor)
{
val = floor(val);
}
else if constexpr (rounding_mode == RoundingMode::Ceil)
{
val = ceil(val);
}
else if constexpr (rounding_mode == RoundingMode::Trunc)
{
val = trunc(val);
}


if constexpr (scale_mode == ScaleMode::Positive)
{
val = val / scale;
}
else if constexpr (scale_mode == ScaleMode::Negative)
{
val = val * scale;
}

if constexpr (std::is_same_v<T, OutputType>)
{
*out = ToDecimal<Float64, T>(val, decimal_scale);
}
else if constexpr (std::is_same_v<OutputType, Int64>)
{
*out = static_cast<Int64>(val);
}
else
{
; // never arrived here
}
}
};


/** Rounding functions for integer values.
*/
Expand Down Expand Up @@ -336,12 +267,74 @@ struct IntegerRoundingComputation
}
}

static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
{
*out = compute(*in, scale);
}
};

/** Rounding functions for decimal values
*/

template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
struct DecimalRoundingComputation
{
static_assert(IsDecimal<T>);
using NativeType = typename T::NativeType;
static const size_t data_count = 1;
static size_t prepare(size_t scale) { return scale; }
// compute need decimal_scale to interpret decimals
static inline void compute(
const T * __restrict in,
size_t scale,
OutputType * __restrict out,
NativeType decimal_scale)
{
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
// Currently, we only use DecimalRoundingComputation for floor/ceil.
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
// So, we only handle ScaleMode::Zero here.
if constexpr (scale_mode == ScaleMode::Zero)
{
try
{
if constexpr (rounding_mode == RoundingMode::Floor)
{
auto x = in->value;
if (x < 0)
x -= decimal_scale - 1;
*out = static_cast<OutputType>(x / decimal_scale);
}
else if constexpr (rounding_mode == RoundingMode::Ceil)
{
auto x = in->value;
if (x >= 0)
x += decimal_scale - 1;
*out = static_cast<OutputType>(x / decimal_scale);
}
else
{
throw Exception(
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
ErrorCodes::LOGICAL_ERROR);
}
}
catch (const std::overflow_error & e)
{
throw Exception(
"Logical error: unexpected overflow in DecimalRoundingComputation",
ErrorCodes::LOGICAL_ERROR);
}
}
else
{
throw Exception(
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
+ toString(scale),
ErrorCodes::LOGICAL_ERROR);
}
}
};

#if __SSE4_1__

Expand Down Expand Up @@ -554,7 +547,7 @@ struct IntegerRoundingImpl

while (p_in < end_in)
{
Op::compute(p_in, scale, p_out);
Op::compute(p_in, static_cast<T>(scale), p_out);
++p_in;
++p_out;
}
Expand Down Expand Up @@ -620,14 +613,18 @@ struct DecimalRoundingImpl;
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
{
static_assert(IsDecimal<T>);
using NativeType = typename T::NativeType;

private:
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
using Data = T;

public:
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnVector<Int64>::Container & out)
{
ScaleType decimal_scale = in.getScale();
ScaleType in_scale = in.getScale();
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
const T * end_in = in.data() + in.size();

const T * __restrict p_in = in.data();
Expand All @@ -645,14 +642,18 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
{
static_assert(IsDecimal<T>);
using NativeType = typename T::NativeType;

private:
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
using Data = T;

public:
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnDecimal<T>::Container & out)
{
ScaleType decimal_scale = in.getScale();
ScaleType in_scale = in.getScale();
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
const T * end_in = in.data() + in.size();

const T * __restrict p_in = in.data();
Expand Down Expand Up @@ -705,7 +706,12 @@ struct Dispatcher

if constexpr (IsDecimal<OutputType>)
{
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), col->getData().getScale());
UInt32 res_scale = 0;
if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
{
res_scale = col->getData().getScale();
}
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), res_scale);
typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData();
applyInternal(col, vec_res, col_res, block, scale_arg, result);
}
Expand Down Expand Up @@ -808,6 +814,20 @@ class FunctionRounding : public IFunction
fmt::format("Illegal type {} of argument of function {}", arguments[0]->getName(), getName()),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
{
if (arguments[0]->isDecimal())
{
if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0].get()))
return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec(), 0);
else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0].get()))
return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec(), 0);
else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0].get()))
return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec(), 0);
else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0].get()))
return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec(), 0);
}
}
return arguments[0];
}

Expand Down
Loading