|
12 | 12 | from collections.abc import Callable, Iterator |
13 | 13 | from copy import deepcopy |
14 | 14 | from dataclasses import asdict |
15 | | -from functools import partial |
16 | 15 | from typing import Any, overload |
17 | 16 |
|
18 | 17 | import torch |
|
56 | 55 | ) |
57 | 56 |
|
58 | 57 | from distributed_shampoo.shampoo_types import ( |
59 | | - AdaGradGraftingConfig, |
60 | | - AdamGraftingConfig, |
| 58 | + AdaGradPreconditionerConfig, |
| 59 | + AdamPreconditionerConfig, |
61 | 60 | AmortizedPreconditionerConfig, |
62 | 61 | BETA3, |
63 | 62 | BETAS, |
|
78 | 77 | FullyShardDistributedConfig, |
79 | 78 | GRAFTING_CONFIG, |
80 | 79 | GRAFTING_PRECONDITIONER_LIST, |
81 | | - GraftingConfig, |
82 | 80 | HSDPDistributedConfig, |
83 | 81 | HybridShardDistributedConfig, |
84 | 82 | LR, |
|
94 | 92 | PRECONDITIONER_CONFIG, |
95 | 93 | PreconditionerConfig, |
96 | 94 | PREVIOUS_GRAD_SELECTOR, |
97 | | - RMSpropGraftingConfig, |
| 95 | + RMSpropPreconditionerConfig, |
98 | 96 | RootInvShampooPreconditionerConfig, |
99 | | - SGDGraftingConfig, |
| 97 | + SGDPreconditionerConfig, |
100 | 98 | SHAMPOO_PRECONDITIONER_LIST, |
101 | 99 | ShampooPT2CompileConfig, |
102 | 100 | SignDescentPreconditionerConfig, |
@@ -338,7 +336,7 @@ def __init__( |
338 | 336 | use_nesterov: bool = False, |
339 | 337 | use_bias_correction: bool = True, |
340 | 338 | use_decoupled_weight_decay: bool = True, |
341 | | - grafting_config: GraftingConfig | None = None, |
| 339 | + grafting_config: PreconditionerConfig | None = None, |
342 | 340 | use_pin_memory: bool = False, |
343 | 341 | shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None, |
344 | 342 | distributed_config: DistributedConfig = DefaultSingleDeviceDistributedConfig, |
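For callers of the public constructor, the visible change in this hunk is that `grafting_config` is now typed as a `PreconditionerConfig` (the former `*GraftingConfig` classes are renamed to `*PreconditionerConfig`). A minimal construction sketch under the renamed API; the toy model, hyperparameter values, and the package-level import of `DistributedShampoo` are illustrative assumptions, not taken from this diff:

```python
import torch

from distributed_shampoo import DistributedShampoo  # package-level re-export assumed
from distributed_shampoo.shampoo_types import AdaGradPreconditionerConfig

# Hypothetical toy model; only grafting_config reflects the rename in this diff
# (AdaGradGraftingConfig -> AdaGradPreconditionerConfig).
model = torch.nn.Linear(8, 2)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    use_decoupled_weight_decay=True,
    grafting_config=AdaGradPreconditionerConfig(epsilon=1e-10),
)
```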
@@ -577,93 +575,103 @@ def _initialize_blocked_parameters_state(self) -> None: |
577 | 575 | ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions." |
578 | 576 | param_state[block_index] = {} |
579 | 577 |
|
| 578 | + @torch.no_grad() |
| 579 | + def _preconditioner_config_to_list_cls( |
| 580 | + self, |
| 581 | + state_lists: dict[str, Any], |
| 582 | + group: dict[str, Any], |
 | 583 | + preconditioner_config: PreconditionerConfig | None, |
| 584 | + ) -> PreconditionerList | None: |
| 585 | + match preconditioner_config: |
| 586 | + case None: |
| 587 | + return None |
| 588 | + case SGDPreconditionerConfig(): |
| 589 | + return SGDPreconditionerList( |
| 590 | + block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
| 591 | + ) |
| 592 | + case ( |
| 593 | + AdaGradPreconditionerConfig() |
| 594 | + | RMSpropPreconditionerConfig() |
| 595 | + | AdamPreconditionerConfig() |
| 596 | + ): |
| 597 | + return AdagradPreconditionerList( |
| 598 | + block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
| 599 | + state=self.state, |
| 600 | + block_info_list=state_lists[DISTRIBUTOR].local_block_info_list, |
| 601 | + beta2=( |
| 602 | + 1.0 |
 | 603 | + if type(preconditioner_config) is AdaGradPreconditionerConfig |
 | 604 | + else preconditioner_config.beta2 |
 | 605 | + ), |
 | 606 | + epsilon=preconditioner_config.epsilon, |
 | 607 | + use_bias_correction=type(preconditioner_config) |
 | 608 | + is AdamPreconditionerConfig, |
| 609 | + ) |
| 610 | + case ( |
| 611 | + RootInvShampooPreconditionerConfig() |
| 612 | + | EigendecomposedShampooPreconditionerConfig() |
| 613 | + | EigenvalueCorrectedShampooPreconditionerConfig() |
| 614 | + ): |
| 615 | + preconditioner_config_to_list_cls: dict[ |
| 616 | + type[PreconditionerConfig], Callable[..., PreconditionerList] |
| 617 | + ] = { |
| 618 | + RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList, |
| 619 | + EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList, |
| 620 | + EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList, |
| 621 | + } |
| 622 | + return preconditioner_config_to_list_cls[type(preconditioner_config)]( |
| 623 | + block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
 | 624 | + preconditioner_config=preconditioner_config, |
| 625 | + state=self.state, |
| 626 | + block_info_list=state_lists[DISTRIBUTOR].local_block_info_list, |
| 627 | + beta2=group[BETAS][1], |
| 628 | + epsilon=group[EPSILON], |
| 629 | + use_bias_correction=group[USE_BIAS_CORRECTION], |
| 630 | + ) |
| 631 | + case SignDescentPreconditionerConfig(): |
| 632 | + return SignDescentPreconditionerList( |
| 633 | + block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
 | 634 | + preconditioner_config=preconditioner_config, |
| 635 | + ) |
| 636 | + case SpectralDescentPreconditionerConfig(): |
| 637 | + assert ( |
| 638 | + group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2 |
| 639 | + ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." |
| 640 | + return SpectralDescentPreconditionerList( |
| 641 | + block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
 | 642 | + preconditioner_config=preconditioner_config, |
| 643 | + ) |
| 644 | + case _: |
| 645 | + raise NotImplementedError(f"{preconditioner_config=} not supported!") |
| 646 | + |
580 | 647 | @torch.no_grad() |
581 | 648 | def _instantiate_shampoo_preconditioner_list(self) -> None: |
582 | 649 | for state_lists, group in zip( |
583 | 650 | self._per_group_state_lists, self.param_groups, strict=True |
584 | 651 | ): |
585 | | - match group[PRECONDITIONER_CONFIG]: |
586 | | - case ( |
587 | | - RootInvShampooPreconditionerConfig() |
588 | | - | EigendecomposedShampooPreconditionerConfig() |
589 | | - | EigenvalueCorrectedShampooPreconditionerConfig() |
590 | | - ): |
591 | | - preconditioner_config_to_list_cls: dict[ |
592 | | - type[PreconditionerConfig], Callable[..., PreconditionerList] |
593 | | - ] = { |
594 | | - RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList, |
595 | | - EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList, |
596 | | - EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList, |
597 | | - } |
598 | | - preconditioner_list_cls: Callable[..., PreconditionerList] = ( |
599 | | - partial( |
600 | | - preconditioner_config_to_list_cls[ |
601 | | - type(group[PRECONDITIONER_CONFIG]) |
602 | | - ], |
603 | | - state=self.state, |
604 | | - block_info_list=state_lists[ |
605 | | - DISTRIBUTOR |
606 | | - ].local_block_info_list, |
607 | | - beta2=group[BETAS][1], |
608 | | - epsilon=group[EPSILON], |
609 | | - use_bias_correction=group[USE_BIAS_CORRECTION], |
610 | | - ) |
611 | | - ) |
612 | | - case SignDescentPreconditionerConfig(): |
613 | | - preconditioner_list_cls = partial(SignDescentPreconditionerList) |
614 | | - case SpectralDescentPreconditionerConfig(): |
615 | | - assert ( |
616 | | - group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2 |
617 | | - ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig." |
618 | | - preconditioner_list_cls = SpectralDescentPreconditionerList |
619 | | - case _: |
620 | | - raise NotImplementedError( |
621 | | - f"{group[PRECONDITIONER_CONFIG]=} not supported!" |
622 | | - ) |
623 | | - |
624 | | - state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls( |
625 | | - block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
626 | | - preconditioner_config=group[PRECONDITIONER_CONFIG], |
| 652 | + assert ( |
| 653 | + group[PRECONDITIONER_CONFIG] is not None |
| 654 | + ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo." |
| 655 | + state_lists[SHAMPOO_PRECONDITIONER_LIST] = ( |
| 656 | + self._preconditioner_config_to_list_cls( |
| 657 | + state_lists=state_lists, |
| 658 | + group=group, |
| 659 | + preconditioner_config=group[PRECONDITIONER_CONFIG], |
| 660 | + ) |
627 | 661 | ) |
628 | 662 |
|
629 | 663 | @torch.no_grad() |
630 | 664 | def _instantiate_grafting(self) -> None: |
631 | 665 | for state_lists, group in zip( |
632 | 666 | self._per_group_state_lists, self.param_groups, strict=True |
633 | 667 | ): |
634 | | - match group[GRAFTING_CONFIG]: |
635 | | - case None: |
636 | | - state_lists[GRAFTING_PRECONDITIONER_LIST] = None |
637 | | - case SGDGraftingConfig(): |
638 | | - state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList( |
639 | | - block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
640 | | - ) |
641 | | - case ( |
642 | | - AdaGradGraftingConfig() |
643 | | - | RMSpropGraftingConfig() |
644 | | - | AdamGraftingConfig() |
645 | | - ): |
646 | | - state_lists[GRAFTING_PRECONDITIONER_LIST] = ( |
647 | | - AdagradPreconditionerList( |
648 | | - block_list=state_lists[DISTRIBUTOR].local_blocked_params, |
649 | | - state=self.state, |
650 | | - block_info_list=state_lists[ |
651 | | - DISTRIBUTOR |
652 | | - ].local_block_info_list, |
653 | | - beta2=( |
654 | | - 1.0 |
655 | | - if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig |
656 | | - else group[GRAFTING_CONFIG].beta2 |
657 | | - ), |
658 | | - epsilon=group[GRAFTING_CONFIG].epsilon, |
659 | | - use_bias_correction=type(group[GRAFTING_CONFIG]) |
660 | | - is AdamGraftingConfig, |
661 | | - ) |
662 | | - ) |
663 | | - case _: |
664 | | - raise NotImplementedError( |
665 | | - f"{group[GRAFTING_CONFIG]=} not supported!" |
666 | | - ) |
| 668 | + state_lists[GRAFTING_PRECONDITIONER_LIST] = ( |
| 669 | + self._preconditioner_config_to_list_cls( |
| 670 | + state_lists=state_lists, |
| 671 | + group=group, |
| 672 | + preconditioner_config=group[GRAFTING_CONFIG], |
| 673 | + ) |
| 674 | + ) |
667 | 675 |
|
668 | 676 | @torch.no_grad() |
669 | 677 | def _instantiate_steps(self) -> None: |
|
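Stepping back, the net effect of this diff is that the two near-duplicate `match` statements formerly in `_instantiate_shampoo_preconditioner_list` and `_instantiate_grafting` now funnel through the single `_preconditioner_config_to_list_cls` helper, with `None` (grafting disabled) handled in the same place. A simplified, self-contained sketch of that dispatch pattern, using stand-in config and list names rather than the library's own classes:

```python
from dataclasses import dataclass

# Stand-ins for the real config / preconditioner-list classes; names are illustrative only.
@dataclass
class SGDConfig: ...

@dataclass
class AdamConfig:
    beta2: float = 0.999
    epsilon: float = 1e-8

def build_preconditioner_list(config: object | None) -> str | None:
    # One dispatcher serves both call sites: the Shampoo preconditioner path and the
    # grafting path, which simply passes None when grafting is disabled.
    match config:
        case None:
            return None
        case SGDConfig():
            return "SGDPreconditionerList()"
        case AdamConfig():
            return f"AdagradPreconditionerList(beta2={config.beta2}, epsilon={config.epsilon})"
        case _:
            raise NotImplementedError(f"{config=} not supported!")

assert build_preconditioner_list(None) is None  # no grafting configured
print(build_preconditioner_list(AdamConfig()))  # dispatches on the config dataclass type
```

In the actual change, both call sites then differ only in which group entry (`PRECONDITIONER_CONFIG` or `GRAFTING_CONFIG`) they hand to the shared helper.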