Skip to content

Commit 2f7e6a8

Browse files
committed
Check/synth record and tuple patterns
1 parent 1d36cb4 commit 2f7e6a8

File tree

4 files changed

+386
-225
lines changed

4 files changed

+386
-225
lines changed

fathom/src/surface/elaboration.rs

Lines changed: 84 additions & 214 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use crate::surface::elaboration::reporting::Message;
3434
use crate::surface::{distillation, pretty, BinOp, FormatField, Item, Module, Pattern, Term};
3535

3636
mod order;
37+
mod patterns;
3738
mod reporting;
3839
mod unification;
3940

@@ -463,6 +464,70 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
463464
(labels.into(), filtered_fields)
464465
}
465466

467+
fn check_tuple_fields<F>(
468+
&mut self,
469+
range: ByteRange,
470+
fields: &[F],
471+
get_range: fn(&F) -> ByteRange,
472+
expected_labels: &[StringId],
473+
) -> Result<(), ()> {
474+
if fields.len() == expected_labels.len() {
475+
return Ok(());
476+
}
477+
478+
let mut found_labels = Vec::with_capacity(fields.len());
479+
let mut fields_iter = fields.iter().enumerate().peekable();
480+
let mut expected_labels_iter = expected_labels.iter();
481+
482+
// use the label names from the expected labels
483+
while let Some(((_, field), label)) =
484+
Option::zip(fields_iter.peek(), expected_labels_iter.next())
485+
{
486+
found_labels.push((get_range(field), *label));
487+
fields_iter.next();
488+
}
489+
490+
// use numeric labels for excess fields
491+
for (index, field) in fields_iter {
492+
found_labels.push((
493+
get_range(field),
494+
self.interner.borrow_mut().get_tuple_label(index),
495+
));
496+
}
497+
498+
self.push_message(Message::MismatchedFieldLabels {
499+
range,
500+
found_labels,
501+
expected_labels: expected_labels.to_vec(),
502+
});
503+
Err(())
504+
}
505+
506+
fn check_record_fields<F>(
507+
&mut self,
508+
range: ByteRange,
509+
fields: &[F],
510+
get_label: impl Fn(&F) -> (ByteRange, StringId),
511+
labels: &'arena [StringId],
512+
) -> Result<(), ()> {
513+
if fields.len() == labels.len()
514+
&& fields
515+
.iter()
516+
.zip(labels.iter())
517+
.all(|(field, type_label)| get_label(field).1 == *type_label)
518+
{
519+
return Ok(());
520+
}
521+
522+
// TODO: improve handling of duplicate labels
523+
self.push_message(Message::MismatchedFieldLabels {
524+
range,
525+
found_labels: fields.iter().map(get_label).collect(),
526+
expected_labels: labels.to_vec(),
527+
});
528+
Err(())
529+
}
530+
466531
/// Parse a source string into number, assuming an ASCII encoding.
467532
fn parse_ascii<T>(
468533
&mut self,
@@ -696,177 +761,6 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
696761
term
697762
}
698763

