Skip to content

Commit

Permalink
[Bugfix] Fix schedule and dockerfile (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Jan 25, 2023
1 parent 2bc223c commit 681f2b8
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,11 @@ RUN cd $HOME/epoi && git fetch && git checkout b2e2e98 && pip3 install -e ".[dev
RUN git clone https://github.com/huggingface/transformers.git $HOME/transformers
RUN cd $HOME/transformers && git checkout 2bdd9fa && pip3 install -e ".[dev]" --no-deps

# FIXME Install official DeepSpeed
USER root
RUN pip3 install deepspeed==0.6.5
USER deepspeed

# Fix dependencies
RUN pip3 install huggingface-hub tokenizers numpy==1.23.4 datasets

Expand Down
3 changes: 2 additions & 1 deletion examples/gpt/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,8 @@ def fwd_post_hook(_module, _input, output):
sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook)

# Shard output embedding.
head_sch.shard("weight", axis=0)
if head_sch is not None:
head_sch.shard("weight", axis=0)


def shard_qkv(
Expand Down
3 changes: 2 additions & 1 deletion examples/opt/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def fwd_post_hook(_module, _input, output):
sch[word_embed_name].sync(mode="fwd_post", sync_op_or_fn=fwd_post_hook)

# Shard output embedding.
head_sch.shard("weight", axis=0)
if head_sch is not None:
head_sch.shard("weight", axis=0)


def shard_qkv(
Expand Down

0 comments on commit 681f2b8

Please sign in to comment.