@@ -36,6 +36,7 @@ def run_test(
     strict_loading: bool = True,
     dynamic_shapes: Dict = None,
     check_num_matches: int = None,  # Additional check of # patterns detected
+    skip_output_assert: bool = False,
     *args,  # Additional arguments for transform
 ) -> GraphModule:
     # run model once
@@ -52,7 +53,8 @@ def run_test(
     num_params_gm = count_parameters(gm)
 
     assert num_params_model == num_params_gm
-    torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
+    if not skip_output_assert:
+        torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol)
 
     # graph transformation + check
     if check_num_matches:
@@ -76,11 +78,11 @@ def run_test(
     # check if the transformation worked
     assert check_transformed_graph(gm_transformed)
 
-    if strict_loading:
+    if strict_loading and not skip_output_assert:
         # check if output equals without loading state dict
         torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol)
 
-    if test_load_hook:
+    if test_load_hook and not skip_output_assert:
         # check if loading hook works from original state dict
         reset_parameters(gm_transformed)
         y_random = gm_transformed(x)
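For context, the new `skip_output_assert` flag only gates the numerical output comparisons inside `run_test`; the parameter-count, pattern-count, and transformation checks still run. A minimal, self-contained sketch of the same opt-out gating pattern (the helper `check_outputs` is invented for illustration and is not part of the test utilities):

```python
import torch


def check_outputs(y_ref, y_test, atol, rtol, skip_output_assert=False):
    # Same opt-out gating as run_test: the numerical comparison is skipped
    # only when a test explicitly asks for it, so existing callers keep
    # their strict behavior by default.
    if not skip_output_assert:
        torch.testing.assert_close(y_test, y_ref, atol=atol, rtol=rtol)


y = torch.randn(4, 8)
check_outputs(y, y + 1e-6, atol=1e-3, rtol=1e-3)  # compared and passes
check_outputs(y, torch.randn(4, 8), atol=1e-3, rtol=1e-3, skip_output_assert=True)  # never compared
```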

@@ -9,7 +9,7 @@
 
 import tensorrt_llm._torch.auto_deploy  # noqa: F401
 
-torch.manual_seed(0)
+torch.manual_seed(1234)
 
 
 @pytest.mark.parametrize("head_dim", [64, 256])  # head_dim must be a multiple of 64
@@ -95,7 +95,7 @@ def test_flashinfer_custom_op_and_hf_impl(dtype, atol, rtol, head_dim):
 @pytest.mark.parametrize(
     "dtype,atol,rtol",
     [
-        (torch.bfloat16, 1e-5, 1e-5),
+        (torch.bfloat16, 1e-4, 1e-4),
         (torch.float16, 5e-4, 5e-4),
     ],
     ids=["bfloat16", "float16"],  # q/k must be in half precision
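The bfloat16 tolerance is loosened here from 1e-5 to 1e-4 (the seed change above is just a different fixed seed). A plausible reading, not stated in the diff, is that 1e-5 is far tighter than what bfloat16 itself can resolve; how tight a tolerance actually passes depends on where the kernels accumulate in float32, but the dtype's machine epsilon is a useful reference point:

```python
import torch

# bfloat16 keeps float32's exponent range but has only 8 significand bits,
# so one rounding step near 1.0 is already ~7.8e-3; float16 resolves ~9.8e-4.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
print(torch.finfo(torch.float16).eps)   # 0.0009765625
```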

@@ -502,15 +502,15 @@ def verify_matcher(gm):
 @pytest.mark.parametrize("has_mask", [True, False])
 @pytest.mark.parametrize("use_division", [False, True])
 @pytest.mark.parametrize(
-    "dropout, rtol, atol",
+    "dropout, skip_output_assert",
     [
-        (0.0, 1e-3, 1e-3),  # (dropout, rtol, atol) for no dropout
-        (0.1, float("inf"), float("inf")),  # (dropout, rtol, atol) for dropout=0.1
+        (0.0, False),
+        (0.1, True),  # skip all_close assertion for dropout=0.1 for its non-deterministic output
     ],
 )
 @pytest.mark.parametrize("model_type", ["standard", "complex"])
 @torch.inference_mode()
-def test_match_eager_attention(has_mask, use_division, dropout, rtol, atol, model_type):
+def test_match_eager_attention(has_mask, use_division, dropout, skip_output_assert, model_type):
     # Set a fixed seed for consistent dropout behavior in tests
     torch.manual_seed(0)
 
@@ -637,11 +637,12 @@ def verify_matcher(gm):
         match_eager_attention,
         verify_matcher,
         lambda num_p_og: num_p_og,
-        atol=atol,
-        rtol=rtol,
-        test_load_hook=True,
+        atol=1e-3,
+        rtol=1e-3,
+        test_load_hook=False,
         strict_loading=True,
         dynamic_shapes=dynamic_shapes,
+        skip_output_assert=skip_output_assert,
     )
 
 
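The previous parametrization effectively disabled the comparison for dropout=0.1 by passing infinite tolerances; the new flag skips it outright and says why: active dropout samples a fresh mask on every forward pass, so even identical computations disagree element-wise. A small standalone illustration of both points:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 256)

# With p > 0 each call draws a new dropout mask, so two forward passes of the
# same computation differ element-wise; no finite atol/rtol is meaningful.
y1 = F.dropout(x, p=0.1, training=True)
y2 = F.dropout(x, p=0.1, training=True)
print(torch.equal(y1, y2))  # False: different masks were sampled

# The old workaround, atol = rtol = inf, makes assert_close vacuous;
# skipping the assertion states the intent directly.
torch.testing.assert_close(y1, y2, atol=float("inf"), rtol=float("inf"))
```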

@@ -63,6 +63,8 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
         True,  # test_load_hook
         False,  # strict_loading
         None,  # dynamic_shapes
+        None,  # check_num_matches
+        False,  # skip_output_assert
         quant_config,
     )
 
@@ -133,6 +135,7 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class):
         False,  # strict_loading
         None,  # dynamic_shapes
         None,  # check_num_matches
+        False,  # skip_output_assert
         quant_config,
     )
 
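These quantization tests (and the attention tests below) pass run_test's arguments positionally, so the parameters that now sit ahead of the trailing *args, check_num_matches and skip_output_assert, must be filled explicitly or quant_config would bind to the wrong slot. A toy sketch of that pitfall; run_test_like only mirrors the shape of the real signature:

```python
def run_test_like(model, atol=1e-3, check_num_matches=None, skip_output_assert=False, *args):
    # Mirrors run_test's shape: keyword-default parameters placed before a
    # *args catch-all cannot be skipped by positional callers.
    return check_num_matches, skip_output_assert, args


# Forgetting the new slots: quant_config lands in check_num_matches, not *args.
print(run_test_like("model", 1e-3, {"algo": "FP8"}))
# -> ({'algo': 'FP8'}, False, ())

# Filling the slots explicitly, as the updated tests now do.
print(run_test_like("model", 1e-3, None, False, {"algo": "FP8"}))
# -> (None, False, ({'algo': 'FP8'},))
```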

@@ -269,6 +269,7 @@ def checker(gm):
             True,  # strict_loading
             dyn,  # dynamic_shapes
             None,  # check_num_matches
+            False,  # skip_output_assert
             target_layout,
         )
     elif transformation == "match":
@@ -284,6 +285,7 @@ def checker(gm):
             True,  # strict_loading
             dyn,  # dynamic_shapes
             1,  # check_num_matches
+            False,  # skip_output_assert
         )
     else:
         _ = run_test(
@@ -298,6 +300,7 @@ def checker(gm):
             True,  # strict_loading
             dyn,  # dynamic_shapes
             None,  # check_num_matches
+            False,  # skip_output_assert
         )
 
 
@@ -428,6 +431,7 @@ def checker(gm):
             True,  # strict_loading
             dynamic_shapes,  # dynamic_shapes
             None,  # check_num_matches
+            False,  # skip_output_assert
             target_layout,
         )
     else:
@@ -443,4 +447,5 @@ def checker(gm):
             True,  # strict_loading
             dynamic_shapes,  # dynamic_shapes
             1,  # check_num_matches
+            False,  # skip_output_assert
         )