diff --git a/analysis/analysisError.ml b/analysis/analysisError.ml index 9a0238303d7..a668400616a 100644 --- a/analysis/analysisError.ml +++ b/analysis/analysisError.ml @@ -1173,6 +1173,16 @@ let rec messages ~concise ~signature location kind = (Type.Variable expected) name; ] + | InvalidTypeParameters + { name; kind = AttributeResolution.ViolateConstraintsVariadic { expected; actual } } -> + [ + Format.asprintf + "Type parameter list `%a` violates constraints on `%s` in generic type `%s`." + (Type.Record.OrderedTypes.pp_concise ~pp_type) + actual + (Type.Record.Variable.RecordVariadic.RecordList.name expected) + name; + ] | InvalidTypeParameters { name; kind = AttributeResolution.UnexpectedKind { expected; actual } } -> let details = @@ -3453,6 +3463,9 @@ let dequalify actual = dequalify actual; expected = Type.Variable.Unary.dequalify ~dequalify_map expected; } + | AttributeResolution.ViolateConstraintsVariadic { actual; expected } -> + AttributeResolution.ViolateConstraintsVariadic + { actual; expected = Type.Variable.Variadic.List.dequalify ~dequalify_map expected } | AttributeResolution.UnexpectedKind { actual; expected } -> AttributeResolution.UnexpectedKind { actual; expected = Type.Variable.dequalify dequalify_map expected } diff --git a/analysis/attributeResolution.ml b/analysis/attributeResolution.ml index c94b81d8126..f63ceb1bab5 100644 --- a/analysis/attributeResolution.ml +++ b/analysis/attributeResolution.ml @@ -169,6 +169,10 @@ module TypeParameterValidationTypes = struct actual: Type.t; expected: Type.Variable.Unary.t; } + | ViolateConstraintsVariadic of { + actual: Type.OrderedTypes.t; + expected: Type.Variable.Variadic.List.t; + } | UnexpectedKind of { actual: Type.Parameter.t; expected: Type.Variable.t; @@ -821,7 +825,7 @@ class base class_metadata_environment dependency = TypeConstraints.empty ~order ~pair - >>| TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair + >>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair |> Option.is_none in if invalid then @@ -848,10 +852,29 @@ class base class_metadata_environment dependency = ( CallableParameters Undefined, Some { name; kind = UnexpectedKind { expected = generic; actual = given } } ) - | ParameterVariadic _, CallableParameters _ - | ListVariadic _, Group _ -> + | ParameterVariadic _, CallableParameters _ -> given, None + | ListVariadic generic, Group given -> (* TODO(T47346673): accept w/ new kind of validation *) - given, None + let invalid = + let order = self#full_order ~assumptions in + let pair = Type.Variable.ListVariadicPair (generic, given) in + TypeOrder.OrderedConstraints.add_lower_bound + TypeConstraints.empty + ~order + ~pair + >>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair + |> Option.is_none + in + if invalid then + ( Type.Parameter.Group Any, + Some + { + name; + kind = + ViolateConstraintsVariadic { actual = given; expected = generic }; + } ) + else + Type.Parameter.Group given, None in List.map paired ~f:check_parameter |> List.unzip diff --git a/analysis/attributeResolution.mli b/analysis/attributeResolution.mli index fc39ce2a7ec..d12c10651b0 100644 --- a/analysis/attributeResolution.mli +++ b/analysis/attributeResolution.mli @@ -34,6 +34,10 @@ type generic_type_problems = actual: Type.t; expected: Type.Variable.Unary.t; } + | ViolateConstraintsVariadic of { + actual: Type.OrderedTypes.t; + expected: Type.Variable.Variadic.List.t; + } | UnexpectedKind of { actual: Type.Parameter.t; expected: Type.Variable.t; diff --git a/analysis/test/integration/typeVariableTest.ml b/analysis/test/integration/typeVariableTest.ml index 7a7bb2aeda8..22937676fa3 100644 --- a/analysis/test/integration/typeVariableTest.ml +++ b/analysis/test/integration/typeVariableTest.ml @@ -1834,6 +1834,53 @@ let test_list_variadics context = () +let test_list_variadics_constraints context = + let assert_type_errors = assert_type_errors ~context in + assert_type_errors + {| + from typing import Generic, TypeVar + from typing_extensions import Literal + from pyre_extensions import ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + Ts = ListVariadic("Ts", bound=int) + A = TypeVar("A") + class Vec(Generic[Ts]): ... + + def f1(x : Vec[Cat[Ts,A]]) -> None: ... + def f2( *ts: Ts) -> Vec[Cat[Ts,float]]: ... + def f3( *ts: Ts) -> Vec[[int,float]]: ... + |} + [ + "Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, Variable[test.A]]` \ + violates constraints on `Ts` in generic type `Vec`."; + "Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, float]` violates \ + constraints on `Ts` in generic type `Vec`."; + "Invalid type parameters [24]: Type parameter list `int, float` violates constraints on `Ts` \ + in generic type `Vec`."; + ]; + assert_type_errors + {| + from typing import Generic, TypeVar + from pyre_extensions import ListVariadic + + Ts1 = ListVariadic("Ts1", bound=int) + Ts2 = ListVariadic("Ts2", bound=float) + + class Vec1(Generic[Ts1]): ... + class Vec2(Generic[Ts2]): ... + + def f1(x: Vec1[Ts1]) -> None: ... + def f2(x: Vec2[Ts2]) -> None: ... + def g1(x: Vec1[Ts2]) -> None: ... + def g2(x: Vec2[Ts1]) -> None: ... + |} + [ + "Invalid type parameters [24]: Type parameter list `test.Ts2` violates constraints on `Ts1` \ + in generic type `Vec1`."; + ] + + let test_map context = let assert_type_errors = assert_type_errors ~context in assert_type_errors @@ -2424,6 +2471,7 @@ let () = "single_explicit_error" >:: test_single_explicit_error; "callable_parameter_variadics" >:: test_callable_parameter_variadics; "list_variadics" >:: test_list_variadics; + "list_variadics_constraints" >:: test_list_variadics_constraints; "map" >:: test_map; "user_defined_variadics" >:: test_user_defined_variadics; "concatenation" >:: test_concatenation_operator; diff --git a/analysis/test/typeTest.ml b/analysis/test/typeTest.ml index cc88bd194ae..e80dc399d6e 100644 --- a/analysis/test/typeTest.ml +++ b/analysis/test/typeTest.ml @@ -2082,7 +2082,16 @@ let test_parse_type_variable_declarations _ = assert_parses_declaration "pyre_extensions.ListVariadic('Ts')" (Type.Variable.ListVariadic (Type.Variable.Variadic.List.create "target")); - assert_declaration_does_not_parse "pyre_extensions.ListVariadic('Ts', int, str)"; + assert_parses_declaration + "pyre_extensions.ListVariadic('Ts', int, str)" + (Type.Variable.ListVariadic + (Type.Variable.Variadic.List.create + "target" + ~constraints:(Explicit [Type.Primitive "int"; Type.Primitive "str"]))); + assert_parses_declaration + "pyre_extensions.ListVariadic('Ts', bound=int)" + (Type.Variable.ListVariadic + (Type.Variable.Variadic.List.create "target" ~constraints:(Bound (Type.Primitive "int")))); () diff --git a/analysis/type.ml b/analysis/type.ml index e6bb39f4dbf..6b33b1fb34d 100644 --- a/analysis/type.ml +++ b/analysis/type.ml @@ -2741,12 +2741,10 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } = in List.find_map ~f:bound arguments in - if not (List.is_empty explicits) then - Record.Variable.Explicit explicits - else if Option.is_some bound then - Bound (Option.value_exn bound) - else - Unconstrained + match explicits, bound with + | [], Some bound -> Record.Variable.Bound bound + | explicits, _ when List.length explicits > 0 -> Explicit explicits + | _ -> Unconstrained in let variance = let variance_definition = function @@ -4101,28 +4099,48 @@ end = struct let parse_declaration value ~target = - match value with - | { - Node.value = - Expression.Call - { - callee = - { - Node.value = - Name - (Name.Attribute - { - base = { Node.value = Name (Name.Identifier "pyre_extensions"); _ }; - attribute = "ListVariadic"; - special = false; - }); - _; - }; - arguments = [{ Call.Argument.value = { Node.value = String _; _ }; _ }]; - }; - _; - } -> + match Node.value value with + | Expression.Call { callee; arguments = [{ value = { Node.value = String _; _ }; _ }] } + when name_is ~name:"pyre_extensions.ListVariadic" callee -> Some (create (Reference.show target)) + | Call + { + callee; + arguments = { Call.Argument.value = { Node.value = String _; _ }; _ } :: arguments; + } + when name_is ~name:"pyre_extensions.ListVariadic" callee -> + let constraints = + let explicits = + let explicit = function + | { + Call.Argument.name = None; + value = { Node.value = Name (Name.Identifier identifier); _ }; + } -> + let identifier = Identifier.sanitized identifier in + Some (Primitive identifier) + | _ -> None + in + List.filter_map ~f:explicit arguments + in + let bound = + let bound = function + | { + Call.Argument.value = { Node.value = Name (Name.Identifier identifier); _ }; + name = Some { Node.value = bound; _ }; + } + when String.equal (Identifier.sanitized bound) "bound" -> + let identifier = Identifier.sanitized identifier in + Some (Primitive identifier) + | _ -> None + in + List.find_map ~f:bound arguments + in + match explicits, bound with + | [], Some bound -> Record.Variable.Bound bound + | explicits, _ when List.length explicits > 0 -> Explicit explicits + | _ -> Unconstrained + in + Some (create (Reference.show target) ~constraints) | _ -> None end end diff --git a/analysis/type.mli b/analysis/type.mli index fae0707f3d0..6514add6938 100644 --- a/analysis/type.mli +++ b/analysis/type.mli @@ -50,7 +50,16 @@ module Record : sig end module RecordList : sig - type 'annotation record [@@deriving compare, eq, sexp, show, hash] + type 'annotation record = { + name: Identifier.t; + constraints: 'annotation constraints; + variance: variance; + state: state; + namespace: RecordNamespace.t; + } + [@@deriving compare, eq, sexp, show, hash] + + val name : 'a record -> string end end @@ -64,10 +73,24 @@ module Record : sig module OrderedTypes : sig module RecordConcatenate : sig module Middle : sig - type 'annotation t [@@deriving compare, eq, sexp, show, hash] + type 'annotation t = { + variable: 'annotation Variable.RecordVariadic.RecordList.record; + mappers: Identifier.t list; + } + [@@deriving compare, eq, sexp, show, hash] end - type ('middle, 'outer) t [@@deriving compare, eq, sexp, show, hash] + type 'annotation wrapping = { + head: 'annotation list; + tail: 'annotation list; + } + [@@deriving compare, eq, sexp, show, hash] + + type ('middle, 'annotation) t = { + middle: 'middle; + wrapping: 'annotation wrapping; + } + [@@deriving compare, eq, sexp, show, hash] end type 'annotation record = diff --git a/analysis/typeConstraints.ml b/analysis/typeConstraints.ml index 84f687f059e..0f344d37870 100644 --- a/analysis/typeConstraints.ml +++ b/analysis/typeConstraints.ml @@ -556,7 +556,68 @@ module OrderedConstraints (Order : OrderType) = struct | Some upper, Some lower -> BothBounds { upper; lower } - let less_or_equal order ~left ~right = + let rec less_or_equal order ~left ~right = + let compare_concretes ~left ~right = + match List.zip left right with + | Ok bounds -> + List.for_all bounds ~f:(fun (left, right) -> + Order.always_less_or_equal order ~left ~right) + | _ -> false + in + let concrete_less_concatenation + ~inverse + ~concretes + ~concatenation: + { + Type.OrderedTypes.RecordConcatenate.middle = + { Type.OrderedTypes.RecordConcatenate.Middle.variable = { constraints; _ }; _ }; + wrapping = { head; tail }; + } + = + let valid = List.length head + List.length tail <= List.length concretes in + let left_concrete, right_concrete = + List.take concretes (List.length head), List.take (List.rev concretes) (List.length tail) + in + let middle_concretes = + List.drop concretes (List.length head) + |> List.rev + |> fun x -> List.drop x (List.length tail) |> List.rev + in + match constraints with + | Bound upper_bound -> + let middle_as_concrete = + Type.OrderedTypes.Concrete (List.map middle_concretes ~f:(fun _ -> upper_bound)) + in + if inverse then + valid + && compare_concretes ~left:head ~right:left_concrete + && compare_concretes ~left:tail ~right:right_concrete + && less_or_equal + order + ~left:middle_as_concrete + ~right:(Type.OrderedTypes.Concrete concretes) + else + valid + && compare_concretes ~right:head ~left:left_concrete + && compare_concretes ~right:tail ~left:right_concrete + && less_or_equal + order + ~left:(Type.OrderedTypes.Concrete concretes) + ~right:middle_as_concrete + | Explicit explicits -> + List.for_all explicits ~f:(fun concrete -> + List.exists concretes ~f:(Type.equal concrete)) + | _ -> true + in + let is_free_concatenation + { + Type.OrderedTypes.RecordConcatenate.middle = + { Type.OrderedTypes.RecordConcatenate.Middle.variable; _ }; + _; + } + = + Type.Variable.Variadic.List.is_free variable + in if Type.OrderedTypes.equal left right then true else @@ -564,25 +625,66 @@ module OrderedConstraints (Order : OrderType) = struct | _, Any | Any, _ -> true - | Concatenation _, _ - | _, Concatenation _ -> - false - | Concrete upper_bounds, Concrete lower_bounds -> ( - match List.zip upper_bounds lower_bounds with - | Ok bounds -> - List.for_all bounds ~f:(fun (left, right) -> - Order.always_less_or_equal order ~left ~right) - | _ -> false ) - - - let narrowest_valid_value interval ~order ~variable:_ = - match interval with - | NoBounds -> None - | OnlyLowerBound bound - | OnlyUpperBound bound -> - Some bound - | BothBounds { upper; lower } -> - Option.some_if (less_or_equal order ~left:lower ~right:upper) lower + | Concatenation concatenation, Concrete concretes + | Concrete concretes, Concatenation concatenation + when is_free_concatenation concatenation -> + concrete_less_concatenation ~concretes ~concatenation ~inverse:false + | Concrete left, Concrete right -> compare_concretes ~left ~right + | ( Concatenation + ( { + Type.OrderedTypes.RecordConcatenate.middle = + { variable = { constraints = constraints_left; _ }; mappers = [] }; + wrapping = { head = head_left; tail = tail_left }; + } as concatenation_left ), + Concatenation + ( { + Type.OrderedTypes.RecordConcatenate.middle = + { variable = { constraints = constraints_right; _ }; mappers = [] }; + wrapping = { head = head_right; tail = tail_right }; + } as concatenation_right ) ) + when is_free_concatenation concatenation_left && is_free_concatenation concatenation_right + -> + (* Example [A1,B1,Ts1,C1] <: [A2,Ts2,C2] [A1] <: [A2] && [C1] <: [C2] && [B1, Ts1] <: + [Ts2] *) + let drop_head = min (List.length head_left) (List.length head_right) in + let drop_tail = min (List.length tail_left) (List.length tail_right) in + let (prefix_head_left, head_left), (prefix_head_right, head_right) = + List.split_n head_left drop_head, List.split_n head_right drop_head + in + let (prefix_tail_left, tail_left), (prefix_tail_right, tail_right) = + List.split_n tail_left drop_tail, List.split_n tail_right drop_tail + in + let valid = + compare_concretes ~left:prefix_head_left ~right:prefix_head_right + && compare_concretes ~left:prefix_tail_left ~right:prefix_tail_right + in + let valid = + valid + && + match constraints_left, constraints_right with + | Bound left, Bound right -> Order.always_less_or_equal order ~left ~right + | Explicit lefts, Bound bound -> + List.for_all lefts ~f:(fun left -> + Order.always_less_or_equal order ~left ~right:bound) + | Bound _, Explicit _ -> false + | Explicit lefts, Explicit rights -> + List.for_all lefts ~f:(List.mem rights ~equal:Type.equal) + | _, Unconstrained -> true + | Unconstrained, _ -> false + | _, _ -> false + in + let valid = + valid + && + match head_left @ tail_left, head_right @ tail_right with + (* [A,B,Ts1, C] <: [Ts2] *) + | head_tail_left, [] -> less_or_equal order ~left:(Concrete head_tail_left) ~right + | [], head_tail_right -> less_or_equal order ~left ~right:(Concrete head_tail_right) + | _, _ -> false + in + valid + | Concatenation _, _ -> false + | _, Concatenation _ -> false let intersection left right ~order = @@ -643,6 +745,30 @@ module OrderedConstraints (Order : OrderType) = struct | _ -> None ) + let narrowest_valid_value interval ~order ~(variable : Type.Variable.list_variadic_t) = + let { Type.Variable.RecordVariadic.RecordList.constraints = exogenous_constraint; _ } = + variable + in + let prune_interval interval = + match interval with + | NoBounds -> None + | OnlyLowerBound bound + | OnlyUpperBound bound -> + Some bound + | BothBounds { upper; lower } -> + Option.some_if (less_or_equal order ~left:lower ~right:upper) lower + in + let variable = ListVariadic.self_reference variable in + let intersected_interval = + match exogenous_constraint with + | Bound _ -> intersection interval (create ~upper_bound:variable ()) ~order + | Explicit _ -> + intersection interval (create ~upper_bound:variable ~lower_bound:variable ()) ~order + | _ -> Some interval + in + intersected_interval >>= prune_interval + + let merge_solution_in interval ~variable ~solution = let upper_bound, lower_bound = match interval with diff --git a/pyre_extensions/__init__.py b/pyre_extensions/__init__.py index 42f3a3dc887..d8c6939007a 100644 --- a/pyre_extensions/__init__.py +++ b/pyre_extensions/__init__.py @@ -92,7 +92,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: pass -def ListVariadic(name) -> object: +def ListVariadic(name, bound=None) -> object: return Any