diff --git a/swift/trainers/rlhf_trainer/dpo_trainer.py b/swift/trainers/rlhf_trainer/dpo_trainer.py index b84a7051eb..2e759bee79 100644 --- a/swift/trainers/rlhf_trainer/dpo_trainer.py +++ b/swift/trainers/rlhf_trainer/dpo_trainer.py @@ -182,4 +182,5 @@ def training_step(self, model, inputs, *args, **kwargs): def prediction_step(self, model, inputs, *args, **kwargs): with self.template.forward_context(self.model, inputs): + inputs = self._prepare_inputs(inputs) return super().prediction_step(model, inputs, *args, **kwargs)