Skip to content

Commit 4e03c92

Browse files
authored
refactor: add dialect enum (#18043)
## Which issue does this PR close? - Closes #18042 ## Rationale for this change This PR introduces a new dialect enum to improve type safety and code maintainability when handling different SQL dialects in DataFusion. The goals are to: 1. Provide compile-time guarantees for dialect handling 2. Improve code readability and self-documentation 3. Enable better IDE support and autocomplete ## What changes are included in this PR? - Added a new `Dialect` enum to represent supported SQL dialects - Refactored existing code to use the new enum instead of the previous `String` representation - Modified tests to work with the new enum-based approach ## Are these changes tested? Yes ## Are there any user-facing changes? Yes, this is an API change: the type of the `sql_parser.dialect` configuration field changed from `String` to `Dialect`
1 parent b1723e5 commit 4e03c92

File tree

10 files changed

+141
-35
lines changed

10 files changed

+141
-35
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-cli/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ async-trait = { workspace = true }
4040
aws-config = "1.8.7"
4141
aws-credential-types = "1.2.7"
4242
chrono = { workspace = true }
43-
clap = { version = "4.5.47", features = ["derive", "cargo"] }
43+
clap = { version = "4.5.47", features = ["cargo", "derive"] }
4444
datafusion = { workspace = true, features = [
4545
"avro",
4646
"compression",
@@ -55,6 +55,7 @@ datafusion = { workspace = true, features = [
5555
"sql",
5656
"unicode_expressions",
5757
] }
58+
datafusion-common = { workspace = true }
5859
dirs = "6.0.0"
5960
env_logger = { workspace = true }
6061
futures = { workspace = true }
@@ -65,7 +66,7 @@ parking_lot = { workspace = true }
6566
parquet = { workspace = true, default-features = false }
6667
regex = { workspace = true }
6768
rustyline = "17.0"
68-
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] }
69+
tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] }
6970
url = { workspace = true }
7071

7172
[dev-dependencies]

datafusion-cli/src/helper.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
2424

2525
use datafusion::sql::parser::{DFParser, Statement};
2626
use datafusion::sql::sqlparser::dialect::dialect_from_str;
27+
use datafusion_common::config::Dialect;
2728

2829
use rustyline::completion::{Completer, FilenameCompleter, Pair};
2930
use rustyline::error::ReadlineError;
@@ -34,33 +35,33 @@ use rustyline::{Context, Helper, Result};
3435

3536
pub struct CliHelper {
3637
completer: FilenameCompleter,
37-
dialect: String,
38+
dialect: Dialect,
3839
highlighter: Box<dyn Highlighter>,
3940
}
4041

