Skip to content
302 changes: 184 additions & 118 deletions engine/baml-compiler/src/codegen.rs

Large diffs are not rendered by default.

57 changes: 41 additions & 16 deletions engine/baml-compiler/src/hir/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
//!
//! This file contains the conversions from Baml AST nodes to HIR nodes.

use std::collections::HashSet;

use baml_types::{
ir_type::TypeGeneric,
type_meta::{self, base::StreamingBehavior},
Constraint, ConstraintLevel, TypeIR, TypeValue,
};
Expand Down Expand Up @@ -48,6 +51,38 @@ impl Hir {
}
}

let enums = HashSet::<&str>::from_iter(hir.enums.iter().map(|e| e.name.as_str()));

let param_type: fn(&mut Parameter) -> &mut TypeIR = |p| &mut p.r#type;

// Patch parameter and return types because only at this point in the code
// do we have the full context for enums.
hir.expr_functions
.iter_mut()
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
.chain(
hir.llm_functions
.iter_mut()
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type)),
)
.chain(hir.classes.iter_mut().flat_map(|c| {
c.methods
.iter_mut()
.map(|f| (f.parameters.iter_mut().map(param_type), &mut f.return_type))
}))
.flat_map(|(parameters, return_type)| parameters.chain(std::iter::once(return_type)))
.for_each(|ty| match ty {
TypeIR::Class { name, meta, .. } if enums.contains(name.as_str()) => {
*ty = TypeIR::Enum {
name: name.to_owned(),
dynamic: false, // TODO: How to know if it's dynamic.
meta: meta.clone(),
}
}

_ => {}
});

hir
}
}
Expand Down Expand Up @@ -125,22 +160,12 @@ pub fn type_ir_from_ast(type_: &ast::FieldType) -> TypeIR {
};

