Add support for BF16 optim state in SkipStepAdamW #148

Draft · wants to merge 2 commits into base: main
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added the option to throttle checkpoint uploads to one rank from each node at a time.
- Added `unshard_strategy` parameter to `unshard_checkpoint()` function in `olmo_core.distributed.checkpoint`.
- Added function `load_keys()` to `olmo_core.distributed.checkpoint`.
- Added support for low precision optim state in `SkipStepAdamW`.

### Changed

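Usage sketch (an addition for illustration, not part of this PR's diff), based on the config and test changes below: the low-precision state is opted into through the new `dtype` field on `SkipStepAdamWConfig`. A CUDA device is assumed here, since the test below skips bf16 on CPU.

```python
import torch.nn as nn

from olmo_core.config import DType
from olmo_core.optim import SkipStepAdamWConfig

# Hypothetical model; any nn.Module works with config.build().
model = nn.Linear(16, 16).cuda()

# Keep the AdamW moment buffers (exp_avg / exp_avg_sq) in bfloat16.
config = SkipStepAdamWConfig(dtype=DType.bfloat16)
optim = config.build(model)
```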
src/olmo_core/optim/adamw.py (14 changes: 10 additions & 4 deletions)
@@ -1,9 +1,10 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Type
from typing import Optional, Tuple, Type, Union

import torch
import torch.nn as nn

from ..config import DType
from .config import OptimConfig
from .skip_step_optimizer import SkipStepOptimizer

@@ -29,7 +30,7 @@ def adamw_step(
p.mul_(1 - step_factor * (lr * weight_decay))

# Decay the first and second moment running average coefficient.
exp_avg.lerp_(p.grad, step_factor * (1 - beta1))
exp_avg.lerp_(p.grad.type_as(exp_avg), (step_factor * (1 - beta1)).type_as(exp_avg))
exp_avg_sq.mul_(1 - step_factor * (1 - beta2))
exp_avg_sq.add_(step_factor * p.grad * p.grad, alpha=1 - beta2)

@@ -61,6 +62,7 @@ def __init__(
fused: Optional[bool] = None,
rolling_interval_length: int = 128,
sigma_factor: int = 6,
dtype: Optional[Union[torch.dtype, DType]] = None,
) -> None:
assert lr > 0.0
assert all([0.0 <= beta <= 1.0 for beta in betas])
@@ -73,6 +75,9 @@ def __init__(
rolling_interval_length=rolling_interval_length,
sigma_factor=sigma_factor,
)
if isinstance(dtype, DType):
dtype = dtype.as_pt()
self.dtype = dtype
self._step_skipped: Optional[torch.Tensor] = None

@property
@@ -98,8 +103,8 @@ def step(self, closure=None) -> None:
state = self.state[p]
if len(state) == 0:
state["step"] = torch.tensor(0.0, dtype=torch.float32, device=p.device)
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["exp_avg"] = torch.zeros_like(p, dtype=self.dtype)
state["exp_avg_sq"] = torch.zeros_like(p, dtype=self.dtype)

adamw_step(
p,
@@ -144,6 +149,7 @@ class SkipStepAdamWConfig(OptimConfig):
weight_decay: float = 1e-2
rolling_interval_length: int = 128
sigma_factor: int = 6
dtype: Optional[DType] = None

@classmethod
def optimizer(cls) -> Type[SkipStepAdamW]:
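A minimal standalone sketch (an illustration, not part of the PR) of why the `type_as` casts in `adamw_step` matter once the state can live in bfloat16: an in-place PyTorch op cannot downcast a float32 result into a bfloat16 buffer, so both the gradient and the tensor-valued lerp weight have to be cast to the state's dtype first.

```python
import torch

# Mixed-precision first-moment update: bf16 state, fp32 gradient (assumed shapes/values).
exp_avg = torch.zeros(4, dtype=torch.bfloat16)  # bf16 optimizer state
grad = torch.randn(4)                           # fp32 gradient
step_factor = torch.tensor(1.0)                 # fp32 tensor, as in the skip-step logic
beta1 = 0.9

# Without the casts, `exp_avg.lerp_(grad, step_factor * (1 - beta1))` raises a
# RuntimeError: the promoted fp32 result cannot be written into the bf16 buffer.
exp_avg.lerp_(grad.type_as(exp_avg), (step_factor * (1 - beta1)).type_as(exp_avg))
print(exp_avg.dtype)  # torch.bfloat16
```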
src/test/optim/adamw_test.py (13 changes: 10 additions & 3 deletions)
@@ -1,11 +1,14 @@
from test.utils import DEVICES
from typing import Optional

import pytest
import torch
import torch.nn as nn

from olmo_core.config import DType
from olmo_core.optim import AdamWConfig, OptimGroupOverride, SkipStepAdamWConfig

from ..utils import DEVICES


class MyModel(nn.Module):
def __init__(self):
@@ -64,8 +67,12 @@ def test_adamw(device: torch.device):


@pytest.mark.parametrize("device", DEVICES)
def test_skip_step_adamw(device: torch.device):
config = SkipStepAdamWConfig()
@pytest.mark.parametrize("dtype", [None, DType.bfloat16])
def test_skip_step_adamw(device: torch.device, dtype: Optional[DType]):
if dtype == DType.bfloat16 and device.type == "cpu":
pytest.skip("bfloat16 dtype requires cuda")

config = SkipStepAdamWConfig(dtype=dtype)
model = MyModel().train().to(device)
optim = config.build(model)

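As a follow-up sketch (hypothetical, not part of this diff): one way to confirm the new behavior outside the test harness is to build the optimizer with `DType.bfloat16`, drive one step, and check the dtype of the moment buffers. This assumes a CUDA device (the test skips bf16 on CPU) and assumes the skip-step optimizer is fed the latest loss via a `latest_loss` attribute before `step()`.

```python
import torch
import torch.nn as nn

from olmo_core.config import DType
from olmo_core.optim import SkipStepAdamWConfig

device = torch.device("cuda")
model = nn.Linear(8, 8).to(device)
optim = SkipStepAdamWConfig(dtype=DType.bfloat16).build(model)

loss = model(torch.randn(2, 8, device=device)).sum()
loss.backward()
optim.latest_loss = loss.detach()  # assumption: the skip-step logic needs the loss to decide whether to skip
optim.step()

# After the first step, the moment buffers should be allocated in bfloat16.
for group in optim.param_groups:
    for p in group["params"]:
        state = optim.state[p]
        assert state["exp_avg"].dtype == torch.bfloat16
        assert state["exp_avg_sq"].dtype == torch.bfloat16
```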