4142
impl CliHelper {
42-
pub fn new(dialect: &str, color: bool) -> Self {
43+
pub fn new(dialect: &Dialect, color: bool) -> Self {
4344
let highlighter: Box<dyn Highlighter> = if !color {
4445
Box::new(NoSyntaxHighlighter {})
4546
} else {
4647
Box::new(SyntaxHighlighter::new(dialect))
4748
};
4849
Self {
4950
completer: FilenameCompleter::new(),
50-
dialect: dialect.into(),
51+
dialect: *dialect,
5152
highlighter,
5253
}
5354
}
5455

55-
pub fn set_dialect(&mut self, dialect: &str) {
56-
if dialect != self.dialect {
57-
self.dialect = dialect.to_string();
56+
pub fn set_dialect(&mut self, dialect: &Dialect) {
57+
if *dialect != self.dialect {
58+
self.dialect = *dialect;
5859
}
5960
}
6061

6162
fn validate_input(&self, input: &str) -> Result<ValidationResult> {
6263
if let Some(sql) = input.strip_suffix(';') {
63-
let dialect = match dialect_from_str(&self.dialect) {
64+
let dialect = match dialect_from_str(self.dialect) {
6465
Some(dialect) => dialect,
6566
None => {
6667
return Ok(ValidationResult::Invalid(Some(format!(
@@ -97,7 +98,7 @@ impl CliHelper {
9798

9899
impl Default for CliHelper {
99100
fn default() -> Self {
100-
Self::new("generic", false)
101+
Self::new(&Dialect::Generic, false)
101102
}
102103
}
103104

@@ -289,7 +290,7 @@ mod tests {
289290
);
290291

291292
// valid in postgresql dialect
292-
validator.set_dialect("postgresql");
293+
validator.set_dialect(&Dialect::PostgreSQL);
293294
let result =
294295
readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?;
295296
assert!(matches!(result, ValidationResult::Valid(None)));

datafusion-cli/src/highlighter.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use datafusion::sql::sqlparser::{
2727
keywords::Keyword,
2828
tokenizer::{Token, Tokenizer},
2929
};
30+
use datafusion_common::config;
3031
use rustyline::highlight::{CmdKind, Highlighter};
3132

3233
/// The syntax highlighter.
@@ -36,7 +37,7 @@ pub struct SyntaxHighlighter {
3637
}
3738

3839
impl SyntaxHighlighter {
39-
pub fn new(dialect: &str) -> Self {
40+
pub fn new(dialect: &config::Dialect) -> Self {
4041
let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {}));
4142
Self { dialect }
4243
}
@@ -93,13 +94,14 @@ impl Color {
9394

9495
#[cfg(test)]
9596
mod tests {
97+
use super::config::Dialect;
9698
use super::SyntaxHighlighter;
9799
use rustyline::highlight::Highlighter;
98100

99101
#[test]
100102
fn highlighter_valid() {
101103
let s = "SElect col_a from tab_1;";
102-
let highlighter = SyntaxHighlighter::new("generic");
104+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
103105
let out = highlighter.highlight(s, s.len());
104106
assert_eq!(
105107
"\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1;",
@@ -110,7 +112,7 @@ mod tests {
110112
#[test]
111113
fn highlighter_valid_with_new_line() {
112114
let s = "SElect col_a from tab_1\n WHERE col_b = 'なにか';";
113-
let highlighter = SyntaxHighlighter::new("generic");
115+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
114116
let out = highlighter.highlight(s, s.len());
115117
assert_eq!(
116118
"\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1\n \u{1b}[91mWHERE\u{1b}[0m col_b = \u{1b}[92m'なにか'\u{1b}[0m;",
@@ -121,7 +123,7 @@ mod tests {
121123
#[test]
122124
fn highlighter_invalid() {
123125
let s = "SElect col_a from tab_1 WHERE col_b = ';";
124-
let highlighter = SyntaxHighlighter::new("generic");
126+
let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
125127
let out = highlighter.highlight(s, s.len());
126128
assert_eq!("SElect col_a from tab_1 WHERE col_b = ';", out);
127129
}

datafusion-examples/examples/remote_catalog.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ async fn main() -> Result<()> {
7575
let state = ctx.state();
7676

7777
// First, parse the SQL (but don't plan it / resolve any table references)
78-
let dialect = state.config().options().sql_parser.dialect.as_str();
79-
let statement = state.sql_to_statement(sql, dialect)?;
78+
let dialect = state.config().options().sql_parser.dialect;
79+
let statement = state.sql_to_statement(sql, &dialect)?;
8080

8181
// Find all `TableReferences` in the parsed queries. These correspond to the
8282
// tables referred to by the query (in this case

datafusion/common/src/config.rs

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ config_namespace! {
258258

259259
/// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
260260
/// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks.
261-
pub dialect: String, default = "generic".to_string()
261+
pub dialect: Dialect, default = Dialect::Generic
262262
// no need to lowercase because [`sqlparser::dialect_from_str`] is case-insensitive
263263

264264
/// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but
@@ -292,6 +292,94 @@ config_namespace! {
292292
}
293293
}
294294

295+
/// This is the SQL dialect used by DataFusion's parser.
296+
/// This mirrors the [sqlparser::dialect::Dialect](https://docs.rs/sqlparser/latest/sqlparser/dialect/trait.Dialect.html)
297+
/// trait in order to offer an easier API and avoid adding the `sqlparser` dependency.
298+
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
299+
pub enum Dialect {
300+
#[default]
301+
Generic,
302+
MySQL,
303+
PostgreSQL,
304+
Hive,
305+
SQLite,
306+
Snowflake,
307+
Redshift,
308+
MsSQL,
309+
ClickHouse,
310+
BigQuery,
311+
Ansi,
312+
DuckDB,
313+
Databricks,
314+
}
315+
316+
impl AsRef<str> for Dialect {
317+
fn as_ref(&self) -> &str {
318+
match self {
319+
Self::Generic => "generic",
320+
Self::MySQL => "mysql",
321+
Self::PostgreSQL => "postgresql",
322+
Self::Hive => "hive",
323+
Self::SQLite => "sqlite",
324+
Self::Snowflake => "snowflake",
325+
Self::Redshift => "redshift",
326+
Self::MsSQL => "mssql",
327+
Self::ClickHouse => "clickhouse",
328+
Self::BigQuery => "bigquery",
329+
Self::Ansi => "ansi",
330+
Self::DuckDB => "duckdb",
331+
Self::Databricks => "databricks",
332+
}
333+
}
334+
}
335+
336+
impl FromStr for Dialect {
337+
type Err = DataFusionError;
338+
339+
fn from_str(s: &str) -> Result<Self, Self::Err> {
340+
let value = match s.to_ascii_lowercase().as_str() {
341+
"generic" => Self::Generic,
342+
"mysql" => Self::MySQL,
343+
"postgresql" | "postgres" => Self::PostgreSQL,
344+
"hive" => Self::Hive,
345+
"sqlite" => Self::SQLite,
346+
"snowflake" => Self::Snowflake,
347+
"redshift" => Self::Redshift,
348+
"mssql" => Self::MsSQL,
349+
"clickhouse" => Self::ClickHouse,
350+
"bigquery" => Self::BigQuery,
351+
"ansi" => Self::Ansi,
352+
"duckdb" => Self::DuckDB,
353+
"databricks" => Self::Databricks,
354+
other => {
355+
let error_message = format!(
356+
"Invalid Dialect: {other}. Expected one of: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks"
357+
);
358+
return Err(DataFusionError::Configuration(error_message));
359+
}
360+
};
361+
Ok(value)
362+
}
363+
}
364+
365+
impl ConfigField for Dialect {
366+
fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
367+
v.some(key, self, description)
368+
}
369+
370+
fn set(&mut self, _: &str, value: &str) -> Result<()> {
371+
*self = Self::from_str(value)?;
372+
Ok(())
373+
}
374+
}
375+
376+
impl Display for Dialect {
377+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
378+
let str = self.as_ref();
379+
write!(f, "{str}")
380+
}
381+
}
382+
295383
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
296384
pub enum SpillCompression {
297385
Zstd,

datafusion/core/benches/sql_planner.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use criterion::Bencher;
3030
use datafusion::datasource::MemTable;
3131
use datafusion::execution::context::SessionContext;
3232
use datafusion::prelude::DataFrame;
33-
use datafusion_common::ScalarValue;
33+
use datafusion_common::{config::Dialect, ScalarValue};
3434
use datafusion_expr::Expr::Literal;
3535
use datafusion_expr::{cast, col, lit, not, try_cast, when};
3636
use datafusion_functions::expr_fn::{
@@ -288,7 +288,10 @@ fn benchmark_with_param_values_many_columns(
288288
}
289289
// SELECT max(attr0), ..., max(attrN) FROM t1.
290290
let query = format!("SELECT {aggregates} FROM t1");
291-
let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap();
291+
let statement = ctx
292+
.state()
293+
.sql_to_statement(&query, &Dialect::Generic)
294+
.unwrap();
292295
let plan =
293296
rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() });
294297
b.iter(|| {

datafusion/core/src/execution/session_state.rs

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@ use crate::datasource::provider_as_source;
3030
use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
3131
use crate::execution::SessionStateDefaults;
3232
use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
33+
use arrow::datatypes::DataType;
3334
use datafusion_catalog::information_schema::{
3435
InformationSchemaProvider, INFORMATION_SCHEMA,
3536
};
36-
37-
use arrow::datatypes::DataType;
3837
use datafusion_catalog::MemoryCatalogProviderList;
3938
use datafusion_catalog::{TableFunction, TableFunctionImpl};
4039
use datafusion_common::alias::AliasGenerator;
41-
use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions};
40+
use datafusion_common::config::{ConfigExtension, ConfigOptions, Dialect, TableOptions};
4241
use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
4342
use datafusion_common::tree_node::TreeNode;
4443
use datafusion_common::{
@@ -374,7 +373,7 @@ impl SessionState {
374373
pub fn sql_to_statement(
375374
&self,
376375
sql: &str,
377-
dialect: &str,
376+
dialect: &Dialect,
378377
) -> datafusion_common::Result<Statement> {
379378
let dialect = dialect_from_str(dialect).ok_or_else(|| {
380379
plan_datafusion_err!(
@@ -411,7 +410,7 @@ impl SessionState {
411410
pub fn sql_to_expr(
412411
&self,
413412
sql: &str,
414-
dialect: &str,
413+
dialect: &Dialect,
415414
) -> datafusion_common::Result<SQLExpr> {
416415
self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
417416
}
@@ -423,7 +422,7 @@ impl SessionState {
423422
pub fn sql_to_expr_with_alias(
424423
&self,
425424
sql: &str,
426-
dialect: &str,
425+
dialect: &Dialect,
427426
) -> datafusion_common::Result<SQLExprWithAlias> {
428427
let dialect = dialect_from_str(dialect).ok_or_else(|| {
429428
plan_datafusion_err!(
@@ -527,8 +526,8 @@ impl SessionState {
527526
&self,
528527
sql: &str,
529528
) -> datafusion_common::Result<LogicalPlan> {
530-
let dialect = self.config.options().sql_parser.dialect.as_str();
531-
let statement = self.sql_to_statement(sql, dialect)?;
529+
let dialect = self.config.options().sql_parser.dialect;
530+
let statement = self.sql_to_statement(sql, &dialect)?;
532531
let plan = self.statement_to_plan(statement).await?;
533532
Ok(plan)
534533
}
@@ -542,9 +541,9 @@ impl SessionState {
542541
sql: &str,
543542
df_schema: &DFSchema,
544543
) -> datafusion_common::Result<Expr> {
545-
let dialect = self.config.options().sql_parser.dialect.as_str();
544+
let dialect = self.config.options().sql_parser.dialect;
546545

547-
let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;
546+
let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?;
548547

549548
let provider = SessionContextProvider {
550549
state: self,
@@ -2034,6 +2033,7 @@ mod tests {
20342033
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
20352034
use arrow::datatypes::{DataType, Field, Schema};
20362035
use datafusion_catalog::MemoryCatalogProviderList;
2036+
use datafusion_common::config::Dialect;
20372037
use datafusion_common::DFSchema;
20382038
use datafusion_common::Result;
20392039
use datafusion_execution::config::SessionConfig;
@@ -2059,8 +2059,8 @@ mod tests {
20592059
let sql = "[1,2,3]";
20602060
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
20612061
let df_schema = DFSchema::try_from(schema)?;
2062-
let dialect = state.config.options().sql_parser.dialect.as_str();
2063-
let sql_expr = state.sql_to_expr(sql, dialect)?;
2062+
let dialect = state.config.options().sql_parser.dialect;
2063+
let sql_expr = state.sql_to_expr(sql, &dialect)?;
20642064

20652065
let query = SqlToRel::new_with_options(&provider, state.get_parser_options());
20662066
query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())
@@ -2218,7 +2218,8 @@ mod tests {
22182218
}
22192219

22202220
let state = &context_provider.state;
2221-
let statement = state.sql_to_statement("select count(*) from t", "mysql")?;
2221+
let statement =
2222+
state.sql_to_statement("select count(*) from t", &Dialect::MySQL)?;
22222223
let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?;
22232224
state.create_physical_plan(&plan).await
22242225
}

0 commit comments

Comments
 (0)