Skip to content

Commit 3db2bd9

Browse files
Add brain module for statistics inference (#2832)
1 parent df07b50 commit 3db2bd9

File tree

4 files changed

+158
-0
lines changed

4 files changed

+158
-0
lines changed

astroid/brain/brain_statistics.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2+
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4+
5+
"""Astroid hooks for understanding statistics library module.
6+
7+
Provides inference improvements for statistics module functions that have
8+
complex runtime behavior difficult to analyze statically.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
from collections.abc import Iterator
14+
from typing import TYPE_CHECKING
15+
16+
from astroid.context import InferenceContext
17+
from astroid.inference_tip import inference_tip
18+
from astroid.manager import AstroidManager
19+
from astroid.nodes.node_classes import Attribute, Call, ImportFrom, Name
20+
from astroid.util import Uninferable
21+
22+
if TYPE_CHECKING:
23+
from astroid.typing import InferenceResult
24+
25+
26+
def _is_statistics_quantiles_import(definition: object) -> bool:
    """Return True if *definition* binds ``quantiles`` from ``statistics``.

    Only a ``from statistics import quantiles`` statement (optionally spelled
    ``quantiles as quantiles``) qualifies; an entry that merely re-uses the
    name as an alias for something else (``median as quantiles``) does not.
    """
    if not isinstance(definition, ImportFrom) or definition.modname != "statistics":
        return False
    return any(
        imported == "quantiles" and alias in (None, "quantiles")
        for imported, alias in definition.names or ()
    )


def _looks_like_statistics_quantiles(node: Call) -> bool:
    """Check whether *node* is a call to ``statistics.quantiles``.

    Two spellings are recognised:

    * ``statistics.quantiles(...)`` -- an attribute access named ``quantiles``
      on a plain name ``statistics``.
    * ``quantiles(...)`` where the name ``quantiles`` is bound by
      ``from statistics import quantiles`` in the nearest enclosing scope
      that defines it.

    :param node: The call node being considered for the inference tip.
    :returns: True if the call matches one of the spellings above.
    """
    func = node.func

    # Case 1: statistics.quantiles(...)
    if isinstance(func, Attribute):
        return (
            func.attrname == "quantiles"
            and isinstance(func.expr, Name)
            and func.expr.name == "statistics"
        )

    # Case 2: from statistics import quantiles; quantiles(...)
    if not (isinstance(func, Name) and func.name == "quantiles"):
        return False

    # Walk outwards from the call's own scope.  Looking only at the immediate
    # frame would miss the common layout where the import sits at module level
    # and the call lives inside a function or method.  Consulting ``locals``
    # (instead of iterating ``body``) also finds imports nested in ``if`` /
    # ``try`` blocks and works for scopes whose ``body`` is not a list.
    # NOTE(review): class scopes are treated like any other enclosing scope
    # here, which is slightly more permissive than Python's own name
    # resolution for functions nested in a class body.
    scope = node.scope()
    while True:
        definitions = scope.locals.get("quantiles")
        if definitions:
            # The innermost binding wins: if it is not the statistics import,
            # the name has been shadowed and the call is left alone.
            return any(
                _is_statistics_quantiles_import(definition)
                for definition in definitions
            )
        if scope.parent is None:
            # Reached the module without finding a binding.
            return False
        scope = scope.parent.scope()
55+
56+
57+
def infer_statistics_quantiles(
    node: Call, context: InferenceContext | None = None
) -> Iterator[InferenceResult]:
    """Inference tip for ``statistics.quantiles()`` calls.

    The call is deliberately reported as ``Uninferable``.  At runtime
    ``quantiles()`` returns a list of ``n - 1`` cut points, but static
    analysis of the function body only sees the empty list it starts from.
    Yielding ``Uninferable`` keeps that misleading empty list from reaching
    consumers such as pylint's unbalanced-tuple-unpacking checker.

    :param node: The matched call node (unused).
    :param context: The inference context, if any (unused).
    :returns: An iterator producing the single value ``Uninferable``.
    """
    yield Uninferable
70+
71+
72+
def register(manager: AstroidManager) -> None:
    """Install the ``statistics.quantiles`` inference tip on *manager*.

    A transform is registered for ``Call`` nodes; it applies
    ``infer_statistics_quantiles`` to every call accepted by
    ``_looks_like_statistics_quantiles``.
    """
    quantiles_tip = inference_tip(infer_statistics_quantiles)
    manager.register_transform(Call, quantiles_tip, _looks_like_statistics_quantiles)

astroid/brain/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def register_all_brains(manager: AstroidManager) -> None:
7575
brain_six,
7676
brain_sqlalchemy,
7777
brain_ssl,
78+
brain_statistics,
7879
brain_subprocess,
7980
brain_threading,
8081
brain_type,
@@ -126,6 +127,7 @@ def register_all_brains(manager: AstroidManager) -> None:
126127
brain_six.register(manager)
127128
brain_sqlalchemy.register(manager)
128129
brain_ssl.register(manager)
130+
brain_statistics.register(manager)
129131
brain_subprocess.register(manager)
130132
brain_threading.register(manager)
131133
brain_type.register(manager)

tests/brain/test_brain.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,16 @@ class Derived(collections.abc.Hashable, collections.abc.Iterator[int]):
321321
],
322322
)
323323

324+
def test_statistics_quantiles_from_import(self):
    """A bare ``quantiles(...)`` call imported from statistics is Uninferable."""
    call = builder.extract_node(
        """
        from statistics import quantiles
        quantiles([1, 2, 3, 4, 5, 6, 7, 8, 9], n=4)
        """
    )
    first_result = next(call.infer())
    self.assertIs(first_result, util.Uninferable)
333+
324334

325335
class TypingBrain(unittest.TestCase):
326336
def test_namedtuple_base(self) -> None:

tests/brain/test_statistics.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
2+
# For details: https://github.com/pylint-dev/astroid/blob/main/LICENSE
3+
# Copyright (c) https://github.com/pylint-dev/astroid/blob/main/CONTRIBUTORS.txt
4+
5+
"""Tests for brain statistics module."""
6+
7+
from __future__ import annotations
8+
9+
import unittest
10+
11+
from astroid import extract_node
12+
from astroid.util import Uninferable
13+
14+
15+
class StatisticsBrainTest(unittest.TestCase):
    """Checks for the inference tip attached to ``statistics.quantiles``."""

    def _assert_single_uninferable(self, node) -> None:
        """Assert that inferring *node* yields exactly one ``Uninferable``."""
        results = list(node.infer())
        self.assertEqual(len(results), 1)
        self.assertIs(results[0], Uninferable)

    def test_statistics_quantiles_inference(self) -> None:
        """``statistics.quantiles()`` infers to Uninferable, not an empty list."""
        call = extract_node(
            """
            import statistics
            statistics.quantiles(list(range(100)), n=4) #@
            """
        )
        self._assert_single_uninferable(call)

    def test_statistics_quantiles_different_args(self) -> None:
        """The result does not depend on which arguments are passed."""
        call = extract_node(
            """
            import statistics
            statistics.quantiles([1, 2, 3, 4, 5], n=10, method='inclusive') #@
            """
        )
        self._assert_single_uninferable(call)

    def test_statistics_quantiles_assignment_unpacking(self) -> None:
        """Cover the tuple-unpacking assignment that produced false positives."""
        assignment = extract_node(
            """
            import statistics
            q1, q2, q3 = statistics.quantiles(list(range(100)), n=4) #@
            """
        )
        # The marker selects the whole assignment; the call is its value.
        self._assert_single_uninferable(assignment.value)

    def test_other_statistics_functions_not_affected(self) -> None:
        """Other ``statistics`` functions still produce inference results."""
        call = extract_node(
            """
            import statistics
            statistics.mean([1, 2, 3, 4, 5]) #@
            """
        )
        results = list(call.infer())
        self.assertNotEqual(len(results), 0)
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main()

0 commit comments

Comments
 (0)