Skip to content

Commit

Permalink
Update 'Gather-Excite'
Browse files Browse the repository at this point in the history
  • Loading branch information
ffiirree committed Dec 4, 2022
1 parent 1adb783 commit b87ba3e
Showing 1 changed file with 4 additions and 18 deletions.
22 changes: 4 additions & 18 deletions cvm/models/ops/blocks/gather_excite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,31 @@
from functools import partial
from contextlib import contextmanager
from torch import nn
import torch.nn.functional as F
from . import norm_act
from .vanilla_conv2d import Conv2d1x1
from ..functional import make_divisible
from typing import OrderedDict
from .depthwise_separable_conv2d import DepthwiseBlock, DepthwiseConv2dBN
from .interpolate import Interpolate

_GE_INNER_NONLINEAR: nn.Module = partial(nn.ReLU, inplace=True)
_GE_GATING_FN: nn.Module = nn.Sigmoid
_GE_DIVISOR: int = 8
_GE_USE_NORM: bool = True


@contextmanager
def ge(
inner_nonlinear: nn.Module = _GE_INNER_NONLINEAR,
gating_fn: nn.Module = _GE_GATING_FN,
divisor: int = _GE_DIVISOR,
use_norm: bool = _GE_USE_NORM
gating_fn: nn.Module = _GE_GATING_FN
):
global _GE_INNER_NONLINEAR
global _GE_GATING_FN
global _GE_DIVISOR
global _GE_USE_NORM

_pre_inner_fn = _GE_INNER_NONLINEAR
_pre_fn = _GE_GATING_FN
_pre_divisor = _GE_DIVISOR
_pre_use_norm = _GE_USE_NORM

_GE_INNER_NONLINEAR = inner_nonlinear
_GE_GATING_FN = gating_fn
_GE_DIVISOR = divisor
_GE_USE_NORM = use_norm

yield

_GE_INNER_NONLINEAR = _pre_inner_fn
_GE_GATING_FN = _pre_fn
_GE_DIVISOR = _pre_divisor
_GE_USE_NORM = _pre_use_norm


class GatherExciteBlock(nn.Module):
Expand Down

0 comments on commit b87ba3e

Please sign in to comment.