From b412a486f99a6d72c2225ad7cbba589252823105 Mon Sep 17 00:00:00 2001
From: Martin Rubey
Date: Thu, 30 Oct 2025 18:12:53 +0100
Subject: [PATCH] slightly restructure the logic of first_terms, to avoid
 copying when calling a statistic

---
NOTE(review): everything between this "---" line and the diffstat is
reviewer commentary; it is not part of the commit message or of the diff.

Recovery caveats.  This patch text was recovered from a copy in which all
line breaks and runs of whitespace had been collapsed.  Line breaks and
leading indentation below were rebuilt from the hunk-header line counts and
from Python syntax; spacing inside lines is kept exactly as found.  Verify
against the commit before relying on this file:
 * the author's e-mail address is absent from the "From:" header;
 * in the hunk at -1818, the doctest output line following
   "sage: FindStatCombinatorialStatistic()" was missing; the hunk counts
   require one context line there, and the object repr shown is an
   inference, not recovered text;
 * in the hunk at -1786, the '-'/'+' pair for q.set_sage_code(...) differed
   only in whitespace that could not be recovered, so both lines read the
   same here;
 * nested and continuation-line indentation is a best guess;
 * because context whitespace may not match the target file, expect to need
   `git apply --ignore-whitespace`, or regenerate the patch from the commit.

Summary of the change, as visible in the hunks below:
 * `_first_terms_cache` of a combinatorial statistic becomes a lazy
   attribute holding `dict(self._fetch_first_terms())`; `__call__` and
   `_generating_functions_dict` read it directly instead of calling
   `first_terms()`, which returns a copy.
 * `set_first_terms` stores a dictionary whose keys are passed through a
   new local helper `to_domain`.
 * The statistic query gets its own `_first_terms_cache`, filled
   incrementally by `first_terms` and tracked by `_known_terms_number`.
 * `_data_from_data` counts down `max_values` directly.
 * A comment typo and the expected E021 doctest messages are corrected.

Review observations on the change itself:
 * set_first_terms() now iterates `values` twice (once to build `new`, once
   to build the cache dictionary); a one-shot iterator would leave the
   second pass empty.
 * In the query's first_terms(), `max_values` now bounds the slice of
   `_known_terms` that is scanned rather than the number of singleton
   pairs returned, and a call with a smaller `max_values` after a larger
   one returns the whole accumulated cache.
 * _data_from_data() no longer clamps `max_values` to FINDSTAT_MAX_VALUES;
   presumably every caller already passes a bounded value - TODO confirm.

 src/sage/databases/findstat.py | 85 +++++++++++++++++++++++++---------
 1 file changed, 63 insertions(+), 22 deletions(-)

diff --git a/src/sage/databases/findstat.py b/src/sage/databases/findstat.py
index 83b2a21c921..05260b86598 100644
--- a/src/sage/databases/findstat.py
+++ b/src/sage/databases/findstat.py
@@ -198,6 +198,7 @@ def mapping(sigma):
 # https://www.gnu.org/licenses/
 # ****************************************************************************
 from sage.misc.lazy_list import lazy_list
+from sage.misc.lazy_attribute import lazy_attribute
 from sage.misc.inherit_comparison import InheritComparisonClasscallMetaclass
 from sage.structure.element import Element
 from sage.structure.parent import Parent
@@ -631,7 +632,7 @@ def _data_from_iterable(iterable, mapping=False, domain=None,
         pre_data = [(elts, vals)]

     # pre_data is a list of all elements of the iterator accessed so
-    # far, for each of its elements and also the remainder ot the
+    # far, for each of its elements and also the remainder of the
     # iterator, each element is either a pair ``(object, value)`` or
     # a pair ``(objects, values)``
     elts, vals = pre_data[0]
@@ -730,16 +731,15 @@ def _data_from_data(data, max_values):
           [0, 0, 1, 1, 2, 2, 1, 0, 0, 0, 1, 1, 1, 2, 3])]
     """
     query = []
-    total = min(max_values, FINDSTAT_MAX_VALUES)
     iterator = iter(data)
-    while total > 0:
+    while max_values > 0:
         try:
             elts, vals = next(iterator)
         except StopIteration:
             break
-        if total >= len(elts):
+        if max_values >= len(elts):
             query.append((elts, vals))
-            total -= len(elts)
+            max_values -= len(elts)
         else:
             break # assuming that the next pair is even larger

@@ -1017,12 +1017,12 @@ def findstat(query=None, values=None, distribution=None, domain=None,
         sage: findstat("Permutations", lambda x: 1, depth='x') # optional -- internet
         Traceback (most recent call last):
         ...
-        ValueError: E021: Depth should be a nonnegative integer at most 9, but is x.
+        ValueError: E021: Depth should be a non-negative integer at most 9, but is x.

         sage: findstat("Permutations", lambda x: 1, depth=100) # optional -- internet
         Traceback (most recent call last):
         ...
-        ValueError: E021: Depth should be a nonnegative integer at most 9, but is 100.
+        ValueError: E021: Depth should be a non-negative integer at most 9, but is 100.

         sage: S = Permutation
         sage: findstat([(S([1,2]), 1), ([S([1,3,2]), S([1,2])], [2,3])]) # optional -- internet
@@ -1786,10 +1786,10 @@ def set_sage_code(self, value):
         EXAMPLES::

             sage: q = findstat([(d, randint(1,1000)) for d in DyckWords(4)]) # optional -- internet
-            sage: q.set_sage_code("def statistic(x):\n return randint(1, 1000)") # optional -- internet
+            sage: q.set_sage_code("def statistic(x):\n return randint(1, 1000)") # optional -- internet
             sage: print(q.sage_code()) # optional -- internet
             def statistic(x):
-                return randint(1,1000)
+                return randint(1, 1000)
         """
         if value != self.sage_code():
             self._modified = True
