Skip to content

Commit

Permalink
fixes #533
Browse files Browse the repository at this point in the history
  • Loading branch information
amaiya committed Aug 22, 2024
1 parent 59c876a commit 5a08165
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,18 @@ Most recent releases are shown at the top. Each release shows:
- **Fixed**: Bug fixes that don't change documented behaviour


## 0.41.5 (TBD)

### new:
- N/A

### changed
- N/A

### fixed:
- Update `test_lda.py` due to changes in `numpy` (#533)


## 0.41.4 (2024-06-18)

### new:
Expand Down
2 changes: 1 addition & 1 deletion ktrain/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__all__ = ["__version__"]
__version__ = "0.41.4"
__version__ = "0.41.5"
13 changes: 9 additions & 4 deletions tests/test_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,23 @@ def test_qa(self):
texts = newsgroups_train.data + newsgroups_test.data

# buld and test LDA topic model
tm = ktrain.text.get_topic_model(texts, n_features=10000)
tm = ktrain.text.get_topic_model(texts)
tm.build(texts, threshold=0.25)
texts = tm.filter(texts)
tags = tm.topics[np.argmax(tm.predict([rawtext]))]
self.assertEqual(
tags, "space nasa earth data launch surface solar moon mission planet"
tags.split(" ")[:5], ["space", "nasa", "earth", "data", "launch"]
)
# "space nasa earth data launch surface solar moon mission planet"
tm.save("/tmp/tm")
tm = ktrain.text.load_topic_model("/tmp/tm")
tm.build(texts, threshold=0.25)
tags = tm.topics[np.argmax(tm.predict([rawtext]))]
# self.assertEqual(
# tags, "space nasa earth data launch surface solar moon mission planet"
# )
self.assertEqual(
tags, "space nasa earth data launch surface solar moon mission planet"
tags.split(" ")[:5], ["space", "nasa", "earth", "data", "launch"]
)

# document similarity
Expand All @@ -61,7 +65,8 @@ def test_qa(self):
reverse=True,
)
df = pd.DataFrame(data, columns=["Prediction", "Score", "Text"])
self.assertTrue("recommendations for a laser printer" in df["Text"].values[0])
print(f"Best match for technical topic: {df['Text'].values[0]}")
self.assertTrue("Stacker achieves better compression" in df["Text"].values[0])

# recommender
tm.train_recommender()
Expand Down

0 comments on commit 5a08165

Please sign in to comment.