Skip to content

Commit 0fcc31b

Browse files
committed
feat(parser): add support for parsing and replacing custom functions
1 parent 150e1ea commit 0fcc31b

File tree

7 files changed

+403
-11
lines changed

7 files changed

+403
-11
lines changed

src/markProcessor.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export type MarkName =
2020
| 'float'
2121
| 'func_args_end'
2222
| 'func_call'
23+
| 'func_decl'
2324
| 'ident'
2425
| 'inc_range'
2526
| 'integer'

src/nodeTypes.ts

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import type {GroqFunction, GroqPipeFunction} from './evaluator/functions'
22

33
/** Any sort of node which appears as syntax */
4-
export type SyntaxNode = ExprNode | ArrayElementNode | ObjectAttributeNode | SelectAlternativeNode
4+
export type SyntaxNode =
5+
| ExprNode
6+
| ArrayElementNode
7+
| ObjectAttributeNode
8+
| SelectAlternativeNode
9+
| FunctionDeclarationNode
510

611
export type ObjectAttributeNode =
712
| ObjectAttributeValueNode
@@ -280,6 +285,14 @@ export interface ValueNode<P = any> {
280285
value: P
281286
}
282287

288+
export interface FunctionDeclarationNode extends BaseNode {
289+
type: 'FuncDeclaration'
290+
namespace: string
291+
name: string
292+
params: ParameterNode[]
293+
body: ExprNode
294+
}
295+
283296
export interface FlatMapNode extends BaseNode {
284297
type: 'FlatMap'
285298
base: ExprNode

src/parser.ts

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,21 @@ import {tryConstantEvaluate} from './evaluator'
33
import {type GroqFunctionArity, namespaces, pipeFunctions} from './evaluator/functions'
44
import {MarkProcessor, type MarkVisitor} from './markProcessor'
55
import {
6+
type ArrayCoerceNode,
67
type ArrayElementNode,
8+
type DerefNode,
79
type ExprNode,
810
type FuncCallNode,
11+
type FunctionDeclarationNode,
12+
type InlineFuncCallNode,
913
isSelectorNested,
14+
type MapNode,
1015
type ObjectAttributeNode,
1116
type ObjectSplatNode,
1217
type OpCall,
18+
type ParameterNode,
1319
type ParentNode,
20+
type ProjectionNode,
1421
type SelectNode,
1522
type SelectorNode,
1623
walk,
@@ -50,9 +57,14 @@ class GroqQueryError extends Error {
5057
public override name = 'GroqQueryError'
5158
}
5259

60+
type FunctionId = `${string}::${string}`
61+
type CustomFunctions = Record<FunctionId, FunctionDeclarationNode>
62+
5363
function createExpressionBuilder(): {
5464
exprBuilder: MarkVisitor<ExprNode>
65+
customFunctions: CustomFunctions
5566
} {
67+
const customFunctions: CustomFunctions = {}
5668
const exprBuilder: MarkVisitor<ExprNode> = {
5769
group(p) {
5870
const inner = p.process(exprBuilder)
@@ -522,6 +534,37 @@ function createExpressionBuilder(): {
522534
name,
523535
}
524536
},
537+
538+
func_decl(p) {
539+
const namespace = p.processString()
540+
const name = p.processString()
541+
const params: ParameterNode[] = []
542+
while (p.getMark().name !== 'func_params_end') {
543+
const param = p.process(exprBuilder)
544+
if (param.type !== 'Parameter') throw new Error('expected parameter')
545+
params.push(param)
546+
}
547+
548+
if (params.length !== 1) {
549+
throw new GroqQueryError('Custom functions can only have one parameter')
550+
}
551+
552+
p.shift() // func_params_end
553+
554+
const body = p.process(exprBuilder)
555+
556+
const decl = {
557+
type: 'FuncDeclaration',
558+
namespace,
559+
name,
560+
params,
561+
body,
562+
} satisfies FunctionDeclarationNode
563+
564+
customFunctions[`${namespace}::${name}`] = decl
565+
566+
return p.process(exprBuilder)
567+
},
525568
}
526569

527570
const OBJECT_BUILDER: MarkVisitor<ObjectAttributeNode> = {
@@ -860,7 +903,7 @@ function createExpressionBuilder(): {
860903
},
861904
}
862905

863-
return {exprBuilder}
906+
return {exprBuilder, customFunctions}
864907
}
865908

866909
function extractPropertyKey(node: ExprNode): string {
@@ -899,6 +942,71 @@ function validateArity(name: string, arity: GroqFunctionArity, count: number) {
899942
}
900943
}
901944

945+
/**
946+
* The function body is one of the forms:
947+
* - $param{…}
948+
* - $param->{…}
949+
* - $param[]{…}
950+
* - $param[]->{…}
951+
*
952+
* https://github.com/sanity-io/go-groq/blob/b7fb57f5aefe080becff9e3522c0b7b52a79ffd0/parser/internal/parserv2/parser.go#L975-L981
953+
*/
954+
function resolveFunctionParameter(
955+
parameter: ParameterNode,
956+
funcDeclaration: FunctionDeclarationNode,
957+
funcCall: InlineFuncCallNode,
958+
) {
959+
const index = funcDeclaration.params.findIndex((p) => p.name === parameter.name)
960+
if (index === -1) {
961+
throw new GroqQueryError(`Missing argument for parameter ${parameter.name} in function call`)
962+
}
963+
return funcCall.args[index]
964+
}
965+
function replaceCustomFunctionBody(
966+
funcDeclaration: FunctionDeclarationNode,
967+
funcCall: InlineFuncCallNode,
968+
): ExprNode {
969+
const {body} = funcDeclaration
970+
971+
if (body.type === 'Projection') {
972+
if (body.base.type === 'Parameter') {
973+
return {
974+
type: 'Projection',
975+
base: resolveFunctionParameter(body.base, funcDeclaration, funcCall),
976+
expr: body.expr,
977+
}
978+
}
979+
980+
if (body.base.type === 'Deref') {
981+
if (body.base.base.type === 'Parameter') {
982+
return {
983+
type: 'Projection',
984+
base: {
985+
type: 'Deref',
986+
base: resolveFunctionParameter(body.base.base, funcDeclaration, funcCall),
987+
},
988+
expr: body.expr,
989+
}
990+
}
991+
}
992+
}
993+
994+
if (body.type === 'Map' && body.base.type === 'ArrayCoerce') {
995+
if (body.base.base.type === 'Parameter') {
996+
return {
997+
type: 'Map',
998+
base: {
999+
type: 'ArrayCoerce',
1000+
base: resolveFunctionParameter(body.base.base, funcDeclaration, funcCall),
1001+
},
1002+
expr: body.expr,
1003+
}
1004+
}
1005+
}
1006+
1007+
throw new GroqQueryError(`Unexpected function body, must be a projection. Got "${body.type}"`)
1008+
}
1009+
9021010
function argumentShouldBeSelector(namespace: string, functionName: string, argCount: number) {
9031011
const functionsRequiringSelectors = ['changedAny', 'changedOnly']
9041012

@@ -924,20 +1032,41 @@ export function parse(input: string, options: ParseOptions = {}): ExprNode {
9241032
throw new GroqSyntaxError(result.position, result.message)
9251033
}
9261034
const processor = new MarkProcessor(input, result.marks, options)
927-
const {exprBuilder} = createExpressionBuilder()
1035+
const {exprBuilder, customFunctions} = createExpressionBuilder()
9281036
const procssed = processor.process(exprBuilder)
929-
const replaceInlineFuncCalls = createReplaceInlineFuncCalls(options)
930-
return walk(procssed, replaceInlineFuncCalls)
1037+
const replaceInlineFuncCalls = createReplaceInlineFuncCalls(options, customFunctions)
1038+
return walk(procssed, (node) => replaceInlineFuncCalls(node))
9311039
}
9321040

933-
function createReplaceInlineFuncCalls(options: ParseOptions) {
934-
const replacer = (node: ExprNode): ExprNode => {
1041+
function createReplaceInlineFuncCalls(
1042+
options: ParseOptions,
1043+
customFunctions: Record<string, FunctionDeclarationNode>,
1044+
) {
1045+
const replacer = (
1046+
node: ExprNode,
1047+
recurssion: Set<FunctionId> = new Set<FunctionId>(),
1048+
): ExprNode => {
9351049
if (node.type !== 'InlineFuncCall') {
9361050
return node
9371051
}
9381052

9391053
const {namespace, name} = node
9401054

1055+
const functionId: FunctionId = `${namespace}::${name}`
1056+
if (recurssion.has(functionId)) {
1057+
throw new GroqQueryError(`Recursion detected in function ${name}`)
1058+
}
1059+
1060+
// Check for custom function first
1061+
const customFunction = customFunctions[functionId]
1062+
if (customFunction) {
1063+
validateArity(name, customFunction.params.length, node.args.length)
1064+
1065+
return walk(replaceCustomFunctionBody(customFunction, node), (node) =>
1066+
replacer(node, new Set([...recurssion, functionId])),
1067+
)
1068+
}
1069+
9411070
const funcs = namespaces[namespace]
9421071
if (!funcs) {
9431072
throw new GroqQueryError(`Undefined namespace: ${namespace}`)

src/rawParser.js

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,22 @@ const PREC_NEG = 8
2525
function parse(str) {
2626
let pos = 0
2727
pos = skipWS(str, pos)
28+
29+
let fnMarks = []
30+
31+
// Parse function declarations first
32+
while (pos < str.length && str.substring(pos, pos + 2) === 'fn') {
33+
let funcResult = parseFunctionDeclaration(str, pos)
34+
if (funcResult.type === 'error') return funcResult
35+
fnMarks = fnMarks.concat(funcResult.marks)
36+
pos = skipWS(str, funcResult.position)
37+
}
38+
39+
// Parse the main query expression
2840
let result = parseExpr(str, pos, 0)
2941
if (result.type === 'error') return result
3042
pos = skipWS(str, result.position)
43+
3144
if (pos !== str.length) {
3245
if (result.failPosition) {
3346
pos = result.failPosition - 1
@@ -36,6 +49,7 @@ function parse(str) {
3649
}
3750
delete result.position
3851
delete result.failPosition
52+
result.marks = fnMarks.concat(result.marks)
3953
return result
4054
}
4155

@@ -831,4 +845,109 @@ function parseRegexStr(str, pos, re) {
831845
return m ? m[0] : null
832846
}
833847

848+
/**
849+
* Parses a function declaration: fn namespace::name(params) = body;
850+
*/
851+
function parseFunctionDeclaration(str, startPos) {
852+
let pos = startPos
853+
let marks = []
854+
855+
// Parse "fn"
856+
if (str.substring(pos, pos + 2) !== 'fn') {
857+
return {
858+
type: 'success',
859+
position: pos,
860+
marks: marks,
861+
}
862+
}
863+
pos = skipWS(str, pos + 2)
864+
865+
marks.push({name: 'func_decl', position: startPos})
866+
867+
let identStart = pos
868+
let identResult = parseRegexStr(str, pos, IDENT)
869+
if (!identResult) {
870+
return {type: 'error', message: 'Expected function name', position: pos}
871+
}
872+
pos += identResult.length
873+
pos = skipWS(str, pos)
874+
875+
marks.push({name: 'ident', position: identStart}, {name: 'ident_end', position: pos})
876+
877+
// Check for "::"
878+
if (str.substring(pos, pos + 2) !== '::') {
879+
return {type: 'error', message: 'Expected "::" after namespace', position: pos}
880+
}
881+
882+
pos = skipWS(str, pos + 2)
883+
let nameLen = parseRegex(str, pos, IDENT)
884+
if (!nameLen) return {type: 'error', message: 'Expected function name', position: pos}
885+
marks.push({name: 'ident', position: pos}, {name: 'ident_end', position: pos + nameLen})
886+
pos = skipWS(str, pos + nameLen)
887+
888+
if (str[pos] !== '(') {
889+
return {type: 'error', message: 'Expected "("', position: pos}
890+
}
891+
pos = skipWS(str, pos + 1)
892+
893+
// Parse parameters
894+
while (pos < str.length && str[pos] !== ')') {
895+
// Parse parameter (should start with $)
896+
if (str[pos] !== '$') {
897+
return {type: 'error', message: 'Parameter should start with "$"', position: pos}
898+
}
899+
const startPos = pos
900+
pos++
901+
902+
const paramName = parseRegexStr(str, pos, IDENT)
903+
if (!paramName) {
904+
return {type: 'error', message: 'Expected function name', position: pos}
905+
}
906+
pos += paramName.length
907+
marks.push(
908+
{name: 'param', position: startPos},
909+
{name: 'ident', position: startPos + 1},
910+
{name: 'ident_end', position: pos},
911+
)
912+
pos = skipWS(str, pos)
913+
914+
// Check for comma
915+
if (str[pos] === ',') {
916+
pos = skipWS(str, pos + 1)
917+
} else if (str[pos] !== ')') {
918+
return {type: 'error', message: 'Expected "," or ")"', position: pos}
919+
}
920+
}
921+
922+
if (str[pos] !== ')') {
923+
return {type: 'error', message: 'Expected ")"', position: pos}
924+
}
925+
marks.push({name: 'func_params_end', position: pos})
926+
pos = skipWS(str, pos + 1)
927+
928+
if (str[pos] !== '=') {
929+
return {type: 'error', message: 'Expected "="', position: pos}
930+
}
931+
pos = skipWS(str, pos + 1)
932+
933+
// Parse function body (expression)
934+
// marks.push({name: 'func_body', position: pos})
935+
let bodyResult = parseExpr(str, pos, 0)
936+
if (bodyResult.type === 'error') return bodyResult
937+
marks = marks.concat(bodyResult.marks)
938+
pos = skipWS(str, bodyResult.position)
939+
940+
// Parse ";"
941+
if (str[pos] !== ';') {
942+
return {type: 'error', message: 'Expected ";" after function declaration', position: pos}
943+
}
944+
pos++
945+
946+
return {
947+
type: 'success',
948+
position: pos,
949+
marks: marks,
950+
}
951+
}
952+
834953
export {parse}

0 commit comments

Comments
 (0)