Skip to content

Commit

Permalink
Add merge_bn() call to Risev3.0
Browse files Browse the repository at this point in the history
  • Loading branch information
QueensGambit committed May 6, 2024
1 parent ce12a9c commit 5a12a23
Showing 1 changed file with 13 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
https://arxiv.org/pdf/1807.06521.pdf
"""
import logging

import torch
from torch.nn import Sequential, Conv2d, BatchNorm2d, Module
from timm.models.layers import DropPath
Expand Down Expand Up @@ -145,11 +147,11 @@ def __init__(self, nb_input_channels, board_height, board_width,
raise Exception(f"Unavailable se_type: {se_type}. Available se_types include {se_types}")

path_dropout_rates = [x.item() for x in torch.linspace(0, path_dropout, len(kernels))] # stochastic depth decay rule
res_blocks = _get_res_blocks(act_types, channels, channels_operating_init, channel_expansion, kernels, se_types, use_transformers, path_dropout_rates, conv_block, kernel_5_channel_ratio, round_channels_to_next_32)
self.res_blocks = _get_res_blocks(act_types, channels, channels_operating_init, channel_expansion, kernels, se_types, use_transformers, path_dropout_rates, conv_block, kernel_5_channel_ratio, round_channels_to_next_32)

self.body_spatial = Sequential(
_Stem(channels=channels, act_type=act_types[0], nb_input_channels=nb_input_channels),
*res_blocks,
*self.res_blocks,
)
self.nb_body_spatial_out = channels * board_height * board_width

Expand All @@ -171,6 +173,15 @@ def forward(self, x):

return process_value_policy_head(out, self.value_head, self.policy_head, self.use_plys_to_end, self.use_wdl)

def merge_bn(self):
"""
Call merge_bn() on every NTB block stored in self.res_blocks.

Iterates over self.res_blocks and, for each entry that is an instance of NTB,
invokes that block's own merge_bn() method. Entries of any other type are skipped.
An INFO-level message "Called merge_bn()" is emitted via the logging module.

NOTE(review): what NTB.merge_bn() actually does is not visible here; presumably it
folds BatchNorm parameters into neighbouring layers for inference -- confirm in NTB.
NOTE(review): indentation is lost in this view, so it is unclear whether the log
message is emitted once per merged block or once after the loop -- confirm.
"""
# self.res_blocks is assigned in __init__ from _get_res_blocks(...)
for res_block in self.res_blocks:
# only NTB blocks are asked to merge; other block types are left untouched
if isinstance(res_block, NTB):
res_block.merge_bn()
logging.info("Called merge_bn()")


def get_rise_v33_model(args):
"""
Expand Down

0 comments on commit 5a12a23

Please sign in to comment.