@@ -34,6 +34,7 @@ use crate::surface::elaboration::reporting::Message;
34
34
use crate :: surface:: { distillation, pretty, BinOp , FormatField , Item , Module , Pattern , Term } ;
35
35
36
36
mod order;
37
+ mod patterns;
37
38
mod reporting;
38
39
mod unification;
39
40
@@ -463,6 +464,70 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
463
464
( labels. into ( ) , filtered_fields)
464
465
}
465
466
467
+ fn check_tuple_fields < F > (
468
+ & mut self ,
469
+ range : ByteRange ,
470
+ fields : & [ F ] ,
471
+ get_range : fn ( & F ) -> ByteRange ,
472
+ expected_labels : & [ StringId ] ,
473
+ ) -> Result < ( ) , ( ) > {
474
+ if fields. len ( ) == expected_labels. len ( ) {
475
+ return Ok ( ( ) ) ;
476
+ }
477
+
478
+ let mut found_labels = Vec :: with_capacity ( fields. len ( ) ) ;
479
+ let mut fields_iter = fields. iter ( ) . enumerate ( ) . peekable ( ) ;
480
+ let mut expected_labels_iter = expected_labels. iter ( ) ;
481
+
482
+ // use the label names from the expected labels
483
+ while let Some ( ( ( _, field) , label) ) =
484
+ Option :: zip ( fields_iter. peek ( ) , expected_labels_iter. next ( ) )
485
+ {
486
+ found_labels. push ( ( get_range ( field) , * label) ) ;
487
+ fields_iter. next ( ) ;
488
+ }
489
+
490
+ // use numeric labels for excess fields
491
+ for ( index, field) in fields_iter {
492
+ found_labels. push ( (
493
+ get_range ( field) ,
494
+ self . interner . borrow_mut ( ) . get_tuple_label ( index) ,
495
+ ) ) ;
496
+ }
497
+
498
+ self . push_message ( Message :: MismatchedFieldLabels {
499
+ range,
500
+ found_labels,
501
+ expected_labels : expected_labels. to_vec ( ) ,
502
+ } ) ;
503
+ Err ( ( ) )
504
+ }
505
+
506
+ fn check_record_fields < F > (
507
+ & mut self ,
508
+ range : ByteRange ,
509
+ fields : & [ F ] ,
510
+ get_label : impl Fn ( & F ) -> ( ByteRange , StringId ) ,
511
+ labels : & ' arena [ StringId ] ,
512
+ ) -> Result < ( ) , ( ) > {
513
+ if fields. len ( ) == labels. len ( )
514
+ && fields
515
+ . iter ( )
516
+ . zip ( labels. iter ( ) )
517
+ . all ( |( field, type_label) | get_label ( field) . 1 == * type_label)
518
+ {
519
+ return Ok ( ( ) ) ;
520
+ }
521
+
522
+ // TODO: improve handling of duplicate labels
523
+ self . push_message ( Message :: MismatchedFieldLabels {
524
+ range,
525
+ found_labels : fields. iter ( ) . map ( get_label) . collect ( ) ,
526
+ expected_labels : labels. to_vec ( ) ,
527
+ } ) ;
528
+ Err ( ( ) )
529
+ }
530
+
466
531
/// Parse a source string into number, assuming an ASCII encoding.
467
532
fn parse_ascii < T > (
468
533
& mut self ,
@@ -696,177 +761,6 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
696
761
term
697
762
}
698
763
699
- /// Check that a pattern matches an expected type.
700
- fn check_pattern (
701
- & mut self ,
702
- pattern : & Pattern < ByteRange > ,
703
- expected_type : & ArcValue < ' arena > ,
704
- ) -> CheckedPattern {
705
- match pattern {
706
- Pattern :: Name ( range, name) => CheckedPattern :: Binder ( * range, * name) ,
707
- Pattern :: Placeholder ( range) => CheckedPattern :: Placeholder ( * range) ,
708
- Pattern :: StringLiteral ( range, lit) => {
709
- let constant = match expected_type. match_prim_spine ( ) {
710
- Some ( ( Prim :: U8Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U8 ) ,
711
- Some ( ( Prim :: U16Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U16 ) ,
712
- Some ( ( Prim :: U32Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U32 ) ,
713
- Some ( ( Prim :: U64Type , [ ] ) ) => self . parse_ascii ( * range, * lit, Const :: U64 ) ,
714
- // Some((Prim::Array8Type, [len, _])) => todo!(),
715
- // Some((Prim::Array16Type, [len, _])) => todo!(),
716
- // Some((Prim::Array32Type, [len, _])) => todo!(),
717
- // Some((Prim::Array64Type, [len, _])) => todo!(),
718
- Some ( ( Prim :: ReportedError , _) ) => None ,
719
- _ => {
720
- let expected_type = self . pretty_print_value ( expected_type) ;
721
- self . push_message ( Message :: StringLiteralNotSupported {
722
- range : * range,
723
- expected_type,
724
- } ) ;
725
- None
726
- }
727
- } ;
728
-
729
- match constant {
730
- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
731
- None => CheckedPattern :: ReportedError ( * range) ,
732
- }
733
- }
734
- Pattern :: NumberLiteral ( range, lit) => {
735
- let constant = match expected_type. match_prim_spine ( ) {
736
- Some ( ( Prim :: U8Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U8 ) ,
737
- Some ( ( Prim :: U16Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U16 ) ,
738
- Some ( ( Prim :: U32Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U32 ) ,
739
- Some ( ( Prim :: U64Type , [ ] ) ) => self . parse_number_radix ( * range, * lit, Const :: U64 ) ,
740
- Some ( ( Prim :: S8Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S8 ) ,
741
- Some ( ( Prim :: S16Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S16 ) ,
742
- Some ( ( Prim :: S32Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S32 ) ,
743
- Some ( ( Prim :: S64Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: S64 ) ,
744
- Some ( ( Prim :: F32Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: F32 ) ,
745
- Some ( ( Prim :: F64Type , [ ] ) ) => self . parse_number ( * range, * lit, Const :: F64 ) ,
746
- Some ( ( Prim :: ReportedError , _) ) => None ,
747
- _ => {
748
- let expected_type = self . pretty_print_value ( expected_type) ;
749
- self . push_message ( Message :: NumericLiteralNotSupported {
750
- range : * range,
751
- expected_type,
752
- } ) ;
753
- None
754
- }
755
- } ;
756
-
757
- match constant {
758
- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
759
- None => CheckedPattern :: ReportedError ( * range) ,
760
- }
761
- }
762
- Pattern :: BooleanLiteral ( range, boolean) => {
763
- let constant = match expected_type. match_prim_spine ( ) {
764
- Some ( ( Prim :: BoolType , [ ] ) ) => match * boolean {
765
- true => Some ( Const :: Bool ( true ) ) ,
766
- false => Some ( Const :: Bool ( false ) ) ,
767
- } ,
768
- _ => {
769
- self . push_message ( Message :: BooleanLiteralNotSupported { range : * range } ) ;
770
- None
771
- }
772
- } ;
773
-
774
- match constant {
775
- Some ( constant) => CheckedPattern :: ConstLit ( * range, constant) ,
776
- None => CheckedPattern :: ReportedError ( * range) ,
777
- }
778
- }
779
- Pattern :: RecordLiteral ( _, _) => todo ! ( ) ,
780
- Pattern :: Tuple ( _, _) => todo ! ( ) ,
781
- }
782
- }
783
-
784
- /// Synthesize the type of a pattern.
785
- fn synth_pattern (
786
- & mut self ,
787
- pattern : & Pattern < ByteRange > ,
788
- ) -> ( CheckedPattern , ArcValue < ' arena > ) {
789
- match pattern {
790
- Pattern :: Name ( range, name) => {
791
- let source = MetaSource :: NamedPatternType ( * range, * name) ;
792
- let r#type = self . push_unsolved_type ( source) ;
793
- ( CheckedPattern :: Binder ( * range, * name) , r#type)
794
- }
795
- Pattern :: Placeholder ( range) => {
796
- let source = MetaSource :: PlaceholderPatternType ( * range) ;
797
- let r#type = self . push_unsolved_type ( source) ;
798
- ( CheckedPattern :: Placeholder ( * range) , r#type)
799
- }
800
- Pattern :: StringLiteral ( range, _) => {
801
- self . push_message ( Message :: AmbiguousStringLiteral { range : * range } ) ;
802
- let source = MetaSource :: ReportedErrorType ( * range) ;
803
- let r#type = self . push_unsolved_type ( source) ;
804
- ( CheckedPattern :: ReportedError ( * range) , r#type)
805
- }
806
- Pattern :: NumberLiteral ( range, _) => {
807
- self . push_message ( Message :: AmbiguousNumericLiteral { range : * range } ) ;
808
- let source = MetaSource :: ReportedErrorType ( * range) ;
809
- let r#type = self . push_unsolved_type ( source) ;
810
- ( CheckedPattern :: ReportedError ( * range) , r#type)
811
- }
812
- Pattern :: BooleanLiteral ( range, val) => {
813
- let r#const = Const :: Bool ( * val) ;
814
- let r#type = self . bool_type . clone ( ) ;
815
- ( CheckedPattern :: ConstLit ( * range, r#const) , r#type)
816
- }
817
- Pattern :: RecordLiteral ( _, _) => todo ! ( ) ,
818
- Pattern :: Tuple ( _, _) => todo ! ( ) ,
819
- }
820
- }
821
-
822
- /// Check that the type of an annotated pattern matches an expected type.
823
- fn check_ann_pattern (
824
- & mut self ,
825
- pattern : & Pattern < ByteRange > ,
826
- r#type : Option < & Term < ' _ , ByteRange > > ,
827
- expected_type : & ArcValue < ' arena > ,
828
- ) -> CheckedPattern {
829
- match r#type {
830
- None => self . check_pattern ( pattern, expected_type) ,
831
- Some ( r#type) => {
832
- let range = r#type. range ( ) ;
833
- let r#type = self . check ( r#type, & self . universe . clone ( ) ) ;
834
- let r#type = self . eval_env ( ) . eval ( & r#type) ;
835
-
836
- match self . unification_context ( ) . unify ( & r#type, expected_type) {
837
- Ok ( ( ) ) => self . check_pattern ( pattern, & r#type) ,
838
- Err ( error) => {
839
- let lhs = self . pretty_print_value ( & r#type) ;
840
- let rhs = self . pretty_print_value ( expected_type) ;
841
- self . push_message ( Message :: FailedToUnify {
842
- range,
843
- lhs,
844
- rhs,
845
- error,
846
- } ) ;
847
- CheckedPattern :: ReportedError ( range)
848
- }
849
- }
850
- }
851
- }
852
- }
853
-
854
- /// Synthesize the type of an annotated pattern.
855
- fn synth_ann_pattern (
856
- & mut self ,
857
- pattern : & Pattern < ByteRange > ,
858
- r#type : Option < & Term < ' _ , ByteRange > > ,
859
- ) -> ( CheckedPattern , ArcValue < ' arena > ) {
860
- match r#type {
861
- None => self . synth_pattern ( pattern) ,
862
- Some ( r#type) => {
863
- let r#type = self . check ( r#type, & self . universe . clone ( ) ) ;
864
- let type_value = self . eval_env ( ) . eval ( & r#type) ;
865
- ( self . check_pattern ( pattern, & type_value) , type_value)
866
- }
867
- }
868
- }
869
-
870
764
/// Push a local definition onto the context.
871
765
/// The supplied `pattern` is expected to be irrefutable.
872
766
fn push_local_def (
@@ -886,6 +780,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
886
780
None
887
781
}
888
782
CheckedPattern :: ReportedError ( _) => None ,
783
+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
889
784
} ;
890
785
891
786
self . local_env . push_def ( name, expr, r#type) ;
@@ -911,6 +806,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
911
806
None
912
807
}
913
808
CheckedPattern :: ReportedError ( _) => None ,
809
+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
914
810
} ;
915
811
916
812
let expr = self . local_env . push_param ( name, r#type) ;
@@ -970,18 +866,10 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
970
866
self . check_fun_lit ( * range, patterns, body_expr, & expected_type)
971
867
}
972
868
( Term :: RecordLiteral ( range, expr_fields) , Value :: RecordType ( labels, types) ) => {
973
- // TODO: improve handling of duplicate labels
974
- if expr_fields. len ( ) != labels. len ( )
975
- || Iterator :: zip ( expr_fields. iter ( ) , labels. iter ( ) )
976
- . any ( |( expr_field, type_label) | expr_field. label . 1 != * type_label)
869
+ if self
870
+ . check_record_fields ( * range, expr_fields, |field| field. label , labels)
871
+ . is_err ( )
977
872
{
978
- self . push_message ( Message :: MismatchedFieldLabels {
979
- range : * range,
980
- expr_labels : ( expr_fields. iter ( ) )
981
- . map ( |expr_field| expr_field. label )
982
- . collect ( ) ,
983
- type_labels : labels. to_vec ( ) ,
984
- } ) ;
985
873
return core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError ) ;
986
874
}
987
875
@@ -1045,33 +933,11 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
1045
933
core:: Term :: FormatRecord ( range. into ( ) , labels, formats)
1046
934
}
1047
935
( Term :: Tuple ( range, elem_exprs) , Value :: RecordType ( labels, types) ) => {
1048
- if elem_exprs. len ( ) != labels. len ( ) {
1049
- let mut expr_labels = Vec :: with_capacity ( elem_exprs. len ( ) ) ;
1050
- let mut elem_exprs = elem_exprs. iter ( ) . enumerate ( ) . peekable ( ) ;
1051
- let mut label_iter = labels. iter ( ) ;
1052
-
1053
- // use the label names from the expected type
1054
- while let Some ( ( ( _, elem_expr) , label) ) =
1055
- Option :: zip ( elem_exprs. peek ( ) , label_iter. next ( ) )
1056
- {
1057
- expr_labels. push ( ( elem_expr. range ( ) , * label) ) ;
1058
- elem_exprs. next ( ) ;
1059
- }
1060
-
1061
- // use numeric labels for excess elems
1062
- for ( index, elem_expr) in elem_exprs {
1063
- expr_labels. push ( (
1064
- elem_expr. range ( ) ,
1065
- self . interner . borrow_mut ( ) . get_tuple_label ( index) ,
1066
- ) ) ;
1067
- }
1068
-
1069
- self . push_message ( Message :: MismatchedFieldLabels {
1070
- range : * range,
1071
- expr_labels,
1072
- type_labels : labels. to_vec ( ) ,
1073
- } ) ;
1074
- return core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError ) ;
936
+ if self
937
+ . check_tuple_fields ( * range, elem_exprs, |expr| expr. range ( ) , labels)
938
+ . is_err ( )
939
+ {
940
+ return core:: Term :: error ( range. into ( ) ) ;
1075
941
}
1076
942
1077
943
let mut types = types. clone ( ) ;
@@ -2027,6 +1893,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
2027
1893
self . elab_match_unreachable ( match_info, equations) ;
2028
1894
core:: Term :: Prim ( range. into ( ) , Prim :: ReportedError )
2029
1895
}
1896
+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
2030
1897
}
2031
1898
}
2032
1899
None => self . elab_match_absurd ( is_reachable, match_info) ,
@@ -2116,6 +1983,7 @@ impl<'interner, 'arena> Context<'interner, 'arena> {
2116
1983
default_branch = ( None , self . scope . to_scope ( default_expr) as & _ ) ;
2117
1984
self . local_env . pop ( ) ;
2118
1985
}
1986
+ CheckedPattern :: RecordLit ( _, _, _) => todo ! ( ) ,
2119
1987
} ;
2120
1988
2121
1989
// A default pattern was found, check any unreachable patterns.
@@ -2196,15 +2064,17 @@ impl_from_str_radix!(u64);
2196
2064
2197
2065
/// Simple patterns that have had some initial elaboration performed on them
2198
2066
#[ derive( Debug ) ]
2199
- enum CheckedPattern {
2200
- /// Pattern that binds local variable
2201
- Binder ( ByteRange , StringId ) ,
2067
+ enum CheckedPattern < ' arena > {
2068
+ /// Error sentinel
2069
+ ReportedError ( ByteRange ) ,
2202
2070
/// Placeholder patterns that match everything
2203
2071
Placeholder ( ByteRange ) ,
2072
+ /// Pattern that binds local variable
2073
+ Binder ( ByteRange , StringId ) ,
2204
2074
/// Constant literals
2205
2075
ConstLit ( ByteRange , Const ) ,
2206
- /// Error sentinel
2207
- ReportedError ( ByteRange ) ,
2076
+ /// Record literals
2077
+ RecordLit ( ByteRange , & ' arena [ StringId ] , & ' arena [ Self ] ) ,
2208
2078
}
2209
2079
2210
2080
/// Scrutinee of a match expression
0 commit comments