1. module file
import math
from typing import Optional
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.nn.modules import transformer
from timm.models.vision_transformer import PatchEmbed, VisionTransformer


class DecoderLayer(nn.Module):
    """A Transformer decoder layer supporting two-stream attention (XLNet).

    This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu', layer_norm_eps=1e-5):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm_c = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = transformer._get_activation_fn(activation)

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.gelu
        super().__setstate__(state)

    def forward_stream(
        self,
        tgt: Tensor,
        tgt_norm: Tensor,
        tgt_kv: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor],
        tgt_key_padding_mask: Optional[Tensor],
    ):
        """Forward pass for a single stream (i.e. content or query).

        tgt_norm is just a LayerNorm'd tgt, added as a separate parameter for efficiency.
        Both tgt_kv and memory are expected to be LayerNorm'd too (memory is LayerNorm'd by the ViT encoder).
        """
        tgt2, sa_weights = self.self_attn(
            tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )
        tgt = tgt + self.dropout1(tgt2)

        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
        update_content: bool = True,
    ):
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        # Query stream attends to the (normalized) content stream; the content stream attends to itself.
        query = self.forward_stream(query, query_norm, content_norm, memory, query_mask, content_key_padding_mask)[0]
        if update_content:
            content = self.forward_stream(
                content, content_norm, content_norm, memory, content_mask, content_key_padding_mask
            )[0]
        return query, content


class Decoder(nn.Module):
    __constants__ = ['norm']

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = transformer._get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[Tensor] = None,
        content_mask: Optional[Tensor] = None,
        content_key_padding_mask: Optional[Tensor] = None,
    ):
        for i, mod in enumerate(self.layers):
            last = i == len(self.layers) - 1
            # The content stream is only needed as context for the query stream, so skip updating it in the last layer.
            query, content = mod(
                query, content, memory, query_mask, content_mask, content_key_padding_mask, update_content=not last
            )
        query = self.norm(query)
        return query
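
For reference, a minimal usage sketch (not from the repo) of how the two streams are wired, assuming the DecoderLayer and Decoder classes above are in scope; all sizes are illustrative only:

import torch
from torch import nn

d_model, nhead = 384, 6
layer = DecoderLayer(d_model, nhead, dim_feedforward=4 * d_model, dropout=0.1)
decoder = Decoder(layer, num_layers=1, norm=nn.LayerNorm(d_model))

N, L, S = 2, 26, 128                    # batch, label length (incl. <eos>), number of image tokens
query = torch.zeros(N, L, d_model)      # query stream: positional queries only
content = torch.randn(N, L, d_model)    # content stream: token + position embeddings
memory = torch.randn(N, S, d_model)     # encoder (ViT) output, already LayerNorm'd in PARSeq
out = decoder(query, content, memory)   # no masks, i.e. full cloze-style attention
print(out.shape)                        # torch.Size([2, 26, 384])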
2. model file
from functools import partial
from typing import Optional, Sequence
import torch
import torch.nn as nn
from torch import Tensor
from timm.models.helpers import named_apply
from strhub.data.utils import Tokenizer
from strhub.models.utils import init_weights
from .modules import Decoder, DecoderLayer, Encoder, TokenEmbedding


class PARSeq(nn.Module):
    def __init__(
        self,
        num_tokens: int,
        max_label_length: int,
        img_size: Sequence[int],
        patch_size: Sequence[int],
        embed_dim: int,
        enc_num_heads: int,
        enc_mlp_ratio: int,
        enc_depth: int,
        dec_num_heads: int,
        dec_mlp_ratio: int,
        dec_depth: int,
        decode_ar: bool,
        refine_iters: int,
        dropout: float,
    ) -> None:
        super().__init__()
        self.max_label_length = max_label_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters

        self.encoder = Encoder(
            img_size, patch_size, embed_dim=embed_dim, depth=enc_depth, num_heads=enc_num_heads, mlp_ratio=enc_mlp_ratio
        )
        decoder_layer = DecoderLayer(embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout)
        self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(embed_dim))

        # We don't predict <bos> nor <pad>
        self.head = nn.Linear(embed_dim, num_tokens - 2)
        self.text_embed = TokenEmbedding(num_tokens, embed_dim)

        # +1 for <eos>
        self.pos_queries = nn.Parameter(torch.Tensor(1, max_label_length + 1, embed_dim))
        self.dropout = nn.Dropout(p=dropout)
        # Encoder has its own init.
        named_apply(partial(init_weights, exclude=['encoder']), self)
        nn.init.trunc_normal_(self.pos_queries, std=0.02)

    @property
    def _device(self) -> torch.device:
        return next(self.head.parameters(recurse=False)).device

    @torch.jit.ignore
    def no_weight_decay(self):
        param_names = {'text_embed.embedding.weight', 'pos_queries'}
        enc_param_names = {'encoder.' + n for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img: torch.Tensor):
        return self.encoder(img)

    def decode(
        self,
        tgt: torch.Tensor,
        memory: torch.Tensor,
        tgt_mask: Optional[Tensor] = None,
        tgt_padding_mask: Optional[Tensor] = None,
        tgt_query: Optional[Tensor] = None,
        tgt_query_mask: Optional[Tensor] = None,
    ):
        N, L = tgt.shape
        # <bos> stands for the null context. We only supply position information for characters after <bos>.
        null_ctx = self.text_embed(tgt[:, :1])
        tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
        tgt_emb = self.dropout(torch.cat([null_ctx, tgt_emb], dim=1))
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(N, -1, -1)
        tgt_query = self.dropout(tgt_query)
        return self.decoder(tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask)

    def forward(self, tokenizer: Tokenizer, images: Tensor, max_length: Optional[int] = None) -> Tensor:
        testing = max_length is None
        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
        bs = images.shape[0]
        # +1 for <eos> at end of sequence.
        num_steps = max_length + 1
        memory = self.encode(images)

        # Query positions up to `num_steps`
        pos_queries = self.pos_queries[:, :num_steps].expand(bs, -1, -1)

        # Special case for the forward permutation. Faster than using `generate_attn_masks()`.
        # Boolean attention mask: True = position not allowed to attend.
        tgt_mask = query_mask = torch.triu(torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device), 1)

        if self.decode_ar:
            tgt_in = torch.full((bs, num_steps), tokenizer.pad_id, dtype=torch.long, device=self._device)
            tgt_in[:, 0] = tokenizer.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1  # next token index
                # Efficient decoding:
                # Input the context up to the ith token. We use only one query (at position = i) at a time.
                # This works because of the lookahead masking effect of the canonical (forward) AR context.
                # Past tokens have no access to future tokens, hence are fixed once computed.
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                )
                # The next token probability is in the output's ith token position.
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
                    # Greedy decode: add the next token index to the target input.
                    tgt_in[:, j] = p_i.squeeze().argmax(-1)
                    # Efficient batch decoding: if all output words have at least one EOS token, end decoding.
                    if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
                        break
            logits = torch.cat(logits, dim=1)
        else:
            # No prior context, so input is just <bos>. We query all positions.
            tgt_in = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)

        if self.refine_iters:
            # For iterative refinement, we always use a 'cloze' mask.
            # We can derive it from the AR forward mask by unmasking the token context to the right.
            query_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool, device=self._device), 2)] = 0
            bos = torch.full((bs, 1), tokenizer.bos_id, dtype=torch.long, device=self._device)
            for i in range(self.refine_iters):
                # Prior context is the previous output.
                tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
                # Mask tokens beyond the first EOS token.
                tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
                tgt_out = self.decode(
                    tgt_in, memory, tgt_mask, tgt_padding_mask, pos_queries, query_mask[:, : tgt_in.shape[1]]
                )
                logits = self.head(tgt_out)

        return logits
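
To make decode() easier to follow, here is a small standalone shape sketch (placeholder tensors and illustrative sizes, not part of the model): the content stream is the <bos> "null context" plus position-encoded token embeddings, while the query stream is purely positional.

import torch

N, L, D = 2, 5, 384                       # batch, context length (<bos> + 4 tokens), embed_dim
text_embed = torch.randn(N, L, D)         # stand-in for self.text_embed(tgt)
pos_queries = torch.randn(1, 26, D)       # stand-in for self.pos_queries (max_label_length + 1 = 26)

null_ctx = text_embed[:, :1]                            # <bos> embedding: the null context, no position added
tgt_emb = pos_queries[:, : L - 1] + text_embed[:, 1:]   # positions are only added to tokens after <bos>
content = torch.cat([null_ctx, tgt_emb], dim=1)         # content stream, shape (2, 5, 384)
query = pos_queries[:, :L].expand(N, -1, -1)            # query stream, shape (2, 5, 384)
print(content.shape, query.shape)                       # torch.Size([2, 5, 384]) torch.Size([2, 5, 384])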
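
The masking logic in forward() can likewise be illustrated with a standalone sketch (num_steps = 4 for illustration, True = blocked). The forward AR mask lets the query at position i see only context positions 0..i; the refinement ('cloze') mask is derived from it by unmasking everything except the single token each query is predicting:

import torch

num_steps = 4
# Forward AR mask: the strictly-upper triangle is blocked.
ar_mask = torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool), 1)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# Cloze mask: clear everything above the second diagonal, so the query at position i
# is blocked only from the context token at position i + 1 (the one it is asked to predict).
cloze_mask = ar_mask.clone()
cloze_mask[torch.triu(torch.ones(num_steps, num_steps, dtype=torch.bool), 2)] = 0
# tensor([[False,  True, False, False],
#         [False, False,  True, False],
#         [False, False, False,  True],
#         [False, False, False, False]])
print(ar_mask, cloze_mask, sep='\n')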