Skip to content

Commit 74aa271

Browse files
committed
add support setting session params in query
see TestSplitQueriesWithSession for sytax of setting session params inline
1 parent c30b29c commit 74aa271

File tree

5 files changed

+334
-55
lines changed

5 files changed

+334
-55
lines changed

presto/client.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,12 @@ func (c *Client) UserPassword(user, password string) *Client {
138138
return c
139139
}
140140

141-
func (c *Client) GetSessionParams() string {
142-
return c.getHeader(SessionHeader)
141+
func (c *Client) GetSessionParams() map[string]any {
142+
params := make(map[string]any)
143+
for k, v := range c.sessionParams {
144+
params[k] = v
145+
}
146+
return params
143147
}
144148

145149
func (c *Client) ClearSessionParams() *Client {

presto/query.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
package presto
22

33
import (
4+
"bufio"
45
"context"
56
"io"
67
"net/http"
78
"pbench/presto/query_json"
9+
"strconv"
10+
"strings"
811
)
912

13+
// QueryWithSession represents a query and its additional session parameters
14+
type QueryWithSession struct {
15+
Query string
16+
SessionParams map[string]any
17+
}
18+
1019
func (c *Client) requestQueryResults(ctx context.Context, req *http.Request) (*QueryResults, *http.Response, error) {
1120
qr := new(QueryResults)
1221
resp, err := c.Do(ctx, req, qr)
@@ -77,3 +86,179 @@ func (c *Client) GetQueryInfo(ctx context.Context, queryId string, pretty bool,
7786
}
7887
return queryInfo, resp, nil
7988
}
89+
90+
// ParseSessionCommand checks if a query is a session parameter command and returns the parameter and value
91+
// Format: --session parameter_name=parameter_value or --SET SESSION parameter_name=parameter_value
92+
func ParseSessionCommand(query string) (paramName string, paramValue string, isSession bool) {
93+
query = strings.TrimSpace(query)
94+
95+
// Check if query starts with --session or --set session (case insensitive)
96+
queryLower := strings.ToLower(query)
97+
if !strings.HasPrefix(queryLower, "--session") && !strings.HasPrefix(queryLower, "--set session") {
98+
return "", "", false
99+
}
100+
101+
// Remove the prefix
102+
if strings.HasPrefix(queryLower, "--set session") {
103+
query = strings.TrimSpace(query[13:]) // len("--set session") = 13
104+
} else {
105+
query = strings.TrimSpace(query[9:]) // len("--session") = 9
106+
}
107+
108+
// Split on equals sign and handle spaces
109+
parts := strings.SplitN(query, "=", 2)
110+
if len(parts) != 2 {
111+
return "", "", false
112+
}
113+
114+
paramName = strings.ToLower(strings.TrimSpace(parts[0]))
115+
paramValue = strings.TrimSpace(parts[1])
116+
117+
// Remove quotes if present
118+
if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") {
119+
paramValue = paramValue[1:len(paramValue)-1]
120+
}
121+
if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") {
122+
paramValue = paramValue[1:len(paramValue)-1]
123+
}
124+
125+
// Remove trailing semicolon if present
126+
if strings.HasSuffix(paramValue, ";") {
127+
paramValue = strings.TrimSuffix(paramValue, ";")
128+
}
129+
130+
// Convert value to uppercase for enum values
131+
paramValue = strings.ToUpper(paramValue)
132+
133+
return paramName, paramValue, true
134+
}
135+
136+
// cleanQuery removes unnecessary whitespace, newlines, comments and trailing semicolon from a query
137+
func cleanQuery(query string) string {
138+
// Split into lines and handle each line
139+
lines := strings.Split(query, "\n")
140+
cleanLines := make([]string, 0, len(lines))
141+
142+
for _, line := range lines {
143+
// Remove inline comments
144+
if idx := strings.Index(line, "--"); idx >= 0 {
145+
line = line[:idx]
146+
}
147+
148+
trimmed := strings.TrimSpace(line)
149+
if trimmed != "" {
150+
cleanLines = append(cleanLines, trimmed)
151+
}
152+
}
153+
154+
// Join with single spaces
155+
query = strings.Join(cleanLines, " ")
156+
157+
// Remove trailing semicolon
158+
if strings.HasSuffix(query, ";") {
159+
query = strings.TrimSuffix(query, ";")
160+
}
161+
162+
return query
163+
}
164+
165+
// SplitQueriesWithSession splits a SQL file into individual queries and their associated session parameters
166+
func SplitQueriesWithSession(r io.Reader) ([]QueryWithSession, error) {
167+
queries := make([]QueryWithSession, 0)
168+
currentSessionParams := make(map[string]any)
169+
170+
scanner := bufio.NewScanner(r)
171+
var currentQuery strings.Builder
172+
inMultilineComment := false
173+
174+
for scanner.Scan() {
175+
line := scanner.Text()
176+
trimmedLine := strings.TrimSpace(line)
177+
178+
// Skip empty lines
179+
if len(trimmedLine) == 0 {
180+
continue
181+
}
182+
183+
// Handle multiline comments
184+
if strings.HasPrefix(trimmedLine, "/*") {
185+
inMultilineComment = true
186+
}
187+
if inMultilineComment {
188+
if strings.HasSuffix(trimmedLine, "*/") {
189+
inMultilineComment = false
190+
}
191+
continue
192+
}
193+
194+
// Handle single line comments and session parameters
195+
if strings.HasPrefix(trimmedLine, "--") {
196+
paramName, paramValue, isSession := ParseSessionCommand(trimmedLine)
197+
if isSession {
198+
// Try to parse value as number or boolean first
199+
if val, err := strconv.ParseInt(paramValue, 10, 64); err == nil {
200+
currentSessionParams[paramName] = val
201+
} else if val, err := strconv.ParseFloat(paramValue, 64); err == nil {
202+
currentSessionParams[paramName] = val
203+
} else if val, err := strconv.ParseBool(paramValue); err == nil {
204+
currentSessionParams[paramName] = val
205+
} else {
206+
// Remove any remaining quotes from string values
207+
if strings.HasPrefix(paramValue, "'") && strings.HasSuffix(paramValue, "'") {
208+
paramValue = paramValue[1:len(paramValue)-1]
209+
}
210+
if strings.HasPrefix(paramValue, "\"") && strings.HasSuffix(paramValue, "\"") {
211+
paramValue = paramValue[1:len(paramValue)-1]
212+
}
213+
// Treat as string if not a number or boolean
214+
currentSessionParams[paramName] = paramValue
215+
}
216+
}
217+
continue
218+
}
219+
220+
currentQuery.WriteString(line)
221+
currentQuery.WriteString("\n")
222+
223+
// Check if line ends with semicolon
224+
if strings.HasSuffix(trimmedLine, ";") {
225+
query := strings.TrimSpace(currentQuery.String())
226+
if len(query) > 0 {
227+
// Clean up the query formatting
228+
query = cleanQuery(query)
229+
230+
// Create a copy of current session parameters for this query
231+
sessionParams := make(map[string]any, len(currentSessionParams))
232+
for k, v := range currentSessionParams {
233+
sessionParams[k] = v
234+
}
235+
queries = append(queries, QueryWithSession{
236+
Query: query,
237+
SessionParams: sessionParams,
238+
})
239+
// Clear session parameters after query
240+
currentSessionParams = make(map[string]any)
241+
}
242+
currentQuery.Reset()
243+
}
244+
}
245+
246+
// Handle last query if it doesn't end with semicolon
247+
lastQuery := strings.TrimSpace(currentQuery.String())
248+
if len(lastQuery) > 0 {
249+
sessionParams := make(map[string]any, len(currentSessionParams))
250+
for k, v := range currentSessionParams {
251+
sessionParams[k] = v
252+
}
253+
queries = append(queries, QueryWithSession{
254+
Query: lastQuery,
255+
SessionParams: sessionParams,
256+
})
257+
}
258+
259+
if err := scanner.Err(); err != nil {
260+
return nil, err
261+
}
262+
263+
return queries, nil
264+
}

presto/query_splitter_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,34 @@ another query;;missing semicolon, should be discarded
3636
}
3737
}
3838
}
39+
40+
func TestSplitQueriesWithSession(t *testing.T) {
41+
input := `/* header comment */
42+
--SET SESSION join_reordering_strategy = 'NONE';
43+
--session query_max_memory = '1GB'
44+
--session max_splits_per_node = 1234
45+
--session optimize_hash_generation = true
46+
-- normal comment
47+
SELECT
48+
* -- inline comment
49+
FROM
50+
table1
51+
WHERE
52+
id > 0;`
53+
54+
expected := []presto.QueryWithSession{
55+
{
56+
Query: "SELECT * FROM table1 WHERE id > 0",
57+
SessionParams: map[string]any{
58+
"join_reordering_strategy": "NONE",
59+
"query_max_memory": "1GB",
60+
"max_splits_per_node": int64(1234),
61+
"optimize_hash_generation": true,
62+
},
63+
},
64+
}
65+
66+
queries, err := presto.SplitQueriesWithSession(strings.NewReader(input))
67+
assert.NoError(t, err)
68+
assert.Equal(t, expected, queries)
69+
}

0 commit comments

Comments
 (0)