@@ -111,6 +111,10 @@ pub(super) enum ParsedStatement {
111111 variable : StmtParam ,
112112 value : StmtWithParams ,
113113 } ,
114+ StaticSimpleSet {
115+ variable : StmtParam ,
116+ value : SimpleSelectValue ,
117+ } ,
114118 CsvImport ( CsvImport ) ,
115119 Error ( anyhow:: Error ) ,
116120}
@@ -142,6 +146,7 @@ fn parse_sql<'a>(
142146 return None ;
143147 }
144148 let statement = parse_single_statement ( & mut parser, db_info, sql) ;
149+ log:: debug!( "Parsed statement: {statement:?}" ) ;
145150 if let Some ( ParsedStatement :: Error ( _) ) = & statement {
146151 has_error = true ;
147152 }
@@ -182,7 +187,6 @@ fn parse_single_statement(
182187 Ok ( stmt) => stmt,
183188 Err ( err) => return Some ( syntax_error ( err, parser, source_sql) ) ,
184189 } ;
185- log:: debug!( "Parsed statement: {stmt}" ) ;
186190 let mut semicolon = false ;
187191 while parser. consume_token ( & SemiColon ) {
188192 semicolon = true ;
@@ -404,40 +408,10 @@ fn extract_static_simple_select(
404408 let mut items = Vec :: with_capacity ( select_items. len ( ) ) ;
405409 let mut params_iter = params. iter ( ) . cloned ( ) ;
406410 for select_item in select_items {
407- use serde_json:: Value :: { Bool , Null , Number , String } ;
408- use SimpleSelectValue :: { Dynamic , Static } ;
409411 let sqlparser:: ast:: SelectItem :: ExprWithAlias { expr, alias } = select_item else {
410412 return None ;
411413 } ;
412- let value = match expr {
413- Expr :: Value ( ValueWithSpan {
414- value : Value :: Boolean ( b) ,
415- ..
416- } ) => Static ( Bool ( * b) ) ,
417- Expr :: Value ( ValueWithSpan {
418- value : Value :: Number ( n, _) ,
419- ..
420- } ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
421- Expr :: Value ( ValueWithSpan {
422- value : Value :: SingleQuotedString ( s) ,
423- ..
424- } ) => Static ( String ( s. clone ( ) ) ) ,
425- Expr :: Value ( ValueWithSpan {
426- value : Value :: Null , ..
427- } ) => Static ( Null ) ,
428- e if is_simple_select_placeholder ( e) => {
429- if let Some ( p) = params_iter. next ( ) {
430- Dynamic ( p)
431- } else {
432- log:: error!( "Parameter not extracted for placehorder: {expr:?}" ) ;
433- return None ;
434- }
435- }
436- other => {
437- log:: trace!( "Cancelling simple select optimization because of expr: {other:?}" ) ;
438- return None ;
439- }
440- } ;
414+ let value = expr_to_simple_select_val ( & mut params_iter, expr) ?;
441415 let key = alias. value . clone ( ) ;
442416 items. push ( ( key, value) ) ;
443417 }
@@ -448,6 +422,43 @@ fn extract_static_simple_select(
448422 Some ( items)
449423}
450424
425+ fn expr_to_simple_select_val (
426+ params_iter : & mut impl Iterator < Item = StmtParam > ,
427+ expr : & Expr ,
428+ ) -> Option < SimpleSelectValue > {
429+ use serde_json:: Value :: { Bool , Null , Number , String } ;
430+ use SimpleSelectValue :: { Dynamic , Static } ;
431+ Some ( match expr {
432+ Expr :: Value ( ValueWithSpan {
433+ value : Value :: Boolean ( b) ,
434+ ..
435+ } ) => Static ( Bool ( * b) ) ,
436+ Expr :: Value ( ValueWithSpan {
437+ value : Value :: Number ( n, _) ,
438+ ..
439+ } ) => Static ( Number ( n. parse ( ) . ok ( ) ?) ) ,
440+ Expr :: Value ( ValueWithSpan {
441+ value : Value :: SingleQuotedString ( s) ,
442+ ..
443+ } ) => Static ( String ( s. clone ( ) ) ) ,
444+ Expr :: Value ( ValueWithSpan {
445+ value : Value :: Null , ..
446+ } ) => Static ( Null ) ,
447+ e if is_simple_select_placeholder ( e) => {
448+ if let Some ( p) = params_iter. next ( ) {
449+ Dynamic ( p)
450+ } else {
451+ log:: error!( "Parameter not extracted for placehorder: {expr:?}" ) ;
452+ return None ;
453+ }
454+ }
455+ other => {
456+ log:: trace!( "Cancelling simple select optimization because of expr: {other:?}" ) ;
457+ return None ;
458+ }
459+ } )
460+ }
461+
451462fn is_simple_select_placeholder ( e : & Expr ) -> bool {
452463 match e {
453464 Expr :: Value ( ValueWithSpan {
@@ -485,6 +496,11 @@ fn extract_set_variable(
485496 StmtParam :: PostOrGet ( std:: mem:: take ( & mut ident. value ) )
486497 } ;
487498 let owned_expr = std:: mem:: replace ( value, Expr :: value ( Value :: Null ) ) ;
499+ let mut params_iter = params. iter ( ) . cloned ( ) ;
500+ if let Some ( value) = expr_to_simple_select_val ( & mut params_iter, & owned_expr) {
501+ return Some ( ParsedStatement :: StaticSimpleSet { variable, value } ) ;
502+ }
503+
488504 let mut select_stmt: Statement = expr_to_statement ( owned_expr) ;
489505 let delayed_functions = extract_toplevel_functions ( & mut select_stmt) ;
490506 if let Err ( err) = validate_function_calls ( & select_stmt) {
@@ -1248,26 +1264,24 @@ mod test {
12481264 }
12491265
12501266 #[ test]
1251- fn test_set_variable ( ) {
1267+ fn test_set_variable_to_other_variable ( ) {
12521268 let sql = "set x = $y" ;
12531269 for & ( dialect, dbms) in ALL_DIALECTS {
12541270 let mut parser = Parser :: new ( dialect) . try_with_sql ( sql) . unwrap ( ) ;
12551271 let db_info = create_test_db_info ( dbms) ;
1256- let stmt = parse_single_statement ( & mut parser, & db_info, sql) ;
1257- if let Some ( ParsedStatement :: SetVariable {
1258- variable,
1259- value : StmtWithParams { query, params, .. } ,
1260- } ) = stmt
1261- {
1262- assert_eq ! (
1263- variable,
1264- StmtParam :: PostOrGet ( "x" . to_string( ) ) ,
1265- "{dialect:?}"
1266- ) ;
1267- assert ! ( query. starts_with( "SELECT " ) ) ;
1268- assert_eq ! ( params, [ StmtParam :: PostOrGet ( "y" . to_string( ) ) ] ) ;
1269- } else {
1270- panic ! ( "Failed for dialect {dialect:?}: {stmt:#?}" , ) ;
1272+ match parse_single_statement ( & mut parser, & db_info, sql) {
1273+ Some ( ParsedStatement :: StaticSimpleSet { variable, value } ) => {
1274+ assert_eq ! (
1275+ variable,
1276+ StmtParam :: PostOrGet ( "x" . to_string( ) ) ,
1277+ "{dialect:?}"
1278+ ) ;
1279+ assert_eq ! (
1280+ value,
1281+ SimpleSelectValue :: Dynamic ( StmtParam :: PostOrGet ( "y" . to_string( ) ) )
1282+ ) ;
1283+ }
1284+ other => panic ! ( "Failed for dialect {dialect:?}: {other:#?}" ) ,
12711285 }
12721286 }
12731287 }
@@ -1398,7 +1412,7 @@ mod test {
13981412
13991413 #[ test]
14001414 fn test_extract_set_variable ( ) {
1401- let sql = "set x = 42 " ;
1415+ let sql = "set x = CURRENT_TIMESTAMP " ;
14021416 for & ( dialect, dbms) in ALL_DIALECTS {
14031417 let mut parser = Parser :: new ( dialect) . try_with_sql ( sql) . unwrap ( ) ;
14041418 let db_info = create_test_db_info ( dbms) ;
@@ -1413,14 +1427,33 @@ mod test {
14131427 StmtParam :: PostOrGet ( "x" . to_string( ) ) ,
14141428 "{dialect:?}"
14151429 ) ;
1416- assert_eq ! ( query, "SELECT 42 AS sqlpage_set_expr" ) ;
1430+ assert_eq ! ( query, "SELECT CURRENT_TIMESTAMP AS sqlpage_set_expr" ) ;
14171431 assert ! ( params. is_empty( ) ) ;
14181432 } else {
14191433 panic ! ( "Failed for dialect {dialect:?}: {stmt:#?}" , ) ;
14201434 }
14211435 }
14221436 }
14231437
1438+ #[ test]
1439+ fn test_extract_set_variable_static ( ) {
1440+ let sql = "set x = 'hello'" ;
1441+ for & ( dialect, dbms) in ALL_DIALECTS {
1442+ let mut parser = Parser :: new ( dialect) . try_with_sql ( sql) . unwrap ( ) ;
1443+ let db_info = create_test_db_info ( dbms) ;
1444+ match parse_single_statement ( & mut parser, & db_info, sql) {
1445+ Some ( ParsedStatement :: StaticSimpleSet {
1446+ variable : StmtParam :: PostOrGet ( var_name) ,
1447+ value : SimpleSelectValue :: Static ( value) ,
1448+ } ) => {
1449+ assert_eq ! ( var_name, "x" ) ;
1450+ assert_eq ! ( value, "hello" ) ;
1451+ }
1452+ other => panic ! ( "Failed for dialect {dialect:?}: {other:#?}" ) ,
1453+ }
1454+ }
1455+ }
1456+
14241457 #[ test]
14251458 fn test_static_extract_doesnt_match ( ) {
14261459 assert_eq ! (
0 commit comments