Skip to content

Commit 50b870f

Browse files
authored
Yuyang/switch (#21) (#182)
This pull request modifies the compiler support for switch statements. Modifications - Introduces complete translation logic for switch statements, preserving fallthrough, break, and default behaviors. - Initializes control variables: switch_discriminant_, switch_fallthrough_, and switch_break_*. - Defines helpers for splitting at break, translating case bodies, and building conditional chains. - Ensures each case conditionally executes based on the discriminant, fallthrough, and break flags, with deferred evaluation of break to maintain semantic correctness.
1 parent c952c0a commit 50b870f

File tree

6 files changed

+152
-68
lines changed

6 files changed

+152
-68
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ vcs/*.smt2
77
Strata.code-workspace
88

99
conformance_testing/__pycache__
10-
conformance_testing/failures
10+
conformance_testing/failures
11+
12+
test_single_file.sh

Strata/Languages/TypeScript/TS_to_Strata.lean

Lines changed: 92 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -890,50 +890,6 @@ partial def translate_statement_core
890890
-- output: init break flag, init statements, then a loop statement
891891
(ctx1, [initBreakFlag] ++ initStmts ++ [ .loop combinedCondition none none loopBody])
892892

893-
| .TS_SwitchStatement switchStmt =>
894-
-- Handle switch statement: switch discriminant { cases }
895-
896-
-- Process all cases in their original order, separating regular from default
897-
let allCases := switchStmt.cases.toList
898-
let (regularCaseStmts, defaultStmts) := allCases.foldl (fun (regCases, defStmts) case =>
899-
match case.test with
900-
| some expr =>
901-
-- Regular case
902-
let discrimExpr := translate_expr switchStmt.discriminant
903-
let caseValue := translate_expr expr
904-
let testExpr := Heap.HExpr.app (Heap.HExpr.app (Heap.HExpr.deferredOp "Int.Eq" none) discrimExpr) caseValue
905-
let (caseCtx, stmts) := case.consequent.foldl (fun (accCtx, accStmts) stmt =>
906-
let (newCtx, newStmts) := translate_statement_core stmt accCtx
907-
(newCtx, accStmts ++ newStmts)) (ctx, [])
908-
(regCases ++ [(testExpr, stmts)], defStmts)
909-
| none =>
910-
-- Default case
911-
let (defaultCtx, stmts) := case.consequent.foldl (fun (accCtx, accStmts) stmt =>
912-
let (newCtx, newStmts) := translate_statement_core stmt accCtx
913-
(newCtx, accStmts ++ newStmts)) (ctx, [])
914-
(regCases, stmts)
915-
) ([], [])
916-
917-
-- Build nested if-then-else structure for regular cases
918-
let rec build_cases (cases: List (Heap.HExpr × List TSStrataStatement)) (defaultStmts: List TSStrataStatement) : TSStrataStatement :=
919-
match cases with
920-
| [] =>
921-
-- No regular cases, just execute default if it exists
922-
let defaultBlock : Imperative.Block TSStrataExpression TSStrataCommand := { ss := defaultStmts }
923-
.block "default" defaultBlock
924-
| [(test, stmts)] =>
925-
let thenBlock : Imperative.Block TSStrataExpression TSStrataCommand := { ss := stmts }
926-
let elseBlock : Imperative.Block TSStrataExpression TSStrataCommand := { ss := defaultStmts }
927-
.ite test thenBlock elseBlock
928-
| (test, stmts) :: rest =>
929-
let thenBlock : Imperative.Block TSStrataExpression TSStrataCommand := { ss := stmts }
930-
let elseBlock := build_cases rest defaultStmts
931-
let elseBlockWrapped : Imperative.Block TSStrataExpression TSStrataCommand := { ss := [elseBlock] }
932-
.ite test thenBlock elseBlockWrapped
933-
934-
let switchStructure := build_cases regularCaseStmts defaultStmts
935-
(ctx, [switchStructure])
936-
937893
| .TS_ContinueStatement cont =>
938894
let tgt :=
939895
match ct.continueLabel? with
@@ -958,7 +914,98 @@ partial def translate_statement_core
958914
| none => "__unbound_break"
959915
(ctx, [ .goto tgt ])
960916

961-
| _ => panic! s!"Unimplemented statement: {repr s}"
917+
| .TS_SwitchStatement switchStmt =>
918+
-- Handle switch statement with fallthrough and break semantics
919+
dbg_trace s!"[DEBUG] Translating switch statement at loc {switchStmt.start_loc}-{switchStmt.end_loc}"
920+
921+
-- Variables for storing control variables
922+
let loc := switchStmt.start_loc
923+
let discriminantVar := s!"switch_discriminant_{loc}" -- Stores the switch expression value
924+
let fallthroughVar := s!"switch_fallthrough_{loc}" -- Stores fallthrough state
925+
let breakFlagVar := s!"switch_break_{loc}" -- Stores break state
926+
927+
-- Initialize control variables
928+
let initDiscriminant : TSStrataStatement := .cmd (.init discriminantVar (infer_type_from_expr switchStmt.discriminant) (translate_expr switchStmt.discriminant))
929+
let initFallthrough : TSStrataStatement := .cmd (.init fallthroughVar Heap.HMonoTy.bool Heap.HExpr.false)
930+
let initBreakFlag : TSStrataStatement := .cmd (.init breakFlagVar Heap.HMonoTy.bool Heap.HExpr.false)
931+
932+
-- Helper: split statements at break
933+
let splitAtBreak (stmts : List TS_Statement) : List TS_Statement × Bool :=
934+
let rec loop acc rest :=
935+
match rest with
936+
| [] => (acc.reverse, false)
937+
| .TS_BreakStatement _ :: _ => (acc.reverse, true)
938+
| s :: tail => loop (s :: acc) tail
939+
loop [] stmts
940+
941+
-- Helper: translate case body
942+
let translateCaseBody (stmts : List TS_Statement) (caseCtx : TranslationContext) : TranslationContext × List TSStrataStatement :=
943+
stmts.foldl (fun (c, acc) stmt =>
944+
let (c2, ss) := translate_statement_core stmt c ct
945+
(c2, acc ++ ss)) (caseCtx, [])
946+
947+
-- Helper: build case statements with optional break and fallthrough
948+
let buildCaseStmts (caseStmts : List TSStrataStatement) (hasBreak : Bool) (isDefault : Bool) : List TSStrataStatement :=
949+
let setFallthrough := .cmd (.set fallthroughVar Heap.HExpr.true)
950+
let setBreak := .cmd (.set breakFlagVar Heap.HExpr.true)
951+
let stmts := if isDefault then caseStmts else setFallthrough :: caseStmts
952+
if hasBreak then stmts ++ [setBreak] else stmts
953+
954+
-- Flag references
955+
let breakFlagRef := Heap.HExpr.lambda (.fvar breakFlagVar none)
956+
let discriminantRef := Heap.HExpr.lambda (.fvar discriminantVar none)
957+
let fallthroughRef := Heap.HExpr.lambda (.fvar fallthroughVar none)
958+
959+
-- Helper: create condition (if break then false else baseCondition)
960+
let mkCondition (baseCondition : Heap.HExpr) : Heap.HExpr :=
961+
Heap.HExpr.deferredIte breakFlagRef Heap.HExpr.false baseCondition
962+
963+
-- Helper: build case condition for regular case
964+
let mkCaseCondition (testExpr : TS_Expression) : Heap.HExpr :=
965+
let testVal := translate_expr testExpr
966+
let matchCond := Heap.HExpr.app (Heap.HExpr.app (Heap.HExpr.deferredOp "Int.Eq" none) discriminantRef) testVal
967+
let matchOrFallthrough := Heap.HExpr.app (Heap.HExpr.app (Heap.HExpr.deferredOp "Bool.Or" none) fallthroughRef) matchCond
968+
mkCondition matchOrFallthrough
969+
970+
-- Recursive case builder
971+
let rec buildCases (remainingCases : List TS_SwitchCase) (accCtx : TranslationContext) : TranslationContext × TSStrataStatement :=
972+
let emptyBlock : Imperative.Block TSStrataExpression TSStrataCommand := { ss := [] }
973+
974+
match remainingCases with
975+
| [] => (accCtx, .ite Heap.HExpr.false emptyBlock emptyBlock)
976+
977+
| [singleCase] =>
978+
-- Last case: no rest to chain
979+
let (stmtsBeforeBreak, hasBreak) := splitAtBreak singleCase.consequent.toList
980+
let (caseCtx, caseStmts) := translateCaseBody stmtsBeforeBreak accCtx
981+
let isDefault := singleCase.test.isNone
982+
let finalStmts := buildCaseStmts caseStmts hasBreak isDefault
983+
984+
let condition := match singleCase.test with
985+
| none => mkCondition Heap.HExpr.true -- Default: if !break then true
986+
| some testExpr => mkCaseCondition testExpr
987+
988+
(caseCtx, .ite condition { ss := finalStmts } emptyBlock)
989+
990+
| currentCase :: restCases =>
991+
-- Non-last case: chain with rest
992+
let (stmtsBeforeBreak, hasBreak) := splitAtBreak currentCase.consequent.toList
993+
let (caseCtx, caseStmts) := translateCaseBody stmtsBeforeBreak accCtx
994+
let isDefault := currentCase.test.isNone
995+
let finalStmts := buildCaseStmts caseStmts hasBreak isDefault
996+
let (restCtx, restStmt) := buildCases restCases caseCtx
997+
998+
let condition := match currentCase.test with
999+
| none => mkCondition Heap.HExpr.true -- Default: if !break then true
1000+
| some testExpr => mkCaseCondition testExpr
1001+
1002+
(restCtx, .ite condition { ss := finalStmts ++ [restStmt] } { ss := [restStmt] })
1003+
1004+
let (finalCtx, switchBody) := buildCases switchStmt.cases.toList ctx
1005+
dbg_trace s!"[DEBUG] Switch statement translated with {switchStmt.cases.size} cases (with break support)"
1006+
(finalCtx, [initDiscriminant, initFallthrough, initBreakFlag, switchBody])
1007+
1008+
| _ => panic! s!"Unimplemented statement: {repr s}"
9621009

9631010
-- Translate TypeScript statements to TypeScript-Strata statements
9641011
partial def translate_statement (s: TS_Statement) (ctx : TranslationContext) : TranslationContext × List TSStrataStatement :=
Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
function getWeekendName(date: number): string | undefined {
2-
switch (date) {
3-
case 0:
4-
return "Sunday";
5-
case 6:
6-
return "Saturday";
7-
}
8-
}
1+
// Test switch statement with fallthrough
2+
let x: number = 2;
3+
let result: number = 0;
94

10-
let day: number = 6;
11-
let weekendName: string | undefined = getWeekendName(day);
5+
switch (x) {
6+
case 1:
7+
result = 10;
8+
case 2:
9+
result = 20;
10+
case 3:
11+
result = 30;
12+
default:
13+
result = 40;
14+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Simple switch with break
2+
let x: number = 1;
3+
let r: number = 0;
4+
5+
switch (x) {
6+
case 1:
7+
r = 10;
8+
break;
9+
case 2:
10+
r = 20;
11+
}
12+
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Test switch with mixed break and fallthrough
2+
let x: number = 2;
3+
let result: number = 0;
4+
5+
switch (x) {
6+
case 1:
7+
result = 10;
8+
case 2:
9+
result = 20;
10+
case 3:
11+
result = 30;
12+
break;
13+
default:
14+
result = 40;
15+
}
16+
17+
result;
18+
19+

conformance_testing/babel_to_lean.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -307,18 +307,7 @@ def parse_continue_statement(j):
307307
}
308308
add_missing_node_info(j, target_j)
309309
return target_j
310-
311-
def parse_for_statement(j):
312-
target_body = parse_statement(j['body'])
313-
target_j = {
314-
"init": parse_variable_declaration(j['init']),
315-
"test": parse_expression(j['test']),
316-
"update": parse_assignment_expression(j['update']),
317-
"body": target_body
318-
}
319-
add_missing_node_info(j, target_j)
320-
return target_j
321-
310+
322311
def parse_switch_statement(j):
323312
target_j = {
324313
"discriminant": parse_expression(j['discriminant']),
@@ -340,12 +329,24 @@ def parse_switch_case(j):
340329
add_missing_node_info(j, target_j)
341330
return target_j
342331

332+
343333
def parse_break_statement(j):
344334
target_j = {
345335
"label": parse_identifier(j['label']) if j.get('label') else None
346336
}
347337
add_missing_node_info(j, target_j)
348338
return target_j
339+
340+
def parse_for_statement(j):
341+
target_body = parse_statement(j['body'])
342+
target_j = {
343+
"init": parse_variable_declaration(j['init']),
344+
"test": parse_expression(j['test']),
345+
"update": parse_assignment_expression(j['update']),
346+
"body": target_body
347+
}
348+
add_missing_node_info(j, target_j)
349+
return target_j
349350

350351
def parse_statement(j):
351352
match j['type']:

0 commit comments

Comments
 (0)