@@ -71,7 +71,7 @@ class HuggingFacePostTrainingConfig(BaseModel):
dpo_beta: float = 0.1
use_reference_model: bool = True
dpo_loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
- dpo_output_dir: str
+ dpo_output_dir: str | None = None
Contributor
I see, yes, this is a backwards-incompatible change as you noted. We should make that a runtime check where the provider yells if this value is still None at runtime, as @cdoern notes.

Contributor Author
So, do we need to change anything else here, or is this good?

Contributor Author
Currently it's throwing a stack trace. Do we need to show this as a log error only, or handle it some other way?

Contributor
You'll need to add checks in the code that uses dpo_output_dir; that is the source of the errors you are seeing. Also, please run pre-commit to regenerate the OpenAPI schema.

Contributor Author
I could see only one place where dpo_output_dir was being used and have added a check there as well. Please review again.

Contributor Author
@cdoern I am unable to run pre-commit because of dependency issues that came up after upgrading my laptop, solely because of the change in architecture.


@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
@@ -135,7 +135,7 @@ async def handler(on_log_message_cb, on_status_change_cb, on_artifact_collected_

resources_allocated, checkpoints = await recipe.train(
model=finetuned_model,
- output_dir=f"{self.config.dpo_output_dir}/{job_uuid}",
+ output_dir=f"{self.config.dpo_output_dir}/{job_uuid}" if self.config.dpo_output_dir else None,
job_uuid=job_uuid,
dpo_config=algorithm_config,
config=training_config,