Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 222 additions & 7 deletions wren-core/core/src/mdl/dialect/inner_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::Expr;

use datafusion::scalar::ScalarValue;
use datafusion::sql::sqlparser::ast::{self, ExtractSyntax, Ident, WindowFrameBound};
use datafusion::sql::sqlparser::ast::{
self, DataType, DateTimeField, Expr as AstExpr, ExtractSyntax, Function,
FunctionArg, FunctionArgExpr, Ident, Interval, TimezoneInfo, Value, WindowFrameBound,
};
use datafusion::sql::unparser::Unparser;
use regex::Regex;

Expand Down Expand Up @@ -135,18 +138,223 @@ impl InnerDialect for BigQueryDialect {
) -> Result<Option<ast::Expr>> {
match function_name {
"date_part" => {
if args.len() != 2 {
if args.len() != 2 && args.len() != 3 {
return plan_err!(
"date_part requires exactly 2 arguments, found {}",
"date_part requires 2 or 3 arguments, found {}",
args.len()
);
}
// Base timestamp/datetime expression
let mut source_expr = unparser.expr_to_sql(&args[1])?;
// Apply timezone if provided as 3rd arg
if args.len() == 3 {
if let Expr::Literal(ScalarValue::Utf8(Some(tz))) = &args[2] {
source_expr = AstExpr::AtTimeZone {
timestamp: Box::new(source_expr),
time_zone: TimezoneInfo::Tz(tz.clone()),
};
}
}
Ok(Some(ast::Expr::Extract {
field: self.datetime_field_from_expr(&args[0])?,
syntax: ExtractSyntax::From,
expr: Box::new(unparser.expr_to_sql(&args[1])?),
expr: Box::new(source_expr),
}))
}
"date_trunc" | "datetime_trunc" | "timestamp_trunc" | "time_trunc" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"date_add" | "datetime_add" | "timestamp_add" | "time_add" | "date_sub"
| "datetime_sub" | "timestamp_sub" | "time_sub" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}

let interval_expr = match &args[1] {
Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))) => {
let (days, ms) = (*interval >> 32, *interval as i32);
let use_day_unit = matches!(function_name, "date_add" | "date_sub");
let (value_str, unit) = if use_day_unit {
(format!("{}", days), DateTimeField::Day)
} else {
(
format!("{}", days * 24 * 3600 * 1000 + ms as i64),
DateTimeField::Millisecond,
)
};
AstExpr::Value(Value::Interval(Interval {
value: Box::new(AstExpr::Value(Value::Number(value_str, false))),
leading_field: Some(unit),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
}))
}
Expr::Literal(ScalarValue::IntervalYearMonth(Some(interval))) => {
let (years, months) = (*interval / 12, *interval % 12);
if function_name.starts_with("time_") {
return plan_err!(
"Cannot add/subtract YEAR/MONTH interval to/from a TIME value"
);
}
AstExpr::Value(Value::Interval(Interval {
value: Box::new(AstExpr::Value(Value::Number(
format!("{}", years * 12 + months),
false,
))),
leading_field: Some(DateTimeField::Month),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
}))
}
_ => return plan_err!("Invalid interval for {}", function_name),
};

Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(interval_expr)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"date_diff" | "datetime_diff" | "timestamp_diff" | "time_diff" => {
if args.len() != 3 {
return plan_err!(
"{} requires exactly 3 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[2])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"parse_date" | "parse_datetime" | "parse_timestamp" | "format_date"
| "format_datetime" | "format_timestamp" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"current_date" | "current_datetime" | "current_timestamp" => {
if !args.is_empty() {
return plan_err!(
"{} requires no arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"generate_date_array" => {
if args.len() != 2 && args.len() != 3 {
return plan_err!(
"generate_date_array requires 2 or 3 arguments, found {}",
args.len()
);
}
let mut fn_args = vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[0])?)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[1])?)),
];
if args.len() == 3 {
fn_args.push(FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[2])?)));
}

Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: fn_args,
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
_ => Ok(None),
}
}
Expand Down Expand Up @@ -203,9 +411,11 @@ impl BigQueryDialect {
if let Some(end) = s.find(')') {
let weekday = &s[start + 1..end];
match weekday {
"SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY"
"SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY"
| "THURSDAY" | "FRIDAY" | "SATURDAY" => {
return Ok(ast::DateTimeField::Week(Some(Ident::new(weekday))));
return Ok(ast::DateTimeField::Week(Some(Ident::new(
weekday,
))));
}
_ => return plan_err!("Invalid weekday '{}' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY", weekday),
}
Expand All @@ -224,6 +434,11 @@ impl BigQueryDialect {
"QUARTER" => Ok(ast::DateTimeField::Quarter),
"YEAR" => Ok(ast::DateTimeField::Year),
"ISOYEAR" => Ok(ast::DateTimeField::Isoyear),
"HOUR" => Ok(ast::DateTimeField::Hour),
"MINUTE" => Ok(ast::DateTimeField::Minute),
"SECOND" => Ok(ast::DateTimeField::Second),
"MILLISECOND" => Ok(ast::DateTimeField::Millisecond),
"MICROSECOND" => Ok(ast::DateTimeField::Microsecond),
_ => {
plan_err!("Unsupported date part '{}' for BigQuery", s)
}
Expand Down Expand Up @@ -251,4 +466,4 @@ impl InnerDialect for OracleDialect {
fn non_uppercase(sql: &str) -> bool {
let uppsercase = sql.to_uppercase();
uppsercase != sql
}
}
14 changes: 12 additions & 2 deletions wren-core/core/src/mdl/dialect/wren_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
* specific language governing permissions and limitations
* under the License.
*/

use crate::mdl::dialect::inner_dialect::{get_inner_dialect, InnerDialect};
use crate::mdl::manifest::DataSource;
use crate::mdl::utils::scalar_value_to_ast_value;
use datafusion::common::{internal_err, plan_err, Result, ScalarValue};
use datafusion::logical_expr::sqlparser::ast::{Ident, Subscript};
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
Expand Down Expand Up @@ -84,7 +86,15 @@ impl Dialect for WrenDialect {
let sql = self.named_struct_to_sql(args, unparser)?;
Ok(Some(sql))
}
_ => Ok(None),
// Add override for Literal
_ => {
if let Some(Expr::Literal(value)) = args.get(0) {
if func_name == "lit" {
return Ok(Some(ast::Expr::Value(scalar_value_to_ast_value(value))));
}
}
Ok(None)
}
}
}

Expand Down Expand Up @@ -218,4 +228,4 @@ impl WrenDialect {
fn non_lowercase(sql: &str) -> bool {
let lowercase = sql.to_lowercase();
lowercase != sql
}
}
Loading