@@ -592,3 +592,184 @@ def method(self):
592592
593593 assert isinstance (inferred [1 ], nodes .Const )
594594 assert inferred [1 ].value == fail_val
595+
596+
597+ @common_params (node = "x" )
598+ def test_if_exp_body (
599+ condition : str , satisfy_val : int | None , fail_val : int | None
600+ ) -> None :
601+ """Test constraint for a variable that is used in an if exp body."""
602+ node1 , node2 = builder .extract_node (
603+ f"""
604+ def f1(x = { fail_val } ):
605+ return (
606+ x if { condition } else None #@
607+ )
608+
609+ def f2(x = { satisfy_val } ):
610+ return (
611+ x if { condition } else None #@
612+ )
613+ """
614+ )
615+
616+ inferred = node1 .body .inferred ()
617+ assert len (inferred ) == 1
618+ assert inferred [0 ] is Uninferable
619+
620+ inferred = node2 .body .inferred ()
621+ assert len (inferred ) == 2
622+ assert isinstance (inferred [0 ], nodes .Const )
623+ assert inferred [0 ].value == satisfy_val
624+ assert inferred [1 ] is Uninferable
625+
626+
627+ @common_params (node = "x" )
628+ def test_if_exp_else (
629+ condition : str , satisfy_val : int | None , fail_val : int | None
630+ ) -> None :
631+ """Test constraint for a variable that is used in an if exp else block."""
632+ node1 , node2 = builder .extract_node (
633+ f"""
634+ def f1(x = { satisfy_val } ):
635+ return (
636+ None if { condition } else x #@
637+ )
638+
639+ def f2(x = { fail_val } ):
640+ return (
641+ None if { condition } else x #@
642+ )
643+ """
644+ )
645+
646+ inferred = node1 .orelse .inferred ()
647+ assert len (inferred ) == 1
648+ assert inferred [0 ] is Uninferable
649+
650+ inferred = node2 .orelse .inferred ()
651+ assert len (inferred ) == 2
652+ assert isinstance (inferred [0 ], nodes .Const )
653+ assert inferred [0 ].value == fail_val
654+ assert inferred [1 ] is Uninferable
655+
656+
657+ @common_params (node = "x" )
658+ def test_outside_if_exp (
659+ condition : str , satisfy_val : int | None , fail_val : int | None
660+ ) -> None :
661+ """Test that constraint in an if exp condition doesn't apply outside of the if exp."""
662+ nodes_ = builder .extract_node (
663+ f"""
664+ def f1(x = { fail_val } ):
665+ x if { condition } else None
666+ return (
667+ x #@
668+ )
669+
670+ def f2(x = { satisfy_val } ):
671+ None if { condition } else x
672+ return (
673+ x #@
674+ )
675+ """
676+ )
677+ for node , val in zip (nodes_ , (fail_val , satisfy_val )):
678+ inferred = node .inferred ()
679+ assert len (inferred ) == 2
680+ assert isinstance (inferred [0 ], nodes .Const )
681+ assert inferred [0 ].value == val
682+ assert inferred [1 ] is Uninferable
683+
684+
685+ @common_params (node = "x" )
686+ def test_nested_if_exp (
687+ condition : str , satisfy_val : int | None , fail_val : int | None
688+ ) -> None :
689+ """Test that constraint in an if exp condition applies within inner if exp."""
690+ node1 , node2 = builder .extract_node (
691+ f"""
692+ def f1(y, x = { fail_val } ):
693+ return (
694+ (x if y else None) if { condition } else None #@
695+ )
696+
697+ def f2(y, x = { satisfy_val } ):
698+ return (
699+ (x if y else None) if { condition } else None #@
700+ )
701+ """
702+ )
703+
704+ inferred = node1 .body .body .inferred ()
705+ assert len (inferred ) == 1
706+ assert inferred [0 ] is Uninferable
707+
708+ inferred = node2 .body .body .inferred ()
709+ assert len (inferred ) == 2
710+ assert isinstance (inferred [0 ], nodes .Const )
711+ assert inferred [0 ].value == satisfy_val
712+ assert inferred [1 ] is Uninferable
713+
714+
715+ @common_params (node = "self.x" )
716+ def test_if_exp_instance_attr (
717+ condition : str , satisfy_val : int | None , fail_val : int | None
718+ ) -> None :
719+ """Test constraint for an instance attribute in an if exp."""
720+ node1 , node2 = builder .extract_node (
721+ f"""
722+ class A1:
723+ def __init__(self, x = { fail_val } ):
724+ self.x = x
725+
726+ def method(self):
727+ return (
728+ self.x if { condition } else None #@
729+ )
730+
731+ class A2:
732+ def __init__(self, x = { satisfy_val } ):
733+ self.x = x
734+
735+ def method(self):
736+ return (
737+ self.x if { condition } else None #@
738+ )
739+ """
740+ )
741+
742+ inferred = node1 .body .inferred ()
743+ assert len (inferred ) == 1
744+ assert inferred [0 ] is Uninferable
745+
746+ inferred = node2 .body .inferred ()
747+ assert len (inferred ) == 2
748+ assert isinstance (inferred [0 ], nodes .Const )
749+ assert inferred [0 ].value == satisfy_val
750+ assert inferred [1 ].value is Uninferable
751+
752+
753+ @common_params (node = "self.x" )
754+ def test_if_exp_instance_attr_varname_collision (
755+ condition : str , satisfy_val : int | None , fail_val : int | None
756+ ) -> None :
757+ """Test that constraint in an if exp condition doesn't apply to a variable with the same name."""
758+ node = builder .extract_node (
759+ f"""
760+ class A:
761+ def __init__(self, x = { fail_val } ):
762+ self.x = x
763+
764+ def method(self, x = { fail_val } ):
765+ return (
766+ x if { condition } else None #@
767+ )
768+ """
769+ )
770+
771+ inferred = node .body .inferred ()
772+ assert len (inferred ) == 2
773+ assert isinstance (inferred [0 ], nodes .Const )
774+ assert inferred [0 ].value == fail_val
775+ assert inferred [1 ].value is Uninferable
0 commit comments