Skip to content

Commit

Permalink
Version 1.1.0 & improved batching code
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Stoeckel committed Nov 25, 2019
1 parent edd886e commit 341bff9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.hucompute.textimager.uima</groupId>
<artifactId>deep-eos-uima</artifactId>
<version>1.0.0</version>
<version>1.1.0</version>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down
28 changes: 11 additions & 17 deletions src/main/resources/python/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np
import os
from typing import List

import numpy as np

from utils import Utils

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
Expand Down Expand Up @@ -41,26 +39,22 @@ def tag(self, text) -> List[int]:
potential_eos_list = util.build_potential_eos_list(text, self.window_size)

eos_pos = []
for batch_no in range(int(len(potential_eos_list) / self.batch_size) + 1):
i = batch_no * self.batch_size
j = min(len(potential_eos_list), i + self.batch_size)
batch = potential_eos_list[i:j]
for i in range(0, len(potential_eos_list), self.batch_size):
batch = potential_eos_list[i:i + self.batch_size]
batch_size = len(batch)

eos_positions = [eos_position for eos_position, _ in batch]
char_sequences = [(-1.0, char_sequence) for _, char_sequence in batch]
data_set = util.build_data_set(char_sequences, self.char_2_id_dict, self.window_size)
features = np.array([i[1] for i in data_set])
batch_size = len(features)

if batch_size > 0:
predicted = self.deep_eos_model.predict(
features,
batch_size=batch_size,
verbose=0)

for i in range(batch_size):
if predicted[i][0] >= 0.5:
eos_pos.append(int(eos_positions[i]))
predicted = self.deep_eos_model.predict(
features,
batch_size=batch_size,
verbose=0)

for j in range(batch_size):
if predicted[j][0] >= 0.5:
eos_pos.append(int(eos_positions[j]))

return eos_pos

0 comments on commit 341bff9

Please sign in to comment.