Skip to content

Commit

Permalink
Merge pull request hpcaitech#6149 from ver217/hotfix/ckpt
Browse files Browse the repository at this point in the history
[checkpointio] disable buffering
  • Loading branch information
wangbluo authored Nov 21, 2024
2 parents 152162a + 8fddbab commit 8ecff0c
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 5 deletions.
8 changes: 6 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ def save_unsharded_optimizer(

from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
f_writer = AsyncFileWriter(
fp=open(checkpoint, "wb", buffering=0), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
)
save_nested(f_writer, state_dict)
self.async_writers.append(f_writer)
else:
Expand Down Expand Up @@ -225,7 +227,9 @@ def save_sharded_optimizer(
from colossalai.utils.safetensors import save_nested

f_writer = AsyncFileWriter(
fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
fp=open(checkpoint_file_path, "wb", buffering=0),
n_entries=self.N_WRITE_ENTRIES,
backend="pthread",
)
save_nested(f_writer, shard)
self.async_writers.append(f_writer)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def save_unsharded_model(
if use_async:
from tensornvme.async_file_io import AsyncFileWriter

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread")
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
4 changes: 3 additions & 1 deletion colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,9 @@ def save_unsharded_model(

from colossalai.utils.safetensors import move_and_save

writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
writer = AsyncFileWriter(
open(checkpoint, "wb", buffering=0), self.N_WRITE_ENTRIES, backend="pthread"
)
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
self.async_writers.append(writer)
Expand Down
2 changes: 1 addition & 1 deletion colossalai/checkpoint_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def async_save_state_dict_shards(
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)

writer = AsyncFileWriter(open(checkpoint_file_path, "wb"), n_write_entries, backend="pthread")
writer = AsyncFileWriter(open(checkpoint_file_path, "wb", buffering=0), n_write_entries, backend="pthread")
writers.append(writer)

if pinned_state_dict is not None:
Expand Down

0 comments on commit 8ecff0c

Please sign in to comment.