Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 222 additions & 7 deletions wren-core/core/src/mdl/dialect/inner_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::Expr;

use datafusion::scalar::ScalarValue;
use datafusion::sql::sqlparser::ast::{self, ExtractSyntax, Ident, WindowFrameBound};
use datafusion::sql::sqlparser::ast::{
self, DataType, DateTimeField, Expr as AstExpr, ExtractSyntax, Function,
FunctionArg, FunctionArgExpr, Ident, Interval, TimezoneInfo, Value, WindowFrameBound,
};
use datafusion::sql::unparser::Unparser;
use regex::Regex;

Expand Down Expand Up @@ -135,18 +138,223 @@ impl InnerDialect for BigQueryDialect {
) -> Result<Option<ast::Expr>> {
match function_name {
"date_part" => {
if args.len() != 2 {
if args.len() != 2 && args.len() != 3 {
return plan_err!(
"date_part requires exactly 2 arguments, found {}",
"date_part requires 2 or 3 arguments, found {}",
args.len()
);
}
// Base timestamp/datetime expression
let mut source_expr = unparser.expr_to_sql(&args[1])?;
// Apply timezone if provided as 3rd arg
if args.len() == 3 {
if let Expr::Literal(ScalarValue::Utf8(Some(tz))) = &args[2] {
source_expr = AstExpr::AtTimeZone {
timestamp: Box::new(source_expr),
time_zone: TimezoneInfo::Tz(tz.clone()),
};
}
}
Ok(Some(ast::Expr::Extract {
field: self.datetime_field_from_expr(&args[0])?,
syntax: ExtractSyntax::From,
expr: Box::new(unparser.expr_to_sql(&args[1])?),
expr: Box::new(source_expr),
}))
}
"date_trunc" | "datetime_trunc" | "timestamp_trunc" | "time_trunc" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"date_add" | "datetime_add" | "timestamp_add" | "time_add" | "date_sub"
| "datetime_sub" | "timestamp_sub" | "time_sub" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}

let interval_expr = match &args[1] {
Expr::Literal(ScalarValue::IntervalDayTime(Some(interval))) => {
let (days, ms) = (*interval >> 32, *interval as i32);
let use_day_unit = matches!(function_name, "date_add" | "date_sub");
let (value_str, unit) = if use_day_unit {
(format!("{}", days), DateTimeField::Day)
} else {
(
format!("{}", days * 24 * 3600 * 1000 + ms as i64),
DateTimeField::Millisecond,
)
};
AstExpr::Value(Value::Interval(Interval {
value: Box::new(AstExpr::Value(Value::Number(value_str, false))),
leading_field: Some(unit),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
}))
}
Expr::Literal(ScalarValue::IntervalYearMonth(Some(interval))) => {
let (years, months) = (*interval / 12, *interval % 12);
if function_name.starts_with("time_") {
return plan_err!(
"Cannot add/subtract YEAR/MONTH interval to/from a TIME value"
);
}
AstExpr::Value(Value::Interval(Interval {
value: Box::new(AstExpr::Value(Value::Number(
format!("{}", years * 12 + months),
false,
))),
leading_field: Some(DateTimeField::Month),
leading_precision: None,
last_field: None,
fractional_seconds_precision: None,
}))
}
_ => return plan_err!("Invalid interval for {}", function_name),
};

Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(interval_expr)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"date_diff" | "datetime_diff" | "timestamp_diff" | "time_diff" => {
if args.len() != 3 {
return plan_err!(
"{} requires exactly 3 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[2])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"parse_date" | "parse_datetime" | "parse_timestamp" | "format_date"
| "format_datetime" | "format_timestamp" => {
if args.len() != 2 {
return plan_err!(
"{} requires exactly 2 arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[0])?,
)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(
unparser.expr_to_sql(&args[1])?,
)),
],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"current_date" | "current_datetime" | "current_timestamp" => {
if !args.is_empty() {
return plan_err!(
"{} requires no arguments, found {}",
function_name,
args.len()
);
}
Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: vec![],
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
"generate_date_array" => {
if args.len() != 2 && args.len() != 3 {
return plan_err!(
"generate_date_array requires 2 or 3 arguments, found {}",
args.len()
);
}
let mut fn_args = vec![
FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[0])?)),
FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[1])?)),
];
if args.len() == 3 {
fn_args.push(FunctionArg::Unnamed(FunctionArgExpr::Expr(unparser.expr_to_sql(&args[2])?)));
}

Ok(Some(AstExpr::Function(Function {
name: ast::ObjectName(vec![Ident::new(function_name.to_uppercase())]),
args: fn_args,
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: Vec::new(),
})))
}
_ => Ok(None),
}
}
Expand Down Expand Up @@ -203,9 +411,11 @@ impl BigQueryDialect {
if let Some(end) = s.find(')') {
let weekday = &s[start + 1..end];
match weekday {
"SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY"
"SUNDAY" | "MONDAY" | "TUESDAY" | "WEDNESDAY"
| "THURSDAY" | "FRIDAY" | "SATURDAY" => {
return Ok(ast::DateTimeField::Week(Some(Ident::new(weekday))));
return Ok(ast::DateTimeField::Week(Some(Ident::new(
weekday,
))));
}
_ => return plan_err!("Invalid weekday '{}' for WEEK. Valid values are SUNDAY, MONDAY, TUESDAY, WEDNESDAY, THURSDAY, FRIDAY, and SATURDAY", weekday),
}
Expand All @@ -224,6 +434,11 @@ impl BigQueryDialect {
"QUARTER" => Ok(ast::DateTimeField::Quarter),
"YEAR" => Ok(ast::DateTimeField::Year),
"ISOYEAR" => Ok(ast::DateTimeField::Isoyear),
"HOUR" => Ok(ast::DateTimeField::Hour),
"MINUTE" => Ok(ast::DateTimeField::Minute),
"SECOND" => Ok(ast::DateTimeField::Second),
"MILLISECOND" => Ok(ast::DateTimeField::Millisecond),
"MICROSECOND" => Ok(ast::DateTimeField::Microsecond),
_ => {
plan_err!("Unsupported date part '{}' for BigQuery", s)
}
Expand Down Expand Up @@ -251,4 +466,4 @@ impl InnerDialect for OracleDialect {
fn non_uppercase(sql: &str) -> bool {
let uppsercase = sql.to_uppercase();
uppsercase != sql
}
}
22 changes: 20 additions & 2 deletions wren-core/core/src/mdl/dialect/wren_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
use crate::mdl::dialect::inner_dialect::{get_inner_dialect, InnerDialect};
use crate::mdl::manifest::DataSource;
use crate::mdl::utils::scalar_value_to_ast_value;
use datafusion::common::{internal_err, plan_err, Result, ScalarValue};
use datafusion::logical_expr::sqlparser::ast::{Ident, Subscript};
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
Expand Down Expand Up @@ -84,7 +85,24 @@ impl Dialect for WrenDialect {
let sql = self.named_struct_to_sql(args, unparser)?;
Ok(Some(sql))
}
_ => Ok(None),
_ => {
if func_name == "lit" {
if args.len() != 1 {
return plan_err!("lit requires exactly 1 argument");
}
match &args[0] {
Expr::Literal(value) => {
Ok(Some(ast::Expr::Value(scalar_value_to_ast_value(value))))
}
other => {
// Fall back to the expression itself to avoid emitting `lit(...)` in SQL
Ok(Some(unparser.expr_to_sql(other)?))
}
}
} else {
Ok(None)
}
}
}
}

Expand Down Expand Up @@ -218,4 +236,4 @@ impl WrenDialect {
fn non_lowercase(sql: &str) -> bool {
let lowercase = sql.to_lowercase();
lowercase != sql
}
}
Loading
Loading