Skip to content

Commit 7171b47

Browse files
tsunghsienleefacebook-github-bot
authored andcommitted
Better type hint of instantiate_preconditioner_config()
Summary: `instantiate_preconditioner_config()` only returns `None` when encoutnering `PreconditionerComputationType.None` as input, and this type hint enhancements make the intention more clear. Differential Revision: D81278578
1 parent 80562ff commit 7171b47

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

distributed_shampoo/examples/trainer_utils.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from functools import partial
1818
from operator import attrgetter
1919
from pathlib import Path
20-
from typing import overload, Type
20+
from typing import Literal, overload, Type
2121

2222
import numpy as np
2323

@@ -406,11 +406,6 @@ def instantiate_optimizer(
406406
weight_decay=weight_decay,
407407
)
408408
elif optimizer_type == OptimizerType.DISTRIBUTED_SHAMPOO:
409-
assert (
410-
preconditioner_config := instantiate_preconditioner_config(
411-
preconditioner_computation_type=preconditioner_computation_type,
412-
)
413-
) is not None
414409
optimizer_cls = partial(
415410
DistributedShampoo,
416411
betas=betas,
@@ -431,14 +426,32 @@ def instantiate_optimizer(
431426
grafting_epsilon=grafting_epsilon,
432427
),
433428
distributed_config=distributed_config,
434-
preconditioner_config=preconditioner_config,
429+
preconditioner_config=instantiate_preconditioner_config(
430+
preconditioner_computation_type=preconditioner_computation_type,
431+
),
435432
)
436433
else:
437434
raise ValueError(f"Invalid OptimizerType {optimizer_type}!")
438435

439436
return optimizer_cls(parameters, lr=lr)
440437

441438

439+
@overload
440+
def instantiate_preconditioner_config(
441+
preconditioner_computation_type: Literal[PreconditionerComputationType.NONE],
442+
grafting_beta2: float = ...,
443+
grafting_epsilon: float = ...,
444+
) -> None: ...
445+
446+
447+
@overload
448+
def instantiate_preconditioner_config(
449+
preconditioner_computation_type: PreconditionerComputationType,
450+
grafting_beta2: float = ...,
451+
grafting_epsilon: float = ...,
452+
) -> PreconditionerConfig: ...
453+
454+
442455
def instantiate_preconditioner_config(
443456
preconditioner_computation_type: PreconditionerComputationType,
444457
grafting_beta2: float = 1.0,

0 commit comments

Comments
 (0)