Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 207 additions & 75 deletions datafusion/spark/src/function/datetime/next_day.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, Date32Array, StringArrayType, new_null_array};
use arrow::datatypes::{DataType, Date32Type};
use arrow::datatypes::{DataType, Date32Type, Field, FieldRef};
use chrono::{Datelike, Duration, Weekday};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
Volatility,
};

/// <https://spark.apache.org/docs/latest/api/sql/index.html#next_day>
Expand Down Expand Up @@ -63,7 +64,30 @@ impl ScalarUDFImpl for SparkNextDay {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
    // The output nullability of `next_day` depends on its arguments (a
    // constant malformed weekday always yields NULL), so the planner must
    // call `return_field_from_args` instead of this type-only hook.
    internal_err!("return_field_from_args should be used instead")
}

/// Derives the output field for `next_day`.
///
/// The result type is always `Date32`. The result is nullable when either
/// input field is nullable, or when a constant argument is already known to
/// produce NULL at execution time: a null scalar, or a weekday string that
/// does not name a day of the week.
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
    let [date_field, weekday_field] = args.arg_fields else {
        return internal_err!("Spark `next_day` expects exactly two arguments");
    };

    // Scan the constant (scalar) arguments for anything that forces a NULL
    // result: string scalars must spell a valid weekday, and any other
    // scalar must be non-null.
    let null_producing_constant =
        args.scalar_arguments.iter().flatten().any(|scalar| match scalar {
            ScalarValue::Utf8(Some(s))
            | ScalarValue::LargeUtf8(Some(s))
            | ScalarValue::Utf8View(Some(s)) => !is_valid_weekday(s),
            other => other.is_null(),
        });

    let nullable = date_field.is_nullable()
        || weekday_field.is_nullable()
        || null_producing_constant;

    Ok(Arc::new(Field::new(self.name(), DataType::Date32, nullable)))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Expand All @@ -90,8 +114,6 @@ impl ScalarUDFImpl for SparkNextDay {
spark_next_day(*days, day_of_week.as_str()),
)))
} else {
// TODO: if spark.sql.ansi.enabled is false,
// returns NULL instead of an error for a malformed dayOfWeek.
Ok(ColumnarValue::Scalar(ScalarValue::Date32(None)))
}
} else {
Expand Down Expand Up @@ -120,8 +142,6 @@ impl ScalarUDFImpl for SparkNextDay {
.with_data_type(DataType::Date32);
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
} else {
// TODO: if spark.sql.ansi.enabled is false,
// returns NULL instead of an error for a malformed dayOfWeek.
Ok(ColumnarValue::Array(Arc::new(new_null_array(
&DataType::Date32,
date_array.len(),
Expand All @@ -146,33 +166,26 @@ impl ScalarUDFImpl for SparkNextDay {
let date_array: &Date32Array =
date_array.as_primitive::<Date32Type>();
match day_of_week_array.data_type() {
DataType::Utf8 => {
let day_of_week_array =
day_of_week_array.as_string::<i32>();
process_next_day_arrays(date_array, day_of_week_array)
}
DataType::LargeUtf8 => {
let day_of_week_array =
day_of_week_array.as_string::<i64>();
process_next_day_arrays(date_array, day_of_week_array)
}
DataType::Utf8View => {
let day_of_week_array =
day_of_week_array.as_string_view();
process_next_day_arrays(date_array, day_of_week_array)
}
other => {
exec_err!(
"Spark `next_day` function: second arg must be string. Got {other:?}"
)
}
DataType::Utf8 => process_next_day_arrays(
date_array,
day_of_week_array.as_string::<i32>(),
),
DataType::LargeUtf8 => process_next_day_arrays(
date_array,
day_of_week_array.as_string::<i64>(),
),
DataType::Utf8View => process_next_day_arrays(
date_array,
day_of_week_array.as_string_view(),
),
other => exec_err!(
"Spark `next_day` function: second arg must be string. Got {other:?}"
),
}
}
(left, right) => {
exec_err!(
"Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
)
}
(left, right) => exec_err!(
"Spark `next_day` function: first arg must be date, second arg must be string. Got {left:?}, {right:?}"
),
}?;
Ok(ColumnarValue::Array(result))
}
Expand All @@ -191,57 +204,176 @@ where
let result = date_array
.iter()
.zip(day_of_week_array.iter())
.map(|(days, day_of_week)| {
if let Some(days) = days {
if let Some(day_of_week) = day_of_week {
spark_next_day(days, day_of_week)
} else {
// TODO: if spark.sql.ansi.enabled is false,
// returns NULL instead of an error for a malformed dayOfWeek.
None
}
} else {
None
}
.map(|(days, day_of_week)| match (days, day_of_week) {
(Some(days), Some(day_of_week)) => spark_next_day(days, day_of_week),
_ => None,
})
.collect::<Date32Array>();

Ok(Arc::new(result) as ArrayRef)
}

fn spark_next_day(days: i32, day_of_week: &str) -> Option<i32> {
let date = Date32Type::to_naive_date(days);

let day_of_week = day_of_week.trim().to_uppercase();
let day_of_week = match day_of_week.as_str() {
"MO" | "MON" | "MONDAY" => Some("MONDAY"),
"TU" | "TUE" | "TUESDAY" => Some("TUESDAY"),
"WE" | "WED" | "WEDNESDAY" => Some("WEDNESDAY"),
"TH" | "THU" | "THURSDAY" => Some("THURSDAY"),
"FR" | "FRI" | "FRIDAY" => Some("FRIDAY"),
"SA" | "SAT" | "SATURDAY" => Some("SATURDAY"),
"SU" | "SUN" | "SUNDAY" => Some("SUNDAY"),
_ => {
// TODO: if spark.sql.ansi.enabled is false,
// returns NULL instead of an error for a malformed dayOfWeek.
None
}
};
let s = day_of_week.trim();

if let Some(day_of_week) = day_of_week {
let day_of_week = day_of_week.parse::<Weekday>();
match day_of_week {
Ok(day_of_week) => Some(Date32Type::from_naive_date(
date + Duration::days(
(7 - date.weekday().days_since(day_of_week)) as i64,
),
)),
Err(_) => {
// TODO: if spark.sql.ansi.enabled is false,
// returns NULL instead of an error for a malformed dayOfWeek.
None
}
}
let day_of_week = if s.eq_ignore_ascii_case("MO")
|| s.eq_ignore_ascii_case("MON")
|| s.eq_ignore_ascii_case("MONDAY")
{
Weekday::Mon
} else if s.eq_ignore_ascii_case("TU")
|| s.eq_ignore_ascii_case("TUE")
|| s.eq_ignore_ascii_case("TUESDAY")
{
Weekday::Tue
} else if s.eq_ignore_ascii_case("WE")
|| s.eq_ignore_ascii_case("WED")
|| s.eq_ignore_ascii_case("WEDNESDAY")
{
Weekday::Wed
} else if s.eq_ignore_ascii_case("TH")
|| s.eq_ignore_ascii_case("THU")
|| s.eq_ignore_ascii_case("THURSDAY")
{
Weekday::Thu
} else if s.eq_ignore_ascii_case("FR")
|| s.eq_ignore_ascii_case("FRI")
|| s.eq_ignore_ascii_case("FRIDAY")
{
Weekday::Fri
} else if s.eq_ignore_ascii_case("SA")
|| s.eq_ignore_ascii_case("SAT")
|| s.eq_ignore_ascii_case("SATURDAY")
{
Weekday::Sat
} else if s.eq_ignore_ascii_case("SU")
|| s.eq_ignore_ascii_case("SUN")
|| s.eq_ignore_ascii_case("SUNDAY")
{
Weekday::Sun
} else {
None
return None;
};

Some(Date32Type::from_naive_date(
date + Duration::days((7 - date.weekday().days_since(day_of_week)) as i64),
))
}

/// Weekday spellings accepted by Spark's `next_day`: the two-letter,
/// three-letter, and full English names, matched case-insensitively.
const VALID_WEEKDAYS: [&str; 21] = [
    "MO", "MON", "MONDAY",
    "TU", "TUE", "TUESDAY",
    "WE", "WED", "WEDNESDAY",
    "TH", "THU", "THURSDAY",
    "FR", "FRI", "FRIDAY",
    "SA", "SAT", "SATURDAY",
    "SU", "SUN", "SUNDAY",
];

/// Returns `true` when `s`, after trimming surrounding whitespace, names a
/// day of the week in any of the forms in [`VALID_WEEKDAYS`], ignoring ASCII
/// case. Used to decide output nullability for constant weekday arguments.
fn is_valid_weekday(s: &str) -> bool {
    let s = s.trim();
    // Data-driven lookup instead of a 21-branch boolean chain; still
    // allocation-free thanks to `eq_ignore_ascii_case`.
    VALID_WEEKDAYS.iter().any(|name| s.eq_ignore_ascii_case(name))
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::ReturnFieldArgs;

    /// `return_type` is intentionally unimplemented in favor of
    /// `return_field_from_args`.
    #[test]
    fn return_type_is_not_used() {
        let func = SparkNextDay::new();
        let err = func
            .return_type(&[DataType::Date32, DataType::Utf8])
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("return_field_from_args should be used instead")
        );
    }

    /// Output nullability follows the nullability of the input fields.
    #[test]
    fn next_day_nullability_derived_from_inputs() {
        let func = SparkNextDay::new();

        let non_nullable_date = Arc::new(Field::new("date", DataType::Date32, false));
        let non_nullable_weekday = Arc::new(Field::new("weekday", DataType::Utf8, false));

        let field = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[
                    Arc::clone(&non_nullable_date),
                    Arc::clone(&non_nullable_weekday),
                ],
                scalar_arguments: &[None, None],
            })
            .unwrap();

        assert!(!field.is_nullable());

        let nullable_date = Arc::new(Field::new("date", DataType::Date32, true));

        let field = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[
                    Arc::clone(&nullable_date),
                    Arc::clone(&non_nullable_weekday),
                ],
                scalar_arguments: &[None, None],
            })
            .unwrap();

        assert!(field.is_nullable());
    }

    /// A constant, well-formed weekday string does not make the result
    /// nullable when both inputs are non-nullable.
    #[test]
    fn next_day_valid_scalar_is_not_nullable() {
        let func = SparkNextDay::new();

        let date_field = Arc::new(Field::new("date", DataType::Date32, false));
        let weekday_field = Arc::new(Field::new("weekday", DataType::Utf8, false));

        let scalar = ScalarValue::Utf8(Some("MONDAY".to_string()));

        let field = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[date_field, weekday_field],
                scalar_arguments: &[None, Some(&scalar)],
            })
            .unwrap();

        assert!(!field.is_nullable());
    }

    /// A constant malformed weekday string always evaluates to NULL, so the
    /// output must be nullable even with non-nullable inputs.
    #[test]
    fn next_day_invalid_scalar_is_nullable() {
        let func = SparkNextDay::new();

        let date_field = Arc::new(Field::new("date", DataType::Date32, false));
        let weekday_field = Arc::new(Field::new("weekday", DataType::Utf8, false));

        let invalid_scalar = ScalarValue::Utf8(Some("FUNDAY".to_string()));

        let field = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[date_field, weekday_field],
                scalar_arguments: &[None, Some(&invalid_scalar)],
            })
            .unwrap();

        assert!(field.is_nullable());
    }
}