Skip to content

Commit

Permalink
Merge branch 'main' into chameleon
Browse files Browse the repository at this point in the history
  • Loading branch information
AkshitaB authored Jul 17, 2024
2 parents 006bc60 + ab63296 commit e44f991
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 21 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Changed default distributed training strategy from single-GPU to FSDP
- Fixed behavior of `effective_memmap_dtype` to prevent unrecognized dtypes from being parsed as `uint16`.

## [v0.4.0](https://github.com/allenai/OLMo/releases/tag/v0.4.0) - 2024-07-11

### Added

- Added clipping fix to `Optimizer` class to make it work with FSDP `no_shard` and DDP.
- Added tests to compare grad norm differences between torch optimizer and clipping and OLMo optimizer and clipping on both CPU and GPU.
- Expose memmap dtype in data config
- Expose memmap dtype in data config
- Added support for DDP training.
- Added caching to disk of HF datasets used in downstream evals
- Added FLOPs logging
Expand Down
17 changes: 7 additions & 10 deletions olmo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,16 +598,13 @@ class DataConfig(BaseConfig):

@property
def effective_memmap_dtype(self):
# Resolves the string stored in `self.memmap_dtype` to the numpy type of the same name.
# NOTE(review): this span is a rendered diff hunk with indentation and +/- markers stripped.
# Going by the "7 additions & 10 deletions" header, the if/elif chain (through the
# `return np.uint16` fallback) looks like the removed implementation and the try/except
# after it the replacement -- confirm against the repository before relying on this.
if self.memmap_dtype == "uint8":
return np.uint8
if self.memmap_dtype == "uint16":
return np.uint16
elif self.memmap_dtype == "uint32":
return np.uint32
elif self.memmap_dtype == "uint64":
return np.uint64
# default to uint16 if not set
return np.uint16
# From here on, an unresolvable name raises TypeError rather than falling back to uint16.
try:
# getattr will check this is part of numpy module, while np.dtype will check
# if this is a valid numpy dtype.
# The walrus binds the looked-up attribute to `dtype` so that it can be returned
# once np.dtype() has accepted it.
# NOTE(review): np.dtype() may accept some non-scalar numpy attributes (e.g. arbitrary
# classes as object dtype) -- confirm whether stricter validation is wanted.
np.dtype(dtype := getattr(np, self.memmap_dtype))
except (AttributeError, TypeError) as e:
# Both failure modes are reported as TypeError; the original exception is chained as the cause.
raise TypeError(f"Value {self.memmap_dtype} is not a valid numpy type") from e
return dtype


class EvaluatorType(StrEnum):
Expand Down
26 changes: 16 additions & 10 deletions tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from dataclasses import dataclass
from pathlib import Path
from typing import List
from unittest import TestCase

import numpy

Expand Down Expand Up @@ -50,13 +51,18 @@ def test_new():
assert config.seed == 2


def test_data_config():
# Checks the string -> numpy type mapping exposed by `DataConfig.effective_memmap_dtype`.
# NOTE(review): these 10 lines match the "10 deletions" reported for tests/config_test.py,
# so this function appears to be the removed, pre-change test -- confirm.
data_config = DataConfig.new()
# A freshly created config reports "uint16" as its memmap dtype.
assert data_config.memmap_dtype == "uint16"
assert data_config.effective_memmap_dtype == numpy.uint16
data_config.memmap_dtype = "uint32"
assert data_config.effective_memmap_dtype == numpy.uint32
data_config.memmap_dtype = "uint64"
assert data_config.effective_memmap_dtype == numpy.uint64
# An unrecognized name is expected here to fall back to uint16.
data_config.memmap_dtype = "unknown"
assert data_config.effective_memmap_dtype == numpy.uint16
class TestDataConfig(TestCase):
# unittest-style checks for `DataConfig.memmap_dtype` / `DataConfig.effective_memmap_dtype`.
# NOTE(review): indentation was stripped by the diff rendering; the method below belongs
# inside this class.
def test_data_config(self):
# A freshly created config reports "uint16" and resolves it to numpy.uint16.
data_config = DataConfig.new()
self.assertEqual(data_config.memmap_dtype, "uint16")
self.assertEqual(data_config.effective_memmap_dtype, numpy.uint16)

# Other valid numpy type names resolve to the matching numpy type.
data_config.memmap_dtype = "uint32"
self.assertEqual(data_config.effective_memmap_dtype, numpy.uint32)

data_config.memmap_dtype = "uint64"
self.assertEqual(data_config.effective_memmap_dtype, numpy.uint64)

# An unrecognized name must raise TypeError; merely reading the property triggers the check.
data_config.memmap_dtype = "unknown"
with self.assertRaises(TypeError):
data_config.effective_memmap_dtype

0 comments on commit e44f991

Please sign in to comment.