Skip to content

Commit 51fc365

Browse files
VM Errors & Type Conversions & Missing Types (#2403)
<!-- ELLIPSIS_HIDDEN --> > [!IMPORTANT] > Adds support for enums in BAML compiler and VM, including bytecode generation, type checking, and runtime handling, with corresponding tests and client updates. > > - **Behavior**: > - Adds support for enums in BAML compiler and VM, including bytecode generation and runtime handling. > - Updates `compile_thir_to_bytecode()` in `codegen.rs` to handle enums and their variants. > - Adds `Enum` and `Variant` handling in `vm.rs` and `bytecode.rs`. > - Updates `typecheck_expression()` in `typecheck.rs` to handle enum types. > - **Tests**: > - Adds tests for enum functionality in `tests/vm.rs` and `test_vm_async_runtime.py`. > - Updates existing tests to reflect changes in global index handling. > - **Misc**: > - Updates Python and TypeScript clients to handle new enum-related functions. > - Fixes and refactors various parts of the codebase to support enums. > > <sup>This description was created by </sup>[<img alt="Ellipsis" src="https://img.shields.io/badge/Ellipsis-blue?color=175173">](https://www.ellipsis.dev?ref=BoundaryML%2Fbaml&utm_source=github&utm_medium=referral)<sup> for 30c77e0. You can [customize](https://app.ellipsis.dev/BoundaryML/settings/summaries) this summary. It will automatically update as commits are pushed.</sup> <!-- ELLIPSIS_HIDDEN -->
1 parent 8100bc2 commit 51fc365

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1442
-258
lines changed

engine/baml-compiler/src/codegen.rs

Lines changed: 184 additions & 118 deletions
Large diffs are not rendered by default.

engine/baml-compiler/src/hir/lowering.rs

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
//!
33
//! This file contains the conversions from Baml AST nodes to HIR nodes.
44
5+
use std::collections::HashSet;
6+
57
use baml_types::{
8+
ir_type::TypeGeneric,
69
type_meta::{self, base::StreamingBehavior},
710
Constraint, ConstraintLevel, TypeIR, TypeValue,
811
};
@@ -48,6 +51,38 @@ impl Hir {
4851
}
4952
}
5053

54+
let enums = HashSet::<&str>::from_iter(hir.enums.iter().map(|e| e.name.as_str()));
55+
56+
let param_type: fn(&mut Parameter) -> &mut TypeIR = |p| &mut p.r#type;
57+
58+
// Patch parameter and return types because only here in the code we have the full
59+
// context for enums.
60+
hir.expr_functions
61+
.iter_mut()
62+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
63+
.chain(
64+
hir.llm_functions
65+
.iter_mut()
66+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type)),
67+
)
68+
.chain(hir.classes.iter_mut().flat_map(|c| {
69+
c.methods
70+
.iter_mut()
71+
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
72+
}))
73+
.flat_map(|(parameters, return_type)| parameters.chain(std::iter::once(return_type)))
74+
.for_each(|ty| match ty {
75+
TypeIR::Class { name, meta, .. } if enums.contains(name.as_str()) => {
76+
*ty = TypeIR::Enum {
77+
name: name.to_owned(),
78+
dynamic: false, // TODO: How to know if it's dynamic.
79+
meta: meta.clone(),
80+
}
81+
}
82+
83+
_ => {}
84+
});
85+
5186
hir
5287
}
5388
}
@@ -125,22 +160,12 @@ pub fn type_ir_from_ast(type_: &ast::FieldType) -> TypeIR {
125160
};
126161

127162
match type_ {
128-
ast::FieldType::Symbol(_, name, _) => {
129-
if name.name().starts_with("Enum") {
130-
TypeIR::Enum {
131-
name: name.name().to_string(),
132-
dynamic: false,
133-
meta,
134-
}
135-
} else {
136-
TypeIR::Class {
137-
name: name.name().to_string(),
138-
mode: baml_types::ir_type::StreamingMode::NonStreaming,
139-
dynamic: false,
140-
meta,
141-
}
142-
}
143-
}
163+
ast::FieldType::Symbol(_, name, _) => TypeIR::Class {
164+
name: name.name().to_string(),
165+
mode: baml_types::ir_type::StreamingMode::NonStreaming,
166+
dynamic: false,
167+
meta,
168+
},
144169
ast::FieldType::Primitive(_, prim, _, _) => TypeIR::Primitive(*prim, meta),
145170
ast::FieldType::List(_, inner, dims, _, _) => {
146171
// Respect multi-dimensional arrays (e.g., int[][] has dims=2)

engine/baml-compiler/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,20 @@ pub mod test {
99
use internal_baml_diagnostics::Diagnostics;
1010
use internal_baml_parser_database::{parse_and_diagnostics, ParserDatabase};
1111

12+
use crate::{hir, thir};
13+
1214
/// Shim helper function for testing.
1315
pub fn ast(source: &'static str) -> anyhow::Result<ParserDatabase> {
14-
let (parser_db, diagnostics) = parse_and_diagnostics(source)?;
16+
let (parser_db, mut diagnostics) = parse_and_diagnostics(source)?;
17+
18+
if diagnostics.has_errors() {
19+
let errors = diagnostics.to_pretty_string();
20+
anyhow::bail!("{errors}");
21+
}
1522

23+
// Done here because of cyclic dependencies between crates.
24+
// TODO: We're building this like 3 different times, needs refactoring.
25+
thir::typecheck::typecheck(&hir::Hir::from_ast(&parser_db.ast), &mut diagnostics);
1626
if diagnostics.has_errors() {
1727
let errors = diagnostics.to_pretty_string();
1828
anyhow::bail!("{errors}");

engine/baml-compiler/src/thir.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
///
33
use baml_types::ir_type::TypeIR;
44

5-
use crate::hir::{self, AssignOp, BinaryOperator, Enum, LlmFunction, UnaryOperator};
5+
use crate::hir::{self, AssignOp, BinaryOperator, LlmFunction, UnaryOperator};
66

77
pub mod interpret;
88
pub mod typecheck;
@@ -54,6 +54,14 @@ pub struct Class<T> {
5454
pub span: Span,
5555
}
5656

57+
#[derive(Clone, Debug)]
58+
pub struct Enum {
59+
pub name: String,
60+
pub variants: Vec<hir::EnumVariant>,
61+
pub span: Span,
62+
pub ty: TypeIR, // TODO: Used for type checking, but do we need this?
63+
}
64+
5765
/// A BAML expression term.
5866
/// T is the type of the metadata.
5967
#[derive(Debug, Clone)]

engine/baml-compiler/src/thir/typecheck.rs

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ pub fn typecheck_returning_context<'a>(
4444
.map(|c| (c.name.clone(), c))
4545
.collect();
4646

47-
let enums = hir
47+
let enums: BamlMap<String, hir::Enum> = hir
4848
.enums
4949
.clone()
5050
.into_iter()
@@ -54,6 +54,7 @@ pub fn typecheck_returning_context<'a>(
5454
// Create typing context with all functions
5555
let mut typing_context = TypeContext::new();
5656
typing_context.classes.extend(classes.clone());
57+
typing_context.enums.extend(enums.clone());
5758

5859
// Add expr functions to typing context
5960
for func in &hir.expr_functions {
@@ -271,11 +272,32 @@ pub fn typecheck_returning_context<'a>(
271272
})
272273
.collect();
273274

275+
// TODO: Those are HIR enums, figure out if there's something different we
276+
// would need in a THIR enum? Does it need a "type"?
277+
let thir_enums = enums
278+
.iter()
279+
.map(|(name, enum_def)| {
280+
(
281+
name.clone(),
282+
thir::Enum {
283+
name: enum_def.name.clone(),
284+
variants: enum_def.variants.clone(),
285+
span: enum_def.span.clone(),
286+
ty: TypeIR::Enum {
287+
name: enum_def.name.clone(),
288+
dynamic: false,
289+
meta: Default::default(),
290+
},
291+
},
292+
)
293+
})
294+
.collect();
295+
274296
(
275297
THir {
276298
llm_functions: hir.llm_functions.clone(),
277299
classes: thir_classes,
278-
enums,
300+
enums: thir_enums,
279301
expr_functions,
280302
global_assignments: BamlMap::new(),
281303
},
@@ -302,6 +324,7 @@ pub struct TypeContext<'func> {
302324
// Variables in scope with mutability info
303325
pub vars: BamlMap<String, VarInfo>,
304326
pub classes: BamlMap<String, hir::Class>,
327+
pub enums: BamlMap<String, hir::Enum>,
305328
// Used for knowing whether `break` and `continue` are inside a loop or not.
306329
pub is_inside_loop: bool,
307330

@@ -336,6 +359,7 @@ impl TypeContext<'_> {
336359
symbols: BamlMap::new(),
337360
vars,
338361
classes: BamlMap::new(),
362+
enums: BamlMap::new(),
339363
is_inside_loop: false,
340364
function_return_type: None,
341365
}
@@ -428,7 +452,9 @@ impl TypeContext<'_> {
428452
}
429453
hir::Expression::Identifier(name, _) => {
430454
// Look up type in context
431-
self.get_type(name).cloned()
455+
self.get_type(name)
456+
.cloned()
457+
.or_else(|| self.enums.get(name).map(|e| TypeIR::r#enum(&e.name)))
432458
}
433459
hir::Expression::Array(items, _) => {
434460
// Infer array type from first item
@@ -575,6 +601,14 @@ impl TypeContext<'_> {
575601
None
576602
}
577603
}
604+
TypeIR::Enum {
605+
name: enum_name, ..
606+
} => {
607+
// Look up field in enum definition
608+
self.enums
609+
.get(&enum_name)
610+
.map(|enum_def| TypeIR::r#enum(&enum_def.name))
611+
}
578612
_ => None, // Not a class
579613
}
580614
} else {
@@ -1162,6 +1196,14 @@ pub fn typecheck_expression(
11621196
BamlValueWithMeta::String(value.clone(), (span.clone(), Some(TypeIR::string()))),
11631197
),
11641198
hir::Expression::Identifier(name, span) => {
1199+
// Enum access: let x = Shape.Rectangle
1200+
if let Some(enum_def) = context.enums.get(name) {
1201+
return thir::Expr::Var(
1202+
name.clone(),
1203+
(span.clone(), Some(TypeIR::r#enum(&enum_def.name))),
1204+
);
1205+
}
1206+
11651207
// Look up type in context
11661208
let var_type = context.get_type(name).cloned();
11671209
if var_type.is_none() {
@@ -1740,6 +1782,20 @@ pub fn typecheck_expression(
17401782
None
17411783
}
17421784
}
1785+
Some(TypeIR::Enum {
1786+
name: enum_name, ..
1787+
}) => {
1788+
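// Look up the enum definition; the access resolves to the enum type itself
// (the accessed variant name is not validated in this arm).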
1789+
if let Some(enum_def) = context.enums.get(enum_name) {
1790+
Some(TypeIR::r#enum(&enum_def.name))
1791+
} else {
1792+
diagnostics.push_error(DatamodelError::new_validation_error(
1793+
&format!("Enum {enum_name} not found"),
1794+
span.clone(),
1795+
));
1796+
None
1797+
}
1798+
}
17431799
_ => {
17441800
diagnostics.push_error(DatamodelError::new_validation_error(
17451801
"Can only access fields on class instances",

engine/baml-lib/baml/tests/bytecode_files/array_access.baml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function ArrayAccessWithVariable(arr: float[], idx: int) -> float {
4545
// 4 RETURN
4646
//
4747
// Class: std::Request with 3 fields
48+
// Enum std::HttpMethod
4849
// Function: std.Array.len
4950
//
5051
// Function: std.Map.len

engine/baml-lib/baml/tests/bytecode_files/assert.baml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ fn assertNotOk() -> int {
3232
// 5 RETURN
3333
//
3434
// Class: std::Request with 3 fields
35+
// Enum std::HttpMethod
3536
// Function: std.Array.len
3637
//
3738
// Function: std.Map.len

engine/baml-lib/baml/tests/bytecode_files/function_calls.baml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ function Nested(x: int) -> int {
5353
// 10 RETURN
5454
//
5555
// Class: std::Request with 3 fields
56+
// Enum std::HttpMethod
5657
// Function: std.Array.len
5758
//
5859
// Function: std.Map.len

engine/baml-lib/baml/tests/bytecode_files/literal_values.baml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ function ReturnArray() -> int[] {
4646
// 6 RETURN
4747
//
4848
// Class: std::Request with 3 fields
49+
// Enum std::HttpMethod
4950
// Function: std.Array.len
5051
//
5152
// Function: std.Map.len

engine/baml-lib/baml/tests/bytecode_files/llm_functions.baml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ function AnalyzeSentiment(text: string) -> Sentiment {
3636
// Function: AnalyzeSentiment
3737
//
3838
// Class: std::Request with 3 fields
39+
// Enum std::HttpMethod
40+
// Enum Sentiment
3941
// Function: std.Array.len
4042
//
4143
// Function: std.Map.len

0 commit comments

Comments
 (0)