diff --git a/README.md b/README.md
index 20544ca..f03c637 100644
--- a/README.md
+++ b/README.md
@@ -2,24 +2,56 @@
## Final Project: Fake News Detection
-### Introduction & Motivation
+### 1. Introduction & Motivation
Fake news has become a major problem. A Pew Research Center survey conducted in 2016 found that nearly two-thirds of U.S. adults say fake news has caused a great deal of confusion about basic facts of current events, and nearly a quarter say they have knowingly or unknowingly shared fake news online[^1]. Because of the widespread nature of social media, misinformation spreads quickly and can have significant consequences for individuals, communities, and societies. For example, fake news can erode trust in the media, undermine the democratic process, contribute to political divisiveness, and encourage erroneous decision-making[^2].
By developing accurate and effective methods to predict fake news, social scientists can: 1. understand how people consume information (specifically, cognitive biases, social homophily, and inattention), shedding light on the psychological, social, and technological factors that contribute to the viral spread of misinformation; 2. assess the impact of fake news on the communities and individuals most susceptible to sharing misinformation without proper verification, and develop interventions to mitigate its harmful effects; and 3. understand the specific features of fake news that manipulate exposed populations and how to combat this phenomenon.
-### Research Questions
-1. Can fake news be predicted and if so, how well?
-2. What are the biggest differences between articles from reliable and unreliable sources and are there topics that are more susceptible to being faked?
-
-### Why Large-Scale Computing is Important in NLP (non-exhaustive list):
+### 2. Why Large-Scale Computing is Important in NLP (non-exhaustive list):
1. Text processing, that is, turning text into numeric input that a machine learning model can take in, benefits greatly from parallelization. In particular, tokenization, the act of breaking chunks of text into smaller subunits (e.g., words), is a necessary step that can be computationally expensive, especially for large documents.
2. Feature extraction, such as obtaining n-grams from text, can produce extremely wide dataframes (count-vectorized feature matrices grow with the vocabulary size, which can reach tens of thousands of terms), requiring substantial memory resources.
3. Large language models (not used in this project, but applicable for increasing accuracy) have millions of parameters, which calls for more compute-intensive resources.
4. Model fine-tuning often involves computationally expensive and time-consuming procedures such as hyperparameter tuning via grid search.
-### Project Structure
+### 3. Project Details
+
+#### Research Questions
+1. Can fake news be predicted, and if so, how well?
+2. What are the biggest differences between articles from reliable and unreliable sources, and are there topics that are more susceptible to being faked?
#### Data
+The data come from [this](https://www.kaggle.com/competitions/fake-news/data) Kaggle competition. The key file is `train.csv`, a labeled dataset of 20,800 news articles. The `test.csv` file does not contain labels, so I excluded it from this project.
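+
+As a minimal sketch, the labeled data can be loaded with PySpark roughly as follows (the columns match the schema printed in `lda.ipynb`; the path assumes `train.csv` sits in the working directory):
+
+```python
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.master("local[*]").getOrCreate()
+
+# multiLine=True because article text spans multiple lines within a single field
+train = spark.read.csv("train.csv", header=True, inferSchema=True, multiLine=True)
+train.printSchema()  # id, title, author, text, label (1 = unreliable/fake)
+```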
+
+#### Procedure
+The project is divided into two main parts and uses PySpark (a condensed code sketch follows the list):
+
+1. Build a text cleaning and preprocessing pipeline
+ 1. Data cleaning
+ 1. Tokenize text
+        2. Clean & normalize tokens: remove stop words and punctuation, lowercase all tokens, and lemmatize words (reduce each word to its base form--for example, "running" --> "run")
+    2. Text processing: convert the preprocessed tokens into a numerical format the models can take in, using a count vectorizer that counts how often each n-gram from the corpus appears in a given article
+
+2. Build a machine learning pipeline to obtain predictions (each notebook also performs text cleaning and preprocessing)
+ 1. Build and tune models (logistic regression and gradient boosted trees) to predict whether an article is from an unreliable source (fake)
+ * Code: [fake_news_prediction.ipynb](https://github.com/macs30123-s23/final-project-fake_news/blob/main/fake_news_prediction.ipynb)
+ 2. Perform LDA topic modeling to analyze which topics are more likely to be manipulated into fake news.
+ * Code: [lda.ipynb](https://github.com/macs30123-s23/final-project-fake_news/blob/main/lda.ipynb)
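+
+A condensed sketch of the cleaning and text-processing stages above, mirroring the Spark NLP and PySpark ML stages in `lda.ipynb` (`data` is assumed to be the loaded training DataFrame, and the vocabulary size follows the value used there):
+
+```python
+from pyspark.ml import Pipeline
+from pyspark.ml.feature import CountVectorizer
+from sparknlp.base import DocumentAssembler, Finisher
+from sparknlp.annotator import Tokenizer, StopWordsCleaner, Normalizer, LemmatizerModel
+
+# 1. Clean the raw article text: tokenize, drop stop words, lowercase/normalize, lemmatize
+preprocess = Pipeline(stages=[
+    DocumentAssembler().setInputCol("text"),
+    Tokenizer().setInputCols(["document"]).setOutputCol("token"),
+    StopWordsCleaner.pretrained("stopwords_en", "en").setInputCols(["token"]).setOutputCol("token_s"),
+    Normalizer().setInputCols(["token_s"]).setOutputCol("cleaned_tokens").setLowercase(True),
+    LemmatizerModel.pretrained("lemma_antbnc", "en").setInputCols(["cleaned_tokens"]).setOutputCol("lemma"),
+    Finisher().setInputCols(["lemma"]).setIncludeMetadata(False),
+])
+tokens = preprocess.fit(data).transform(data).withColumnRenamed("finished_lemma", "allTokens")
+
+# 2. Turn the lemmatized tokens into a token-count matrix the models can consume
+# (higher-order n-grams can be added with pyspark.ml.feature.NGram before counting)
+cv = CountVectorizer(inputCol="allTokens", outputCol="features", vocabSize=7000)
+features = cv.fit(tokens).transform(tokens)
+```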
+
+#### Results
+* Fake News Prediction - the data for both models were split into an 80/20 train-test split (a hyperparameter-tuning sketch follows the two model summaries below)
+
+    * Logistic Regression: I chose logistic regression because it is relatively simple, interpretable, and provides a probabilistic interpretation of its classifications. I performed hyperparameter tuning via 5-fold grid-search cross-validation over the regularization parameter and the elastic-net mixing parameter. The evaluator was PySpark's BinaryClassificationEvaluator with AUC-ROC as the metric. The test AUC and test accuracy came out to 0.9732 and 0.9217, respectively, indicating that fake news can be predicted well from a matrix of n-gram token counts using logistic regression.
+
+    * Gradient Boosted Tree Classifier: The second model I chose was a gradient boosted tree, since these models are generally considered accurate, stable, and highly interpretable. Additionally, unlike linear models such as logistic regression, tree-based models do not assume linear decision boundaries. I performed hyperparameter tuning via 5-fold grid-search cross-validation over the maximum tree depth and the maximum number of iterations, again using PySpark's BinaryClassificationEvaluator with AUC-ROC as the metric. The test AUC and test accuracy came out to 0.9724 and 0.9071, respectively. The test AUC is similar to that of the logistic regression model, but the test accuracy was slightly lower.
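+
+Both models were tuned with the same PySpark `CrossValidator` pattern; the sketch below shows it for logistic regression, assuming `features` is the count-vectorized DataFrame from the preprocessing sketch above (the grid values are illustrative, not necessarily the exact ones searched in the notebook):
+
+```python
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.evaluation import BinaryClassificationEvaluator
+from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
+
+train, test = features.randomSplit([0.8, 0.2], seed=1)
+
+lr = LogisticRegression(featuresCol="features", labelCol="label")
+grid = (ParamGridBuilder()
+        .addGrid(lr.regParam, [0.01, 0.1, 1.0])        # regularization strength (illustrative values)
+        .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])  # elastic-net mixing parameter
+        .build())
+
+evaluator = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
+cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator, numFolds=5)
+
+best_lr = cv.fit(train).bestModel
+print("Test AUC:", evaluator.evaluate(best_lr.transform(test)))
+```
+
+The gradient boosted tree model follows the same pattern with `GBTClassifier`, swapping the maximum depth (`maxDepth`) and maximum iterations (`maxIter`) into the parameter grid.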
+
+* LDA Topic Modeling
+
+The image below shows word clouds for the 10 topics extracted from unreliable (fake) articles.
+
+
+The image below shows word clouds for the 10 topics extracted from reliable (real) articles.
+
+
[^1]: https://www.pewresearch.org/journalism/2016/12/15/many-americans-believe-fake-news-is-sowing-confusion/
diff --git a/lda.ipynb b/lda.ipynb
new file mode 100644
index 0000000..82f6cfa
--- /dev/null
+++ b/lda.ipynb
@@ -0,0 +1,653 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "source": [
+ "%%bash\n",
+ "apt-get update -qq\n",
+ "apt-get install -y openjdk-8-jdk-headless -qq > /dev/null\n",
+ "wget -q \"https://archive.apache.org/dist/spark/spark-3.1.1/spark-3.1.1-bin-hadoop2.7.tgz\" > /dev/null\n",
+ "tar -xvf spark-3.1.1-bin-hadoop2.7.tgz > /dev/null\n",
+ "\n",
+ "pip install pyspark findspark --quiet\n",
+ "pip install sparknlp --quiet"
+ ],
+ "metadata": {
+ "id": "wDPQfQjdlina",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "649f6d58-cf61-4085-9dfd-ee29175d4c2b"
+ },
+ "id": "wDPQfQjdlina",
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 310.8/310.8 MB 4.5 MB/s eta 0:00:00\n",
+ " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 489.4/489.4 kB 27.9 MB/s eta 0:00:00\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "c44c787e",
+ "metadata": {
+ "id": "c44c787e"
+ },
+ "outputs": [],
+ "source": [
+ "# setup spark session\n",
+ "import os\n",
+ "import sys\n",
+ "\n",
+ "os.environ['PYSPARK_PYTHON'] = sys.executable\n",
+ "os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable\n",
+ "\n",
+ "# # Find Spark so that we can access session within our notebook\n",
+ "import findspark\n",
+ "findspark.init()\n",
+ "\n",
+ "# Start SparkSession on all available cores\n",
+ "from pyspark.sql import SparkSession\n",
+ "# spark = SparkSession.builder.master(\"local[*]\").getOrCreate()\n",
+ "spark = SparkSession.builder \\\n",
+ " .master(\"local[*]\")\\\n",
+ " .config(\"spark.jars.packages\", \"com.johnsnowlabs.nlp:spark-nlp_2.12:4.3.2\")\\\n",
+ " .getOrCreate()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "0f67e38c",
+ "metadata": {
+ "id": "0f67e38c"
+ },
+ "outputs": [],
+ "source": [
+ "data = spark.read.csv('train1.csv',\n",
+ " header='true',\n",
+ " inferSchema='true',\n",
+ " multiLine=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "393c26b4",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "393c26b4",
+ "outputId": "bc8768c9-d362-4483-9ae7-c9001adaec37"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Total Columns: 5\n",
+ "Total Rows: 20800\n",
+ "root\n",
+ " |-- id: integer (nullable = true)\n",
+ " |-- title: string (nullable = true)\n",
+ " |-- author: string (nullable = true)\n",
+ " |-- text: string (nullable = true)\n",
+ " |-- label: integer (nullable = true)\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('Total Columns: %d' % len(data.dtypes))\n",
+ "print('Total Rows: %d' % data.count())\n",
+ "data.printSchema()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "17cd0b6f",
+ "metadata": {
+ "id": "17cd0b6f"
+ },
+ "source": [
+ "### Text Preprocessing"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "b3437bd8",
+ "metadata": {
+ "id": "b3437bd8"
+ },
+ "outputs": [],
+ "source": [
+ "from pyspark.ml.feature import Word2Vec, Word2VecModel\n",
+ "from pyspark.ml.feature import StopWordsRemover, IDF, CountVectorizer\n",
+ "from pyspark.sql.types import ArrayType, StringType\n",
+ "from pyspark.sql.functions import regexp_replace, array, col, udf, split\n",
+ "from pyspark.ml import Pipeline\n",
+ "from sparknlp.annotator import Lemmatizer, LemmatizerModel, Tokenizer, StopWordsCleaner, Normalizer\n",
+ "from sparknlp.base import DocumentAssembler, Finisher\n",
+ "from pyspark.ml.clustering import LDA\n",
+ "from tqdm import tqdm\n",
+ "import matplotlib.pyplot as plt\n",
+ "from wordcloud import WordCloud"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bd929335",
+ "metadata": {
+ "id": "bd929335"
+ },
+ "outputs": [],
+ "source": [
+ "def doc_assembler(inputCol):\n",
+ " '''Spark NLP document assembler'''\n",
+ " \n",
+ " return DocumentAssembler().setInputCol(inputCol)\n",
+ "\n",
+ "\n",
+ "def tokenizer(inputCol, outputCol):\n",
+ " '''Tokenize text for input to the lemmatizer'''\n",
+ " \n",
+ " tokenizer = (Tokenizer()\n",
+ " .setInputCols([inputCol])\n",
+ " .setOutputCol(outputCol))\n",
+ " return tokenizer\n",
+ "\n",
+ "\n",
+ "def stopwords(inputCol, outputCol):\n",
+ " '''Remove stopwords'''\n",
+ "\n",
+ " stopwords = StopWordsCleaner.pretrained(\"stopwords_en\", \"en\") \\\n",
+ " .setInputCols([inputCol]) \\\n",
+ " .setOutputCol(outputCol)\n",
+ " return stopwords\n",
+ "\n",
+ "\n",
+ "def normalizer(inputCol, outputCol):\n",
+ " '''Remove unnecessary characters and make tokens lowercase'''\n",
+ " \n",
+ " normalizer = (Normalizer() \n",
+ " .setInputCols([inputCol])\n",
+ " .setOutputCol(outputCol)\n",
+ " .setLowercase(True))\n",
+ " return normalizer\n",
+ " \n",
+ " \n",
+ "def lemmatizer(inputCol, outputCol):\n",
+ " '''\n",
+ " Retrieve root words out of the input tokens\n",
+ " using default pretrained lemmatizer\n",
+ " '''\n",
+ " \n",
+ " lemmatizer = (LemmatizerModel.pretrained(name=\"lemma_antbnc\", lang=\"en\")\n",
+ " .setInputCols([inputCol])\n",
+ " .setOutputCol(outputCol))\n",
+ " return lemmatizer\n",
+ "\n",
+ "\n",
+ "def finisher(finishedCol):\n",
+ " '''Finisher transform for Spark NLP pipeline'''\n",
+ " \n",
+ " finisher = (Finisher()\n",
+ " .setInputCols([finishedCol])\n",
+ " .setIncludeMetadata(False))\n",
+ " return finisher\n",
+ "\n",
+ "\n",
+ "def run_sparknlp_pipeline(df):\n",
+    "    \"\"\"Create a SparkNLP pipeline that transforms the input DataFrame to produce a final output\n",
+ " column storing each document as a sequence of lemmas (root words).\n",
+ " \"\"\"\n",
+ " nlpPipeline = Pipeline(stages=[\n",
+ " doc_assembler(\"text\"),\n",
+ " tokenizer(\"document\", \"token\"),\n",
+ " stopwords('token', 'token_s'),\n",
+ " normalizer('token_s', 'cleaned_tokens'),\n",
+ " lemmatizer(\"cleaned_tokens\", \"lemma\"),\n",
+ " finisher(\"lemma\")\n",
+ " ])\n",
+ " df1 = nlpPipeline.fit(df).transform(df).withColumnRenamed('finished_lemma', 'allTokens')\n",
+ "\n",
+ " return df1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "819e2bb3",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "819e2bb3",
+ "outputId": "125ff8b4-6776-4418-e71b-0034a45ed1de"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "stopwords_en download started this may take some time.\n",
+ "Approximate size to download 2.9 KB\n",
+ "[OK!]\n",
+ "lemma_antbnc download started this may take some time.\n",
+ "Approximate size to download 907.6 KB\n",
+ "[OK!]\n"
+ ]
+ }
+ ],
+ "source": [
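+    "# Run the Spark NLP preprocessing pipeline on a random 50% sample of the articles\n",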
+ "sampled_data = data.sample(fraction=0.5)\n",
+ "\n",
+ "nlpPipelineDF = run_sparknlp_pipeline(sampled_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "fafa496f",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "fafa496f",
+ "outputId": "f1a55646-722b-4f77-a957-4b4053f935c8"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "+---+--------------------+--------------------+--------------------+-----+--------------------+\n",
+ "| id| title| author| text|label| allTokens|\n",
+ "+---+--------------------+--------------------+--------------------+-----+--------------------+\n",
+ "| 0|House Dem Aide: W...| Darrell Lucus|\"House Dem Aide: ...| 1|[house, dem, aide...|\n",
+ "| 1|FLYNN: Hillary Cl...| Daniel J. Flynn|Ever get the feel...| 0|[feel, life, circ...|\n",
+ "| 3|15 Civilians Kill...| Jessica Purkiss|Videos 15 Civilia...| 1|[video, civilian,...|\n",
+ "| 4|Iranian woman jai...| Howard Portnoy|Print An Iranian ...| 1|[print, iranian, ...|\n",
+ "| 9|A Back-Channel Pl...|Megan Twohey and ...|A week before Mic...| 0|[week, michael, f...|\n",
+ "| 12|Russian Researche...| Amando Flavio|The mystery surro...| 1|[mystery, surroun...|\n",
+ "| 15|In Major League S...| Jack Williams|Guillermo Barros ...| 0|[guillermo, barro...|\n",
+ "| 18|FBI Closes In On ...| The Doc|FBI Closes In On ...| 1|[fbi, close, hill...|\n",
+ "| 23|Massachusetts Cop...| null|Massachusetts Cop...| 1|[massachusetts, c...|\n",
+ "| 24|Abortion Pill Ord...|Donald G. McNeil ...|Orders for aborti...| 0|[order, abortion,...|\n",
+ "| 27|Humiliated Hillar...| Amanda Shea|Humiliated Hillar...| 1|[humiliate, hilla...|\n",
+ "| 29|How Hillary Clint...| Mark Landler|Hillary Clinton s...| 0|[hillary, clinton...|\n",
+ "| 30|Chuck Todd to Buz...| Ian Hanchett|During a discussi...| 0|[discussion, buzz...|\n",
+ "| 31|Israel is Becomin...| null|Country: Israel W...| 1|[country, israel,...|\n",
+ "| 32|Having Won Boris ...| Steven Erlanger|LONDON — With ...| 0|[london, giddy, c...|\n",
+ "| 33|Texas Oil Fields ...| Clifford Krauss|MIDLAND Tex. — ...| 0|[midland, tex, la...|\n",
+ "| 34|Bayer Deal for Mo...|Leslie Picker, Da...|Don Halcomb a ...| 0|[don, halcomb, fa...|\n",
+ "| 37|Open Thread (NOT ...| b|\"Open Thread (NOT...| 1|[open, thread, we...|\n",
+ "| 38|Democrat Gutierre...| AWR Hawkins|Rep. Luis Gutierr...| 0|[rep, luis, gutie...|\n",
+ "| 39|Avoiding Peanuts ...| Aaron E. Carroll|This article orig...| 0|[article, origina...|\n",
+ "+---+--------------------+--------------------+--------------------+-----+--------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "nlpPipelineDF.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "7e111e53",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "7e111e53",
+ "outputId": "6439fcbd-5e51-4abc-8810-b8997f257987"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Total Rows: 10336\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "DataFrame[id: int, title: string, author: string, text: string, label: int, allTokens: array]"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 10
+ }
+ ],
+ "source": [
+ "print('Total Rows: %d' % nlpPipelineDF.count())\n",
+ "nlpPipelineDF.persist()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4037682d",
+ "metadata": {
+ "id": "4037682d"
+ },
+ "source": [
+ "### LDA: Count Vectorizer & Prepare for Topic Modeling\n",
+ "\n",
+    "* Note: LDA code adapted from [this](https://github.com/lsc4ss-s21/large-scale-personal-finance/blob/main/4_Pyspark_topic_modeling.ipynb) sample project "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "ebfb9950",
+ "metadata": {
+ "id": "ebfb9950"
+ },
+ "outputs": [],
+ "source": [
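+    "# Builders for the topic-modeling pipeline stages: count vectorizer, IDF, and the LDA model\n",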
+ "def count_vec(inputCol, outputCol, params):\n",
+ " cv = CountVectorizer(\n",
+ " inputCol=inputCol,\n",
+ " outputCol=outputCol,\n",
+ " vocabSize=params['vocabsize'],\n",
+ " minDF=params['minDF'],\n",
+ " maxDF=params['maxDF'],\n",
+ " minTF=1.0\n",
+ " )\n",
+ " return cv\n",
+ "\n",
+ "def idf(inputCol, outputCol):\n",
+    "    return IDF(inputCol=inputCol, outputCol=outputCol)\n",
+ "\n",
+ "def lda_model(params):\n",
+ " lda = LDA(\n",
+ " k=params['num_topics'],\n",
+ " maxIter=params['iterations'],\n",
+ " optimizer=\"online\",\n",
+ " seed=1,\n",
+ " learningOffset = 100.0, \n",
+ " learningDecay = 0.75,\n",
+ " )\n",
+ " return lda\n",
+ "\n",
+ "\n",
+ "def lda_pipeline(df, params):\n",
+ " '''\n",
+ " Create a Spark ML pipeline and transform the input NLP-transformed DataFrame \n",
+ " to produce an LDA topic model.\n",
+ " '''\n",
+ "\n",
+ " mlPipeline = Pipeline(\n",
+ " stages=[\n",
+ " count_vec(\"allTokens\", \"features\", params),\n",
+ " idf(\"features\", \"idf\"),\n",
+ " lda_model(params)\n",
+ " ]\n",
+ " )\n",
+ " mlModel = mlPipeline.fit(df)\n",
+ " ldaModel = mlModel.stages[2]\n",
+ " \n",
+ " return mlModel, ldaModel\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "a0f91788",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "a0f91788",
+ "outputId": "33b69a84-ed0a-4ed9-8538-ac996b33ec9a"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "{'num_topics': 10,\n",
+ " 'iterations': 10,\n",
+ " 'vocabsize': 7000,\n",
+ " 'minDF': 0.02,\n",
+ " 'maxDF': 0.8}"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 12
+ }
+ ],
+ "source": [
+ "lda_params = dict(num_topics = 10,\n",
+ " iterations = 10,\n",
+ " vocabsize = 7000,\n",
+ " minDF = 0.02,\n",
+ " maxDF = 0.8\n",
+ " )\n",
+ "lda_params"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "a58d1f6c",
+ "metadata": {
+ "id": "a58d1f6c"
+ },
+ "outputs": [],
+ "source": [
+ "# separate fake and real news data\n",
+ "fake = nlpPipelineDF.filter(col(\"label\") == 1)\n",
+ "real = nlpPipelineDF.filter(col(\"label\") != 1)\n",
+ "\n",
+ "f_mlModel, f_ldaModel = lda_pipeline(fake, lda_params)\n",
+ "r_mlModel, r_ldaModel = lda_pipeline(real, lda_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9c76097b",
+ "metadata": {
+ "id": "9c76097b"
+ },
+ "source": [
+ "### Get Topics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "1295e24e",
+ "metadata": {
+ "id": "1295e24e"
+ },
+ "outputs": [],
+ "source": [
+ "def get_topics(mlModel, ldaModel, num_words=15):\n",
+ " '''\n",
+ " Get words and weights from the LDA model.\n",
+ " '''\n",
+ " # Store vocab from CountVectorizer\n",
+ " vocab = mlModel.stages[0].vocabulary\n",
+ " \n",
+ " # Store LDA model part of pipeline\n",
+ " ldaModel = mlModel.stages[2]\n",
+ "\n",
+ " # Take top n words in each topic\n",
+ " topics = ldaModel.describeTopics(num_words)\n",
+ " topics_rdd = topics.rdd\n",
+ "\n",
+ " topic_words = topics_rdd \\\n",
+ " .map(lambda row: row['termIndices']) \\\n",
+ " .map(lambda idx_list: [vocab[idx] for idx in idx_list]) \\\n",
+ " .collect()\n",
+ "\n",
+ " topic_weights = topics_rdd \\\n",
+ " .map(lambda row: row['termWeights']) \\\n",
+ " .collect()\n",
+ "\n",
+ " # Store topic words and weights as a list of dicts\n",
+ " topics = [dict(zip(words, weights))\n",
+ " for words, weights in zip(topic_words, topic_weights)]\n",
+ " return topics"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# get topics for real and fake news\n",
+ "fake_news_topics = get_topics(f_mlModel, f_ldaModel, 10)\n",
+ "real_news_topics = get_topics(r_mlModel, r_ldaModel, 10)"
+ ],
+ "metadata": {
+ "id": "ad1Ndo2fxxRT"
+ },
+ "id": "ad1Ndo2fxxRT",
+ "execution_count": 15,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "662ddbbc",
+ "metadata": {
+ "id": "662ddbbc"
+ },
+ "outputs": [],
+ "source": [
+ "# Plot topic wordclouds\n",
+ "def wordclouds(topics, fname, colormap=\"viridis\"):\n",
+ " cloud = WordCloud(\n",
+ " background_color='white',\n",
+ " width=600,\n",
+ " height=400,\n",
+ " colormap=colormap,\n",
+ " prefer_horizontal=1.0,\n",
+ " )\n",
+ "\n",
+ " num_topics = len(topics)\n",
+ " fig = plt.figure(figsize=(16, 10))\n",
+ "\n",
+ " for idx, word_weights in tqdm(enumerate(topics), total=num_topics):\n",
+ " ax = fig.add_subplot(int((num_topics / 5)) + 1, 5, int(idx + 1))\n",
+ " wordcloud = cloud.generate_from_frequencies(word_weights)\n",
+ " ax.imshow(wordcloud, interpolation=\"bilinear\")\n",
+ " ax.set_title('Topic {}'.format(idx + 1))\n",
+ " ax.set_xticklabels([])\n",
+ " ax.set_yticklabels([])\n",
+ " ax.tick_params(length=0)\n",
+ "\n",
+ " plt.tick_params(labelsize=14)\n",
+ " plt.subplots_adjust(wspace=0.1, hspace=0.1)\n",
+ " plt.margins(x=0.1, y=0.1)\n",
+ " fig.savefig(fname, bbox_inches='tight')\n",
+ " plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# plot fake news topics\n",
+ "plt.close()\n",
+ "wordclouds(fake_news_topics, 'fake_topics.png')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 494
+ },
+ "id": "TqmJQwtVws6y",
+ "outputId": "2a0b3dc6-a797-4c3b-ce58-d8fb312b2ed3"
+ },
+ "id": "TqmJQwtVws6y",
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "100%|██████████| 10/10 [00:02<00:00, 3.50it/s]\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "