Skip to content

Commit 4d96c47

Browse files
committed
Remove strict=False in hot loops
This is actually slower than just not specifying it
1 parent e98cbbc commit 4d96c47

File tree

15 files changed

+44
-50
lines changed

15 files changed

+44
-50
lines changed

pyproject.toml

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -130,12 +130,8 @@ exclude = ["doc/", "pytensor/_version.py"]
130130
docstring-code-format = true
131131

132132
[tool.ruff.lint]
133-
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
133+
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
134134
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
135-
unfixable = [
136-
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
137-
"B905",
138-
]
139135

140136

141137
[tool.ruff.lint.isort]

pytensor/compile/builders.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -873,7 +873,6 @@ def clone(self):
873873

874874
def perform(self, node, inputs, outputs):
875875
variables = self.fn(*inputs)
876-
assert len(variables) == len(outputs)
877-
# strict=False because asserted above
878-
for output, variable in zip(outputs, variables, strict=False):
876+
# strict=None because we are in a hot loop
877+
for output, variable in zip(outputs, variables):
879878
output[0] = variable

pytensor/link/basic.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -373,7 +373,12 @@ def make_all(
373373

374374
# The function that actually runs your program is one of the f's in streamline.
375375
f = streamline(
376-
fgraph, thunks, order, post_thunk_old_storage, no_recycling=no_recycling
376+
fgraph,
377+
thunks,
378+
order,
379+
post_thunk_old_storage=post_thunk_old_storage,
380+
no_recycling=no_recycling,
381+
output_storage=output_storage,
377382
)
378383

379384
f.allow_gc = (
@@ -539,14 +544,14 @@ def make_thunk(self, **kwargs):
539544

540545
def f():
541546
for inputs in input_lists[1:]:
542-
# strict=False because we are in a hot loop
543-
for input1, input2 in zip(inputs0, inputs, strict=False):
547+
# strict=None because we are in a hot loop
548+
for input1, input2 in zip(inputs0, inputs):
544549
input2.storage[0] = copy(input1.storage[0])
545550
for x in to_reset:
546551
x[0] = None
547552
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
548-
# strict=False because we are in a hot loop
549-
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
553+
# strict=None because we are in a hot loop
554+
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
550555
try:
551556
wrapper(self.fgraph, i, node, *thunks)
552557
except Exception:
@@ -668,8 +673,8 @@ def thunk(
668673
# since the error may come from any of them?
669674
raise_with_op(self.fgraph, output_nodes[0], thunk)
670675

671-
# strict=False because we are in a hot loop
672-
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
676+
# strict=None because we are in a hot loop
677+
for o_storage, o_val in zip(thunk_outputs, outputs):
673678
o_storage[0] = o_val
674679

675680
thunk.inputs = thunk_inputs

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
312312
else:
313313

314314
def py_perform_return(inputs):
315-
# strict=False because we are in a hot loop
315+
# strict=None because we are in a hot loop
316316
return tuple(
317317
out_type.filter(out[0])
318-
for out_type, out in zip(output_types, py_perform(inputs), strict=False)
318+
for out_type, out in zip(output_types, py_perform(inputs))
319319
)
320320

321321
@numba_njit

pytensor/link/numba/dispatch/cython_support.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,10 +166,7 @@ def __wrapper_address__(self):
166166
def __call__(self, *args, **kwargs):
167167
# no strict argument because of the JIT
168168
# TODO: check
169-
args = [
170-
dtype(arg)
171-
for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905
172-
]
169+
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
173170
if self.has_pyx_skip_dispatch():
174171
output = self._pyfunc(*args[:-1], **kwargs)
175172
else:

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
186186
new_arr = arr.T.astype(np.float64).copy()
187187
for i, b in enumerate(new_arr):
188188
# no strict argument to this zip because numba doesn't support it
189-
for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905
189+
for j, (d, v) in enumerate(zip(shape, b)):
190190
if v < 0 or v >= d:
191191
mode_fn(new_arr, i, j, v, d)
192192

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -183,7 +183,7 @@ def block_diag(*arrs):
183183

184184
r, c = 0, 0
185185
# no strict argument because it is incompatible with numba
186-
for arr, shape in zip(arrs, shapes): # noqa: B905
186+
for arr, shape in zip(arrs, shapes):
187187
rr, cc = shape
188188
out[r : r + rr, c : c + cc] = arr
189189
r += rr

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs):
219219
shape_aft = x_shape[after_last_axis:]
220220
out_shape = (*shape_bef, *idx_shape, *shape_aft)
221221
out_buffer = np.empty(out_shape, dtype=x.dtype)
222-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
222+
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
223223
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
224224
return out_buffer
225225

@@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs):
253253
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
254254

255255
for outer in np.ndindex(x_shape[:first_axis]):
256-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
256+
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
257257
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
258258
return out
259259

@@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
275275
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])
276276

277277
for outer in np.ndindex(x_shape[:first_axis]):
278-
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
278+
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
279279
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
280280
return out
281281

@@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
314314
if not len(idxs) == len(vals):
315315
raise ValueError("The number of indices and values must match.")
316316
# no strict argument because incompatible with numba
317-
for idx, val in zip(idxs, vals): # noqa: B905
317+
for idx, val in zip(idxs, vals):
318318
x[idx] = val
319319
return x
320320
else:
@@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
342342
raise ValueError("The number of indices and values must match.")
343343
# no strict argument because unsupported by numba
344344
# TODO: this doesn't come up in tests
345-
for idx, val in zip(idxs, vals): # noqa: B905
345+
for idx, val in zip(idxs, vals):
346346
x[idx] += val
347347
return x
348348

pytensor/link/utils.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -190,9 +190,9 @@ def streamline_default_f():
190190
for x in no_recycling:
191191
x[0] = None
192192
try:
193-
# strict=False because we are in a hot loop
193+
# strict=None because we are in a hot loop
194194
for thunk, node, old_storage in zip(
195-
thunks, order, post_thunk_old_storage, strict=False
195+
thunks, order, post_thunk_old_storage
196196
):
197197
thunk()
198198
for old_s in old_storage:
@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
207207
for x in no_recycling:
208208
x[0] = None
209209
try:
210-
# strict=False because we are in a hot loop
211-
for thunk, node in zip(thunks, order, strict=False):
210+
# strict=None because we are in a hot loop
211+
for thunk, node in zip(thunks, order):
212212
thunk()
213213
except Exception:
214214
raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4416,8 +4416,8 @@ def make_node(self, *inputs):
44164416

44174417
def perform(self, node, inputs, output_storage):
44184418
outputs = self.py_perform_fn(*inputs)
4419-
# strict=False because we are in a hot loop
4420-
for storage, out_val in zip(output_storage, outputs, strict=False):
4419+
# strict=None because we are in a hot loop
4420+
for storage, out_val in zip(output_storage, outputs):
44214421
storage[0] = out_val
44224422

44234423
def grad(self, inputs, output_grads):

0 commit comments

Comments (0)