@@ -1818,9 +1818,22 @@ def __init__(self):
             sage: FindStatCombinatorialStatistic()
             <sage.databases.findstat.FindStatCombinatorialStatistic object at 0x...>
         """
-        self._first_terms_cache = None
         self._first_terms_raw_cache = None

+    @lazy_attribute
+    def _first_terms_cache(self):
+        """
+        Return the first terms of the (compound) statistic as a
+        dictionary.
+
+        EXAMPLES::
+
+            sage: findstat(41)._first_terms_cache[PerfectMatching([(1,6),(2,5),(3,4)])] # optional -- internet
+            3
+        """
+        # this indirectly initializes self._first_terms_raw_cache
+        return dict(self._fetch_first_terms())
+
     def first_terms(self):
         r"""
         Return the first terms of the (compound) statistic as a
@@ -1838,10 +1851,6 @@ def first_terms(self):
             sage: findstat(41).first_terms()[PerfectMatching([(1,6),(2,5),(3,4)])] # optional -- internet
             3
         """
-        # initialize self._first_terms_cache and
-        # self._first_terms_raw_cache on first call
-        if self._first_terms_cache is None:
-            self._first_terms_cache = self._fetch_first_terms()
         # a shallow copy suffices - tuples are immutable
         return dict(self._first_terms_cache)

@@ -1944,7 +1953,7 @@ def _generating_functions_dict(self,
         domain = self.domain()
         levels_with_sizes = domain.levels_with_sizes()
         total = 0
-        for elt, val in self.first_terms().items():
+        for elt, val in self._first_terms_cache.items():
             if total == max_values:
                 break
             lvl = domain.element_level(elt)
@@ -2153,7 +2162,7 @@ def __call__(self, elt):
             sage: q(graphs.PetersenGraph().copy(immutable=True)) # optional -- internet
             2
         """
-        val = self.first_terms().get(elt, None)
+        val = self._first_terms_cache.get(elt, None)
         if val is None:
             return FindStatFunction.__call__(self, elt)
         return val
@@ -2267,12 +2276,22 @@ def set_first_terms(self, values):
             [(1, 4), (2, 3)] => 3
             sage: s.reset() # optional -- internet
         """
-        to_str = self.domain().to_string()
+        domain = self.domain()
+        from_str = domain.from_string()
+        to_str = domain.to_string()
+
+        def to_domain(elt):
+            if domain.is_element(elt):
+                return elt
+            if not isinstance(elt, str):
+                elt = str(elt)
+            return from_str(elt)
+
         new = [(to_str(obj), value) for obj, value in values]
         if sorted(new) != sorted(self.first_terms_str()):
             self._modified = True
             self._first_terms_raw_cache = new
-            self._first_terms_cache = values
+            self._first_terms_cache = {to_domain(elt): v for elt, v in values}

     def code(self):
         r"""
@@ -2584,6 +2603,7 @@ def __init__(self, data=None, values_of=None, distribution_of=None,
             self._known_terms = data
         else:
             self._known_terms = known_terms
+        self._known_terms_number = 0
         self._values_of = None
         self._distribution_of = None
         self._depth = depth
@@ -2647,9 +2667,26 @@ def __init__(self, data=None, values_of=None, distribution_of=None,
                                   function=function)
         Element.__init__(self, FindStatStatistics()) # this is not completely correct, but it works

+    @lazy_attribute
+    def _first_terms_cache(self):
+        """
+        Return the pairs of the known terms which contain
+        singletons, as a dictionary.
+
+        EXAMPLES::
+
+            sage: PM = PerfectMatchings
+            sage: l = [(PM(2*n), [m.number_of_nestings() for m in PM(2*n)]) for n in range(5)]
+            sage: r = findstat(l, depth=0) # optional -- internet
+            sage: r._first_terms_cache # optional -- internet
+            {}
+        """
+        return dict()
+
     def first_terms(self, max_values=FINDSTAT_MAX_SUBMISSION_VALUES):
         """
-        Return the pairs of the known terms which contain singletons as a dictionary.
+        Return the pairs of the known terms which contain
+        singletons, as a dictionary.

         EXAMPLES::

@@ -2660,10 +2697,14 @@ def first_terms(self, max_values=FINDSTAT_MAX_SUBMISSION_VALUES):
             1: St000042 (quality [99, 100])
             sage: r.first_terms() # optional -- internet
             {[]: 0, [(1, 2)]: 0}
+
         """
-        return dict(itertools.islice(((objs[0], vals[0])
-                                      for objs, vals in self._known_terms
-                                      if len(vals) == 1), max_values))
+        new_terms = self._known_terms[self._known_terms_number:max_values]
+        self._first_terms_cache.update((objs[0], vals[0])
+                                       for objs, vals in new_terms
+                                       if len(vals) == 1)
+        self._known_terms_number = max(max_values, self._known_terms_number)
+        return dict(self._first_terms_cache)

     def _first_terms_raw(self, max_values):
         """