Skip to content

Commit

Permalink
fix ref_src_lines mode & speed up mit48px inference"
Browse files Browse the repository at this point in the history
  • Loading branch information
dmMaze committed Sep 5, 2024
1 parent c1ae207 commit 77e4e52
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 17 deletions.
170 changes: 168 additions & 2 deletions modules/ocr/mit48px.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def __call__(self, textblk_lst: List[TextBlock], regions: List[np.ndarray], text
image_tensor = image_tensor.to(self.device)

with torch.no_grad():
ret = self.model.infer_beam_batch(image_tensor, widths, beams_k = 5, max_seq_length = 255)
ret = self.model.infer_beam_batch_tensor(image_tensor, widths, beams_k = 5, max_seq_length = 255)
for i, (pred_chars_index, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred) in enumerate(ret):
if prob < 0.2:
continue
Expand Down Expand Up @@ -551,11 +551,15 @@ def __init__(self, dictionary, max_len):
encoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
encoder.forward = transformer_encoder_forward
self.encoders.append(encoder)
self.encoders.forward = self.encoder_forward

for i in range(5) :
decoder = nn.TransformerDecoderLayer(embd_dim, nhead, dropout = 0, batch_first = True, norm_first = True)
decoder.self_attn = XposMultiheadAttention(embd_dim, nhead, self_attention = True)
decoder.multihead_attn = XposMultiheadAttention(embd_dim, nhead, encoder_decoder_attention = True)
self.decoders.append(decoder)
self.decoders.forward = self.decoder_forward

self.embd = nn.Embedding(self.dict_size, embd_dim)
self.pred1 = nn.Sequential(nn.Linear(embd_dim, embd_dim), nn.GELU(), nn.Dropout(0.15))
self.pred = nn.Linear(embd_dim, self.dict_size)
Expand Down Expand Up @@ -670,6 +674,168 @@ def infer_beam_batch(self, img: torch.FloatTensor, img_widths: List[int], beams_
result.append((cur_hypo.out_idx[1:], cur_hypo.prob(), fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))
return result

def infer_beam_batch_tensor(self, img: torch.FloatTensor, img_widths: List[int], beams_k: int = 5, start_tok = 1, end_tok = 2, pad_tok = 0, max_finished_hypos: int = 2, max_seq_length = 384):
N, C, H, W = img.shape
assert H == 48 and C == 3


memory = self.backbone(img)
memory = einops.rearrange(memory, 'N C 1 W -> N W C')
valid_feats_length = [(x + 3) // 4 + 2 for x in img_widths]
input_mask = torch.zeros(N, memory.size(1), dtype = torch.bool).to(img.device)

for i, l in enumerate(valid_feats_length):
input_mask[i, l:] = True
memory = self.encoders(memory, input_mask) # N, W, Dim


out_idx = torch.full((N, 1), start_tok, dtype=torch.long, device=img.device) # Shape [N, 1]
cached_activations = torch.zeros(N, len(self.decoders)+1, max_seq_length, 320, device=img.device) # [N, L, S, E]
log_probs = torch.zeros(N, 1, device=img.device) # Shape [N, 1] # N, E
idx_embedded = self.embd(out_idx[:, -1:])


decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, 0)
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # N, n_chars
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim = 1) # N, k


out_idx = torch.cat([out_idx.unsqueeze(1).expand(-1, beams_k, -1), pred_chars_index.unsqueeze(-1)], dim=-1).reshape(-1, 2) # Shape [N * k, 2]
log_probs = pred_chars_values.view(-1, 1) # Shape [N * k, 1]
memory = memory.repeat_interleave(beams_k, dim=0)
input_mask = input_mask.repeat_interleave(beams_k, dim=0)
cached_activations = cached_activations.repeat_interleave(beams_k, dim=0)
batch_index = torch.arange(N).repeat_interleave(beams_k, dim=0).to(img.device)


finished_hypos = defaultdict(list)
N_remaining = N


for step in range(1, max_seq_length):
idx_embedded = self.embd(out_idx[:, -1:])
decoded, cached_activations = self.decoders(idx_embedded, cached_activations, memory, input_mask, step)
pred_char_logprob = self.pred(self.pred1(decoded)).log_softmax(-1) # Shape [N * k, dict_size]
pred_chars_values, pred_chars_index = torch.topk(pred_char_logprob, beams_k, dim=1) # [N * k, k]


finished = out_idx[:, -1] == end_tok
pred_chars_values[finished] = 0
pred_chars_index[finished] = end_tok


# Extend hypotheses
new_out_idx = out_idx.unsqueeze(1).expand(-1, beams_k, -1) # Shape [N * k, k, seq_len]
new_out_idx = torch.cat([new_out_idx, pred_chars_index.unsqueeze(-1)], dim=-1) # Shape [N * k, k, seq_len + 1]
new_out_idx = new_out_idx.view(-1, step + 2) # Reshape to [N * k^2, seq_len + 1]
new_log_probs = log_probs.unsqueeze(1).expand(-1, beams_k, -1) + pred_chars_values.unsqueeze(-1) # Shape [N * k^2, 1]
new_log_probs = new_log_probs.view(-1, 1) # [N * k^2, 1]


# Sort and select top-k hypotheses per sample
new_out_idx = new_out_idx.view(N_remaining, -1, step + 2) # [N, k^2, seq_len + 1]
new_log_probs = new_log_probs.view(N_remaining, -1) # [N, k^2]
batch_topk_log_probs, batch_topk_indices = new_log_probs.topk(beams_k, dim=1) # [N, k]

# Gather the top-k hypotheses based on log probabilities
expanded_topk_indices = batch_topk_indices.unsqueeze(-1).expand(-1, -1, new_out_idx.shape[-1]) # Shape [N, k, seq_len + 1]
out_idx = torch.gather(new_out_idx, 1, expanded_topk_indices).reshape(-1, step + 2) # [N * k, seq_len + 1]
log_probs = batch_topk_log_probs.view(-1, 1) # Reshape to [N * k, 1]


# Check for finished sequences
finished = (out_idx[:, -1] == end_tok) # Check if the last token is the end token
finished = finished.view(N_remaining, beams_k) # Reshape to [N, k]
finished_counts = finished.sum(dim=1) # Count the number of finished hypotheses per sample
finished_batch_indices = (finished_counts >= max_finished_hypos).nonzero(as_tuple=False).squeeze()


if finished_batch_indices.numel() == 0:
continue


if finished_batch_indices.dim() == 0:
finished_batch_indices = finished_batch_indices.unsqueeze(0)

for idx in finished_batch_indices:
batch_log_probs = batch_topk_log_probs[idx]
best_beam_idx = batch_log_probs.argmax()
finished_hypos[batch_index[beams_k * idx].item()] = \
out_idx[idx * beams_k + best_beam_idx], \
torch.exp(batch_log_probs[best_beam_idx]).item(), \
cached_activations[idx * beams_k + best_beam_idx]


remaining_indexs = []
for i in range(N_remaining):
if i not in finished_batch_indices:
for j in range(beams_k):
remaining_indexs.append(i * beams_k + j)


if not remaining_indexs:
break


N_remaining = int(len(remaining_indexs) / beams_k)
out_idx = out_idx.index_select(0, torch.tensor(remaining_indexs, device=img.device))
log_probs = log_probs.index_select(0, torch.tensor(remaining_indexs, device=img.device))
memory = memory.index_select(0, torch.tensor(remaining_indexs, device=img.device))
cached_activations = cached_activations.index_select(0, torch.tensor(remaining_indexs, device=img.device))
input_mask = input_mask.index_select(0, torch.tensor(remaining_indexs, device=img.device))
batch_index = batch_index.index_select(0, torch.tensor(remaining_indexs, device=img.device))


# Ensure we have the correct number of finished hypotheses for each sample
assert len(finished_hypos) == N


# Final output processing and color predictions
result = []
for i in range(N):
final_idx, prob, decoded = finished_hypos[i]
color_feats = self.color_pred1(decoded[-1].unsqueeze(0))
fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = \
self.color_pred_fg(color_feats), \
self.color_pred_bg(color_feats), \
self.color_pred_fg_ind(color_feats), \
self.color_pred_bg_ind(color_feats)
result.append((final_idx[1:], prob, fg_pred[0], bg_pred[0], fg_ind_pred[0], bg_ind_pred[0]))


return result

def encoder_forward(self, memory, encoder_mask):
for layer in self.encoders :
memory = layer(layer, src = memory, src_key_padding_mask = encoder_mask)
return memory

def decoder_forward(
self,
embd: torch.Tensor,
cached_activations: torch.Tensor, # Shape [N, L, T, E] where L=num_layers, T=sequence length, E=embedding size
memory: torch.Tensor, # Shape [N, H, W, C] (Encoder memory output)
memory_mask: torch.BoolTensor,
step: int
):

layer: nn.TransformerDecoderLayer
tgt = embd # N, 1, E for the last token embedding

for l, layer in enumerate(self.decoders):
combined_activations = cached_activations[:, l, :step, :] # N, T, E
combined_activations = torch.cat([combined_activations, tgt], dim=1) # N, T+1, E
cached_activations[:, l, step, :] = tgt.squeeze(1)

# Update cache and perform self attention
tgt = tgt + layer.self_attn(layer.norm1(tgt), layer.norm1(combined_activations), layer.norm1(combined_activations), q_offset=step)[0]
tgt = tgt + layer.multihead_attn(layer.norm2(tgt), memory, memory, key_padding_mask=memory_mask, q_offset=step)[0]
tgt = tgt + layer._ff_block(layer.norm3(tgt))

cached_activations[:, l+1, step, :] = tgt.squeeze(1) # Append the new activations

return tgt.squeeze_(1), cached_activations

import numpy as np

def convert_pl_model(filename: str) :
Expand Down Expand Up @@ -707,7 +873,7 @@ def test_infer() :
img_torch = einops.rearrange((torch.from_numpy(img) / 127.5 - 1.0), 'h w c -> 1 c h w')

with torch.no_grad() :
idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch(img_torch, [new_w], 5, max_seq_length = 32)[0]
idx, prob, fg_pred, bg_pred, fg_ind_pred, bg_ind_pred = model.infer_beam_batch_tensor(img_torch, [new_w], 5, max_seq_length = 32)[0]
txt = ''
for i in idx :
txt += dictionary[i]
Expand Down
2 changes: 1 addition & 1 deletion modules/translators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self,
if TRANSLATORS.module_dict[key] == self.__class__:
self.name = key
break
self.textblk_break = '\n###\n'
self.textblk_break = '\n##\n'
self.lang_source: str = lang_source
self.lang_target: str = lang_target
self.lang_map: Dict = LANGMAP_GLOBAL.copy()
Expand Down
6 changes: 3 additions & 3 deletions utils/text_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def layout_lines_aligncenter(
elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
line_valid = False
if (len(lines) == 1 and ref_src_lines or line_right_no + 1 >= len(srcline_wlist)) and \
if ref_src_lines and (len(wl_list) == 1 or line_right_no + 1 >= len(srcline_wlist)) and \
line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines):
line_valid = True
else:
Expand Down Expand Up @@ -271,7 +271,7 @@ def layout_lines_aligncenter(
elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
line_valid = False
if line_left_no - 1 < 0 and \
if ref_src_lines and line_left_no - 1 < 0 and \
line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines):
line_valid = True
else:
Expand Down Expand Up @@ -358,7 +358,7 @@ def layout_lines_alignside(
if mask[np.clip(pos_y, 0, bh - 1): np.clip(line_bottom - lh_pad, 0, bh), new_x].mean() > 240:
line_valid = True
else:
if line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
if ref_src_lines and line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
line_valid = True
if line_valid:
line_valid = line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines)
Expand Down
23 changes: 12 additions & 11 deletions utils/textblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,17 +392,18 @@ def to_dict(self, deep_copy=False):
def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
im_h, im_w = img.shape[:2]

lines = np.round(np.array(self.lines[idx])).astype(np.int64)[None]


expand_size = max(int(self._detected_font_size * 0.1), 2)
rad = np.deg2rad(self.angle)
shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
lines = lines + shifted_vec
lines[..., 0] = np.clip(lines[..., 0], 0, im_w)
lines[..., 1] = np.clip(lines[..., 1], 0, im_h)
line = np.round(lines[0]).astype(np.int64)
line = np.round(np.array(self.lines[idx])).astype(np.int64)

if not self.src_is_vertical and self.det_model == 'ctd':
# ctd detected horizontal bbox is smaller than GT
expand_size = max(int(self._detected_font_size * 0.1), 3)
rad = np.deg2rad(self.angle)
shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
line = line + shifted_vec
line[..., 0] = np.clip(line[..., 0], 0, im_w)
line[..., 1] = np.clip(line[..., 1], 0, im_h)
line = np.round(line[0]).astype(np.int64)

x1, y1, x2, y2 = line[:, 0].min(), line[:, 1].min(), line[:, 0].max(), line[:, 1].max()

Expand Down

0 comments on commit 77e4e52

Please sign in to comment.