match type_ {
ast::FieldType::Symbol(_, name, _) => {
if name.name().starts_with("Enum") {
TypeIR::Enum {
name: name.name().to_string(),
dynamic: false,
meta,
}
} else {
TypeIR::Class {
name: name.name().to_string(),
mode: baml_types::ir_type::StreamingMode::NonStreaming,
dynamic: false,
meta,
}
}
}
ast::FieldType::Symbol(_, name, _) => TypeIR::Class {
name: name.name().to_string(),
mode: baml_types::ir_type::StreamingMode::NonStreaming,
dynamic: false,
meta,
},
ast::FieldType::Primitive(_, prim, _, _) => TypeIR::Primitive(*prim, meta),
ast::FieldType::List(_, inner, dims, _, _) => {
// Respect multi-dimensional arrays (e.g., int[][] has dims=2)
Expand Down
12 changes: 11 additions & 1 deletion engine/baml-compiler/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@ pub mod test {
use internal_baml_diagnostics::Diagnostics;
use internal_baml_parser_database::{parse_and_diagnostics, ParserDatabase};

use crate::{hir, thir};

/// Shim helper function for testing.
pub fn ast(source: &'static str) -> anyhow::Result<ParserDatabase> {
let (parser_db, diagnostics) = parse_and_diagnostics(source)?;
let (parser_db, mut diagnostics) = parse_and_diagnostics(source)?;

if diagnostics.has_errors() {
let errors = diagnostics.to_pretty_string();
anyhow::bail!("{errors}");
}

// Done here because of cyclic dependencies between crates.
// TODO: We're building this like 3 different times, needs refactoring.
thir::typecheck::typecheck(&hir::Hir::from_ast(&parser_db.ast), &mut diagnostics);
if diagnostics.has_errors() {
let errors = diagnostics.to_pretty_string();
anyhow::bail!("{errors}");
Expand Down
10 changes: 9 additions & 1 deletion engine/baml-compiler/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
///
use baml_types::ir_type::TypeIR;

use crate::hir::{self, AssignOp, BinaryOperator, Enum, LlmFunction, UnaryOperator};
use crate::hir::{self, AssignOp, BinaryOperator, LlmFunction, UnaryOperator};

pub mod interpret;
pub mod typecheck;
Expand Down Expand Up @@ -54,6 +54,14 @@ pub struct Class<T> {
pub span: Span,
}

/// An enum definition as stored in the THIR.
///
/// Carries the enum's name, its variants (the HIR variant type is reused
/// as-is), the source span, and a `TypeIR` describing the enum.
#[derive(Clone, Debug)]
pub struct Enum {
/// Name of the enum.
pub name: String,
/// Variants of the enum, kept in their HIR representation.
pub variants: Vec<hir::EnumVariant>,
/// Source span associated with this enum.
pub span: Span,
/// Type of the enum itself.
/// NOTE(review): the typechecker appears to populate this with a
/// non-dynamic `TypeIR::Enum` of the same name — confirm.
pub ty: TypeIR, // TODO: Used for type checking, but do we need this?
}

/// A BAML expression term.
/// T is the type of the metadata.
#[derive(Debug, Clone)]
Expand Down
62 changes: 59 additions & 3 deletions engine/baml-compiler/src/thir/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub fn typecheck_returning_context<'a>(
.map(|c| (c.name.clone(), c))
.collect();

let enums = hir
let enums: BamlMap<String, hir::Enum> = hir
.enums
.clone()
.into_iter()
Expand All @@ -54,6 +54,7 @@ pub fn typecheck_returning_context<'a>(
// Create typing context with all functions
let mut typing_context = TypeContext::new();
typing_context.classes.extend(classes.clone());
typing_context.enums.extend(enums.clone());

// Add expr functions to typing context
for func in &hir.expr_functions {
Expand Down Expand Up @@ -271,11 +272,32 @@ pub fn typecheck_returning_context<'a>(
})
.collect();

// TODO: These are HIR enums; figure out whether a THIR enum needs anything
// different. Does it need a "type"?
let thir_enums = enums
.iter()
.map(|(name, enum_def)| {
(
name.clone(),
thir::Enum {
name: enum_def.name.clone(),
variants: enum_def.variants.clone(),
span: enum_def.span.clone(),
ty: TypeIR::Enum {
name: enum_def.name.clone(),
dynamic: false,
meta: Default::default(),
},
},
)
})
.collect();

(
THir {
llm_functions: hir.llm_functions.clone(),
classes: thir_classes,
enums,
enums: thir_enums,
expr_functions,
global_assignments: BamlMap::new(),
},
Expand All @@ -302,6 +324,7 @@ pub struct TypeContext<'func> {
// Variables in scope with mutability info
pub vars: BamlMap<String, VarInfo>,
pub classes: BamlMap<String, hir::Class>,
pub enums: BamlMap<String, hir::Enum>,
// Used for knowing whether `break` and `continue` are inside a loop or not.
pub is_inside_loop: bool,

Expand Down Expand Up @@ -336,6 +359,7 @@ impl TypeContext<'_> {
symbols: BamlMap::new(),
vars,
classes: BamlMap::new(),
enums: BamlMap::new(),
is_inside_loop: false,
function_return_type: None,
}
Expand Down Expand Up @@ -428,7 +452,9 @@ impl TypeContext<'_> {
}
hir::Expression::Identifier(name, _) => {
// Look up type in context
self.get_type(name).cloned()
self.get_type(name)
.cloned()
.or_else(|| self.enums.get(name).map(|e| TypeIR::r#enum(&e.name)))
}
hir::Expression::Array(items, _) => {
// Infer array type from first item
Expand Down Expand Up @@ -575,6 +601,14 @@ impl TypeContext<'_> {
None
}
}
TypeIR::Enum {
name: enum_name, ..
} => {
// Look up field in enum definition
self.enums
.get(&enum_name)
.map(|enum_def| TypeIR::r#enum(&enum_def.name))
}
_ => None, // Not a class
}
} else {
Expand Down Expand Up @@ -1162,6 +1196,14 @@ pub fn typecheck_expression(
BamlValueWithMeta::String(value.clone(), (span.clone(), Some(TypeIR::string()))),
),
hir::Expression::Identifier(name, span) => {
// Enum access: let x = Shape.Rectangle
if let Some(enum_def) = context.enums.get(name) {
return thir::Expr::Var(
name.clone(),
(span.clone(), Some(TypeIR::r#enum(&enum_def.name))),
);
}

// Look up type in context
let var_type = context.get_type(name).cloned();
if var_type.is_none() {
Expand Down Expand Up @@ -1740,6 +1782,20 @@ pub fn typecheck_expression(
None
}
}
Some(TypeIR::Enum {
name: enum_name, ..
}) => {
// Look up field in enum definition
if let Some(enum_def) = context.enums.get(enum_name) {
Some(TypeIR::r#enum(&enum_def.name))
} else {
diagnostics.push_error(DatamodelError::new_validation_error(
&format!("Enum {enum_name} not found"),
span.clone(),
));
None
}
}
_ => {
diagnostics.push_error(DatamodelError::new_validation_error(
"Can only access fields on class instances",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function ArrayAccessWithVariable(arr: float[], idx: int) -> float {
// 4 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml/tests/bytecode_files/assert.baml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn assertNotOk() -> int {
// 5 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ function Nested(x: int) -> int {
// 10 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function ReturnArray() -> int[] {
// 6 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
2 changes: 2 additions & 0 deletions engine/baml-lib/baml/tests/bytecode_files/llm_functions.baml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ function AnalyzeSentiment(text: string) -> Sentiment {
// Function: AnalyzeSentiment
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Enum Sentiment
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml/tests/bytecode_files/loops/break.baml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ fn Nested() -> int {
// 22 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
1 change: 1 addition & 0 deletions engine/baml-lib/baml/tests/bytecode_files/loops/c_for.baml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ fn Nothing() -> int {
// 4 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ fn ContinueNested() -> int {
// 18 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
11 changes: 6 additions & 5 deletions engine/baml-lib/baml/tests/bytecode_files/loops/for.baml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 1 0 LOAD_CONST 0 (0)
//
// 3 1 LOAD_VAR 1 (xs)
// 2 LOAD_GLOBAL 5 (<fn std.Array.len>)
// 2 LOAD_GLOBAL 6 (<fn std.Array.len>)
// 3 LOAD_VAR 3 (__baml for loop iterated array 0)
// 4 CALL 1
// 5 LOAD_CONST 0 (0)
Expand Down Expand Up @@ -84,7 +84,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 11 0 LOAD_CONST 0 (0)
//
// 13 1 LOAD_VAR 1 (xs)
// 2 LOAD_GLOBAL 5 (<fn std.Array.len>)
// 2 LOAD_GLOBAL 6 (<fn std.Array.len>)
// 3 LOAD_VAR 3 (__baml for loop iterated array 1)
// 4 CALL 1
// 5 LOAD_CONST 0 (0)
Expand Down Expand Up @@ -127,7 +127,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 24 0 LOAD_CONST 0 (0)
//
// 26 1 LOAD_VAR 1 (xs)
// 2 LOAD_GLOBAL 5 (<fn std.Array.len>)
// 2 LOAD_GLOBAL 6 (<fn std.Array.len>)
// 3 LOAD_VAR 3 (__baml for loop iterated array 2)
// 4 CALL 1
// 5 LOAD_CONST 0 (0)
Expand Down Expand Up @@ -170,7 +170,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 38 0 LOAD_CONST 0 (0)
//
// 40 1 LOAD_VAR 1 (as)
// 2 LOAD_GLOBAL 5 (<fn std.Array.len>)
// 2 LOAD_GLOBAL 6 (<fn std.Array.len>)
// 3 LOAD_VAR 4 (__baml for loop iterated array 3)
// 4 CALL 1
// 5 LOAD_CONST 0 (0)
Expand All @@ -188,7 +188,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 17 STORE_VAR 6 (__baml for loop index 3)
//
// 41 18 LOAD_VAR 2 (bs)
// 19 LOAD_GLOBAL 5 (<fn std.Array.len>)
// 19 LOAD_GLOBAL 6 (<fn std.Array.len>)
// 20 LOAD_VAR 8 (__baml for loop iterated array 4)
// 21 CALL 1
// 22 LOAD_CONST 0 (0)
Expand Down Expand Up @@ -224,6 +224,7 @@ fn NestedFor(as: int[], bs: int[]) -> int {
// 50 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ fn WhileWithScopes() -> int {
// 33 RETURN
//
// Class: std::Request with 3 fields
// Enum std::HttpMethod
// Function: std.Array.len
//
// Function: std.Map.len
Expand Down
Loading
Loading