[tests] unbloat tests/lora/utils.py
#11845
base: main
Conversation
@@ -103,34 +103,6 @@ def get_dummy_inputs(self, with_generator=True):
        return noise, input_ids, pipeline_inputs

    @unittest.skip("Not supported in AuraFlow.")
These are skipped appropriately from the parent method. I think it's okay in this case, because it eases things a bit.
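For readers skimming the thread, here is a minimal sketch of the pattern being referred to (the supports_block_lora flag and the class names are hypothetical, not the PR's actual mechanism): the parent test method handles the skip itself, so subclass-level @unittest.skip overrides like the one removed above become unnecessary.

```python
import unittest

class LoraBaseTests:
    supports_block_lora = True  # hypothetical capability flag

    def test_multi_adapter_block_lora(self):
        # the shared method skips unsupported scenarios on its own
        if not self.supports_block_lora:
            self.skipTest("Block LoRA not supported for this pipeline.")
        self.assertTrue(True)

class AuraFlowLoRATests(LoraBaseTests, unittest.TestCase):
    # no @unittest.skip override needed anymore
    supports_block_lora = False
```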
tests/lora/utils.py
I haven't checked the PR yet, but I was wondering: when I do bigger refactors of tests, I always check the line coverage before and after the refactor and ensure that they're the same (in PEFT we use pytest-cov for this).

Indeed, it's important. Do you have any more guidelines for how to do that?

So you can install pytest-cov and generate a coverage report on main and on this branch. It can be a bit difficult to parse what has changed, but basically, you want the covered lines to stay the same before and after.
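For reference, the comparison scripts later in this thread read the JSON layout that coverage.py emits with --cov-report=json; roughly the following shape (the file path and numbers below are made up):

```python
# Sketch of the relevant parts of a coverage.py JSON report; only the keys the
# scripts in this thread rely on are shown, with invented values.
report = {
    "files": {
        "src/diffusers/utils/peft_utils.py": {
            "summary": {"covered_lines": 163, "num_statements": 197},
            "missing_lines": [222, 230],
        }
    }
}

summary = report["files"]["src/diffusers/utils/peft_utils.py"]["summary"]
print(100 * summary["covered_lines"] / summary["num_statements"])  # per-file percentage
```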
@BenjaminBossan I reran with:

CUDA_VISIBLE_DEVICES="" pytest \
-n 24 --max-worker-restart=0 --dist=loadfile \
--cov=src/diffusers/ \
--cov-report=term-missing \
--cov-report=json:unbloat.json \
tests/lora/

on this branch, and I then used Gemini to report the file locations where this branch has reduced coverage:

Coverage Comparison Summary
==============================
📊 Coverage Changes:
- utils/testing_utils.py: 33.53% -> 33.38% (-0.15%)
- schedulers/scheduling_k_dpm_2_ancestral_discrete.py: 16.40% -> 0.00% (-16.40%)
- utils/typing_utils.py: 64.86% -> 8.11% (-56.75%)
- models/unets/unet_2d_condition.py: 53.81% -> 53.36% (-0.45%)
- configuration_utils.py: 66.24% -> 46.50% (-19.74%)
- schedulers/scheduling_dpmsolver_singlestep.py: 8.67% -> 0.00% (-8.67%)
- models/model_loading_utils.py: 33.18% -> 17.73% (-15.45%)
- pipelines/pipeline_loading_utils.py: 28.78% -> 16.77% (-12.01%)
- utils/hub_utils.py: 30.26% -> 22.56% (-7.70%)
- schedulers/scheduling_euler_discrete.py: 22.22% -> 14.14% (-8.08%)
- utils/logging.py: 52.27% -> 50.76% (-1.51%)
- utils/peft_utils.py: 82.74% -> 82.23% (-0.51%)
- schedulers/scheduling_deis_multistep.py: 10.63% -> 0.00% (-10.63%)
- schedulers/scheduling_unipc_multistep.py: 7.60% -> 0.00% (-7.60%)
- schedulers/scheduling_utils.py: 98.00% -> 84.00% (-14.00%)
- pipelines/pipeline_utils.py: 52.76% -> 33.55% (-19.21%)
- schedulers/scheduling_edm_euler.py: 22.22% -> 0.00% (-22.22%)
- models/modeling_utils.py: 51.14% -> 23.00% (-28.14%)
- schedulers/scheduling_k_dpm_2_discrete.py: 17.17% -> 0.00% (-17.17%)
==============================

Code:

import json
from decimal import Decimal, ROUND_HALF_UP
def get_coverage_data(report_path):
"""Loads a JSON coverage report and extracts file coverage data."""
with open(report_path, 'r') as f:
data = json.load(f)
file_coverage = {}
for filename, stats in data['files'].items():
# Clean up the filename for better readability
clean_filename = filename.replace('src/diffusers/', '')
# Calculate coverage percentage with two decimal places
covered_lines = stats['summary']['covered_lines']
num_statements = stats['summary']['num_statements']
if num_statements > 0:
coverage_percent = (Decimal(covered_lines) / Decimal(num_statements)) * 100
file_coverage[clean_filename] = coverage_percent.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
else:
file_coverage[clean_filename] = Decimal('0.0')
return file_coverage
def compare_coverage(main_report, feature_report):
"""Compares two coverage reports and prints a summary of the differences."""
main_coverage = get_coverage_data(main_report)
feature_coverage = get_coverage_data(feature_report)
main_files = set(main_coverage.keys())
feature_files = set(feature_coverage.keys())
# --- Report Summary ---
print("Coverage Comparison Summary\n" + "="*30)
# Files with changed coverage
common_files = main_files.intersection(feature_files)
changed_coverage_files = {
file: (main_coverage[file], feature_coverage[file])
for file in common_files if main_coverage[file] != feature_coverage[file]
}
if changed_coverage_files:
print("\n📊 Coverage Changes:")
for file, (main_cov, feature_cov) in changed_coverage_files.items():
change = feature_cov - main_cov
print(f" - {file}: {main_cov}% -> {feature_cov}% ({'+' if change > 0 else ''}{change.quantize(Decimal('0.01'))}%)")
else:
print("\nNo change in coverage for existing files.")
# New files in the feature branch
new_files = feature_files - main_files
if new_files:
print("\n✨ New Files in Feature Branch:")
for file in new_files:
print(f" - {file} (Coverage: {feature_coverage[file]}%)")
# Removed files from the feature branch
removed_files = main_files - feature_files
if removed_files:
print("\n🗑️ Removed Files from Feature Branch:")
for file in removed_files:
print(f" - {file}")
print("\n" + "="*30)
if __name__ == "__main__":
    compare_coverage('coverage_main.json', 'coverage_feature.json')

Will try to improve it / see what is going on. I think coverage reductions in files like
Nice, there seem to be some big drops in a couple of files, definitely worth investigating. I skimmed the script and I think it's not quite correct. If, say, before, foo.py was covered on lines 0-10 of 20 total lines, and after, lines 10-20 are covered, the difference would be reported as 0. But in reality, 10 lines are being missed. So the more accurate way would be to check the missing_lines entries.
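To make this concrete, a tiny sketch (with made-up line numbers) of why diffing missing_lines catches shifts that a pure percentage comparison hides:

```python
# Before the refactor lines 11-20 were missed; afterwards lines 1-10 are missed.
# The coverage percentage is unchanged, yet ten previously covered lines were lost.
missing_before = set(range(11, 21))
missing_after = set(range(1, 11))

newly_missed = sorted(missing_after - missing_before)
print(newly_missed)                               # ten newly uncovered lines
print(len(missing_after) == len(missing_before))  # True -> a percentage diff reports 0.00%
```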
@BenjaminBossan here are my findings. First, here's the updated comparison script:

Code:

import json
from decimal import Decimal, ROUND_HALF_UP
def parse_coverage_report(report_path: str) -> dict:
"""
Loads a JSON coverage report and extracts detailed data for each file,
including missing lines and coverage percentage.
"""
try:
with open(report_path, 'r') as f:
data = json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as e:
print(f"Error loading {report_path}: {e}")
return {}
coverage_data = {}
for filename, stats in data.get('files', {}).items():
summary = stats.get('summary', {})
covered = summary.get('covered_lines', 0)
total = summary.get('num_statements', 0)
# Calculate coverage percentage
if total > 0:
percentage = (Decimal(covered) / Decimal(total)) * 100
else:
percentage = Decimal('100.0') # No statements means 100% covered
coverage_data[filename] = {
'missing_lines': set(stats.get('missing_lines', [])),
'coverage_pct': percentage.quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)
}
return coverage_data
def format_lines_as_ranges(lines: list[int]) -> str:
"""Converts a list of line numbers into a compact string of ranges."""
if not lines:
return ""
ranges = []
start = end = lines[0]
for i in range(1, len(lines)):
if lines[i] == end + 1:
end = lines[i]
else:
ranges.append(f"{start}-{end}" if start != end else f"{start}")
start = end = lines[i]
ranges.append(f"{start}-{end}" if start != end else f"{start}")
return ", ".join(ranges)
def find_and_report_coverage_changes(main_report_path: str, feature_report_path: str):
"""
Compares two coverage reports and prints a detailed report on any
lost coverage, including percentages and specific line numbers.
"""
main_data = parse_coverage_report(main_report_path)
feature_data = parse_coverage_report(feature_report_path)
lost_coverage_report = {}
# Find files with lost line coverage
for filename, main_file_data in main_data.items():
if filename in feature_data:
feature_file_data = feature_data[filename]
# Find lines that are missing now but were NOT missing before
newly_missed_lines = sorted(list(
feature_file_data['missing_lines'] - main_file_data['missing_lines']
))
# Record if there are newly missed lines OR if the percentage has dropped
# (e.g., due to new uncovered lines being added)
if newly_missed_lines or feature_file_data['coverage_pct'] < main_file_data['coverage_pct']:
lost_coverage_report[filename] = {
'lines': newly_missed_lines,
'main_pct': main_file_data['coverage_pct'],
'feature_pct': feature_file_data['coverage_pct']
}
# --- Print the Final Report ---
print("📊❌ Coverage Change Report")
print("=" * 30)
if not lost_coverage_report:
print("\n✅ No coverage degradation detected. Great job!")
return
print("\nThe following files have reduced coverage:\n")
for filename, changes in lost_coverage_report.items():
clean_filename = filename.replace('src/diffusers/', '')
main_pct = changes['main_pct']
feature_pct = changes['feature_pct']
diff = (feature_pct - main_pct).quantize(Decimal('0.01'))
print(f"📄 File: {clean_filename}")
print(f" Percentage: {main_pct}% → {feature_pct}% ({diff}%)")
if changes['lines']:
print(f" Newly Missed Lines: {format_lines_as_ranges(changes['lines'])}")
print("-" * 25)
if __name__ == "__main__":
    find_and_report_coverage_changes('coverage_main.json', 'unbloat.json')

The JSON files were obtained by running the following command once on main and once on this branch (changing the JSON filename accordingly):

CUDA_VISIBLE_DEVICES="" pytest \
-n 24 --max-worker-restart=0 --dist=loadfile \
--cov=src/diffusers/ \
--cov-report=json:<CHANGE_ME>.json \
tests/lora/
Here is the first report, before fixes:

📊❌ Coverage Change Report
==============================
The following files have reduced coverage:
📄 File: configuration_utils.py
Percentage: 66.24% → 46.50% (-19.74%)
Newly Missed Lines: 161, 164, 167, 169-170, 172, 268, 342-353, 355-356, 358, 360, 366, 373, 375-376, 380, 382, 440-441, 443, 447-448, 450, 452-453, 455-456, 458, 493, 499-500, 567, 570-572, 595-597, 599-600, 602, 604, 606, 613, 615-616, 618, 620, 630-631
-------------------------
📄 File: loaders/lora_base.py
Percentage: 78.16% → 78.16% (0.00%)
Newly Missed Lines: 732, 760
-------------------------
📄 File: models/model_loading_utils.py
Percentage: 33.18% → 17.73% (-15.45%)
Newly Missed Lines: 67, 114, 141-142, 165, 169, 173, 176, 231-232, 234-235, 238, 242, 258-261, 263, 266-268, 270-271, 273, 277, 291, 293, 295, 302, 304, 350-351, 381
-------------------------
📄 File: models/modeling_utils.py
Percentage: 51.14% → 23.00% (-28.14%)
Newly Missed Lines: 82-83, 86-87, 90, 238-239, 241-244, 247-248, 649, 653-654, 666-668, 672, 674, 683, 687-688, 691, 694, 699-701, 703-704, 706-708, 710, 716-719, 722, 726, 743-744, 746, 907-927, 929, 935-938, 940, 949, 956, 962, 968, 975, 977, 985, 993, 999-1000, 1004, 1009, 1012, 1015, 1031, 1035-1036, 1047, 1049, 1065, 1069-1071, 1074, 1077, 1080, 1082-1083, 1086-1089, 1105, 1108, 1110, 1113, 1117, 1139, 1152-1154, 1177, 1193-1194, 1199-1200, 1207, 1209-1210, 1212-1213, 1215, 1218-1219, 1221, 1223, 1225, 1228, 1230, 1236, 1239, 1242, 1265, 1273, 1281, 1285, 1293, 1298, 1301, 1303, 1306, 1460-1463, 1465, 1468-1470, 1472, 1474-1475, 1478, 1490-1491, 1495-1496, 1498, 1501, 1503, 1506-1507, 1509, 1515-1516, 1531, 1533, 1540-1541, 1560, 1568, 1576, 1581, 1583, 1589-1590, 1596, 1610, 1800, 1802-1804, 1806-1808, 1810, 1816, 1820, 1822, 1826, 1828, 1832, 1834, 1838, 1840, 1842
-------------------------
📄 File: models/unets/unet_2d_condition.py
Percentage: 53.81% → 53.36% (-0.45%)
Newly Missed Lines: 536-537
-------------------------
📄 File: pipelines/pipeline_loading_utils.py
Percentage: 28.78% → 16.77% (-12.01%)
Newly Missed Lines: 378, 380, 385, 393, 395-396, 398, 445, 455-456, 721, 725, 735, 737-739, 742, 756, 759-763, 768, 770-771, 775, 784-790, 792, 797, 805-806, 809-810, 814, 826, 829-830, 835, 851, 862, 867, 901-902, 909, 913-914, 916, 922, 926, 1137-1138
-------------------------
📄 File: pipelines/pipeline_utils.py
Percentage: 52.76% → 33.55% (-19.21%)
Newly Missed Lines: 272-276, 278, 286, 288-293, 295-298, 302, 306, 308-310, 316-318, 320-323, 325, 333, 336-339, 341-346, 350, 353, 355, 739, 741-764, 766, 772, 781, 784, 790, 796, 801, 804, 809, 813, 819, 827, 852, 856, 869-870, 876, 878, 882, 888-889, 895, 898, 908, 914, 926-930, 933, 938, 941-944, 946, 948, 951, 959, 965, 968-969, 987-989, 991, 999, 1002-1004, 1007, 1017, 1022, 1047, 1051, 1054, 1064-1070, 1077, 1079-1080, 1083-1085, 1090, 1093, 1096-1097, 1099, 1699-1700
-------------------------
📄 File: schedulers/scheduling_deis_multistep.py
Percentage: 10.63% → 0.00% (-10.63%)
Newly Missed Lines: 18-19, 21-22, 24-26, 29-30, 34, 78, 130-131, 133-134, 210-211, 217-218, 225, 235, 314, 348, 372, 383, 409, 431, 462, 522, 580, 649, 738, 758, 770, 835, 851, 885
-------------------------
📄 File: schedulers/scheduling_dpmsolver_singlestep.py
Percentage: 8.67% → 0.00% (-8.67%)
Newly Missed Lines: 17-18, 20-21, 23-26, 29-30, 32, 36, 80, 145-146, 148-149, 235, 275-276, 282-283, 290, 300, 405, 439, 463, 474, 500, 522, 553, 653, 717, 828, 950, 1014, 1034, 1046, 1117, 1133, 1167
-------------------------
📄 File: schedulers/scheduling_edm_euler.py
Percentage: 22.22% → 0.00% (-22.22%)
Newly Missed Lines: 15-17, 19, 21-24, 27, 30, 32, 45-46, 49, 85-86, 88-89, 133-134, 138-139, 145-146, 153, 163, 168, 176, 191, 215, 265, 276, 287, 302, 310, 410, 443, 447
-------------------------
📄 File: schedulers/scheduling_euler_discrete.py
Percentage: 22.22% → 14.14% (-8.08%)
Newly Missed Lines: 207, 209, 213, 215, 217, 219, 226, 229-230, 232, 237-239, 242, 245, 248, 250, 252-255, 257-259
-------------------------
📄 File: schedulers/scheduling_k_dpm_2_ancestral_discrete.py
Percentage: 16.40% → 0.00% (-16.40%)
Newly Missed Lines: 15-17, 19-20, 22-25, 28-29, 32, 34, 47-48, 52, 96, 135-136, 138-139, 181-182, 189-190, 196-197, 204, 214, 244, 344, 368, 394, 416, 447-448, 452, 467, 475, 583, 616
-------------------------
📄 File: schedulers/scheduling_k_dpm_2_discrete.py
Percentage: 17.17% → 0.00% (-17.17%)
Newly Missed Lines: 15-17, 19-20, 22-24, 27-28, 31, 33, 46-47, 51, 95, 134-135, 137-138, 181-182, 189-190, 196-197, 204, 214, 244, 328-329, 333, 348, 357, 381, 407, 429, 460, 555, 588
-------------------------
📄 File: schedulers/scheduling_unipc_multistep.py
Percentage: 7.60% → 0.00% (-7.60%)
Newly Missed Lines: 18-19, 21-22, 24-26, 29-30, 34, 79, 115, 185-186, 188-189, 276-277, 283-284, 291, 301, 424, 458, 482, 493, 519, 541, 572, 645, 774, 912, 932, 944, 1025, 1041, 1075
-------------------------
📄 File: schedulers/scheduling_utils.py
Percentage: 98.00% → 84.00% (-14.00%)
Newly Missed Lines: 151, 158, 175, 189-191, 194
-------------------------
📄 File: utils/hub_utils.py
Percentage: 30.26% → 22.56% (-7.70%)
Newly Missed Lines: 79-80, 82-84, 87, 90, 92-93, 96, 189, 191-194, 200, 205
-------------------------
📄 File: utils/logging.py
Percentage: 52.27% → 50.76% (-1.51%)
Newly Missed Lines: 307-308
-------------------------
📄 File: utils/peft_utils.py
Percentage: 82.74% → 82.23% (-0.51%)
Newly Missed Lines: 222
-------------------------
📄 File: utils/testing_utils.py
Percentage: 33.53% → 33.38% (-0.15%)
Newly Missed Lines: 543-544, 547, 551
-------------------------
📄 File: utils/typing_utils.py
Percentage: 64.86% → 8.11% (-56.75%)
Newly Missed Lines: 26, 30-32, 35-36, 38, 41, 43, 47-49, 51, 54, 64, 71, 78, 80, 84, 86, 91
-------------------------

Then I added back in this test:

After that, when I added

Line 332 in 2527917

the coverage was no longer lagging behind, except for

Here are my two cents:

LMK if these make sense or if anything is unclear.
Nice, so the coverage is basically back to what it was previously. I'm not sure if that test should be removed, as it seems to hit lines that would otherwise remain untested. If the same coverage can be achieved with a better test, then that should be added; otherwise I don't really see the harm in keeping this simple test.
I will add it back in, but I think the current state of the PR should not be a blocker for reviews.
First of all, thanks a lot for taking this big task to refactor and simplify those tests. Seeing net 1000 lines removed is always great.
As for reviewing the change, I have to admit that it's quite hard. The overall amount of changes is quite large (also, some changes appear to be unrelated, like renaming variables). Although I suspect most new code is just old code moved to a different location (with some small changes), the diff view does not reveal that. Therefore, I haven't done a line per line review at this point.
Generally, refactoring tests is a delicate task. It is easy to accidentally cover fewer scenarios, since the tests will pass just fine. That is why I really wanted to see the change in test coverage. Of course, this is not a magic bullet to ensure that everything is still tested that was tested before, but I think it's the best boost in confidence we can get.
As for the gist of the refactor, I think it is a nice improvement, as witnessed by the lower line count combined with keeping the test coverage. That said, if I could dream of a "perfect" design, it would look more like this to me (using test_lora_set_adapters_scenarios as an example):
# specific model class
@parametrize("scheduler_cls", <scheduler-classes>)
def test_lora_set_adapters_simple(self, scheduler_cls):
    super()._test_lora_set_adapters_simple(scheduler_cls)

@parametrize("scheduler_cls", <scheduler-classes>)
def test_lora_set_adapters_weighted(self, scheduler_cls):
    super()._test_lora_set_adapters_weighted(scheduler_cls)

...

# base class
def _test_lora_set_adapters_simple(self, scheduler_cls):
    # maybe even consider parametrized pytest fixtures
    pipe, inputs, output_no_lora, _ = self._setup_multi_adapter_pipeline(scheduler_cls)
    # test for simple scenario

def _test_lora_set_adapters_weighted(self, scheduler_cls):
    # maybe even consider parametrized pytest fixtures
    pipe, inputs, output_no_lora, _ = self._setup_multi_adapter_pipeline(scheduler_cls)
    # test for weighted scenario
That way, each test would be testing for "one thing" instead of multiple things, which is preferable most of the time. The checks that precede the scenario-specific logic could be moved out into a separate test; that way we avoid duplicated checks.

I'm not asking to do another refactor according to what I described. As I wrote, I think this is already an improvement, and whether my suggestion is really better can be debated; I just wanted to present my opinion on it.
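As a side note on the "maybe even consider parametrized pytest fixtures" remark in the sketch above, here is a minimal, standalone illustration (scheduler names are placeholder strings; the actual suite is unittest-based and would need more wiring):

```python
import pytest

@pytest.fixture(params=["DDIMScheduler", "EulerDiscreteScheduler"])
def scheduler_cls(request):
    # each param value yields its own test case, replacing a manual loop
    # over scheduler classes inside the test body
    return request.param

def test_lora_set_adapters_simple(scheduler_cls):
    assert scheduler_cls in {"DDIMScheduler", "EulerDiscreteScheduler"}
```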
    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @parameterized.expand([("simple",), ("weighted",), ("block_lora",), ("delete_adapter",)])
Why do the scenarios have to be single-element tuples? And is "simple" a valid scenario here?
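For what it's worth, a small sketch of what the question refers to, assuming the parameterized package's usual behavior (bare values get wrapped into single-argument cases, so the one-element tuples are equivalent but not strictly required):

```python
import unittest
from parameterized import parameterized

class ScenarioTests(unittest.TestCase):
    @parameterized.expand(["simple", "weighted"])          # bare values
    def test_scenario_bare(self, scenario):
        self.assertIn(scenario, {"simple", "weighted"})

    @parameterized.expand([("simple",), ("weighted",)])    # one-element tuples
    def test_scenario_tuple(self, scenario):
        self.assertIn(scenario, {"simple", "weighted"})
```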
@unittest.skip("Not supported in CogVideoX.") | ||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self): | ||
pass | ||
# TODO: skip them properly |
I think this comment, on its own, is not very helpful in explaining what needs to be done here.
A few things to consider when refactoring this Mixin. I think we should try to address speed, composability, control flow, and readability.

Regarding Speed: We run a very large combination of tests here, and the base output is computed for each of them. We can speed up test time by ~2x by caching the base output and reusing it across test cases, since it doesn't change. Another thing to help with speed is to only test with a single scheduler. The default is to use two schedulers and manually override in the inheriting class. I don't think the additional scheduler test is giving us much signal in terms of LoRA functionality.

Regarding Control Flow / Readability:

When we refactored

I think

There are utility functions that are probably better off being broken up into individual functions, e.g.:

Line 133 in fd084dd
Line 242 in fd084dd
Line 270 in fd084dd
Line 385 in fd084dd
This is a sign that the function should be broken up into smaller pieces and called individually.

Regarding Composability: It's good to use parameterized.expand if code can be reused across different combinations of LoRA actions, but I don't think having a single function handle all cases will scale well as we add additional testing conditions. There is a risk of ending up with a very large function with lots of conditional paths. It's better to make a lookup table of components that need to be tested (text encoder, denoiser, text + denoiser) and actions to be tested (fuse, unfuse, load, unload), and then compose those together to run the test.

Proposed new Mixin and some pseudo examples:

class PeftLoraLoaderMixinTests:
pipeline_class = None
scheduler_class = None
scheduler_kwargs = None
lora_supported_text_encoders = []
denoiser_name = ""
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
COMPONENT_SETUP_MAP = {
"text_encoder_only": ["_setup_text_encoder", ["text_lora_config"]],
"denoiser_only": ["_setup_denoiser", ["denoiser_lora_config"]],
"text_and_denoiser": ["_setup_text_and_denoiser", ["text_lora_config", "denoiser_lora_config"]],
}
ACTION_MAP = {
"fuse": "_action_fuse",
"unfuse": "_action_unfuse",
"save_load": "_action_save_load",
"unload": "_action_unload",
"scale": "_action_scale",
}
_base_output = None
rank = 4
def get_lora_config(self, rank, target_modules, lora_alpha=None, use_dora=False):
return LoraConfig(
r=rank, target_modules=target_modules, lora_alpha=lora_alpha, init_lora_weights=False, use_dora=use_dora
)
def get_dummy_components(self):
raise NotImplementedError
def get_dummy_inputs(self, with_generator=True):
raise NotImplementedError
@property
def output_shape(self):
raise NotImplementedError
def setup_pipeline(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
return pipe
def get_base_output(self, pipe, with_generator=True):
if self._base_output is None:
inputs = self.get_dummy_inputs(with_generator)
self._base_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
return self._base_output
def _setup_lora_text_encoders(self, pipe, text_lora_config):
for component_name in self.lora_supported_text_encoders:
component = getattr(pipe, component_name)
component.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(component), f"Lora not correctly set in {component_name}")
return self.lora_supported_text_encoders
def _setup_lora_denoiser(self, pipe, denoiser_lora_config):
component = getattr(pipe, self.denoiser_name)
component.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(component), f"Lora not correctly set in {component}.")
return self.denoiser_name
def _setup_text_and_denoiser(self, pipe, text_lora_config, denoiser_lora_config):
text_encoders = self._setup_lora_text_encoders(pipe, text_lora_config)
denoiser = self._setup_lora_denoiser(pipe, denoiser_lora_config)
return text_encoders + [denoiser]
def _action_fuse(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
inputs = self.get_dummy_inputs(with_generator=False)
output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(base_output, output, atol=expected_atol),
f"Output after fuse should differ from base output",
)
def _action_save_load(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = {}
for component_name in lora_components:
if not hasattr(pipe, component_name):
continue
modules_to_save[component_name] = getattr(pipe, component_name)
# Save
state_dicts = {}
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None and getattr(module, "peft_config", None) is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
pipe.save_lora_weights(tmpdir, weight_name="lora.safetensors", **state_dicts, **metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir, weight_name="lora.safetensors")
# remaining assertions
def _action_unload(self, pipe, base_output, lora_output, lora_components, expected_atol=1e-3):
pipe.unload_lora_weights()
for component_name in lora_components:
self.assertFalse(
check_if_lora_correctly_set(getattr(pipe, component_name)),
f"Lora layers should not be present in {component_name} after unloading",
)
inputs = self.get_dummy_inputs(with_generator=False)
outputs = pipe(**inputs, generator=torch.manual_seed(0))[0]
# remaining assertions
def _should_skip_test(self, components):
if components in ["text_encoder_only", "text_and_denoiser"]:
return "text_encoder" not in self.pipeline_class._lora_loadable_modules
return False
def _setup_lora_components(self, pipe, components, text_lora_config, denoiser_lora_config):
method_name, config_names = self.COMPONENT_SETUP_MAP[components]
setup_method = getattr(self, method_name)
kwargs = {"pipe": pipe}
config_map = {"text_lora_config": text_lora_config, "denoiser_lora_config": denoiser_lora_config}
for config_name in config_names:
kwargs.update({config_name: config_map[config_name]})
components = setup_method(**kwargs)
return components
def _execute_lora_action(self, action, pipe, base_output, lora_output, lora_components, expected_atol):
"""Execute a specific LoRA action and return the output"""
action_method = getattr(self, self.ACTION_MAP[action])
return action_method(pipe, base_output, lora_output, lora_components, expected_atol)
def _test_lora_action(self, action, components, expected_atol=1e-4):
# Skip if not supported
if self._should_skip_test(components):
self.skipTest(f"{components} LoRA is not supported")
pipe = self.setup_pipeline()
base_output = self.get_base_output(pipe)
lora_components = self._setup_lora_components(pipe, components)
lora_output = pipe(**get_inputs())
self._execute_lora_action(action, pipe, base_output, lora_output, lora_components, expected_atol)
@parameterized.expand(
[
# Test actions on text_encoder LoRA only
("fused", "text_encoder_only"),
("unloaded", "text_encoder_only"),
("save_load", "text_encoder_only"),
# Test actions on both text_encoder and denoiser LoRA
("fused", "text_and_denoiser"),
("unloaded", "text_and_denoiser"),
("unfused", "text_and_denoiser"),
("save_load", "text_and_denoiser"),
("disable", "text_and_denoiser"),
],
name_func=lambda func, num, p: f"{func.__name__}_{p[0]}_{p[1]}", # so that test logs give us a nice test name and not an index
)
def test_lora_actions(self, action, components):
"""Test various LoRA actions with different component combinations"""
self._test_lora_action(action, components)
def test_low_cpu_mem_usage_with_injection(self):
pipe = self.setup_pipeline()
text_lora_config = self.get_lora_config(...)
denoiser_lora_config = self.get_lora_config(...)
for component_name in self.lora_supported_text_encoders:
if component_name not in self.pipeline_class._lora_loadable_modules:
continue
inject_adapter_in_model(text_lora_config, getattr(pipe, component_name), low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(getattr(pipe, component_name)), f"Lora not correctly set in {component_name}."
)
# remaining assertions
# denoiser tests
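To show how the proposed mixin might be consumed, here is a hedged sketch of an inheriting class (all names and values are placeholders, not actual PR code, and it assumes the PeftLoraLoaderMixinTests sketch above):

```python
import unittest

class DummyPipeline:   # stand-in for a real diffusers pipeline class
    pass

class DummyScheduler:  # stand-in for a real scheduler class
    pass

class DummyPipelineLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    pipeline_class = DummyPipeline
    scheduler_class = DummyScheduler
    scheduler_kwargs = {}
    denoiser_name = "transformer"                    # or "unet", depending on the model
    lora_supported_text_encoders = ["text_encoder"]

    def get_dummy_components(self):
        raise NotImplementedError                    # would build tiny random components

    def get_dummy_inputs(self, with_generator=True):
        raise NotImplementedError

    @property
    def output_shape(self):
        return (1, 8, 8, 3)                          # made-up shape for the sketch
```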
Thanks both for being candid with the feedback. I will try to address as much as possible.
What does this PR do?

We take the following approach:

- Use parameterized to combine similar-flavored tests.
- Make peft>=0.15.0 a mandate. So, I removed the @require_peft_version_greater decorator.

In a follow-up PR, I will attempt to improve the tests from the LoRA test suite that take the most amount of time.