diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index 62dd3a8f3..1e21d2732 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -6,7 +6,7 @@ """ from dataclasses import dataclass, field -from typing import Any, List, Optional +from typing import Any from loguru import logger @@ -23,11 +23,11 @@ class CompressionLifecycle: A class for managing the lifecycle of compression events in the LLM Compressor. :param state: The current state of the compression process - :type state: Optional[State] + :type state: State :param recipe: The compression recipe :type recipe: Recipe :param modifiers: The list of stage modifiers - :type modifiers: List[StageModifiers] + :type modifiers: list[StageModifiers] """ state: State = field(default_factory=State) @@ -37,8 +37,8 @@ class CompressionLifecycle: finalized: bool = False # event order validation - _last_event_type: Optional[EventType] = EventType.BATCH_END - _event_order: List[EventType] = field( + _last_event_type: EventType | None = EventType.BATCH_END + _event_order: list[EventType] = field( default_factory=lambda: [ EventType.BATCH_START, EventType.LOSS_CALCULATED, @@ -72,11 +72,11 @@ def reset(self): def initialize( self, - recipe: Optional[RecipeInput] = None, - recipe_stage: Optional[RecipeStageInput] = None, - recipe_args: Optional[RecipeArgsInput] = None, + recipe: RecipeInput | None = None, + recipe_stage: RecipeStageInput | None = None, + recipe_args: RecipeArgsInput | None = None, **kwargs, - ) -> List[Any]: + ) -> list[Any]: """ Initialize the compression lifecycle. @@ -114,7 +114,7 @@ def initialize( return mod_data - def finalize(self, **kwargs) -> List[Any]: + def finalize(self, **kwargs) -> list[Any]: """ Finalize the compression lifecycle. @@ -149,8 +149,8 @@ def finalize(self, **kwargs) -> List[Any]: return mod_data def event( - self, event_type: EventType, global_step: Optional[int] = 0, **kwargs - ) -> List[Any]: + self, event_type: EventType, global_step: int | None = 0, **kwargs + ) -> list[Any]: """ Handle a compression event.