Changes for 65B and 70B runs #414
@@ -243,6 +243,14 @@ class ModelConfig(BaseConfig):

    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
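The docstring above describes the three attention regimes purely in terms of `n_kv_heads`. As a rough illustration (not OLMo's attention implementation; the shapes, names, and the `repeat_interleave` expansion are assumptions about how grouped-query attention is typically wired), the key/value heads are shared across groups of query heads:

```python
import torch

# Illustrative sketch of what n_kv_heads controls; not OLMo's actual attention code.
batch, seq, d_model = 2, 16, 512
n_heads = 8
n_kv_heads = 2            # 8 (= n_heads) -> multi-head, 1 -> multi-query, 2 -> grouped-query
head_dim = d_model // n_heads

q = torch.randn(batch, seq, n_heads, head_dim)
k = torch.randn(batch, seq, n_kv_heads, head_dim)  # fewer key/value heads than query heads
v = torch.randn(batch, seq, n_kv_heads, head_dim)

# Each group of n_heads // n_kv_heads query heads shares one key/value head.
k = k.repeat_interleave(n_heads // n_kv_heads, dim=2)
v = v.repeat_interleave(n_heads // n_kv_heads, dim=2)

assert q.shape == k.shape == v.shape  # attention proceeds as usual from here
```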
@@ -309,8 +317,7 @@ class ModelConfig(BaseConfig):

    multi_query_attention: bool = False
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference.
    Deprecated. Use n_kv_heads instead.
    """

    attention_layer_norm: bool = False

Review comment (on multi_query_attention): Make this
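With the flag deprecated, the intent it used to express is written through `n_kv_heads` instead. A small sketch, assuming `ModelConfig` is importable from `olmo.config` and that the remaining fields keep their defaults:

```python
from olmo.config import ModelConfig

mha = ModelConfig(n_heads=32)                # n_kv_heads left as None -> multi-head attention
mqa = ModelConfig(n_heads=32, n_kv_heads=1)  # what multi_query_attention=True used to mean
gqa = ModelConfig(n_heads=32, n_kv_heads=8)  # Llama2-style grouped-query attention
```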
@@ -428,6 +435,29 @@ class ModelConfig(BaseConfig):

    See :data:`TrainConfig.precision` instead.
    """

    def __post_init__(self):
        if self.n_kv_heads is None:
            self.n_kv_heads = self.n_heads

Review comment: Then here we could do this:

    if self.multi_query_attention:
        self.n_kv_heads = 1
    elif self.n_kv_heads is None:
        self.n_kv_heads = self.n_heads
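Folding that suggestion into the new method would give something like the sketch below; the divisibility assert is an extra assumption for grouped-query attention, not something shown in the diff:

```python
def __post_init__(self):
    # Resolve the deprecated flag first, then fall back to n_heads.
    if self.multi_query_attention:
        self.n_kv_heads = 1
    elif self.n_kv_heads is None:
        self.n_kv_heads = self.n_heads
    # Assumption: query heads must split evenly across key/value heads.
    assert self.n_heads % self.n_kv_heads == 0
```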

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:

Review comment: Then this won't be needed.
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "multi_query_attention"):
                if hasattr(new_config, "n_kv_heads") and new_config.n_kv_heads is not None:
                    raise OlmoConfigurationError("You can't specify both `multi_query_attention` and `n_kv_heads`. Specify only `n_kv_heads`.")
                if new_config.multi_query_attention:
                    new_config.n_kv_heads = 1
                else:
                    new_config.n_kv_heads = new_config.n_heads

            if hasattr(new_config, "optimizer"):
                new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

Review comment: I don't think we need this here
Review comment (reply): I am learning that

        return new_config

class OptimizerType(StrEnum):
    lionw = "lionw"
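As a usage sketch of the migration path (assuming `om` is `omegaconf.OmegaConf` as in the module's imports, and that `ModelConfig` is importable from `olmo.config`), a legacy model section carrying the deprecated flag would be rewritten like this:

```python
from omegaconf import OmegaConf as om
from olmo.config import ModelConfig

# Hypothetical legacy model section that still uses the deprecated flag.
legacy = om.create({"n_heads": 32, "multi_query_attention": True})

migrated = ModelConfig.update_legacy_settings(legacy)
print(migrated.n_kv_heads)  # -> 1, the multi-query setting expressed via n_kv_heads
```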
@@ -1036,4 +1066,7 @@ def update_legacy_settings(cls, config: D) -> D:

            if hasattr(new_config, "optimizer"):
                new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

            if hasattr(new_config, "model"):
                new_config.model = ModelConfig.update_legacy_settings(new_config.model)

        return new_config
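This last hunk wires the model migration into the top-level legacy pass. In a sketch that assumes the enclosing class is the top-level TrainConfig (an assumption, since this hunk's class isn't shown), a full legacy config gets its nested model section rewritten in the same call:

```python
from omegaconf import OmegaConf as om
from olmo.config import TrainConfig

full_legacy = om.create({"model": {"n_heads": 32, "multi_query_attention": True}})
migrated = TrainConfig.update_legacy_settings(full_legacy)
print(migrated.model.n_kv_heads)  # -> 1; the nested ModelConfig migration ran as well
```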
Review comment (reply): Done, I just can't click "commit" here for some reason.