Skip to content

Commit

Permalink
Merge pull request #13 from Ulas-Scan/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
iyoubee authored Jun 12, 2024
2 parents d4c0ff5 + 067e5a2 commit 360680f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
59 changes: 46 additions & 13 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from utils import get_model, get_tokenizer, predict_sentiment
from utils import get_model, get_tokenizer, predict_results
from flask import Flask, request, jsonify, abort
from dotenv import load_dotenv
from functools import wraps
Expand All @@ -20,6 +20,33 @@
# Get API key from environment variable
API_KEY = os.getenv('API_KEY')

def process_statements(statements):
    """Run sentiment inference over a batch of review statements.

    Delegates to ``predict_results`` with the module-level tokenizer,
    model, and MAX_LENGTH. Returns the raw logits on success, or
    ``None`` if inference raises (best-effort: the error is printed,
    not propagated).
    """
    try:
        return predict_results(statements, tokenizer, model, MAX_LENGTH)
    except Exception as e:
        print(f"Error occurred: {e}")
        return None

def process_statement(statement):
    """Check whether a single review statement can be tokenized.

    The tokenizer is invoked purely as a validity probe; its encoded
    output is discarded. Returns a ``(statement, is_valid)`` pair so
    callers can filter out statements the tokenizer rejects.
    """
    try:
        # Attempt tokenization with the same settings used at inference
        # time; any failure here marks the statement as invalid.
        tokenizer(
            text=statement,
            add_special_tokens=True,
            max_length=MAX_LENGTH,
            truncation=True,
            padding='max_length',
            return_tensors='tf'
        )
    except ValueError as e:
        print(f"Skipping invalid statement: {statement}. Error: {e}")
        return statement, False
    except Exception as e:
        print(f"Error occurred during tokenization: {e}")
        return statement, False
    return statement, True

def require_api_key(f):
@wraps(f)
def decorated_function(*args, **kwargs):
Expand All @@ -44,21 +71,27 @@ def predict():
# statements: list of reviews
report = {'Positive': 0, 'Negative': 0}

def process_statement(statement):
try:
sentiment = predict_sentiment(statement, tokenizer, model, MAX_LENGTH)
return sentiment
except Exception as e:
print(f"Error occurred: {e}, statement: {statement}")
return None

with ThreadPoolExecutor() as executor:
futures = {executor.submit(process_statement, statement): statement for statement in statements}
futures = [executor.submit(process_statement, statement) for statement in statements]

valid_statements = []
for future in as_completed(futures):
result = future.result()
if result:
report[result] += 1
statement, is_valid = future.result()
if is_valid:
valid_statements.append(statement)

if not valid_statements:
return jsonify({"error": "No valid statements provided"}), 400

predictions = process_statements(valid_statements)
if predictions is not None:
for pred in predictions:
if pred[0] > pred[1]:
report['Negative'] += 1
else:
report['Positive'] += 1
else:
return jsonify({"error": "Error processing statements"}), 500

return jsonify(report)

Expand Down
17 changes: 8 additions & 9 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,22 @@ def get_tokenizer(model_name):
def get_model(pretrained_path):
    """Load a TFBertForSequenceClassification model from *pretrained_path*."""
    model = TFBertForSequenceClassification.from_pretrained(pretrained_path)
    return model

# NOTE(review): appears unused by the current pipeline (the call site in
# predict_sentiment was removed) — kept for potential reuse; confirm before deleting.
def translate_to_indo(text):
    """Translate English *text* to Indonesian using GoogleTranslator."""
    translator = GoogleTranslator(source='en', target='id')
    return translator.translate(text)

def predict_sentiment(text, tokenizer, model, max_length):
# translated_text = translate_to_indo(text)
tokenized_text = tokenizer(
text=text,
def predict_results(texts, tokenizer, model, max_length):
tokenized_texts = tokenizer(
text=texts,
add_special_tokens=True,
max_length=max_length,
truncation=True,
padding='max_length',
return_tensors='tf'
)
input_ids = tokenized_text['input_ids']
attention_mask = tokenized_text['attention_mask']
prediction = model.predict([input_ids, attention_mask])
sentiment = "Positive" if prediction[0][0][1] >= 1 else "Negative"
return sentiment
input_ids = tokenized_texts['input_ids']
attention_masks = tokenized_texts['attention_mask']
predictions = model.predict([input_ids, attention_masks], use_multiprocessing=True, workers=4)
return predictions.logits

0 comments on commit 360680f

Please sign in to comment.