From 1fba8eabf74822303b7483907626df987eeda714 Mon Sep 17 00:00:00 2001
From: Ludvig Ericson
Date: Wed, 8 Feb 2023 13:21:49 +0100
Subject: [PATCH] Add batch_first, dtype, device arguments

Also cleaned up imports.
---
 rezero/transformer/rztx.py | 89 +++++++++++++++++++++++---------------
 1 file changed, 54 insertions(+), 35 deletions(-)

diff --git a/rezero/transformer/rztx.py b/rezero/transformer/rztx.py
index 58fbff0..3d5f211 100644
--- a/rezero/transformer/rztx.py
+++ b/rezero/transformer/rztx.py
@@ -1,21 +1,12 @@
-import math
+from typing import Optional
+
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torch.nn.parameter import Parameter
-from torch.nn.init import xavier_uniform_
-from torch.nn.init import constant_
-from torch.nn.init import xavier_normal_
-from torch.nn.modules.module import Module
-from torch.nn.modules.activation import MultiheadAttention
-from torch.nn.modules.container import ModuleList
-from torch.nn.init import xavier_uniform_
-from torch.nn.modules.dropout import Dropout
-from torch.nn.modules.linear import Linear
-from torch.nn import TransformerEncoder
-
-class RZTXEncoderLayer(Module):
+from torch import nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+class RZTXEncoderLayer(nn.Module):
     r"""RZTXEncoderLayer is made up of self-attn and feedforward network with
     residual weights for faster convergence. This encoder layer is based on the
     paper "ReZero is All You Need:
@@ -34,17 +25,20 @@ class RZTXEncoderLayer(Module):
         >>> src = torch.rand(10, 32, 512)
         >>> out = encoder_layer(src)
     """
-    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 activation='relu', batch_first=False, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
         # Implementation of Feedforward model
-        self.linear1 = Linear(d_model, dim_feedforward)
-        self.dropout = Dropout(dropout)
-        self.linear2 = Linear(dim_feedforward, d_model)
-        self.dropout1 = Dropout(dropout)
-        self.dropout2 = Dropout(dropout)
-        self.resweight = nn.Parameter(torch.Tensor([0]))
+        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.resweight = nn.Parameter(torch.zeros((1,), **factory_kwargs))
 
         if activation == "relu":
             self.activation = F.relu
@@ -57,7 +51,7 @@ def __setstate__(self, state):
         super().__setstate__(state)
 
     def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
+        # type: (torch.Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
@@ -70,17 +64,18 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
         src2 = src
         src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask,
                               key_padding_mask=src_key_padding_mask)
-        src2 = src2[0] # no attention weights
+        src2 = src2[0]  # no attention weights
         src2 = src2 * self.resweight
         src = src + self.dropout1(src2)
 
         # Pointwise FF Layer
-        src2 = src 
+        src2 = src
         src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
         src2 = src2 * self.resweight
         src = src + self.dropout2(src2)
         return src
 
+
 class RZTXDecoderLayer(nn.Module):
     r"""RZTXDecoderLayer is made up of self-attn and feedforward network with
     residual weights for faster convergence.
@@ -100,19 +95,30 @@ class RZTXDecoderLayer(nn.Module):
         >>> tgt, memory = torch.rand(10, 32, 512), torch.rand(20, 32, 512)
         >>> out = decoder_layer(tgt, memory)
     """
-    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
+                 activation='relu', batch_first=False, device=None, dtype=None):
+        factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
-        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.self_attn = nn.MultiheadAttention(d_model,
+                                               nhead,
+                                               dropout=dropout,
+                                               batch_first=batch_first,
+                                               **factory_kwargs)
+        self.multihead_attn = nn.MultiheadAttention(d_model,
+                                                    nhead,
+                                                    dropout=dropout,
+                                                    batch_first=batch_first,
+                                                    **factory_kwargs)
         # Implementation of Feedforward model
-        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
         self.dropout = nn.Dropout(dropout)
-        self.linear2 = nn.Linear(dim_feedforward, d_model)
+        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
         self.dropout1 = nn.Dropout(dropout)
         self.dropout2 = nn.Dropout(dropout)
         self.dropout3 = nn.Dropout(dropout)
-        self.resweight = nn.Parameter(torch.Tensor([0]))
+        self.resweight = nn.Parameter(torch.zeros((1,), **factory_kwargs))
 
         if activation == "relu":
             self.activation = F.relu
@@ -146,4 +152,17 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
         else: # for backward compatibility
             tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
         tgt = tgt + self.dropout3(tgt2) * self.resweight
-        return tgt
\ No newline at end of file
+        return tgt
+
+
+def Transformer(*, num_encoder_layers, num_decoder_layers, **kw):
+    layer_keywords = {'d_model', 'nhead', 'dim_feedforward', 'dropout',
+                      'activation', 'batch_first', 'device', 'dtype'}
+    layer_kwargs = {k: v for k, v in kw.items() if k in layer_keywords}
+    encoder_layer = RZTXEncoderLayer(**layer_kwargs)
+    decoder_layer = RZTXDecoderLayer(**layer_kwargs)
+    encoder_norm = None
+    decoder_norm = None
+    encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
+    decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
+    return nn.Transformer(**kw, custom_encoder=encoder, custom_decoder=decoder)
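
Usage notes (reviewer sketch, not part of the diff). The new batch_first flag
switches the expected input layout from (seq, batch, feature) to
(batch, seq, feature), and device/dtype are forwarded to every parameterized
submodule through the factory_kwargs dict, following the torch.nn convention
(these constructor arguments exist in PyTorch >= 1.9, contemporary with this
patch). A minimal example; the import path mirrors the file location
rezero/transformer/rztx.py and is an assumption:

    import torch
    from rezero.transformer.rztx import RZTXEncoderLayer  # assumed import path

    # batch_first=True: inputs are (batch, seq, d_model);
    # dtype is propagated to the attention and linear submodules.
    layer = RZTXEncoderLayer(d_model=512, nhead=8, batch_first=True,
                             dtype=torch.float64)
    src = torch.rand(32, 10, 512, dtype=torch.float64)
    out = layer(src)  # output keeps the input shape: (32, 10, 512)
    assert out.shape == src.shape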
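
The Transformer helper at the bottom keeps one flat keyword interface: the
layer keywords are filtered into the ReZero layers, and the full kw dict is
also forwarded to nn.Transformer, whose constructor accepts the same names
(d_model, nhead, dropout, batch_first, ...), while custom_encoder and
custom_decoder carry the ReZero stacks. The layer counts are consumed as
keyword-only arguments and never reach nn.Transformer. A sketch of a call,
under the same import assumption:

    from rezero.transformer.rztx import Transformer  # assumed import path

    model = Transformer(d_model=512, nhead=8,
                        num_encoder_layers=6, num_decoder_layers=6,
                        batch_first=True)
    src = torch.rand(32, 10, 512)  # (batch, source seq, d_model)
    tgt = torch.rand(32, 20, 512)  # (batch, target seq, d_model)
    out = model(src, tgt)          # decoder output: (32, 20, 512)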