Skip to content

[WIP] Recipe model_dump fixes #1379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 57 additions & 88 deletions src/llmcompressor/recipe/recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ class Recipe(RecipeBase):
when serializing a recipe, yaml will be used by default.
"""

version: Optional[str] = Field(default=None)
args: RecipeArgs = Field(default_factory=RecipeArgs)
stages: List[RecipeStage] = Field(default_factory=list)
metadata: Optional[RecipeMetaData] = Field(default=None)
args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs)

@classmethod
def from_modifiers(
cls,
Expand Down Expand Up @@ -280,12 +286,6 @@ def simplify_combine_recipes(

return combined

version: str = None
args: RecipeArgs = Field(default_factory=RecipeArgs)
stages: List[RecipeStage] = Field(default_factory=list)
metadata: RecipeMetaData = None
args_evaluated: RecipeArgs = Field(default_factory=RecipeArgs)

def calculate_start(self) -> int:
"""
Calculate and return the start epoch of the recipe.
Expand Down Expand Up @@ -399,11 +399,12 @@ def remap_stages(cls, values: Dict[str, Any]) -> Dict[str, Any]:
formatted_values["stages"] = stages

# fill out any default argument values
args = {}
args = {**values.pop("args", {})}
for key, val in values.items():
args[key] = val
# avoid nesting the args in the recipe
if key not in cls.__pydantic_fields__:
args[key] = val
formatted_values["args"] = RecipeArgs(args)

return formatted_values

@staticmethod
Expand Down Expand Up @@ -504,52 +505,62 @@ def combine_metadata(self, metadata: Optional[RecipeMetaData]):
else:
self.metadata.update_missing_metadata(metadata)

def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
:return: A dictionary representation of the recipe
def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
"""
dict_ = super().model_dump(*args, **kwargs)
stages = {}

for stage in dict_["stages"]:
name = f"{stage['group']}_stage"
del stage["group"]
Generate a serializable dictionary representation of this recipe.

if name not in stages:
stages[name] = []
This method transforms the internal recipe structure into a format
suitable for YAML serialization while preserving all necessary
information for round-trip deserialization.

stages[name].append(stage)
:param args: Additional positional arguments for parent method
:param kwargs: Additional keyword arguments for parent method
:return: Dictionary ready for YAML serialization
"""
# Retrieve base representation from parent class
raw_dict = super().model_dump(*args, **kwargs)

# Initialize clean output dictionary
serializable_dict = {}

# Copy recipe metadata attributes
metadata_keys = ["version", "args", "metadata"]
for key in metadata_keys:
if value := raw_dict.get(key):
serializable_dict[key] = value

# Process and organize stages by group
if "stages" in raw_dict:
# Group stages by their type (e.g., "train", "eval")
grouped_stages = {}
for stage in raw_dict["stages"]:
group_id = (
f"{stage.pop('group')}_stage" # Remove group field and use as key
)

dict_["stages"] = stages
if group_id not in grouped_stages:
grouped_stages[group_id] = []

return dict_
grouped_stages[group_id].append(stage)

def model_dump(self, *args, **kwargs) -> Dict[str, Any]:
"""
Override the model_dump method to provide a dictionary representation that
is compatible with model_validate.
# Format each stage for YAML output
for group_id, stages in grouped_stages.items():
for idx, stage_data in enumerate(stages):
# Create unique identifiers for multiple stages of same type
final_id = f"{group_id}_{idx}" if len(stages) > 1 else group_id

Unlike the standard model_dump, this transforms the stages list to a format
expected by the validation logic, ensuring round-trip compatibility with
model_validate.
# Create clean stage representation
stage_yaml = get_yaml_serializable_stage_dict(
modifiers=stage_data["modifiers"]
)

:return: A dictionary representation of the recipe compatible with
model_validate
"""
# Get the base dictionary from parent class
base_dict = super().model_dump(*args, **kwargs)
# Preserve run type if specified
if run_type := stage_data.get("run_type"):
stage_yaml["run_type"] = run_type

# Transform stages into the expected format
if "stages" in base_dict:
stages_dict = {}
for stage in base_dict["stages"]:
group = stage["group"]
if group not in stages_dict:
stages_dict[group] = []
stages_dict[group].append(stage)
base_dict["stages"] = stages_dict
serializable_dict[final_id] = stage_yaml

return base_dict
return serializable_dict

def yaml(self, file_path: Optional[str] = None) -> str:
"""
Expand All @@ -559,10 +570,9 @@ def yaml(self, file_path: Optional[str] = None) -> str:
:return: The yaml string representation of the recipe
"""
file_stream = None if file_path is None else open(file_path, "w")
yaml_dict = self._get_yaml_dict()

ret = yaml.dump(
yaml_dict,
self.model_dump(),
stream=file_stream,
allow_unicode=True,
sort_keys=False,
Expand All @@ -575,47 +585,6 @@ def yaml(self, file_path: Optional[str] = None) -> str:

return ret

def _get_yaml_dict(self) -> Dict[str, Any]:
"""
Get a dictionary representation of the recipe for yaml serialization
The returned dict will only contain information necessary for yaml
serialization and must not be used in place of the dict method

:return: A dictionary representation of the recipe for yaml serialization
"""

original_recipe_dict = self.dict()
yaml_recipe_dict = {}

# populate recipe level attributes
recipe_level_attributes = ["version", "args", "metadata"]

for attribute in recipe_level_attributes:
if attribute_value := original_recipe_dict.get(attribute):
yaml_recipe_dict[attribute] = attribute_value

# populate stages
stages = original_recipe_dict["stages"]
for stage_name, stage_list in stages.items():
for idx, stage in enumerate(stage_list):
if len(stage_list) > 1:
# resolve name clashes caused by combining recipes with
# duplicate stage names
final_stage_name = f"{stage_name}_{idx}"
else:
final_stage_name = stage_name
stage_dict = get_yaml_serializable_stage_dict(
modifiers=stage["modifiers"]
)

# infer run_type from stage
if run_type := stage.get("run_type"):
stage_dict["run_type"] = run_type

yaml_recipe_dict[final_stage_name] = stage_dict

return yaml_recipe_dict


RecipeInput = Union[str, List[str], Recipe, List[Recipe], Modifier, List[Modifier]]
RecipeStageInput = Union[str, List[str], List[List[str]]]
Expand Down
21 changes: 21 additions & 0 deletions tests/llmcompressor/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,25 @@ def valid_recipe_strings():
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
]
""",
"""
version: 1.0
args:
learning_rate: 0.001
train_stage:
pruning_modifiers:
ConstantPruningModifier:
start: 0.0
end: 2.0
targets: ['re:.*weight']
quantization_modifiers:
QuantizationModifier:
bits: 8
targets: ['re:.*weight']
eval_stage:
pruning_modifiers:
ConstantPruningModifier:
start: 2.0
end: 4.0
targets: ['re:.*weight']
""",
]
Loading