From a5f084308d39d5a8f55a824dd2a39b8717f54825 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alfonso=20Casta=C3=B1o?= Date: Thu, 24 Sep 2020 02:37:08 -0700 Subject: [PATCH] Constraints on ListVariadics Summary: This diff brings support for constraints on ListVariadics. The constraints that can be specified are the same as for TypeVar, that is, Bounds and Explicits. Although ListVariadics currently could infer the bounds based on other values of the concatenation, this approach is too limited. While Pyre had already internal support to express that a ListVariadic has constraints, all the logic related with it was empty, since it was not allow to directly specify the constraint of a ListVariadic. For that purpose, this diff defines the required logic, mostly the propagation of constraints to ensure that violation of constraints are identified, as well as subtyping between the various combinations of variadics. Since subtyping of variables can be ambiguous, we support every case except when given something reduced to `[...,A,Ts1] <: [Ts2,B,...]`. Finally, since parse_declaration does not have access to create_logic, currently only primitive types can be used as constraints, for that purpose a follow up diff is proposed. See comment: https://www.internalfb.com/intern/diff/D23342739/?dest_fbid=606703123545772&transaction_id=343730890366364 Differential Revision: D23342739 fbshipit-source-id: 8b6848fdfcfc6d29be5be0970d1b2a58ad24b13e --- analysis/analysisError.ml | 13 ++ analysis/attributeResolution.ml | 31 +++- analysis/attributeResolution.mli | 4 + analysis/test/integration/typeVariableTest.ml | 48 +++++ analysis/test/typeTest.ml | 11 +- analysis/type.ml | 72 +++++--- analysis/type.mli | 29 ++- analysis/typeConstraints.ml | 166 +++++++++++++++--- pyre_extensions/__init__.py | 2 +- 9 files changed, 320 insertions(+), 56 deletions(-) diff --git a/analysis/analysisError.ml b/analysis/analysisError.ml index 9a0238303d7..a668400616a 100644 --- a/analysis/analysisError.ml +++ b/analysis/analysisError.ml @@ -1173,6 +1173,16 @@ let rec messages ~concise ~signature location kind = (Type.Variable expected) name; ] + | InvalidTypeParameters + { name; kind = AttributeResolution.ViolateConstraintsVariadic { expected; actual } } -> + [ + Format.asprintf + "Type parameter list `%a` violates constraints on `%s` in generic type `%s`." + (Type.Record.OrderedTypes.pp_concise ~pp_type) + actual + (Type.Record.Variable.RecordVariadic.RecordList.name expected) + name; + ] | InvalidTypeParameters { name; kind = AttributeResolution.UnexpectedKind { expected; actual } } -> let details = @@ -3453,6 +3463,9 @@ let dequalify actual = dequalify actual; expected = Type.Variable.Unary.dequalify ~dequalify_map expected; } + | AttributeResolution.ViolateConstraintsVariadic { actual; expected } -> + AttributeResolution.ViolateConstraintsVariadic + { actual; expected = Type.Variable.Variadic.List.dequalify ~dequalify_map expected } | AttributeResolution.UnexpectedKind { actual; expected } -> AttributeResolution.UnexpectedKind { actual; expected = Type.Variable.dequalify dequalify_map expected } diff --git a/analysis/attributeResolution.ml b/analysis/attributeResolution.ml index c94b81d8126..f63ceb1bab5 100644 --- a/analysis/attributeResolution.ml +++ b/analysis/attributeResolution.ml @@ -169,6 +169,10 @@ module TypeParameterValidationTypes = struct actual: Type.t; expected: Type.Variable.Unary.t; } + | ViolateConstraintsVariadic of { + actual: Type.OrderedTypes.t; + expected: Type.Variable.Variadic.List.t; + } | UnexpectedKind of { actual: Type.Parameter.t; expected: Type.Variable.t; @@ -821,7 +825,7 @@ class base class_metadata_environment dependency = TypeConstraints.empty ~order ~pair - >>| TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair + >>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair |> Option.is_none in if invalid then @@ -848,10 +852,29 @@ class base class_metadata_environment dependency = ( CallableParameters Undefined, Some { name; kind = UnexpectedKind { expected = generic; actual = given } } ) - | ParameterVariadic _, CallableParameters _ - | ListVariadic _, Group _ -> + | ParameterVariadic _, CallableParameters _ -> given, None + | ListVariadic generic, Group given -> (* TODO(T47346673): accept w/ new kind of validation *) - given, None + let invalid = + let order = self#full_order ~assumptions in + let pair = Type.Variable.ListVariadicPair (generic, given) in + TypeOrder.OrderedConstraints.add_lower_bound + TypeConstraints.empty + ~order + ~pair + >>= TypeOrder.OrderedConstraints.add_upper_bound ~order ~pair + |> Option.is_none + in + if invalid then + ( Type.Parameter.Group Any, + Some + { + name; + kind = + ViolateConstraintsVariadic { actual = given; expected = generic }; + } ) + else + Type.Parameter.Group given, None in List.map paired ~f:check_parameter |> List.unzip diff --git a/analysis/attributeResolution.mli b/analysis/attributeResolution.mli index fc39ce2a7ec..d12c10651b0 100644 --- a/analysis/attributeResolution.mli +++ b/analysis/attributeResolution.mli @@ -34,6 +34,10 @@ type generic_type_problems = actual: Type.t; expected: Type.Variable.Unary.t; } + | ViolateConstraintsVariadic of { + actual: Type.OrderedTypes.t; + expected: Type.Variable.Variadic.List.t; + } | UnexpectedKind of { actual: Type.Parameter.t; expected: Type.Variable.t; diff --git a/analysis/test/integration/typeVariableTest.ml b/analysis/test/integration/typeVariableTest.ml index 7a7bb2aeda8..22937676fa3 100644 --- a/analysis/test/integration/typeVariableTest.ml +++ b/analysis/test/integration/typeVariableTest.ml @@ -1834,6 +1834,53 @@ let test_list_variadics context = () +let test_list_variadics_constraints context = + let assert_type_errors = assert_type_errors ~context in + assert_type_errors + {| + from typing import Generic, TypeVar + from typing_extensions import Literal + from pyre_extensions import ListVariadic + from pyre_extensions.type_variable_operators import Concatenate as Cat + + Ts = ListVariadic("Ts", bound=int) + A = TypeVar("A") + class Vec(Generic[Ts]): ... + + def f1(x : Vec[Cat[Ts,A]]) -> None: ... + def f2( *ts: Ts) -> Vec[Cat[Ts,float]]: ... + def f3( *ts: Ts) -> Vec[[int,float]]: ... + |} + [ + "Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, Variable[test.A]]` \ + violates constraints on `Ts` in generic type `Vec`."; + "Invalid type parameters [24]: Type parameter list `Concatenate[test.Ts, float]` violates \ + constraints on `Ts` in generic type `Vec`."; + "Invalid type parameters [24]: Type parameter list `int, float` violates constraints on `Ts` \ + in generic type `Vec`."; + ]; + assert_type_errors + {| + from typing import Generic, TypeVar + from pyre_extensions import ListVariadic + + Ts1 = ListVariadic("Ts1", bound=int) + Ts2 = ListVariadic("Ts2", bound=float) + + class Vec1(Generic[Ts1]): ... + class Vec2(Generic[Ts2]): ... + + def f1(x: Vec1[Ts1]) -> None: ... + def f2(x: Vec2[Ts2]) -> None: ... + def g1(x: Vec1[Ts2]) -> None: ... + def g2(x: Vec2[Ts1]) -> None: ... + |} + [ + "Invalid type parameters [24]: Type parameter list `test.Ts2` violates constraints on `Ts1` \ + in generic type `Vec1`."; + ] + + let test_map context = let assert_type_errors = assert_type_errors ~context in assert_type_errors @@ -2424,6 +2471,7 @@ let () = "single_explicit_error" >:: test_single_explicit_error; "callable_parameter_variadics" >:: test_callable_parameter_variadics; "list_variadics" >:: test_list_variadics; + "list_variadics_constraints" >:: test_list_variadics_constraints; "map" >:: test_map; "user_defined_variadics" >:: test_user_defined_variadics; "concatenation" >:: test_concatenation_operator; diff --git a/analysis/test/typeTest.ml b/analysis/test/typeTest.ml index cc88bd194ae..e80dc399d6e 100644 --- a/analysis/test/typeTest.ml +++ b/analysis/test/typeTest.ml @@ -2082,7 +2082,16 @@ let test_parse_type_variable_declarations _ = assert_parses_declaration "pyre_extensions.ListVariadic('Ts')" (Type.Variable.ListVariadic (Type.Variable.Variadic.List.create "target")); - assert_declaration_does_not_parse "pyre_extensions.ListVariadic('Ts', int, str)"; + assert_parses_declaration + "pyre_extensions.ListVariadic('Ts', int, str)" + (Type.Variable.ListVariadic + (Type.Variable.Variadic.List.create + "target" + ~constraints:(Explicit [Type.Primitive "int"; Type.Primitive "str"]))); + assert_parses_declaration + "pyre_extensions.ListVariadic('Ts', bound=int)" + (Type.Variable.ListVariadic + (Type.Variable.Variadic.List.create "target" ~constraints:(Bound (Type.Primitive "int")))); () diff --git a/analysis/type.ml b/analysis/type.ml index e6bb39f4dbf..6b33b1fb34d 100644 --- a/analysis/type.ml +++ b/analysis/type.ml @@ -2741,12 +2741,10 @@ let rec create_logic ~aliases ~variable_aliases { Node.value = expression; _ } = in List.find_map ~f:bound arguments in - if not (List.is_empty explicits) then - Record.Variable.Explicit explicits - else if Option.is_some bound then - Bound (Option.value_exn bound) - else - Unconstrained + match explicits, bound with + | [], Some bound -> Record.Variable.Bound bound + | explicits, _ when List.length explicits > 0 -> Explicit explicits + | _ -> Unconstrained in let variance = let variance_definition = function @@ -4101,28 +4099,48 @@ end = struct let parse_declaration value ~target = - match value with - | { - Node.value = - Expression.Call - { - callee = - { - Node.value = - Name - (Name.Attribute - { - base = { Node.value = Name (Name.Identifier "pyre_extensions"); _ }; - attribute = "ListVariadic"; - special = false; - }); - _; - }; - arguments = [{ Call.Argument.value = { Node.value = String _; _ }; _ }]; - }; - _; - } -> + match Node.value value with + | Expression.Call { callee; arguments = [{ value = { Node.value = String _; _ }; _ }] } + when name_is ~name:"pyre_extensions.ListVariadic" callee -> Some (create (Reference.show target)) + | Call + { + callee; + arguments = { Call.Argument.value = { Node.value = String _; _ }; _ } :: arguments; + } + when name_is ~name:"pyre_extensions.ListVariadic" callee -> + let constraints = + let explicits = + let explicit = function + | { + Call.Argument.name = None; + value = { Node.value = Name (Name.Identifier identifier); _ }; + } -> + let identifier = Identifier.sanitized identifier in + Some (Primitive identifier) + | _ -> None + in + List.filter_map ~f:explicit arguments + in + let bound = + let bound = function + | { + Call.Argument.value = { Node.value = Name (Name.Identifier identifier); _ }; + name = Some { Node.value = bound; _ }; + } + when String.equal (Identifier.sanitized bound) "bound" -> + let identifier = Identifier.sanitized identifier in + Some (Primitive identifier) + | _ -> None + in + List.find_map ~f:bound arguments + in + match explicits, bound with + | [], Some bound -> Record.Variable.Bound bound + | explicits, _ when List.length explicits > 0 -> Explicit explicits + | _ -> Unconstrained + in + Some (create (Reference.show target) ~constraints) | _ -> None end end diff --git a/analysis/type.mli b/analysis/type.mli index fae0707f3d0..6514add6938 100644 --- a/analysis/type.mli +++ b/analysis/type.mli @@ -50,7 +50,16 @@ module Record : sig end module RecordList : sig - type 'annotation record [@@deriving compare, eq, sexp, show, hash] + type 'annotation record = { + name: Identifier.t; + constraints: 'annotation constraints; + variance: variance; + state: state; + namespace: RecordNamespace.t; + } + [@@deriving compare, eq, sexp, show, hash] + + val name : 'a record -> string end end @@ -64,10 +73,24 @@ module Record : sig module OrderedTypes : sig module RecordConcatenate : sig module Middle : sig - type 'annotation t [@@deriving compare, eq, sexp, show, hash] + type 'annotation t = { + variable: 'annotation Variable.RecordVariadic.RecordList.record; + mappers: Identifier.t list; + } + [@@deriving compare, eq, sexp, show, hash] end - type ('middle, 'outer) t [@@deriving compare, eq, sexp, show, hash] + type 'annotation wrapping = { + head: 'annotation list; + tail: 'annotation list; + } + [@@deriving compare, eq, sexp, show, hash] + + type ('middle, 'annotation) t = { + middle: 'middle; + wrapping: 'annotation wrapping; + } + [@@deriving compare, eq, sexp, show, hash] end type 'annotation record = diff --git a/analysis/typeConstraints.ml b/analysis/typeConstraints.ml index 84f687f059e..0f344d37870 100644 --- a/analysis/typeConstraints.ml +++ b/analysis/typeConstraints.ml @@ -556,7 +556,68 @@ module OrderedConstraints (Order : OrderType) = struct | Some upper, Some lower -> BothBounds { upper; lower } - let less_or_equal order ~left ~right = + let rec less_or_equal order ~left ~right = + let compare_concretes ~left ~right = + match List.zip left right with + | Ok bounds -> + List.for_all bounds ~f:(fun (left, right) -> + Order.always_less_or_equal order ~left ~right) + | _ -> false + in + let concrete_less_concatenation + ~inverse + ~concretes + ~concatenation: + { + Type.OrderedTypes.RecordConcatenate.middle = + { Type.OrderedTypes.RecordConcatenate.Middle.variable = { constraints; _ }; _ }; + wrapping = { head; tail }; + } + = + let valid = List.length head + List.length tail <= List.length concretes in + let left_concrete, right_concrete = + List.take concretes (List.length head), List.take (List.rev concretes) (List.length tail) + in + let middle_concretes = + List.drop concretes (List.length head) + |> List.rev + |> fun x -> List.drop x (List.length tail) |> List.rev + in + match constraints with + | Bound upper_bound -> + let middle_as_concrete = + Type.OrderedTypes.Concrete (List.map middle_concretes ~f:(fun _ -> upper_bound)) + in + if inverse then + valid + && compare_concretes ~left:head ~right:left_concrete + && compare_concretes ~left:tail ~right:right_concrete + && less_or_equal + order + ~left:middle_as_concrete + ~right:(Type.OrderedTypes.Concrete concretes) + else + valid + && compare_concretes ~right:head ~left:left_concrete + && compare_concretes ~right:tail ~left:right_concrete + && less_or_equal + order + ~left:(Type.OrderedTypes.Concrete concretes) + ~right:middle_as_concrete + | Explicit explicits -> + List.for_all explicits ~f:(fun concrete -> + List.exists concretes ~f:(Type.equal concrete)) + | _ -> true + in + let is_free_concatenation + { + Type.OrderedTypes.RecordConcatenate.middle = + { Type.OrderedTypes.RecordConcatenate.Middle.variable; _ }; + _; + } + = + Type.Variable.Variadic.List.is_free variable + in if Type.OrderedTypes.equal left right then true else @@ -564,25 +625,66 @@ module OrderedConstraints (Order : OrderType) = struct | _, Any | Any, _ -> true - | Concatenation _, _ - | _, Concatenation _ -> - false - | Concrete upper_bounds, Concrete lower_bounds -> ( - match List.zip upper_bounds lower_bounds with - | Ok bounds -> - List.for_all bounds ~f:(fun (left, right) -> - Order.always_less_or_equal order ~left ~right) - | _ -> false ) - - - let narrowest_valid_value interval ~order ~variable:_ = - match interval with - | NoBounds -> None - | OnlyLowerBound bound - | OnlyUpperBound bound -> - Some bound - | BothBounds { upper; lower } -> - Option.some_if (less_or_equal order ~left:lower ~right:upper) lower + | Concatenation concatenation, Concrete concretes + | Concrete concretes, Concatenation concatenation + when is_free_concatenation concatenation -> + concrete_less_concatenation ~concretes ~concatenation ~inverse:false + | Concrete left, Concrete right -> compare_concretes ~left ~right + | ( Concatenation + ( { + Type.OrderedTypes.RecordConcatenate.middle = + { variable = { constraints = constraints_left; _ }; mappers = [] }; + wrapping = { head = head_left; tail = tail_left }; + } as concatenation_left ), + Concatenation + ( { + Type.OrderedTypes.RecordConcatenate.middle = + { variable = { constraints = constraints_right; _ }; mappers = [] }; + wrapping = { head = head_right; tail = tail_right }; + } as concatenation_right ) ) + when is_free_concatenation concatenation_left && is_free_concatenation concatenation_right + -> + (* Example [A1,B1,Ts1,C1] <: [A2,Ts2,C2] [A1] <: [A2] && [C1] <: [C2] && [B1, Ts1] <: + [Ts2] *) + let drop_head = min (List.length head_left) (List.length head_right) in + let drop_tail = min (List.length tail_left) (List.length tail_right) in + let (prefix_head_left, head_left), (prefix_head_right, head_right) = + List.split_n head_left drop_head, List.split_n head_right drop_head + in + let (prefix_tail_left, tail_left), (prefix_tail_right, tail_right) = + List.split_n tail_left drop_tail, List.split_n tail_right drop_tail + in + let valid = + compare_concretes ~left:prefix_head_left ~right:prefix_head_right + && compare_concretes ~left:prefix_tail_left ~right:prefix_tail_right + in + let valid = + valid + && + match constraints_left, constraints_right with + | Bound left, Bound right -> Order.always_less_or_equal order ~left ~right + | Explicit lefts, Bound bound -> + List.for_all lefts ~f:(fun left -> + Order.always_less_or_equal order ~left ~right:bound) + | Bound _, Explicit _ -> false + | Explicit lefts, Explicit rights -> + List.for_all lefts ~f:(List.mem rights ~equal:Type.equal) + | _, Unconstrained -> true + | Unconstrained, _ -> false + | _, _ -> false + in + let valid = + valid + && + match head_left @ tail_left, head_right @ tail_right with + (* [A,B,Ts1, C] <: [Ts2] *) + | head_tail_left, [] -> less_or_equal order ~left:(Concrete head_tail_left) ~right + | [], head_tail_right -> less_or_equal order ~left ~right:(Concrete head_tail_right) + | _, _ -> false + in + valid + | Concatenation _, _ -> false + | _, Concatenation _ -> false let intersection left right ~order = @@ -643,6 +745,30 @@ module OrderedConstraints (Order : OrderType) = struct | _ -> None ) + let narrowest_valid_value interval ~order ~(variable : Type.Variable.list_variadic_t) = + let { Type.Variable.RecordVariadic.RecordList.constraints = exogenous_constraint; _ } = + variable + in + let prune_interval interval = + match interval with + | NoBounds -> None + | OnlyLowerBound bound + | OnlyUpperBound bound -> + Some bound + | BothBounds { upper; lower } -> + Option.some_if (less_or_equal order ~left:lower ~right:upper) lower + in + let variable = ListVariadic.self_reference variable in + let intersected_interval = + match exogenous_constraint with + | Bound _ -> intersection interval (create ~upper_bound:variable ()) ~order + | Explicit _ -> + intersection interval (create ~upper_bound:variable ~lower_bound:variable ()) ~order + | _ -> Some interval + in + intersected_interval >>= prune_interval + + let merge_solution_in interval ~variable ~solution = let upper_bound, lower_bound = match interval with diff --git a/pyre_extensions/__init__.py b/pyre_extensions/__init__.py index 42f3a3dc887..d8c6939007a 100644 --- a/pyre_extensions/__init__.py +++ b/pyre_extensions/__init__.py @@ -92,7 +92,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: pass -def ListVariadic(name) -> object: +def ListVariadic(name, bound=None) -> object: return Any