@@ -491,6 +491,7 @@ class BoringCkptPathModel(BoringModel):
491491 def __init__ (self , out_dim : int = 2 , hidden_dim : int = 2 ) -> None :
492492 super ().__init__ ()
493493 self .save_hyperparameters ()
494+ self .hidden_dim = hidden_dim
494495 self .layer = torch .nn .Linear (32 , out_dim )
495496
496497
@@ -526,6 +527,41 @@ def add_arguments_to_parser(self, parser):
526527 assert "Parsing of ckpt_path hyperparameters failed" in err .getvalue ()
527528
528529
530+ class BoringCkptPathSubclass (BoringCkptPathModel ):
531+ def __init__ (self , extra : bool = True , ** kwargs ) -> None :
532+ super ().__init__ (** kwargs )
533+ self .extra = extra
534+
535+
536+ def test_lightning_cli_ckpt_path_argument_hparams_subclass_mode (cleandir ):
537+ class CkptPathCLI (LightningCLI ):
538+ def add_arguments_to_parser (self , parser ):
539+ parser .link_arguments ("model.init_args.out_dim" , "model.init_args.hidden_dim" , compute_fn = lambda x : x * 2 )
540+
541+ cli_args = ["fit" , "--model=BoringCkptPathSubclass" , "--model.out_dim=4" , "--trainer.max_epochs=1" ]
542+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
543+ cli = CkptPathCLI (BoringCkptPathModel , subclass_mode_model = True )
544+
545+ assert cli .config .fit .model .class_path .endswith (".BoringCkptPathSubclass" )
546+ assert cli .config .fit .model .init_args == Namespace (out_dim = 4 , hidden_dim = 8 , extra = True )
547+ hparams_path = Path (cli .trainer .log_dir ) / "hparams.yaml"
548+ assert hparams_path .is_file ()
549+ hparams = yaml .safe_load (hparams_path .read_text ())
550+ assert hparams ["out_dim" ] == 4
551+ assert hparams ["hidden_dim" ] == 8
552+ assert hparams ["extra" ] is True
553+
554+ checkpoint_path = next (Path (cli .trainer .log_dir , "checkpoints" ).glob ("*.ckpt" ))
555+ cli_args = ["predict" , "--model=BoringCkptPathModel" , f"--ckpt_path={ checkpoint_path } " ]
556+ with mock .patch ("sys.argv" , ["any.py" ] + cli_args ):
557+ cli = CkptPathCLI (BoringCkptPathModel , subclass_mode_model = True )
558+
559+ assert isinstance (cli .model , BoringCkptPathSubclass )
560+ assert cli .model .hidden_dim == 8
561+ assert cli .model .extra is True
562+ assert cli .model .layer .out_features == 4
563+
564+
529565def test_lightning_cli_submodules (cleandir ):
530566 class MainModule (BoringModel ):
531567 def __init__ (self , submodule1 : LightningModule , submodule2 : LightningModule , main_param : int = 1 ):
0 commit comments