12 | 12 | from collections.abc import Callable, Iterator
13 | 13 | from copy import deepcopy
14 | 14 | from dataclasses import asdict
15 |    | -from functools import partial
16 | 15 | from typing import Any, overload
17 | 16 |
18 | 17 | import torch
56 | 55 | )
57 | 56 |
58 | 57 | from distributed_shampoo.shampoo_types import (
59 |    | -    AdaGradGraftingConfig,
60 |    | -    AdamGraftingConfig,
   | 58 | +    AdaGradPreconditionerConfig,
   | 59 | +    AdamPreconditionerConfig,
61 | 60 |     AmortizedPreconditionerConfig,
62 | 61 |     BETA3,
63 | 62 |     BETAS,

78 | 77 |     FullyShardDistributedConfig,
79 | 78 |     GRAFTING_CONFIG,
80 | 79 |     GRAFTING_PRECONDITIONER_LIST,
81 |    | -    GraftingConfig,
82 | 80 |     HSDPDistributedConfig,
83 | 81 |     HybridShardDistributedConfig,
84 | 82 |     LR,

94 | 92 |     PRECONDITIONER_CONFIG,
95 | 93 |     PreconditionerConfig,
96 | 94 |     PREVIOUS_GRAD_SELECTOR,
97 |    | -    RMSpropGraftingConfig,
   | 95 | +    RMSpropPreconditionerConfig,
98 | 96 |     RootInvShampooPreconditionerConfig,
99 |    | -    SGDGraftingConfig,
   | 97 | +    SGDPreconditionerConfig,
100 | 98 |     SHAMPOO_PRECONDITIONER_LIST,
101 | 99 |     ShampooPT2CompileConfig,
102 | 100 |     SignDescentPreconditionerConfig,
@@ -338,7 +336,7 @@ def __init__(
338 | 336 |         use_nesterov: bool = False,
339 | 337 |         use_bias_correction: bool = True,
340 | 338 |         use_decoupled_weight_decay: bool = True,
341 |     | -        grafting_config: GraftingConfig | None = None,
    | 339 | +        grafting_config: PreconditionerConfig | None = None,
342 | 340 |         use_pin_memory: bool = False,
343 | 341 |         shampoo_pt2_compile_config: ShampooPT2CompileConfig | None = None,
344 | 342 |         distributed_config: DistributedConfig = DefaultSingleDeviceDistributedConfig,
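Note: with this hunk, `grafting_config` now accepts the renamed `*PreconditionerConfig` classes (e.g. `AdaGradPreconditionerConfig`, formerly `AdaGradGraftingConfig`). A minimal usage sketch, assuming the package's existing top-level import and constructor defaults; the surrounding keyword arguments are illustrative, not prescribed by this diff:

```python
import torch

from distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdaGradPreconditionerConfig

model = torch.nn.Linear(16, 4)
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    epsilon=1e-12,
    # Renamed from AdaGradGraftingConfig in this change.
    grafting_config=AdaGradPreconditionerConfig(epsilon=1e-10),
)
```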
@@ -577,93 +575,103 @@ def _instantiate_blocked_parameters_state(self) -> None:
577 | 575 |             ), "There should not exist any optimizer state yet. Maybe verify that _instantiate_distributor was called before all other instantiation functions."
578 | 576 |             param_state[block_index] = {}
579 | 577 |
    | 578 | +    @torch.no_grad()
    | 579 | +    def _preconditioner_config_to_list_cls(
    | 580 | +        self,
    | 581 | +        state_lists: dict[str, Any],
    | 582 | +        group: dict[str, Any],
    | 583 | +        preconditioner_config: PreconditionerConfig,
    | 584 | +    ) -> PreconditionerList | None:
    | 585 | +        match preconditioner_config:
    | 586 | +            case None:
    | 587 | +                return None
    | 588 | +            case SGDPreconditionerConfig():
    | 589 | +                return SGDPreconditionerList(
    | 590 | +                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
    | 591 | +                )
    | 592 | +            case (
    | 593 | +                AdaGradPreconditionerConfig()
    | 594 | +                | RMSpropPreconditionerConfig()
    | 595 | +                | AdamPreconditionerConfig()
    | 596 | +            ):
    | 597 | +                return AdagradPreconditionerList(
    | 598 | +                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
    | 599 | +                    state=self.state,
    | 600 | +                    block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
    | 601 | +                    beta2=(
    | 602 | +                        1.0
    | 603 | +                        if type(group[GRAFTING_CONFIG]) is AdaGradPreconditionerConfig
    | 604 | +                        else group[GRAFTING_CONFIG].beta2
    | 605 | +                    ),
    | 606 | +                    epsilon=group[GRAFTING_CONFIG].epsilon,
    | 607 | +                    use_bias_correction=type(group[GRAFTING_CONFIG])
    | 608 | +                    is AdamPreconditionerConfig,
    | 609 | +                )
    | 610 | +            case (
    | 611 | +                RootInvShampooPreconditionerConfig()
    | 612 | +                | EigendecomposedShampooPreconditionerConfig()
    | 613 | +                | EigenvalueCorrectedShampooPreconditionerConfig()
    | 614 | +            ):
    | 615 | +                preconditioner_config_to_list_cls: dict[
    | 616 | +                    type[PreconditionerConfig], Callable[..., PreconditionerList]
    | 617 | +                ] = {
    | 618 | +                    RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
    | 619 | +                    EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
    | 620 | +                    EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
    | 621 | +                }
    | 622 | +                return preconditioner_config_to_list_cls[type(preconditioner_config)](
    | 623 | +                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
    | 624 | +                    preconditioner_config=group[PRECONDITIONER_CONFIG],
    | 625 | +                    state=self.state,
    | 626 | +                    block_info_list=state_lists[DISTRIBUTOR].local_block_info_list,
    | 627 | +                    beta2=group[BETAS][1],
    | 628 | +                    epsilon=group[EPSILON],
    | 629 | +                    use_bias_correction=group[USE_BIAS_CORRECTION],
    | 630 | +                )
    | 631 | +            case SignDescentPreconditionerConfig():
    | 632 | +                return SignDescentPreconditionerList(
    | 633 | +                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
    | 634 | +                    preconditioner_config=group[PRECONDITIONER_CONFIG],
    | 635 | +                )
    | 636 | +            case SpectralDescentPreconditionerConfig():
    | 637 | +                assert (
    | 638 | +                    group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
    | 639 | +                ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
    | 640 | +                return SpectralDescentPreconditionerList(
    | 641 | +                    block_list=state_lists[DISTRIBUTOR].local_blocked_params,
    | 642 | +                    preconditioner_config=group[PRECONDITIONER_CONFIG],
    | 643 | +                )
    | 644 | +            case _:
    | 645 | +                raise NotImplementedError(f"{preconditioner_config=} not supported!")
    | 646 | +
580 | 647 |     @torch.no_grad()
581 | 648 |     def _instantiate_shampoo_preconditioner_list(self) -> None:
582 | 649 |         for state_lists, group in zip(
583 | 650 |             self._per_group_state_lists, self.param_groups, strict=True
584 | 651 |         ):
585 |     | -            match group[PRECONDITIONER_CONFIG]:
586 |     | -                case (
587 |     | -                    RootInvShampooPreconditionerConfig()
588 |     | -                    | EigendecomposedShampooPreconditionerConfig()
589 |     | -                    | EigenvalueCorrectedShampooPreconditionerConfig()
590 |     | -                ):
591 |     | -                    preconditioner_config_to_list_cls: dict[
592 |     | -                        type[PreconditionerConfig], Callable[..., PreconditionerList]
593 |     | -                    ] = {
594 |     | -                        RootInvShampooPreconditionerConfig: RootInvShampooPreconditionerList,
595 |     | -                        EigendecomposedShampooPreconditionerConfig: EigendecomposedShampooPreconditionerList,
596 |     | -                        EigenvalueCorrectedShampooPreconditionerConfig: EigenvalueCorrectedShampooPreconditionerList,
597 |     | -                    }
598 |     | -                    preconditioner_list_cls: Callable[..., PreconditionerList] = (
599 |     | -                        partial(
600 |     | -                            preconditioner_config_to_list_cls[
601 |     | -                                type(group[PRECONDITIONER_CONFIG])
602 |     | -                            ],
603 |     | -                            state=self.state,
604 |     | -                            block_info_list=state_lists[
605 |     | -                                DISTRIBUTOR
606 |     | -                            ].local_block_info_list,
607 |     | -                            beta2=group[BETAS][1],
608 |     | -                            epsilon=group[EPSILON],
609 |     | -                            use_bias_correction=group[USE_BIAS_CORRECTION],
610 |     | -                        )
611 |     | -                    )
612 |     | -                case SignDescentPreconditionerConfig():
613 |     | -                    preconditioner_list_cls = partial(SignDescentPreconditionerList)
614 |     | -                case SpectralDescentPreconditionerConfig():
615 |     | -                    assert (
616 |     | -                        group[DISTRIBUTED_CONFIG].target_parameter_dimensionality == 2
617 |     | -                    ), f"{group[DISTRIBUTED_CONFIG].target_parameter_dimensionality=} must be 2 when using SpectralDescentPreconditionerConfig."
618 |     | -                    preconditioner_list_cls = SpectralDescentPreconditionerList
619 |     | -                case _:
620 |     | -                    raise NotImplementedError(
621 |     | -                        f"{group[PRECONDITIONER_CONFIG]=} not supported!"
622 |     | -                    )
623 |     | -
624 |     | -            state_lists[SHAMPOO_PRECONDITIONER_LIST] = preconditioner_list_cls(
625 |     | -                block_list=state_lists[DISTRIBUTOR].local_blocked_params,
626 |     | -                preconditioner_config=group[PRECONDITIONER_CONFIG],
    | 652 | +            assert (
    | 653 | +                group[PRECONDITIONER_CONFIG] is not None
    | 654 | +            ), f"{group[PRECONDITIONER_CONFIG]=} is None. Please check the instantiation of DistributedShampoo."
    | 655 | +            state_lists[SHAMPOO_PRECONDITIONER_LIST] = (
    | 656 | +                self._preconditioner_config_to_list_cls(
    | 657 | +                    state_lists=state_lists,
    | 658 | +                    group=group,
    | 659 | +                    preconditioner_config=group[PRECONDITIONER_CONFIG],
    | 660 | +                )
627 | 661 |             )
628 | 662 |
629 | 663 |     @torch.no_grad()
630 | 664 |     def _instantiate_grafting(self) -> None:
631 | 665 |         for state_lists, group in zip(
632 | 666 |             self._per_group_state_lists, self.param_groups, strict=True
633 | 667 |         ):
634 |     | -            match group[GRAFTING_CONFIG]:
635 |     | -                case None:
636 |     | -                    state_lists[GRAFTING_PRECONDITIONER_LIST] = None
637 |     | -                case SGDGraftingConfig():
638 |     | -                    state_lists[GRAFTING_PRECONDITIONER_LIST] = SGDPreconditionerList(
639 |     | -                        block_list=state_lists[DISTRIBUTOR].local_blocked_params,
640 |     | -                    )
641 |     | -                case (
642 |     | -                    AdaGradGraftingConfig()
643 |     | -                    | RMSpropGraftingConfig()
644 |     | -                    | AdamGraftingConfig()
645 |     | -                ):
646 |     | -                    state_lists[GRAFTING_PRECONDITIONER_LIST] = (
647 |     | -                        AdagradPreconditionerList(
648 |     | -                            block_list=state_lists[DISTRIBUTOR].local_blocked_params,
649 |     | -                            state=self.state,
650 |     | -                            block_info_list=state_lists[
651 |     | -                                DISTRIBUTOR
652 |     | -                            ].local_block_info_list,
653 |     | -                            beta2=(
654 |     | -                                1.0
655 |     | -                                if type(group[GRAFTING_CONFIG]) is AdaGradGraftingConfig
656 |     | -                                else group[GRAFTING_CONFIG].beta2
657 |     | -                            ),
658 |     | -                            epsilon=group[GRAFTING_CONFIG].epsilon,
659 |     | -                            use_bias_correction=type(group[GRAFTING_CONFIG])
660 |     | -                            is AdamGraftingConfig,
661 |     | -                        )
662 |     | -                    )
663 |     | -                case _:
664 |     | -                    raise NotImplementedError(
665 |     | -                        f"{group[GRAFTING_CONFIG]=} not supported!"
666 |     | -                    )
    | 668 | +            state_lists[GRAFTING_PRECONDITIONER_LIST] = (
    | 669 | +                self._preconditioner_config_to_list_cls(
    | 670 | +                    state_lists=state_lists,
    | 671 | +                    group=group,
    | 672 | +                    preconditioner_config=group[GRAFTING_CONFIG],
    | 673 | +                )
    | 674 | +            )
667 | 675 |
668 | 676 |     @torch.no_grad()
669 | 677 |     def _instantiate_steps(self) -> None:
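Taken together, this hunk moves the config-to-preconditioner-list mapping into the single `_preconditioner_config_to_list_cls` helper, so `_instantiate_shampoo_preconditioner_list` and `_instantiate_grafting` no longer carry their own duplicated `match` blocks. A self-contained toy sketch of that dispatch pattern; the class and function names below are illustrative stand-ins, not the library's:

```python
from dataclasses import dataclass


@dataclass
class SGDConfig:
    """Stand-in for SGDPreconditionerConfig."""


@dataclass
class AdaGradConfig:
    """Stand-in for AdaGradPreconditionerConfig."""

    epsilon: float = 1e-10


def config_to_preconditioner(config: object | None) -> str | None:
    # One shared match statement: every caller (main preconditioner or
    # grafting) resolves its config through the same dispatch logic.
    match config:
        case None:
            return None
        case SGDConfig():
            return "SGDPreconditionerList"
        case AdaGradConfig():
            return f"AdagradPreconditionerList(epsilon={config.epsilon})"
        case _:
            raise NotImplementedError(f"{config=} not supported!")


print(config_to_preconditioner(AdaGradConfig()))  # AdagradPreconditionerList(epsilon=1e-10)
print(config_to_preconditioner(None))             # None
```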