699-
/// Check that a pattern matches an expected type.
700-
fn check_pattern(
701-
&mut self,
702-
pattern: &Pattern<ByteRange>,
703-
expected_type: &ArcValue<'arena>,
704-
) -> CheckedPattern {
705-
match pattern {
706-
Pattern::Name(range, name) => CheckedPattern::Binder(*range, *name),
707-
Pattern::Placeholder(range) => CheckedPattern::Placeholder(*range),
708-
Pattern::StringLiteral(range, lit) => {
709-
let constant = match expected_type.match_prim_spine() {
710-
Some((Prim::U8Type, [])) => self.parse_ascii(*range, *lit, Const::U8),
711-
Some((Prim::U16Type, [])) => self.parse_ascii(*range, *lit, Const::U16),
712-
Some((Prim::U32Type, [])) => self.parse_ascii(*range, *lit, Const::U32),
713-
Some((Prim::U64Type, [])) => self.parse_ascii(*range, *lit, Const::U64),
714-
// Some((Prim::Array8Type, [len, _])) => todo!(),
715-
// Some((Prim::Array16Type, [len, _])) => todo!(),
716-
// Some((Prim::Array32Type, [len, _])) => todo!(),
717-
// Some((Prim::Array64Type, [len, _])) => todo!(),
718-
Some((Prim::ReportedError, _)) => None,
719-
_ => {
720-
let expected_type = self.pretty_print_value(expected_type);
721-
self.push_message(Message::StringLiteralNotSupported {
722-
range: *range,
723-
expected_type,
724-
});
725-
None
726-
}
727-
};
728-
729-
match constant {
730-
Some(constant) => CheckedPattern::ConstLit(*range, constant),
731-
None => CheckedPattern::ReportedError(*range),
732-
}
733-
}
734-
Pattern::NumberLiteral(range, lit) => {
735-
let constant = match expected_type.match_prim_spine() {
736-
Some((Prim::U8Type, [])) => self.parse_number_radix(*range, *lit, Const::U8),
737-
Some((Prim::U16Type, [])) => self.parse_number_radix(*range, *lit, Const::U16),
738-
Some((Prim::U32Type, [])) => self.parse_number_radix(*range, *lit, Const::U32),
739-
Some((Prim::U64Type, [])) => self.parse_number_radix(*range, *lit, Const::U64),
740-
Some((Prim::S8Type, [])) => self.parse_number(*range, *lit, Const::S8),
741-
Some((Prim::S16Type, [])) => self.parse_number(*range, *lit, Const::S16),
742-
Some((Prim::S32Type, [])) => self.parse_number(*range, *lit, Const::S32),
743-
Some((Prim::S64Type, [])) => self.parse_number(*range, *lit, Const::S64),
744-
Some((Prim::F32Type, [])) => self.parse_number(*range, *lit, Const::F32),
745-
Some((Prim::F64Type, [])) => self.parse_number(*range, *lit, Const::F64),
746-
Some((Prim::ReportedError, _)) => None,
747-
_ => {
748-
let expected_type = self.pretty_print_value(expected_type);
749-
self.push_message(Message::NumericLiteralNotSupported {
750-
range: *range,
751-
expected_type,
752-
});
753-
None
754-
}
755-
};
756-
757-
match constant {
758-
Some(constant) => CheckedPattern::ConstLit(*range, constant),
759-
None => CheckedPattern::ReportedError(*range),
760-
}
761-
}
762-
Pattern::BooleanLiteral(range, boolean) => {
763-
let constant = match expected_type.match_prim_spine() {
764-
Some((Prim::BoolType, [])) => match *boolean {
765-
true => Some(Const::Bool(true)),
766-
false => Some(Const::Bool(false)),
767-
},
768-
_ => {
769-
self.push_message(Message::BooleanLiteralNotSupported { range: *range });
770-
None
771-
}
772-
};
773-
774-
match constant {
775-
Some(constant) => CheckedPattern::ConstLit(*range, constant),
776-
None => CheckedPattern::ReportedError(*range),
777-
}
778-
}
779-
Pattern::RecordLiteral(_, _) => todo!(),
780-
Pattern::Tuple(_, _) => todo!(),
781-
}
782-
}
783-
784-
/// Synthesize the type of a pattern.
785-
fn synth_pattern(
786-
&mut self,
787-
pattern: &Pattern<ByteRange>,
788-
) -> (CheckedPattern, ArcValue<'arena>) {
789-
match pattern {
790-
Pattern::Name(range, name) => {
791-
let source = MetaSource::NamedPatternType(*range, *name);
792-
let r#type = self.push_unsolved_type(source);
793-
(CheckedPattern::Binder(*range, *name), r#type)
794-
}
795-
Pattern::Placeholder(range) => {
796-
let source = MetaSource::PlaceholderPatternType(*range);
797-
let r#type = self.push_unsolved_type(source);
798-
(CheckedPattern::Placeholder(*range), r#type)
799-
}
800-
Pattern::StringLiteral(range, _) => {
801-
self.push_message(Message::AmbiguousStringLiteral { range: *range });
802-
let source = MetaSource::ReportedErrorType(*range);
803-
let r#type = self.push_unsolved_type(source);
804-
(CheckedPattern::ReportedError(*range), r#type)
805-
}
806-
Pattern::NumberLiteral(range, _) => {
807-
self.push_message(Message::AmbiguousNumericLiteral { range: *range });
808-
let source = MetaSource::ReportedErrorType(*range);
809-
let r#type = self.push_unsolved_type(source);
810-
(CheckedPattern::ReportedError(*range), r#type)
811-
}
812-
Pattern::BooleanLiteral(range, val) => {
813-
let r#const = Const::Bool(*val);
814-
let r#type = self.bool_type.clone();
815-
(CheckedPattern::ConstLit(*range, r#const), r#type)
816-
}
817-
Pattern::RecordLiteral(_, _) => todo!(),
818-
Pattern::Tuple(_, _) => todo!(),
819-
}
820-
}
821-
822-
/// Check that the type of an annotated pattern matches an expected type.
823-
fn check_ann_pattern(
824-
&mut self,
825-
pattern: &Pattern<ByteRange>,
826-
r#type: Option<&Term<'_, ByteRange>>,
827-
expected_type: &ArcValue<'arena>,
828-
) -> CheckedPattern {
829-
match r#type {
830-
None => self.check_pattern(pattern, expected_type),
831-
Some(r#type) => {
832-
let range = r#type.range();
833-
let r#type = self.check(r#type, &self.universe.clone());
834-
let r#type = self.eval_env().eval(&r#type);
835-
836-
match self.unification_context().unify(&r#type, expected_type) {
837-
Ok(()) => self.check_pattern(pattern, &r#type),
838-
Err(error) => {
839-
let lhs = self.pretty_print_value(&r#type);
840-
let rhs = self.pretty_print_value(expected_type);
841-
self.push_message(Message::FailedToUnify {
842-
range,
843-
lhs,
844-
rhs,
845-
error,
846-
});
847-
CheckedPattern::ReportedError(range)
848-
}
849-
}
850-
}
851-
}
852-
}
853-
854-
/// Synthesize the type of an annotated pattern.
855-
fn synth_ann_pattern(
856-
&mut self,
857-
pattern: &Pattern<ByteRange>,
858-
r#type: Option<&Term<'_, ByteRange>>,
859-
) -> (CheckedPattern, ArcValue<'arena>) {
860-
match r#type {
861-
None => self.synth_pattern(pattern),
862-
Some(r#type) => {
863-
let r#type = self.check(r#type, &self.universe.clone());
864-
let type_value = self.eval_env().eval(&r#type);
865-
(self.check_pattern(pattern, &type_value), type_value)
866-
}
867-
}
868-
}
869-
870764
/// Push a local definition onto the context.
871765
/// The supplied `pattern` is expected to be irrefutable.
872766
fn push_local_def(
@@ -886,6 +780,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
886780
None
887781
}
888782
CheckedPattern::ReportedError(_) => None,
783+
CheckedPattern::RecordLit(_, _, _) => todo!(),
889784
};
890785

891786
self.local_env.push_def(name, expr, r#type);
@@ -911,6 +806,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
911806
None
912807
}
913808
CheckedPattern::ReportedError(_) => None,
809+
CheckedPattern::RecordLit(_, _, _) => todo!(),
914810
};
915811

916812
let expr = self.local_env.push_param(name, r#type);
@@ -970,18 +866,10 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
970866
self.check_fun_lit(*range, patterns, body_expr, &expected_type)
971867
}
972868
(Term::RecordLiteral(range, expr_fields), Value::RecordType(labels, types)) => {
973-
// TODO: improve handling of duplicate labels
974-
if expr_fields.len() != labels.len()
975-
|| Iterator::zip(expr_fields.iter(), labels.iter())
976-
.any(|(expr_field, type_label)| expr_field.label.1 != *type_label)
869+
if self
870+
.check_record_fields(*range, expr_fields, |field| field.label, labels)
871+
.is_err()
977872
{
978-
self.push_message(Message::MismatchedFieldLabels {
979-
range: *range,
980-
expr_labels: (expr_fields.iter())
981-
.map(|expr_field| expr_field.label)
982-
.collect(),
983-
type_labels: labels.to_vec(),
984-
});
985873
return core::Term::Prim(range.into(), Prim::ReportedError);
986874
}
987875

@@ -1045,33 +933,11 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
1045933
core::Term::FormatRecord(range.into(), labels, formats)
1046934
}
1047935
(Term::Tuple(range, elem_exprs), Value::RecordType(labels, types)) => {
1048-
if elem_exprs.len() != labels.len() {
1049-
let mut expr_labels = Vec::with_capacity(elem_exprs.len());
1050-
let mut elem_exprs = elem_exprs.iter().enumerate().peekable();
1051-
let mut label_iter = labels.iter();
1052-
1053-
// use the label names from the expected type
1054-
while let Some(((_, elem_expr), label)) =
1055-
Option::zip(elem_exprs.peek(), label_iter.next())
1056-
{
1057-
expr_labels.push((elem_expr.range(), *label));
1058-
elem_exprs.next();
1059-
}
1060-
1061-
// use numeric labels for excess elems
1062-
for (index, elem_expr) in elem_exprs {
1063-
expr_labels.push((
1064-
elem_expr.range(),
1065-
self.interner.borrow_mut().get_tuple_label(index),
1066-
));
1067-
}
1068-
1069-
self.push_message(Message::MismatchedFieldLabels {
1070-
range: *range,
1071-
expr_labels,
1072-
type_labels: labels.to_vec(),
1073-
});
1074-
return core::Term::Prim(range.into(), Prim::ReportedError);
936+
if self
937+
.check_tuple_fields(*range, elem_exprs, |expr| expr.range(), labels)
938+
.is_err()
939+
{
940+
return core::Term::error(range.into());
1075941
}
1076942

1077943
let mut types = types.clone();
@@ -2027,6 +1893,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
20271893
self.elab_match_unreachable(match_info, equations);
20281894
core::Term::Prim(range.into(), Prim::ReportedError)
20291895
}
1896+
CheckedPattern::RecordLit(_, _, _) => todo!(),
20301897
}
20311898
}
20321899
None => self.elab_match_absurd(is_reachable, match_info),
@@ -2116,6 +1983,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
21161983
default_branch = (None, self.scope.to_scope(default_expr) as &_);
21171984
self.local_env.pop();
21181985
}
1986+
CheckedPattern::RecordLit(_, _, _) => todo!(),
21191987
};
21201988

21211989
// A default pattern was found, check any unreachable patterns.
@@ -2196,15 +2064,17 @@ impl_from_str_radix!(u64);
21962064

21972065
/// Simple patterns that have had some initial elaboration performed on them
21982066
#[derive(Debug)]
2199-
enum CheckedPattern {
2200-
/// Pattern that binds local variable
2201-
Binder(ByteRange, StringId),
2067+
enum CheckedPattern<'arena> {
2068+
/// Error sentinel
2069+
ReportedError(ByteRange),
22022070
/// Placeholder patterns that match everything
22032071
Placeholder(ByteRange),
2072+
/// Pattern that binds local variable
2073+
Binder(ByteRange, StringId),
22042074
/// Constant literals
22052075
ConstLit(ByteRange, Const),
2206-
/// Error sentinel
2207-
ReportedError(ByteRange),
2076+
/// Record literals
2077+
RecordLit(ByteRange, &'arena [StringId], &'arena [Self]),
22082078
}
22092079

22102080
/// Scrutinee of a match expression

0 commit comments

Comments
 (0)