diff --git a/controller/ndslice.py b/controller/ndslice.py index 7dca3d5..26e66c1 100644 --- a/controller/ndslice.py +++ b/controller/ndslice.py @@ -33,12 +33,32 @@ def __iter__(self): yield self.offset + sum(i*s for i, s in zip(loc, self.strides)) def union(self, other: 'NDSlice') -> List['NDSlice']: - raise NotImplementedError() + # return a merged slice for pairs that meet regularity conditions + # and can be detected cheaply, otherwise just return the pair. + def sort(s): + strides, sizes = zip(*sorted(zip(s.strides, s.sizes), reverse=True)) + return NDSlice(s.offset, list(sizes), list(strides)) + s1, s2 = sort(self), sort(other) + s1, s2 = sorted([s1, s2], key=lambda s: s.offset) + extent = s1.sizes[0] * s1.strides[0] + if s1.offset + extent == s2.offset: + if s1.strides == s2.strides and s1.sizes[1:] == s2.sizes[1:]: + # congruent inner shape and strides, concatenated on long axis + sizes = [s1.sizes[0] + s2.sizes[0]] + s1.sizes[1:] + return [NDSlice(s1.offset, sizes, s1.strides)] + # fallback + return [self, other] def contains_any(self, start: int, end: int) -> bool: # does this slice contain any of the elements in [start, end) # will be used to figure out who to broadcast to. - raise NotImplementedError() + base = self.offset + for stride, size in sorted(zip(self.strides, self.sizes), reverse=True): + if end <= base or start >= base + size * stride: + return False + shift = max(0, min(size - 1, (end - base - 1) // stride)) + base += shift * stride + return base >= start and base < end def index(self, value: int) -> int: # return index where self[index] == value, diff --git a/tests/controller/test_controller.py b/tests/controller/test_controller.py index 21267a8..0c21028 100644 --- a/tests/controller/test_controller.py +++ b/tests/controller/test_controller.py @@ -333,6 +333,33 @@ def check(s): check(NDSlice(24, [4, 3, 8], [48*2, 16, 2])) check(NDSlice(37, [], [])) + def check_contains_any(s, gen): + elems = set(s) + max_elem = max(elems) + span = max_elem - s.offset + 1 + for _ in range(len(elems) * 32): + start = gen.randrange(s.offset - 10, max_elem + 10) + end = start + gen.choices(range(span), range(span, 0, -1))[0] + contains = s.contains_any(start, end) + ref = any(x in elems for x in range(start, end)) + if contains != ref: + raise RuntimeError( + f"{s}.contains_any({start}, {end}) broken, false {'positive' if contains else 'negative'}" + ) + + def check_union(s1, gen): + for _ in range(10): + s2 = fuzz(gen) + ref_elems = set(s1) | set(s2) + u = s1.union(s2) + elems = set() + for s in u: + elems |= set(s) + if elems != ref_elems: + raise RuntimeError( + f"{s1}.union({s2}) broken, {'extra' if len(elems) > len(ref_elems) else 'missing'} elements" + ) + def fuzz(gen): ndim = gen.choices([1, 2, 3, 4], [1, 2, 3, 4])[0] sizes = [gen.randrange(1, 10) for _ in range(ndim)] @@ -354,6 +381,8 @@ def fuzz(gen): for _ in range(1000): s = fuzz(gen) check(s) + check_contains_any(s, gen) + check_union(s, gen)