diff --git a/datafusion/spark/src/function/datetime/next_day.rs b/datafusion/spark/src/function/datetime/next_day.rs index 72a0c830ffb25..a4ef15b3a10d9 100644 --- a/datafusion/spark/src/function/datetime/next_day.rs +++ b/datafusion/spark/src/function/datetime/next_day.rs @@ -19,11 +19,12 @@ use std::any::Any; use std::sync::Arc; use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array}; -use arrow::datatypes::{DataType, Date32Type}; +use arrow::datatypes::{DataType, Date32Type, Field, FieldRef}; use chrono::{Datelike, Duration, Weekday}; -use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_common::{Result, ScalarValue, exec_err, internal_err}; use datafusion_expr::{ - ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, }; /// @@ -63,7 +64,30 @@ impl ScalarUDFImpl for SparkNextDay { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Date32) + internal_err!("return_field_from_args should be used instead") + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [date_field, weekday_field] = args.arg_fields else { + return internal_err!("Spark `next_day` expects exactly two arguments"); + }; + + let has_invalid_scalar = args.scalar_arguments.iter().any(|arg| match arg { + Some(ScalarValue::Utf8(Some(s))) + | Some(ScalarValue::LargeUtf8(Some(s))) + | Some(ScalarValue::Utf8View(Some(s))) => !is_valid_weekday(s), + Some(v) => v.is_null(), + _ => false, + }); + + let nullable = + date_field.is_nullable() || weekday_field.is_nullable() || has_invalid_scalar; + + Ok(Arc::new(Field::new( + self.name(), + DataType::Date32, + nullable, + ))) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -90,8 +114,6 @@ impl ScalarUDFImpl for SparkNextDay { spark_next_day(*days, day_of_week.as_str()), ))) } else { - // TODO: if spark.sql.ansi.enabled is false, - // returns NULL instead of an error for a malformed dayOfWeek. Ok(ColumnarValue::Scalar(ScalarValue::Date32(None))) } } else { @@ -120,8 +142,6 @@ impl ScalarUDFImpl for SparkNextDay { .with_data_type(DataType::Date32); Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef)) } else { - // TODO: if spark.sql.ansi.enabled is false, - // returns NULL instead of an error for a malformed dayOfWeek. Ok(ColumnarValue::Array(Arc::new(new_null_array( &DataType::Date32, date_array.len(), @@ -146,33 +166,26 @@ impl ScalarUDFImpl for SparkNextDay { let date_array: &Date32Array = date_array.as_primitive::(); match day_of_week_array.data_type() { - DataType::Utf8 => { - let day_of_week_array = - day_of_week_array.as_string::(); - process_next_day_arrays(date_array, day_of_week_array) - } - DataType::LargeUtf8 => { - let day_of_week_array = - day_of_week_array.as_string::(); - process_next_day_arrays(date_array, day_of_week_array) - } - DataType::Utf8View => { - let day_of_week_array = - day_of_week_array.as_string_view(); - process_next_day_arrays(date_array, day_of_week_array) - } - other => { - exec_err!( - "Spark `next_day` function: second arg must be string. Got {other:?}" - ) - } + DataType::Utf8 => process_next_day_arrays( + date_array, + day_of_week_array.as_string::(), + ), + DataType::LargeUtf8 => process_next_day_arrays( + date_array, + day_of_week_array.as_string::(), + ), + DataType::Utf8View => process_next_day_arrays( + date_array, + day_of_week_array.as_string_view(), + ), + other => exec_err!( + "Spark `next_day` function: second arg must be string. Got {other:?}" + ), } } - (left, right) => { - exec_err!( - "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}" - ) - } + (left, right) => exec_err!( + "Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}" + ), }?; Ok(ColumnarValue::Array(result)) } @@ -191,57 +204,176 @@ where let result = date_array .iter() .zip(day_of_week_array.iter()) - .map(|(days, day_of_week)| { - if let Some(days) = days { - if let Some(day_of_week) = day_of_week { - spark_next_day(days, day_of_week) - } else { - // TODO: if spark.sql.ansi.enabled is false, - // returns NULL instead of an error for a malformed dayOfWeek. - None - } - } else { - None - } + .map(|(days, day_of_week)| match (days, day_of_week) { + (Some(days), Some(day_of_week)) => spark_next_day(days, day_of_week), + _ => None, }) .collect::(); + Ok(Arc::new(result) as ArrayRef) } fn spark_next_day(days: i32, day_of_week: &str) -> Option { let date = Date32Type::to_naive_date(days); - let day_of_week = day_of_week.trim().to_uppercase(); - let day_of_week = match day_of_week.as_str() { - "MO" | "MON" | "MONDAY" => Some("MONDAY"), - "TU" | "TUE" | "TUESDAY" => Some("TUESDAY"), - "WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"), - "TH" | "THU" | "THURSDAY" => Some("THURSDAY"), - "FR" | "FRI" | "FRIDAY" => Some("FRIDAY"), - "SA" | "SAT" | "SATURDAY" => Some("SATURDAY"), - "SU" | "SUN" | "SUNDAY" => Some("SUNDAY"), - _ => { - // TODO: if spark.sql.ansi.enabled is false, - // returns NULL instead of an error for a malformed dayOfWeek. - None - } - }; + let s = day_of_week.trim(); - if let Some(day_of_week) = day_of_week { - let day_of_week = day_of_week.parse::(); - match day_of_week { - Ok(day_of_week) => Some(Date32Type::from_naive_date( - date + Duration::days( - (7 - date.weekday().days_since(day_of_week)) as i64, - ), - )), - Err(_) => { - // TODO: if spark.sql.ansi.enabled is false, - // returns NULL instead of an error for a malformed dayOfWeek. - None - } - } + let day_of_week = if s.eq_ignore_ascii_case("MO") + || s.eq_ignore_ascii_case("MON") + || s.eq_ignore_ascii_case("MONDAY") + { + Weekday::Mon + } else if s.eq_ignore_ascii_case("TU") + || s.eq_ignore_ascii_case("TUE") + || s.eq_ignore_ascii_case("TUESDAY") + { + Weekday::Tue + } else if s.eq_ignore_ascii_case("WE") + || s.eq_ignore_ascii_case("WED") + || s.eq_ignore_ascii_case("WEDNESDAY") + { + Weekday::Wed + } else if s.eq_ignore_ascii_case("TH") + || s.eq_ignore_ascii_case("THU") + || s.eq_ignore_ascii_case("THURSDAY") + { + Weekday::Thu + } else if s.eq_ignore_ascii_case("FR") + || s.eq_ignore_ascii_case("FRI") + || s.eq_ignore_ascii_case("FRIDAY") + { + Weekday::Fri + } else if s.eq_ignore_ascii_case("SA") + || s.eq_ignore_ascii_case("SAT") + || s.eq_ignore_ascii_case("SATURDAY") + { + Weekday::Sat + } else if s.eq_ignore_ascii_case("SU") + || s.eq_ignore_ascii_case("SUN") + || s.eq_ignore_ascii_case("SUNDAY") + { + Weekday::Sun } else { - None + return None; + }; + + Some(Date32Type::from_naive_date( + date + Duration::days((7 - date.weekday().days_since(day_of_week)) as i64), + )) +} + +fn is_valid_weekday(s: &str) -> bool { + let s = s.trim(); + + s.eq_ignore_ascii_case("MO") + || s.eq_ignore_ascii_case("MON") + || s.eq_ignore_ascii_case("MONDAY") + || s.eq_ignore_ascii_case("TU") + || s.eq_ignore_ascii_case("TUE") + || s.eq_ignore_ascii_case("TUESDAY") + || s.eq_ignore_ascii_case("WE") + || s.eq_ignore_ascii_case("WED") + || s.eq_ignore_ascii_case("WEDNESDAY") + || s.eq_ignore_ascii_case("TH") + || s.eq_ignore_ascii_case("THU") + || s.eq_ignore_ascii_case("THURSDAY") + || s.eq_ignore_ascii_case("FR") + || s.eq_ignore_ascii_case("FRI") + || s.eq_ignore_ascii_case("FRIDAY") + || s.eq_ignore_ascii_case("SA") + || s.eq_ignore_ascii_case("SAT") + || s.eq_ignore_ascii_case("SATURDAY") + || s.eq_ignore_ascii_case("SU") + || s.eq_ignore_ascii_case("SUN") + || s.eq_ignore_ascii_case("SUNDAY") +} + +#[cfg(test)] +mod tests { + use super::*; + use datafusion_expr::ReturnFieldArgs; + + #[test] + fn return_type_is_not_used() { + let func = SparkNextDay::new(); + let err = func + .return_type(&[DataType::Date32, DataType::Utf8]) + .unwrap_err(); + assert!( + err.to_string() + .contains("return_field_from_args should be used instead") + ); + } + + #[test] + fn next_day_nullability_derived_from_inputs() { + let func = SparkNextDay::new(); + + let non_nullable_date = Arc::new(Field::new("date", DataType::Date32, false)); + let non_nullable_weekday = Arc::new(Field::new("weekday", DataType::Utf8, false)); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&non_nullable_date), + Arc::clone(&non_nullable_weekday), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert!(!field.is_nullable()); + + let nullable_date = Arc::new(Field::new("date", DataType::Date32, true)); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[ + Arc::clone(&nullable_date), + Arc::clone(&non_nullable_weekday), + ], + scalar_arguments: &[None, None], + }) + .unwrap(); + + assert!(field.is_nullable()); + } + + #[test] + fn next_day_valid_scalar_is_not_nullable() { + let func = SparkNextDay::new(); + + let date_field = Arc::new(Field::new("date", DataType::Date32, false)); + let weekday_field = Arc::new(Field::new("weekday", DataType::Utf8, false)); + + let scalar = ScalarValue::Utf8(Some("MONDAY".to_string())); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[date_field, weekday_field], + scalar_arguments: &[None, Some(&scalar)], + }) + .unwrap(); + + assert!(!field.is_nullable()); + } + + #[test] + fn next_day_invalid_scalar_is_nullable() { + let func = SparkNextDay::new(); + + let date_field = Arc::new(Field::new("date", DataType::Date32, false)); + let weekday_field = Arc::new(Field::new("weekday", DataType::Utf8, false)); + + let invalid_scalar = ScalarValue::Utf8(Some("FUNDAY".to_string())); + + let field = func + .return_field_from_args(ReturnFieldArgs { + arg_fields: &[date_field, weekday_field], + scalar_arguments: &[None, Some(&invalid_scalar)], + }) + .unwrap(); + + assert!(field.is_nullable()); } }