Conversation

@codeflash-ai codeflash-ai bot commented Nov 10, 2025

📄 29% (0.29x) speedup for XLabsIPAdapterExtension._get_weight in invokeai/backend/flux/extensions/xlabs_ip_adapter_extension.py

⏱️ Runtime : 394 microseconds → 305 microseconds (best of 250 runs)

📝 Explanation and details

The optimized code achieves a 29% speedup through a key caching optimization that eliminates redundant mathematical computations in repeated calls.

Primary Optimization - Step Calculation Caching:
The original code recalculates first_step and last_step on every call using math.floor() and math.ceil(). The optimized version caches these values along with total_num_timesteps as instance variables, only recalculating when the step count changes. This is highly effective because in typical AI inference workflows, the same total_num_timesteps is used across multiple timestep calls within a single generation.
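The caching pattern described above can be sketched as follows. This is an illustrative reconstruction, not the actual InvokeAI implementation: the class name, attribute names, and the `None` sentinel for the cache key are assumptions.

```python
import math
from typing import List, Optional, Union


class StepRangeCache:
    """Sketch of the caching optimization: first_step/last_step are
    recomputed only when total_num_timesteps changes between calls."""

    def __init__(
        self,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent
        # Cache key and cached values (hypothetical names).
        self._last_total_num_timesteps: Optional[int] = None
        self._first_step = 0
        self._last_step = 0

    def get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
        # Recompute the active step range only on a cache miss, i.e. when
        # the step count differs from the previous call.
        if self._last_total_num_timesteps != total_num_timesteps:
            self._first_step = math.floor(self._begin_step_percent * total_num_timesteps)
            self._last_step = math.ceil(self._end_step_percent * total_num_timesteps)
            self._last_total_num_timesteps = total_num_timesteps
        if timestep_index < self._first_step or timestep_index > self._last_step:
            return 0.0
        # Local variable avoids repeated attribute lookups.
        weight = self._weight
        if isinstance(weight, list):
            return weight[timestep_index]
        return weight
```

Within one generation, every call after the first hits the cached range, so the `floor`/`ceil` work is paid once per step-count change rather than once per call.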

Performance Impact by Test Pattern:

  • Best case (48.5% faster): Large-scale tests with repeated calls using the same total_num_timesteps benefit most from caching
  • Moderate gains (25-30% faster): Tests with consistent step counts show solid improvements
  • Slight regression (up to 35% slower): First calls with new total_num_timesteps pay a small cache setup cost, but this is amortized across subsequent calls

Secondary Optimizations:

  • Direct int casting: Replaces math.floor() with int() for non-negative multiplication results, avoiding function call overhead
  • Local variable assignment: Reduces attribute access by storing self._weight in a local variable
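The `int()`-for-`math.floor()` substitution only holds for non-negative products, which a quick check illustrates (a standalone example, not code from the PR):

```python
import math

# For non-negative products, truncation via int() agrees with math.floor()
# while skipping an attribute lookup and function call through the math module:
assert int(0.2 * 10) == math.floor(0.2 * 10) == 2

# The two diverge for negative values (int truncates toward zero, floor
# rounds down), so the substitution is safe only when the product is >= 0:
assert int(-2.5) == -2
assert math.floor(-2.5) == -3
```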

The line profiler shows the caching check (self._last_total_num_timesteps != total_num_timesteps) is hit 1,188 times but the expensive calculations only execute 36 times, demonstrating the effectiveness of avoiding redundant computations. This optimization is particularly valuable in AI model inference where the same step configuration is used repeatedly within each generation cycle.

Correctness verification report:

Test                           Status
⚙️ Existing Unit Tests         🔘 None Found
🌀 Generated Regression Tests  1224 Passed
⏪ Replay Tests                🔘 None Found
🔎 Concolic Coverage Tests     🔘 None Found
📊 Tests Coverage              100.0%
🌀 Generated Regression Tests and Runtime

import math
from typing import List, Union

# imports
import pytest  # used for our unit tests
import torch
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension


class XlabsIpAdapterFlux:
    pass  # Dummy class for initialization


from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension

# unit tests

# -------- BASIC TEST CASES --------

def test_constant_weight_within_range():
    """Test constant weight is returned within begin/end step percent."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=0.7,
        begin_step_percent=0.2,
        end_step_percent=0.8,
    )
    total_steps = 10
    # Steps 2 (floor(0.2 * 10)) to 8 (ceil(0.8 * 10)) inclusive should return weight
    for idx in range(2, 9):
        codeflash_output = ext._get_weight(idx, total_steps)  # 4.22μs -> 4.13μs (2.03% faster)


def test_constant_weight_outside_range():
    """Test constant weight returns 0.0 outside begin/end step percent."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=0.5,
        begin_step_percent=0.3,
        end_step_percent=0.7,
    )
    total_steps = 10
    # Steps 0-2 and 8-9 should return 0.0
    for idx in range(0, 3):
        codeflash_output = ext._get_weight(idx, total_steps)  # 1.69μs -> 2.39μs (29.3% slower)
    for idx in range(8, 10):
        codeflash_output = ext._get_weight(idx, total_steps)  # 714ns -> 553ns (29.1% faster)


def test_list_weight_within_range():
    """Test list weight returns correct value within range."""
    weights = [0.1 * i for i in range(10)]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 10
    for idx in range(10):
        codeflash_output = ext._get_weight(idx, total_steps)  # 4.37μs -> 4.27μs (2.34% faster)


def test_list_weight_outside_range():
    """Test list weight returns 0.0 outside begin/end step percent."""
    weights = [1.0] * 10
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=weights,
        begin_step_percent=0.2,
        end_step_percent=0.8,
    )
    total_steps = 10
    # Steps 0, 1, 9 should return 0.0
    for idx in [0, 1, 9]:
        codeflash_output = ext._get_weight(idx, total_steps)  # 1.76μs -> 2.34μs (24.7% slower)
    # Steps 2-8 should return 1.0
    for idx in range(2, 9):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.61μs -> 1.97μs (32.4% faster)
# -------- EDGE TEST CASES --------

def test_begin_step_percent_zero():
    """Test begin_step_percent=0.0 includes step 0."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=2.5,
        begin_step_percent=0.0,
        end_step_percent=0.5,
    )
    total_steps = 4
    # first_step = 0, last_step = ceil(0.5 * 4) = 2
    for idx in range(0, 3):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.22μs -> 2.52μs (11.8% slower)
    for idx in range(3, 4):
        codeflash_output = ext._get_weight(idx, total_steps)  # 396ns -> 289ns (37.0% faster)


def test_end_step_percent_one():
    """Test end_step_percent=1.0 includes last step."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=1.1,
        begin_step_percent=0.5,
        end_step_percent=1.0,
    )
    total_steps = 5
    # first_step = floor(0.5 * 5) = 2, last_step = ceil(1.0 * 5) = 5
    for idx in range(2, 6):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.47μs -> 2.67μs (7.63% slower)
    for idx in range(0, 2):
        codeflash_output = ext._get_weight(idx, total_steps)  # 744ns -> 516ns (44.2% faster)


def test_begin_equals_end_percent():
    """Test begin_step_percent == end_step_percent only one step is active."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=0.9,
        begin_step_percent=0.3,
        end_step_percent=0.3,
    )
    total_steps = 10
    # first_step = floor(0.3 * 10) = 3, last_step = ceil(0.3 * 10) = 3
    for idx in range(10):
        expected = 0.9 if idx == 3 else 0.0
        codeflash_output = ext._get_weight(idx, total_steps)  # 4.08μs -> 3.92μs (4.13% faster)


def test_negative_weight():
    """Test negative constant weight is returned in range."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=-1.5,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 5
    for idx in range(5):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.75μs -> 2.99μs (8.16% slower)


def test_weight_list_with_negative_values():
    """Test weight list with negative values."""
    weights = [-0.5, 0.0, 0.5, 1.0]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 4
    for idx in range(4):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.44μs -> 2.82μs (13.4% slower)


def test_weight_list_length_mismatch_raises():
    """Test that out-of-bounds index raises IndexError for weight list."""
    weights = [0.1, 0.2, 0.3]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 5
    # timestep_index 3 and 4 out of range for weights
    with pytest.raises(IndexError):
        ext._get_weight(3, total_steps)  # 1.72μs -> 2.29μs (25.0% slower)
    with pytest.raises(IndexError):
        ext._get_weight(4, total_steps)  # 764ns -> 716ns (6.70% faster)


def test_invalid_percent_values():
    """Test percent values outside [0,1] behave as expected."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=2.0,
        begin_step_percent=-0.5,
        end_step_percent=1.5,
    )
    total_steps = 4
    # first_step = floor(-0.5 * 4) = -2, last_step = ceil(1.5 * 4) = 6
    # So all steps 0-3 should be in range
    for idx in range(4):
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.48μs -> 2.77μs (10.2% slower)
    # Out of range lower
    codeflash_output = ext._get_weight(-3, total_steps)  # 394ns -> 303ns (30.0% faster)
    # Out of range upper
    codeflash_output = ext._get_weight(7, total_steps)  # 368ns -> 288ns (27.8% faster)


def test_zero_timesteps():
    """Test behavior when total_num_timesteps is zero."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=1.0,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    # Should always return 0.0 since first_step=0, last_step=0
    codeflash_output = ext._get_weight(0, 0)  # 1.22μs -> 1.78μs (31.8% slower)
    codeflash_output = ext._get_weight(1, 0)  # 508ns -> 438ns (16.0% faster)
    codeflash_output = ext._get_weight(-1, 0)  # 355ns -> 275ns (29.1% faster)


def test_non_integer_timesteps():
    """Test behavior when total_num_timesteps is not integer (should be coerced by math.floor/ceil)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(1),
        weight=1.0,
        begin_step_percent=0.5,
        end_step_percent=1.0,
    )
    # Non-integer total_num_timesteps
    total_steps = 5.5
    first_step = math.floor(0.5 * total_steps)  # 2
    last_step = math.ceil(1.0 * total_steps)  # 6
    for idx in range(int(total_steps) + 2):  # up to 7
        if idx < first_step or idx > last_step:
            codeflash_output = ext._get_weight(idx, total_steps)
        else:
            codeflash_output = ext._get_weight(idx, total_steps)

# -------- LARGE SCALE TEST CASES --------

def test_large_list_weight():
    """Test large list of weights (1000 elements)."""
    weights = [float(i) for i in range(1000)]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(10),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 1000
    # Sample a few indices
    for idx in [0, 100, 500, 999]:
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.81μs -> 2.96μs (4.91% slower)


def test_large_constant_weight():
    """Test constant weight with large number of timesteps."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(100),
        weight=3.14,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 1000
    # Sample a few indices
    for idx in [0, 100, 500, 999]:
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.52μs -> 2.75μs (8.15% slower)


def test_large_list_weight_partial_range():
    """Test large list with partial begin/end percent."""
    weights = [float(i) for i in range(1000)]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(10),
        weight=weights,
        begin_step_percent=0.1,
        end_step_percent=0.9,
    )
    total_steps = 1000
    first_step = math.floor(0.1 * total_steps)  # 100
    last_step = math.ceil(0.9 * total_steps)  # 900
    # Indices outside range
    for idx in [0, 50, 99, 901, 999]:
        codeflash_output = ext._get_weight(idx, total_steps)  # 2.21μs -> 2.52μs (12.3% slower)
    # Indices inside range
    for idx in [100, 500, 900]:
        codeflash_output = ext._get_weight(idx, total_steps)  # 1.34μs -> 1.02μs (30.9% faster)


def test_performance_large_scale():
    """Test that large scale call does not crash or hang (performance test)."""
    weights = [1.0] * 1000
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=torch.zeros(100),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 1000
    # Check that all weights are correct for a subset
    for idx in range(0, 1000, 100):
        codeflash_output = ext._get_weight(idx, total_steps)  # 4.48μs -> 4.07μs (10.0% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
import math
from typing import List, Union

# imports
import pytest  # used for our unit tests
import torch
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension


class XlabsIpAdapterFlux:
    # Dummy class for model argument
    pass


from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension

# unit tests

# Helper function for creating dummy tensors
def make_dummy_tensor(size=1):
    return torch.zeros(size)

# -------------------------------
# 1. Basic Test Cases
# -------------------------------

def test_constant_weight_within_range():
    """Test constant float weight, timestep within active range."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=2.5,
        begin_step_percent=0.2,
        end_step_percent=0.8,
    )
    total_steps = 10
    # first_step = 2, last_step = 8
    for i in range(2, 9):
        codeflash_output = ext._get_weight(i, total_steps)  # 3.40μs -> 3.57μs (4.76% slower)


def test_constant_weight_outside_range():
    """Test constant float weight, timestep outside active range."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=1.0,
        begin_step_percent=0.3,
        end_step_percent=0.6,
    )
    total_steps = 10
    # first_step = 3, last_step = 6
    for i in [0, 1, 2, 7, 8, 9]:
        codeflash_output = ext._get_weight(i, total_steps)  # 2.68μs -> 2.94μs (8.93% slower)


def test_list_weight_within_range():
    """Test list weight, timestep within active range."""
    weights = [0.0, 0.1, 0.2, 0.3, 0.4]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=0.8,  # last_step = ceil(0.8 * 5) = 4
    )
    total_steps = 5
    for i in range(0, 5):
        codeflash_output = ext._get_weight(i, total_steps)  # 2.78μs -> 3.09μs (9.91% slower)


def test_list_weight_outside_range():
    """Test list weight, timestep outside active range."""
    weights = [1, 2, 3, 4, 5, 6]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.2,
        end_step_percent=0.5,  # first_step = 1, last_step = 3
    )
    total_steps = 6
    for i in [0, 4, 5]:
        codeflash_output = ext._get_weight(i, total_steps)  # 1.74μs -> 2.30μs (24.5% slower)


def test_begin_equals_end_percent():
    """Test when begin_step_percent == end_step_percent (only one step active)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=42.0,
        begin_step_percent=0.5,
        end_step_percent=0.5,
    )
    total_steps = 10
    # first_step = floor(5) = 5, last_step = ceil(5) = 5
    for i in range(0, 10):
        expected = 42.0 if i == 5 else 0.0
        codeflash_output = ext._get_weight(i, total_steps)  # 4.06μs -> 3.81μs (6.37% faster)

# -------------------------------
# 2. Edge Test Cases
# -------------------------------

def test_zero_total_steps():
    """Test with total_num_timesteps=0 (should never be in range)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=3.0,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    codeflash_output = ext._get_weight(0, 0)  # 1.26μs -> 1.81μs (30.1% slower)


def test_negative_timestep_index():
    """Test with negative timestep_index (should always return 0.0)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=1.23,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    codeflash_output = ext._get_weight(-1, 10)  # 1.04μs -> 1.61μs (35.6% slower)


def test_timestep_index_greater_than_total():
    """Test with timestep_index >= total_num_timesteps (should return 0.0)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=7.0,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    codeflash_output = ext._get_weight(10, 10)  # 1.23μs -> 1.84μs (33.3% slower)
    codeflash_output = ext._get_weight(100, 10)  # 465ns -> 403ns (15.4% faster)


def test_begin_percent_greater_than_end_percent():
    """Test begin_step_percent > end_step_percent (should never be in range)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=5.0,
        begin_step_percent=0.8,
        end_step_percent=0.2,
    )
    total_steps = 10
    for i in range(0, 10):
        codeflash_output = ext._get_weight(i, total_steps)  # 3.76μs -> 3.81μs (1.36% slower)


def test_begin_percent_zero_end_percent_one():
    """Test begin_step_percent=0.0, end_step_percent=1.0 (all steps active)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=9.9,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 7
    for i in range(0, 7):
        codeflash_output = ext._get_weight(i, total_steps)  # 3.41μs -> 3.51μs (2.82% slower)


def test_list_weight_length_mismatch():
    """Test list weight with length < total_num_timesteps (should raise IndexError)."""
    weights = [1.0, 2.0, 3.0]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    # Should raise IndexError for i >= 3
    with pytest.raises(IndexError):
        ext._get_weight(3, 4)  # 1.81μs -> 2.37μs (23.5% slower)


def test_list_weight_non_float_values():
    """Test list weight with non-float values (should return as-is)."""
    weights = [None, "a", 3.14]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    # Should return None for i=0, "a" for i=1, 3.14 for i=2
    codeflash_output = ext._get_weight(0, 3)  # 1.31μs -> 1.75μs (25.5% slower)
    codeflash_output = ext._get_weight(1, 3)  # 607ns -> 512ns (18.6% faster)
    codeflash_output = ext._get_weight(2, 3)  # 334ns -> 267ns (25.1% faster)


def test_weight_zero():
    """Test with weight=0.0 (should return 0.0 in active range)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=0.0,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    codeflash_output = ext._get_weight(0, 1)  # 1.18μs -> 1.82μs (34.9% slower)


def test_weight_negative():
    """Test with negative weight (should return negative value in active range)."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=-5.5,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    codeflash_output = ext._get_weight(0, 1)  # 1.15μs -> 1.78μs (35.6% slower)


def test_float_precision_percent():
    """Test with begin/end percents that are not exactly representable."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=1.0,
        begin_step_percent=0.3333333,
        end_step_percent=0.6666666,
    )
    total_steps = 9
    # first_step = floor(2.9999997) = 2, last_step = ceil(6.0000004) = 7
    for i in range(0, 9):
        expected = 1.0 if 2 <= i <= 7 else 0.0
        codeflash_output = ext._get_weight(i, total_steps)  # 4.08μs -> 3.95μs (3.31% faster)

# -------------------------------
# 3. Large Scale Test Cases
# -------------------------------

def test_large_list_weight():
    """Test with a large list of weights (length 1000)."""
    weights = [float(i) for i in range(1000)]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.1,
        end_step_percent=0.9,
    )
    total_steps = 1000
    first_step = math.floor(0.1 * total_steps)
    last_step = math.ceil(0.9 * total_steps)
    # Test a few in-range and out-of-range
    for i in [0, first_step - 1, last_step + 1, total_steps - 1]:
        codeflash_output = ext._get_weight(i, total_steps)  # 1.94μs -> 2.33μs (16.8% slower)
    # Test some in-range values
    for i in [first_step, (first_step + last_step) // 2, last_step]:
        codeflash_output = ext._get_weight(i, total_steps)  # 1.33μs -> 1.06μs (25.3% faster)


def test_large_constant_weight():
    """Test with a large number of steps and constant weight."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=123.456,
        begin_step_percent=0.0,
        end_step_percent=1.0,
    )
    total_steps = 1000
    for i in [0, 500, 999]:
        codeflash_output = ext._get_weight(i, total_steps)  # 2.08μs -> 2.52μs (17.4% slower)


def test_large_sparse_active_range():
    """Test with a large number of steps, but narrow active range."""
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=77.0,
        begin_step_percent=0.45,
        end_step_percent=0.46,
    )
    total_steps = 1000
    first_step = math.floor(0.45 * total_steps)
    last_step = math.ceil(0.46 * total_steps)
    # Only steps between first_step and last_step (inclusive) are active
    for i in range(0, total_steps):
        expected = 77.0 if first_step <= i <= last_step else 0.0
        codeflash_output = ext._get_weight(i, total_steps)  # 294μs -> 198μs (48.5% faster)


def test_large_list_weight_partial_active():
    """Test with large list of weights, partial active range."""
    weights = [i for i in range(1000)]
    ext = XLabsIPAdapterExtension(
        model=XlabsIpAdapterFlux(),
        image_prompt_clip_embed=make_dummy_tensor(),
        weight=weights,
        begin_step_percent=0.25,
        end_step_percent=0.75,
    )
    total_steps = 1000
    first_step = math.floor(0.25 * total_steps)
    last_step = math.ceil(0.75 * total_steps)
    # Only weights[first_step] to weights[last_step] should be returned, rest 0.0
    for i in [0, first_step - 1, last_step + 1, total_steps - 1]:
        codeflash_output = ext._get_weight(i, total_steps)  # 1.93μs -> 2.32μs (16.7% slower)
    for i in [first_step, last_step]:
        codeflash_output = ext._get_weight(i, total_steps)  # 969ns -> 780ns (24.2% faster)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-XLabsIPAdapterExtension._get_weight-mhtsxaen` and push.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 10, 2025 23:56
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash labels Nov 10, 2025