Skip to content

Commit

Permalink
final commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinKalema committed Jun 11, 2024
1 parent c316114 commit 03a60de
Showing 1 changed file with 68 additions and 20 deletions.
88 changes: 68 additions & 20 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,71 @@
import streamlit as st
from fastai.text.all import *

@st.cache_resource
def load_model():
with st.spinner('Model is being loaded...'):
learn = load_learner('models/text_classifier_model.pkl')
return learn

st.title('ULMFiT Swahili News Article Classifier')

st.markdown("""
ULMFiT (Universal Language Model Fine-tuning) is an effective transfer learning method for NLP tasks.
""")

user_text = st.text_area('Enter text for classification')

if st.button('Classify'):
if user_text:
pred_class, pred_idx, outputs = learn.predict(user_text)
st.write(f"Input text belongs to: {pred_class}")
else:
st.write("Please enter text to classify.")
class TextClassifierApp:
"""
A Streamlit app for classifying Swahili news articles using ULMFiT.
Attributes:
learn (Learner): The FastAI learner object for the text classifier.
Methods:
load_model(): Loads the pre-trained model with a spinner indicating the loading process.
predict(text: str) -> str: Predicts the class of the given text.
run(): Runs the Streamlit app, providing the user interface for text classification.
"""
def __init__(self):
"""
Initializes the TextClassifierApp by loading the model.
"""
self.learn = None
self.load_model()

@st.cache_resource
def load_model(self):
"""
Loads the pre-trained model and shows a spinner during the loading process.
Returns:
None
"""
with st.spinner('Model is being loaded...'):
self.learn = load_learner('models/text_classifier_model.pkl')

def predict(self, text):
"""
Predicts the class of the given text.
Args:
text (str): The text to classify.
Returns:
str: The predicted class.
"""
pred_class, pred_idx, outputs = self.learn.predict(text)
return pred_class

def run(self):
"""
Runs the Streamlit app, providing the user interface for text classification.
Returns:
None
"""
st.title('ULMFiT Swahili News Article Classifier')

st.markdown("""
ULMFiT (Universal Language Model Fine-tuning) is an effective transfer learning method for NLP tasks.
""")

user_text = st.text_area('Enter text for classification')

if st.button('Classify'):
if user_text:
pred_class = self.predict(user_text)
st.write(f"Input text belongs to: {pred_class}")
else:
st.write("Please enter text to classify.")

if __name__ == '__main__':
app = TextClassifierApp()
app.run()

0 comments on commit 03a60de

Please sign in to comment.