From 1d64d827377d692857915693362ffb9a408687ff Mon Sep 17 00:00:00 2001 From: Hillel Sreter Date: Fri, 28 Feb 2025 12:17:55 -0800 Subject: [PATCH] add support setting session params in query see TestSplitQueriesWithSession for sytax of setting session params inline --- presto/client.go | 8 +- presto/query.go | 185 ++++++++++++++++++++++++++++++++++ presto/query_splitter_test.go | 31 ++++++ stage/stage.go | 147 ++++++++++++++++++++++----- stage/stage_utils.go | 2 +- 5 files changed, 344 insertions(+), 29 deletions(-) diff --git a/presto/client.go b/presto/client.go index c6904e8b..016d0dda 100644 --- a/presto/client.go +++ b/presto/client.go @@ -147,8 +147,12 @@ func (c *Client) UserPassword(user, password string) *Client { return c } -func (c *Client) GetSessionParams() string { - return c.getHeader(SessionHeader) +func (c *Client) GetSessionParams() map[string]any { + params := make(map[string]any) + for k, v := range c.sessionParams { + params[k] = v + } + return params } func (c *Client) ClearSessionParams() *Client { diff --git a/presto/query.go b/presto/query.go index 55ee6439..1e04209e 100644 --- a/presto/query.go +++ b/presto/query.go @@ -1,12 +1,21 @@ package presto import ( + "bufio" "context" "io" "net/http" "pbench/presto/query_json" + "strconv" + "strings" ) +// QueryWithSession represents a query and its additional session parameters +type QueryWithSession struct { + Query string + SessionParams map[string]any +} + func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) { qr := new(QueryResults) resp, err := c.Do(ctx, req, qr) @@ -77,3 +86,179 @@ func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool, } return queryInfo, resp, nil } + +// ParseSessionCommand checks if a query is a session parameter command and returns the parameter and value +// Format: --session parameter_name=parameter_value or --SET SESSION parameter_name=parameter_value +func ParseSessionCommand(query string) (paramName string, paramValue string, isSession bool) { + query = strings.TrimSpace(query) + + // Check if query starts with --session or --set session (case insensitive) + queryLower := strings.ToLower(query) + if !strings.HasPrefix(queryLower, "--session") && !strings.HasPrefix(queryLower, "--set session") { + return "", "", false + } + + // Remove the prefix + if strings.HasPrefix(queryLower, "--set session") { + query = strings.TrimSpace(query[13:]) // len("--set session") = 13 + } else { + query = strings.TrimSpace(query[9:]) // len("--session") = 9 + } + + // Split on equals sign and handle spaces + parts := strings.SplitN(query, "=", 2) + if len(parts) != 2 { + return "", "", false + } + + paramName = strings.ToLower(strings.TrimSpace(parts[0])) + paramValue = strings.TrimSpace(parts[1]) + + // Remove quotes if present + if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") { + paramValue = paramValue[1:len(paramValue)-1] + } + if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") { + paramValue = paramValue[1:len(paramValue)-1] + } + + // Remove trailing semicolon if present + if strings.HasSuffix(paramValue, ";") { + paramValue = strings.TrimSuffix(paramValue, ";") + } + + // Convert value to uppercase for enum values + paramValue = strings.ToUpper(paramValue) + + return paramName, paramValue, true +} + +// cleanQuery removes unnecessary whitespace, newlines, comments and trailing semicolon from a query +func cleanQuery(query string) string { + // Split into lines and handle each line + lines := strings.Split(query, "\n") + cleanLines := make([]string, 0, len(lines)) + + for _, line := range lines { + // Remove inline comments + if idx := strings.Index(line, "--"); idx >= 0 { + line = line[:idx] + } + + trimmed := strings.TrimSpace(line) + if trimmed != "" { + cleanLines = append(cleanLines, trimmed) + } + } + + // Join with single spaces + query = strings.Join(cleanLines, " ") + + // Remove trailing semicolon + if strings.HasSuffix(query, ";") { + query = strings.TrimSuffix(query, ";") + } + + return query +} + +// SplitQueriesWithSession splits a SQL file into individual queries and their associated session parameters +func SplitQueriesWithSession(r io.Reader) ([]QueryWithSession, error) { + queries := make([]QueryWithSession, 0) + currentSessionParams := make(map[string]any) + + scanner := bufio.NewScanner(r) + var currentQuery strings.Builder + inMultilineComment := false + + for scanner.Scan() { + line := scanner.Text() + trimmedLine := strings.TrimSpace(line) + + // Skip empty lines + if len(trimmedLine) == 0 { + continue + } + + // Handle multiline comments + if strings.HasPrefix(trimmedLine, "/*") { + inMultilineComment = true + } + if inMultilineComment { + if strings.HasSuffix(trimmedLine, "*/") { + inMultilineComment = false + } + continue + } + + // Handle single line comments and session parameters + if strings.HasPrefix(trimmedLine, "--") { + paramName, paramValue, isSession := ParseSessionCommand(trimmedLine) + if isSession { + // Try to parse value as number or boolean first + if val, err := strconv.ParseInt(paramValue, 10, 64); err == nil { + currentSessionParams[paramName] = val + } else if val, err := strconv.ParseFloat(paramValue, 64); err == nil { + currentSessionParams[paramName] = val + } else if val, err := strconv.ParseBool(paramValue); err == nil { + currentSessionParams[paramName] = val + } else { + // Remove any remaining quotes from string values + if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") { + paramValue = paramValue[1:len(paramValue)-1] + } + if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") { + paramValue = paramValue[1:len(paramValue)-1] + } + // Treat as string if not a number or boolean + currentSessionParams[paramName] = paramValue + } + } + continue + } + + currentQuery.WriteString(line) + currentQuery.WriteString("\n") + + // Check if line ends with semicolon + if strings.HasSuffix(trimmedLine, ";") { + query := strings.TrimSpace(currentQuery.String()) + if len(query) > 0 { + // Clean up the query formatting + query = cleanQuery(query) + + // Create a copy of current session parameters for this query + sessionParams := make(map[string]any, len(currentSessionParams)) + for k, v := range currentSessionParams { + sessionParams[k] = v + } + queries = append(queries, QueryWithSession{ + Query: query, + SessionParams: sessionParams, + }) + // Clear session parameters after query + currentSessionParams = make(map[string]any) + } + currentQuery.Reset() + } + } + + // Handle last query if it doesn't end with semicolon + lastQuery := strings.TrimSpace(currentQuery.String()) + if len(lastQuery) > 0 { + sessionParams := make(map[string]any, len(currentSessionParams)) + for k, v := range currentSessionParams { + sessionParams[k] = v + } + queries = append(queries, QueryWithSession{ + Query: lastQuery, + SessionParams: sessionParams, + }) + } + + if err := scanner.Err(); err != nil { + return nil, err + } + + return queries, nil +} diff --git a/presto/query_splitter_test.go b/presto/query_splitter_test.go index 8ad328b6..1beeecf4 100644 --- a/presto/query_splitter_test.go +++ b/presto/query_splitter_test.go @@ -36,3 +36,34 @@ another query;;missing semicolon, should be discarded } } } + +func TestSplitQueriesWithSession(t *testing.T) { + input := `/* header comment */ +--SET SESSION join_reordering_strategy = 'NONE'; +--session query_max_memory = '1GB' +--session max_splits_per_node = 1234 +--session optimize_hash_generation = true +-- normal comment +SELECT + * -- inline comment +FROM + table1 +WHERE + id > 0;` + + expected := []presto.QueryWithSession{ + { + Query: "SELECT * FROM table1 WHERE id > 0", + SessionParams: map[string]any{ + "join_reordering_strategy": "NONE", + "query_max_memory": "1GB", + "max_splits_per_node": int64(1234), + "optimize_hash_generation": true, + }, + }, + } + + queries, err := presto.SplitQueriesWithSession(strings.NewReader(input)) + assert.NoError(t, err) + assert.Equal(t, expected, queries) +} diff --git a/stage/stage.go b/stage/stage.go index f0ec5192..3b482843 100644 --- a/stage/stage.go +++ b/stage/stage.go @@ -116,6 +116,13 @@ type Stage struct { started atomic.Bool } +func stringValue(s *string) string { + if s == nil { + return "" + } + return *s +} + // Run this stage and trigger its downstream stages. func (s *Stage) Run(ctx context.Context) int { if s.States == nil { @@ -296,10 +303,21 @@ func (s *Stage) runQueryFile(ctx context.Context, queryFile string, expectedRowC } file, err := os.Open(queryFile) - var queries []string - if err == nil { - queries, err = presto.SplitQueries(file) + if err != nil { + if !*s.AbortOnError { + log.Error().Err(err).Str("file_path", queryFile).Msg("failed to read queries from file") + err = nil + // If we run into errors reading the query file, then the offset in the expected row count array will be messed up. + // Reset it to nil to stop showing expected row counts and to avoid confusions. + s.expectedRowCountInCurrentSchema = nil + } else { + s.States.exitCode.CompareAndSwap(0, 1) + } + return err } + defer file.Close() + + queriesWithSession, err := presto.SplitQueriesWithSession(file) if err != nil { if !*s.AbortOnError { log.Error().Err(err).Str("file_path", queryFile).Msg("failed to read queries from file") @@ -314,14 +332,72 @@ func (s *Stage) runQueryFile(ctx context.Context, queryFile string, expectedRowC } if expectedRowCountStartIndex != nil { - err = s.runQueries(ctx, queries, fileAlias, *expectedRowCountStartIndex) - *expectedRowCountStartIndex += len(queries) + err = s.runQueriesInternal(ctx, queriesWithSession, fileAlias, *expectedRowCountStartIndex) + *expectedRowCountStartIndex += len(queriesWithSession) } else { - err = s.runQueries(ctx, queries, fileAlias, 0) + err = s.runQueriesInternal(ctx, queriesWithSession, fileAlias, 0) } return err } +func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *string, expectedRowCountStartIndex int) (retErr error) { + batchSize := len(queries) + for i, queryText := range queries { + // run pre query cycle shell scripts + preQueryCycleErr := s.runShellScripts(ctx, s.PreQueryCycleShellScripts) + if preQueryCycleErr != nil { + return fmt.Errorf("pre-query script execution failed: %w", preQueryCycleErr) + } + + // Dereference the pointers for arithmetic + totalRuns := *s.ColdRuns + *s.WarmRuns + for j := 0; j < totalRuns; j++ { + query := &Query{ + Text: queryText, + File: queryFile, + Index: i, + BatchSize: batchSize, + ColdRun: j < *s.ColdRuns, + SequenceNo: j, + ExpectedRowCount: -1, // -1 means unspecified. + } + if len(s.expectedRowCountInCurrentSchema) > expectedRowCountStartIndex+i { + query.ExpectedRowCount = s.expectedRowCountInCurrentSchema[expectedRowCountStartIndex+i] + } + + result, err := s.runQuery(ctx, query) + // err is already attached to the result, if not nil. + if s.States.OnQueryCompletion != nil { + s.States.OnQueryCompletion(result) + } + // Flags and options are checked within. + s.saveQueryJsonFile(result) + // Each query should have a query result sent to the channel, no matter + // its execution succeeded or not. + s.States.resultChan <- result + if err != nil { + if *s.AbortOnError || ctx.Err() != nil { + // If AbortOnError is set, we skip the rest queries in the same batch. + // Logging etc. will be handled in the parent stack. + // If the context is cancelled or timed out, we cannot continue whatsoever and must return. + s.States.exitCode.CompareAndSwap(0, 1) + return result + } + // Log the error information and continue running + s.logErr(ctx, result) + continue + } + log.Info().EmbedObject(result).Msgf("query finished") + } + // run post query cycle shell scripts + postQueryCycleErr := s.runShellScripts(ctx, s.PostQueryCycleShellScripts) + if postQueryCycleErr != nil { + return fmt.Errorf("post-query script execution failed: %w", postQueryCycleErr) + } + } + return nil +} + func (s *Stage) runRandomly(ctx context.Context) error { var continueExecution func(queryCount int) bool if dur, parseErr := time.ParseDuration(*s.RandomlyExecuteUntil); parseErr == nil { @@ -358,7 +434,11 @@ func (s *Stage) runRandomly(ctx context.Context) error { if idx < len(s.Queries) { // Run query embedded in the json file. pseudoFileName := fmt.Sprintf("rand_%d", i) - if err := s.runQueries(ctx, s.Queries[idx:idx+1], &pseudoFileName, 0); err != nil { + queryWithSession := presto.QueryWithSession{ + Query: s.Queries[idx], + SessionParams: make(map[string]any), + } + if err := s.runQueriesInternal(ctx, []presto.QueryWithSession{queryWithSession}, &pseudoFileName, 0); err != nil { return err } } else { @@ -407,17 +487,34 @@ func (s *Stage) runShellScripts(ctx context.Context, shellScripts []string) erro return nil } -func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *string, expectedRowCountStartIndex int) (retErr error) { +func (s *Stage) runQueriesInternal(ctx context.Context, queries []presto.QueryWithSession, queryFile *string, expectedRowCountStartIndex int) error { + // Log the queries we're about to run + for i, q := range queries { + log.Info(). + Str("stage_id", s.Id). + Int("query_index", i). + Str("query", q.Query). + Interface("session_params", q.SessionParams). + Str("file", stringValue(queryFile)). + Msg("preparing to execute query") + } + batchSize := len(queries) - for i, queryText := range queries { - // run pre query cycle shell scripts - preQueryCycleErr := s.runShellScripts(ctx, s.PreQueryCycleShellScripts) - if preQueryCycleErr != nil { - return fmt.Errorf("pre-query script execution failed: %w", preQueryCycleErr) + for i, queryWithSession := range queries { + // Store current session params to restore later + oldSessionParams := make(map[string]any) + for k, v := range s.Client.GetSessionParams() { + oldSessionParams[k] = v } + + // Apply query-specific session parameters + for name, value := range queryWithSession.SessionParams { + s.Client.SessionParam(name, value) + } + for j := 0; j < *s.ColdRuns+*s.WarmRuns; j++ { query := &Query{ - Text: queryText, + Text: queryWithSession.Query, File: queryFile, Index: i, BatchSize: batchSize, @@ -430,33 +527,31 @@ func (s *Stage) runQueries(ctx context.Context, queries []string, queryFile *str } result, err := s.runQuery(ctx, query) - // err is already attached to the result, if not nil. if s.States.OnQueryCompletion != nil { s.States.OnQueryCompletion(result) } - // Flags and options are checked within. s.saveQueryJsonFile(result) - // Each query should have a query result sent to the channel, no matter - // its execution succeeded or not. s.States.resultChan <- result if err != nil { + // Restore original session parameters before returning + s.Client.ClearSessionParams() + for k, v := range oldSessionParams { + s.Client.SessionParam(k, v) + } if *s.AbortOnError || ctx.Err() != nil { - // If AbortOnError is set, we skip the rest queries in the same batch. - // Logging etc. will be handled in the parent stack. - // If the context is cancelled or timed out, we cannot continue whatsoever and must return. s.States.exitCode.CompareAndSwap(0, 1) return result } - // Log the error information and continue running s.logErr(ctx, result) continue } log.Info().EmbedObject(result).Msgf("query finished") } - // run post query cycle shell scripts - postQueryCycleErr := s.runShellScripts(ctx, s.PostQueryCycleShellScripts) - if postQueryCycleErr != nil { - return fmt.Errorf("post-query script execution failed: %w", postQueryCycleErr) + + // Restore original session parameters after query completes + s.Client.ClearSessionParams() + for k, v := range oldSessionParams { + s.Client.SessionParam(k, v) } } return nil diff --git a/stage/stage_utils.go b/stage/stage_utils.go index de4782bb..d87489e1 100644 --- a/stage/stage_utils.go +++ b/stage/stage_utils.go @@ -176,7 +176,7 @@ func (s *Stage) prepareClient() { } if len(s.SessionParams) > 0 { log.Info().EmbedObject(s). - Str("values", s.Client.GetSessionParams()). + Interface("session_params", s.Client.GetSessionParams()). Msg("set session params") } if s.TimeZone != nil {