Merge pull request LeelaChessZero#5 from almaudoh/attention-net-updates
Fix net.py to match current proto.
Ergodice authored Oct 11, 2023
2 parents d6f59b9 + e1a71d9 commit 1eacb21
Showing 3 changed files with 24 additions and 4 deletions.
2 changes: 1 addition & 1 deletion libs/lczero-common
Submodule lczero-common updated 1 file
+200 −1 proto/net.proto
12 changes: 11 additions & 1 deletion tf/net.py
@@ -13,6 +13,7 @@
 LC0_MINOR_WITH_INPUT_TYPE_5 = 27
 LC0_MINOR_WITH_MISH = 29
 LC0_MINOR_WITH_ATTN_BODY = 30
+LC0_MINOR_WITH_MULTIHEAD = 31
 LC0_PATCH = 0
 WEIGHTS_MAGIC = 0x1c0

@@ -63,6 +64,9 @@ def set_networkformat(self, net):
         if net == pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT \
                 and self.pb.min_version.minor < LC0_MINOR_WITH_ATTN_BODY:
             self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY
+        if net == pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT \
+                and self.pb.min_version.minor < LC0_MINOR_WITH_MULTIHEAD:
+            self.pb.min_version.minor = LC0_MINOR_WITH_MULTIHEAD

     def set_policyformat(self, policy):
         self.pb.format.network_format.policy = policy
@@ -113,6 +117,11 @@ def set_ffn_activation(self, activation):
             self.pb.min_version.minor = LC0_MINOR_WITH_ATTN_BODY
         return None

+    def set_input_embedding(self, embedding):
+        self.pb.format.network_format.input_embedding = embedding
+        if self.pb.min_version.minor < LC0_MINOR_WITH_MULTIHEAD:
+            self.pb.min_version.minor = LC0_MINOR_WITH_MULTIHEAD
+
     def activation(self, name):
         if name == "relu":
             return pb.NetworkFormat.ACTIVATION_RELU
@@ -451,7 +460,8 @@ def moves_left_to_bp(l, w):
             # pb_name = 'policy.' + convblock_to_bp(weights_name)

         elif base_layer == 'value':
-            if layers[1] in ['st', 'vanilla']:
+            pb_prefix = ''
+            if layers[1] in ['st', 'q', 'winner']:
                 pb_prefix = 'value_heads.' + layers[1] + '.'
             if 'dense' in layers[2] or 'embedding' in layers[2]:
                 pb_name = value_to_bp(layers[2], weights_name)
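The net.py changes above all follow the same version-gating pattern as the existing attention-body code: requesting a newer format feature only raises min_version, never lowers it. The snippet below is a minimal, self-contained sketch of that pattern; the MinVersionSketch class and its require() method are illustrative stand-ins, not the repository's Net class, which keeps the value on the protobuf message in self.pb.min_version.minor.

# Hypothetical stand-in for the min_version bookkeeping shown in the diff above.
LC0_MINOR_WITH_ATTN_BODY = 30
LC0_MINOR_WITH_MULTIHEAD = 31   # constant introduced by this commit

class MinVersionSketch:
    def __init__(self):
        self.minor = 0          # plays the role of self.pb.min_version.minor

    def require(self, needed):
        # Raise the required lc0 minor version; never lower it.
        if self.minor < needed:
            self.minor = needed

v = MinVersionSketch()
v.require(LC0_MINOR_WITH_ATTN_BODY)    # attention body needs minor >= 30
v.require(LC0_MINOR_WITH_MULTIHEAD)    # multihead format / new embedding needs minor >= 31
assert v.minor == LC0_MINOR_WITH_MULTIHEAD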
14 changes: 12 additions & 2 deletions tf/tfprocess.py
@@ -304,12 +304,22 @@ def __init__(self, cfg):
         if self.encoder_layers > 0:
             self.net.set_headcount(self.encoder_heads)
             self.net.set_networkformat(
-                pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_HEADFORMAT)
+                pb.NetworkFormat.NETWORK_ATTENTIONBODY_WITH_MULTIHEADFORMAT)
             self.net.set_smolgen_activation(
                 self.net.activation(self.smolgen_activation))
             self.net.set_ffn_activation(self.net.activation(
                 'default'))

+        if self.embedding_style == "new":
+            self.net.set_input_embedding(
+                pb.NetworkFormat.INPUT_EMBEDDING_PE_DENSE)
+        elif self.encoder_layers > 0:
+            self.net.set_input_embedding(
+                pb.NetworkFormat.INPUT_EMBEDDING_PE_MAP)
+        else:
+            self.net.set_input_embedding(
+                pb.NetworkFormat.INPUT_EMBEDDING_NONE)
+
         self.ffn_activation = self.cfg["model"].get(
             "ffn_activation", self.DEFAULT_ACTIVATION)

@@ -1951,4 +1961,4 @@ def apply_sparsity(self):
         for layer in self.model.layers:
             if layer.name in self.sparsity_patterns:
                 kernel = layer.kernel
-                kernel.assign(kernel * self.sparsity_patterns[layer.name])
+                kernel.assign(kernel * self.sparsity_patterns[layer.name])
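The tfprocess.py hunk above selects the input-embedding enum based on the training configuration. The helper below restates that branch in isolation as a sketch; the function name, the SimpleNamespace stand-in, and the placeholder enum values are assumptions made for illustration only (the real constants come from the generated protobuf module as pb.NetworkFormat, and self.embedding_style is presumably read from the config earlier in __init__, which is not shown in this diff).

from types import SimpleNamespace

def select_input_embedding(embedding_style, encoder_layers, net_format):
    # Mirrors the if/elif/else added to TFProcess.__init__ above.
    if embedding_style == "new":
        return net_format.INPUT_EMBEDDING_PE_DENSE   # new dense positional embedding
    elif encoder_layers > 0:
        return net_format.INPUT_EMBEDDING_PE_MAP     # attention body with the older embedding
    else:
        return net_format.INPUT_EMBEDDING_NONE       # no attention body

# Quick check with placeholder enum values standing in for pb.NetworkFormat.
fake_format = SimpleNamespace(INPUT_EMBEDDING_NONE=0,
                              INPUT_EMBEDDING_PE_MAP=1,
                              INPUT_EMBEDDING_PE_DENSE=2)
assert select_input_embedding("new", 10, fake_format) == fake_format.INPUT_EMBEDDING_PE_DENSE
assert select_input_embedding("old", 10, fake_format) == fake_format.INPUT_EMBEDDING_PE_MAP
assert select_input_embedding("old", 0, fake_format) == fake_format.INPUT_EMBEDDING_NONE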
