From 03a60ded0809a6246fa9918fa13ff298d9fb9a3f Mon Sep 17 00:00:00 2001 From: MartinKalema Date: Tue, 11 Jun 2024 17:21:54 +0300 Subject: [PATCH] final commit --- app.py | 88 +++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 68 insertions(+), 20 deletions(-) diff --git a/app.py b/app.py index 41064a8..77ffcc1 100644 --- a/app.py +++ b/app.py @@ -1,23 +1,71 @@ import streamlit as st from fastai.text.all import * -@st.cache_resource -def load_model(): - with st.spinner('Model is being loaded...'): - learn = load_learner('models/text_classifier_model.pkl') - return learn - -st.title('ULMFiT Swahili News Article Classifier') - -st.markdown(""" -ULMFiT (Universal Language Model Fine-tuning) is an effective transfer learning method for NLP tasks. -""") - -user_text = st.text_area('Enter text for classification') - -if st.button('Classify'): - if user_text: - pred_class, pred_idx, outputs = learn.predict(user_text) - st.write(f"Input text belongs to: {pred_class}") - else: - st.write("Please enter text to classify.") \ No newline at end of file +class TextClassifierApp: + """ + A Streamlit app for classifying Swahili news articles using ULMFiT. + + Attributes: + learn (Learner): The FastAI learner object for the text classifier. + + Methods: + load_model(): Loads the pre-trained model with a spinner indicating the loading process. + predict(text: str) -> str: Predicts the class of the given text. + run(): Runs the Streamlit app, providing the user interface for text classification. + """ + def __init__(self): + """ + Initializes the TextClassifierApp by loading the model. + """ + self.learn = None + self.load_model() + + @st.cache_resource + def load_model(self): + """ + Loads the pre-trained model and shows a spinner during the loading process. + + Returns: + None + """ + with st.spinner('Model is being loaded...'): + self.learn = load_learner('models/text_classifier_model.pkl') + + def predict(self, text): + """ + Predicts the class of the given text. + + Args: + text (str): The text to classify. + + Returns: + str: The predicted class. + """ + pred_class, pred_idx, outputs = self.learn.predict(text) + return pred_class + + def run(self): + """ + Runs the Streamlit app, providing the user interface for text classification. + + Returns: + None + """ + st.title('ULMFiT Swahili News Article Classifier') + + st.markdown(""" + ULMFiT (Universal Language Model Fine-tuning) is an effective transfer learning method for NLP tasks. + """) + + user_text = st.text_area('Enter text for classification') + + if st.button('Classify'): + if user_text: + pred_class = self.predict(user_text) + st.write(f"Input text belongs to: {pred_class}") + else: + st.write("Please enter text to classify.") + +if __name__ == '__main__': + app = TextClassifierApp() + app.run()