distributed_shampoo/README.md (52 changes: 24 additions & 28 deletions)
@@ -26,11 +26,7 @@ Ganesh Ajjanagadde (Meta), Rohan Anil (Google), Adnan Aziz (Meta), Pavan Balaji

Key distinctives of this implementation include:
- Homogeneous multi-node multi-GPU support in PyTorch.
-- Learning rate grafting [3]. Our version of grafting only grafts the second moment/diagonal preconditioner. Momentum/first moment updates are performed separate from grafting. Supports the methods:
-  - SGD
-  - Adagrad
-  - RMSprop
-  - Adam
+- Learning rate grafting [3]. Our version of grafting only grafts the second moment/diagonal preconditioner. Momentum/first moment updates are performed separately from grafting.
- Supports both normal and AdamW (decoupled) weight decay.
- Incorporates exponential moving averaging (with or without bias correction) to estimate the first moment (akin to Adam).
- Incorporates momentum and Nesterov acceleration.
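
For orientation, the features listed above map onto `DistributedShampoo` constructor arguments. The snippet below is a minimal sketch, not taken from this diff: `betas`, `epsilon`, and `use_nesterov` are assumed parameter names based on the optimizer's Adam-style interface, and `instantiate_model()` is the same placeholder used in the examples that follow.

```python
from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

model = instantiate_model()  # placeholder model factory, as in the examples below

# Minimal sketch exercising the features above; betas, epsilon, and use_nesterov
# are assumed parameter names and do not appear in the lines shown in this diff.
optimizer = DistributedShampoo(
    model.parameters(),
    lr=0.01,
    betas=(0.9, 0.999),                # EMA coefficients for first/second moments
    epsilon=1e-12,
    momentum=0.9,                      # heavy-ball momentum on the search direction
    use_nesterov=True,                 # Nesterov acceleration
    weight_decay=1e-05,
    use_decoupled_weight_decay=True,   # AdamW-style (decoupled) weight decay
    use_bias_correction=True,          # bias-corrected first-moment estimate
    max_preconditioner_dim=8192,
    precondition_frequency=100,
    grafting_config=SGDPreconditionerConfig(),  # graft the SGD step size
)
```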
@@ -91,7 +87,7 @@ optimizer = SGD(
we would instead use:
```python
import torch
-from distributed_shampoo import DistributedShampoo, SGDGraftingConfig
+from distributed_shampoo import DistributedShampoo, SGDPreconditionerConfig

model = instantiate_model()

@@ -104,7 +100,7 @@ optimizer = DistributedShampoo(
weight_decay=1e-05,
max_preconditioner_dim=8192,
precondition_frequency=100,
-grafting_config=SGDGraftingConfig(),
+grafting_config=SGDPreconditionerConfig(),
)
```

@@ -129,7 +125,7 @@ optimizer = Adam(
we would instead use:
```python
import torch
-from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdamPreconditionerConfig, DistributedShampoo

model = instantiate_model()

@@ -142,7 +138,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=False,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-08,
),
@@ -168,7 +164,7 @@ optimizer = Adagrad(
we would instead use:
```python
import torch
-from distributed_shampoo import AdaGradGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdaGradPreconditionerConfig, DistributedShampoo

model = instantiate_model()

@@ -181,7 +177,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=False,
-grafting_config=AdaGradGraftingConfig(
+grafting_config=AdaGradPreconditionerConfig(
epsilon=1e-10,
),
)
@@ -207,7 +203,7 @@ optimizer = AdamW(
we would instead use:
```python
import torch
-from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
+from distributed_shampoo import AdamPreconditionerConfig, DistributedShampoo

model = instantiate_model()

@@ -220,7 +216,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-08,
),
@@ -308,8 +304,8 @@ optimizer = DistributedShampoo(
{
"params": other_params,
"lr": 3e-4,
"start_preconditioning_step", math.inf,
"grafting_config": AdamGraftingConfig(
"start_preconditioning_step": math.inf,
"grafting_config": AdamPreconditionerConfig(
beta2=0.95,
epsilon=1e-10,
),
@@ -343,7 +339,7 @@ import torch
import torch.distributed as dist

from distributed_shampoo import (
-AdamGraftingConfig,
+AdamPreconditionerConfig,
DDPDistributedConfig,
DistributedShampoo,
)
@@ -376,7 +372,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-12,
),
@@ -404,7 +400,7 @@ import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from distributed_shampoo import (
-AdamGraftingConfig,
+AdamPreconditionerConfig,
compile_fsdp_parameter_metadata,
DistributedShampoo,
FSDPDistributedConfig,
@@ -434,7 +430,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-12,
),
@@ -456,7 +452,7 @@ import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

from distributed_shampoo import (
-AdamGraftingConfig,
+AdamPreconditionerConfig,
compile_fsdp_parameter_metadata,
DistributedShampoo,
HSDPDistributedConfig,
@@ -493,7 +489,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-12,
),
@@ -519,7 +515,7 @@ import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

from distributed_shampoo import (
-AdamGraftingConfig,
+AdamPreconditionerConfig,
DistributedShampoo,
FullyShardDistributedConfig,
)
@@ -548,7 +544,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-12,
),
@@ -570,7 +566,7 @@ import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

from distributed_shampoo import (
-AdamGraftingConfig,
+AdamPreconditionerConfig,
DistributedShampoo,
HybridShardDistributedConfig,
)
@@ -606,7 +602,7 @@ optimizer = DistributedShampoo(
max_preconditioner_dim=8192,
precondition_frequency=100,
use_decoupled_weight_decay=True,
-grafting_config=AdamGraftingConfig(
+grafting_config=AdamPreconditionerConfig(
beta2=0.999,
epsilon=1e-12,
),
@@ -678,7 +674,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rat
momentum=0.9,
weight_decay=0.01,
max_preconditioner_dim=4096,
-grafting_config=SGDGraftingConfig(),
+grafting_config=SGDPreconditionerConfig(),
)
```

@@ -699,7 +695,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rat
momentum=0.9,
weight_decay=0.01,
precondition_frequency=100,
-grafting_config=SGDGraftingConfig(),
+grafting_config=SGDPreconditionerConfig(),
)
```

@@ -718,7 +714,7 @@ With the inclusion of learning rate grafting, we can extract a good learning rat
momentum=0.9,
weight_decay=0.01,
start_preconditioning_step=300,
-grafting_config=SGDGraftingConfig(),
+grafting_config=SGDPreconditionerConfig(),
)
```

distributed_shampoo/preconditioner/README.md (2 changes: 1 addition & 1 deletion)
@@ -281,7 +281,7 @@ optimizer = DistributedShampoo(
precondition_frequency=50,
use_bias_correction=True,
use_decoupled_weight_decay=True,
-grafting_config=RMSpropGraftingConfig(beta2=0.95, epsilon=1e-8),
+grafting_config=RMSpropPreconditionerConfig(beta2=0.95, epsilon=1e-8),
preconditioner_config=RootInvShampooPreconditionerConfig(
amortized_computation_config=EigenConfig(
max_iterations=1000,