Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3826,7 +3826,7 @@ pub enum Statement {
or_alter: bool,
name: ObjectName,
params: Option<Vec<ProcedureParam>>,
body: Vec<Statement>,
body: ConditionalStatements,
},
/// ```sql
/// CREATE MACRO
Expand Down Expand Up @@ -4705,11 +4705,8 @@ impl fmt::Display for Statement {
write!(f, " ({})", display_comma_separated(p))?;
}
}
write!(
f,
" AS BEGIN {body} END",
body = display_separated(body, "; ")
)

write!(f, " AS {body}")
}
Statement::CreateMacro {
or_replace,
Expand Down
5 changes: 5 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,11 @@ pub trait Dialect: Debug + Any {
fn supports_set_names(&self) -> bool {
false
}

/// Returns true if the dialect supports parsing statements without a semicolon delimiter.
///
/// The default implementation returns `false`; a dialect opts in by
/// overriding this method to return `true`.
fn supports_statements_without_semicolon_delimiter(&self) -> bool {
false
}
}

/// This represents the operators for which precedence must be defined
Expand Down
9 changes: 8 additions & 1 deletion src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl Dialect for MsSqlDialect {
}

fn supports_connect_by(&self) -> bool {
true
false
}

fn supports_eq_alias_assignment(&self) -> bool {
Expand Down Expand Up @@ -119,6 +119,10 @@ impl Dialect for MsSqlDialect {
true
}

/// Opts this dialect in to parsing statements that are not separated by a
/// semicolon delimiter.
fn supports_statements_without_semicolon_delimiter(&self) -> bool {
true
}

/// Reports whether `kw` may be used as a column alias.
///
/// A keyword qualifies only when it is absent from both
/// `keywords::RESERVED_FOR_COLUMN_ALIAS` and the unqualified
/// `RESERVED_FOR_COLUMN_ALIAS` list in scope here (presumably the
/// dialect-specific additions; confirm against its definition).
fn is_column_alias(&self, kw: &Keyword, _parser: &mut Parser) -> bool {
    // The shared list is consulted first; the second list is only checked on a miss.
    let is_reserved = keywords::RESERVED_FOR_COLUMN_ALIAS.contains(kw)
        || RESERVED_FOR_COLUMN_ALIAS.contains(kw);
    !is_reserved
}
Expand Down Expand Up @@ -271,6 +275,9 @@ impl MsSqlDialect {
) -> Result<Vec<Statement>, ParserError> {
let mut stmts = Vec::new();
loop {
while let Token::SemiColon = parser.peek_token_ref().token {
parser.advance_token();
}
if let Token::EOF = parser.peek_token_ref().token {
break;
}
Expand Down
9 changes: 9 additions & 0 deletions src/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,7 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::ANTI,
Keyword::SEMI,
Keyword::RETURNING,
Keyword::RETURN,
Keyword::ASOF,
Keyword::MATCH_CONDITION,
// for MSSQL-specific OUTER APPLY (seems reserved in most dialects)
Expand All @@ -1087,6 +1088,11 @@ pub const RESERVED_FOR_TABLE_ALIAS: &[Keyword] = &[
Keyword::TABLESAMPLE,
Keyword::FROM,
Keyword::OPEN,
Keyword::INSERT,
Keyword::UPDATE,
Keyword::DELETE,
Keyword::EXEC,
Keyword::EXECUTE,
];

/// Can't be used as a column alias, so that `SELECT <expr> alias`
Expand Down Expand Up @@ -1115,6 +1121,7 @@ pub const RESERVED_FOR_COLUMN_ALIAS: &[Keyword] = &[
Keyword::CLUSTER,
Keyword::DISTRIBUTE,
Keyword::RETURNING,
Keyword::RETURN,
// Reserved only as a column alias in the `SELECT` clause
Keyword::FROM,
Keyword::INTO,
Expand All @@ -1129,6 +1136,7 @@ pub const RESERVED_FOR_TABLE_FACTOR: &[Keyword] = &[
Keyword::LIMIT,
Keyword::HAVING,
Keyword::WHERE,
Keyword::RETURN,
];

/// Global list of reserved keywords that cannot be parsed as identifiers
Expand All @@ -1139,4 +1147,5 @@ pub const RESERVED_FOR_IDENTIFIER: &[Keyword] = &[
Keyword::INTERVAL,
Keyword::STRUCT,
Keyword::TRIM,
Keyword::RETURN,
];
78 changes: 66 additions & 12 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,17 @@ pub struct ParserOptions {
/// Controls how literal values are unescaped. See
/// [`Tokenizer::with_unescape`] for more details.
pub unescape: bool,
/// Determines if the parser requires a semicolon at the end of every statement.
/// (Default: true)
pub require_semicolon_statement_delimiter: bool,
}

impl Default for ParserOptions {
fn default() -> Self {
Self {
trailing_commas: false,
unescape: true,
require_semicolon_statement_delimiter: true,
}
}
}
Expand Down Expand Up @@ -261,6 +265,22 @@ impl ParserOptions {
self.unescape = unescape;
self
}

/// Set if semicolon statement delimiters are required.
///
/// If this option is `true`, the following SQL will not parse. If the option is `false`, the SQL will parse.
///
/// ```sql
/// SELECT 1
/// SELECT 2
/// ```
pub fn with_require_semicolon_statement_delimiter(
mut self,
require_semicolon_statement_delimiter: bool,
) -> Self {
self.require_semicolon_statement_delimiter = require_semicolon_statement_delimiter;
self
}
}

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -351,7 +371,11 @@ impl<'a> Parser<'a> {
state: ParserState::Normal,
dialect,
recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
options: ParserOptions::new().with_trailing_commas(dialect.supports_trailing_commas()),
options: ParserOptions::new()
.with_trailing_commas(dialect.supports_trailing_commas())
.with_require_semicolon_statement_delimiter(
!dialect.supports_statements_without_semicolon_delimiter(),
),
}
}

Expand Down Expand Up @@ -470,10 +494,10 @@ impl<'a> Parser<'a> {
match self.peek_token().token {
Token::EOF => break,

// end of statement
Token::Word(word) => {
if expecting_statement_delimiter && word.keyword == Keyword::END {
break;
// don't expect a semicolon statement delimiter after a newline when not otherwise required
Token::Whitespace(Whitespace::Newline) => {
if !self.options.require_semicolon_statement_delimiter {
expecting_statement_delimiter = false;
}
}
_ => {}
Expand All @@ -485,7 +509,7 @@ impl<'a> Parser<'a> {

let statement = self.parse_statement()?;
stmts.push(statement);
expecting_statement_delimiter = true;
expecting_statement_delimiter = self.options.require_semicolon_statement_delimiter;
}
Ok(stmts)
}
Expand Down Expand Up @@ -4513,6 +4537,9 @@ impl<'a> Parser<'a> {
) -> Result<Vec<Statement>, ParserError> {
let mut values = vec![];
loop {
// ignore empty statements (between successive statement delimiters)
while self.consume_token(&Token::SemiColon) {}

match &self.peek_nth_token_ref(0).token {
Token::EOF => break,
Token::Word(w) => {
Expand All @@ -4524,7 +4551,13 @@ impl<'a> Parser<'a> {
}

values.push(self.parse_statement()?);
self.expect_token(&Token::SemiColon)?;

if self.options.require_semicolon_statement_delimiter {
self.expect_token(&Token::SemiColon)?;
}

// ignore empty statements (between successive statement delimiters)
while self.consume_token(&Token::SemiColon) {}
}
Ok(values)
}
Expand Down Expand Up @@ -15505,14 +15538,14 @@ impl<'a> Parser<'a> {
let name = self.parse_object_name(false)?;
let params = self.parse_optional_procedure_parameters()?;
self.expect_keyword_is(Keyword::AS)?;
self.expect_keyword_is(Keyword::BEGIN)?;
let statements = self.parse_statements()?;
self.expect_keyword_is(Keyword::END)?;

let body = self.parse_conditional_statements(&[Keyword::END])?;

Ok(Statement::CreateProcedure {
name,
or_alter,
params,
body: statements,
body,
})
}

Expand Down Expand Up @@ -15639,7 +15672,28 @@ impl<'a> Parser<'a> {

/// Parse [Statement::Return]
fn parse_return(&mut self) -> Result<Statement, ParserError> {
match self.maybe_parse(|p| p.parse_expr())? {
let rs = self.maybe_parse(|p| {
let expr = p.parse_expr()?;

match &expr {
Expr::Value(_)
| Expr::Function(_)
| Expr::UnaryOp { .. }
| Expr::BinaryOp { .. }
| Expr::Case { .. }
| Expr::Cast { .. }
| Expr::Convert { .. }
| Expr::Subquery(_) => Ok(expr),
// todo: how to restrict to variables?
Expr::Identifier(id) if id.value.starts_with('@') => Ok(expr),
_ => parser_err!(
"Non-returnable expression found following RETURN",
p.peek_token().span.start
),
}
})?;

match rs {
Some(expr) => Ok(Statement::Return(ReturnStatement {
value: Some(ReturnStatementValue::Expr(expr)),
})),
Expand Down
68 changes: 68 additions & 0 deletions src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,37 @@ impl TestedDialects {
statements
}

/// Variant of [`statements_parse_to`] that strips the semicolons out of
/// `sql` before handing it to the parser.
///
/// The stripped text must parse successfully. After that:
/// - if `canonical` is non-empty and differs from `sql`, parsing `canonical`
///   must produce the same statements;
/// - otherwise the statements, re-serialized (with any trailing `;` trimmed)
///   and joined with `"; "`, must equal the original `sql`.
///
/// Returns the statements parsed from the stripped text.
pub fn statements_without_semicolons_parse_to(
    &self,
    sql: &str,
    canonical: &str,
) -> Vec<Statement> {
    // Applied strictly in this order; the last entry catches any semicolon
    // that was not adjacent to a space or newline.
    let substitutions = [
        ("; ", " "),
        (" ;", " "),
        (";\n", "\n"),
        ("\n;", "\n"),
        (";", " "),
    ];
    let sql_without_semicolons = substitutions
        .iter()
        .fold(sql.to_string(), |text, &(from, to)| text.replace(from, to));

    let statements = self
        .parse_sql_statements(&sql_without_semicolons)
        .expect(&sql_without_semicolons);

    let has_distinct_canonical = !canonical.is_empty() && sql != canonical;
    if has_distinct_canonical {
        let expected = self.parse_sql_statements(canonical).unwrap();
        assert_eq!(expected, statements);
    } else {
        // note: account for format_statement_list manually inserted semicolons
        let rendered: Vec<String> = statements
            .iter()
            .map(|stmt| stmt.to_string().trim_end_matches(";").to_string())
            .collect();
        assert_eq!(sql, rendered.join("; "));
    }

    statements
}

/// Ensures that `sql` parses as an [`Expr`], and that
/// re-serializing the parse result produces canonical
pub fn expr_parses_to(&self, sql: &str, canonical: &str) -> Expr {
Expand Down Expand Up @@ -313,6 +344,43 @@ where
all_dialects_where(|d| !except(d))
}

/// Collects the dialects that require semicolon statement delimiters, i.e.
/// those for which `supports_statements_without_semicolon_delimiter()` is
/// not true.
///
/// # Panics
///
/// Panics if the resulting set of dialects is empty.
pub fn all_dialects_requiring_semicolon_statement_delimiter() -> TestedDialects {
    let requiring =
        all_dialects_except(|dialect| dialect.supports_statements_without_semicolon_delimiter());
    // An empty set would make every test built on it pass vacuously.
    assert_ne!(requiring.dialects.len(), 0);
    requiring
}

/// Collects the dialects that do not require semicolon statement delimiters,
/// i.e. those for which `supports_statements_without_semicolon_delimiter()`
/// is true.
///
/// # Panics
///
/// Panics if the resulting set of dialects is empty.
pub fn all_dialects_not_requiring_semicolon_statement_delimiter() -> TestedDialects {
    let not_requiring =
        all_dialects_where(|dialect| dialect.supports_statements_without_semicolon_delimiter());
    // An empty set would make every test built on it pass vacuously.
    assert_ne!(not_requiring.dialects.len(), 0);
    not_requiring
}

/// Checks that `parse_sql_statements(sql)` fails with the expected message
/// in both groups of dialects:
/// - "Expected: end of statement, found: …" for dialects that require
///   semicolon delimiters;
/// - "Expected: an SQL statement, found: …" for dialects that don't.
///
/// `found` is the text expected after "found: " in the error message.
pub fn assert_err_parse_statements(sql: &str, found: &str) {
    // Dialects that insist on a delimiter report the missing end of statement.
    let expected_when_required =
        ParserError::ParserError(format!("Expected: end of statement, found: {}", found));
    let actual_when_required = all_dialects_requiring_semicolon_statement_delimiter()
        .parse_sql_statements(sql)
        .unwrap_err();
    assert_eq!(expected_when_required, actual_when_required);

    // Dialects without that requirement try to start a new statement instead.
    let expected_when_optional =
        ParserError::ParserError(format!("Expected: an SQL statement, found: {}", found));
    let actual_when_optional = all_dialects_not_requiring_semicolon_statement_delimiter()
        .parse_sql_statements(sql)
        .unwrap_err();
    assert_eq!(expected_when_optional, actual_when_optional);
}

pub fn assert_eq_vec<T: ToString>(expected: &[&str], actual: &[T]) {
assert_eq!(
expected,
Expand Down
Loading
Loading