From e842ebc60e8a448e56fb028c6952e71de0e77434 Mon Sep 17 00:00:00 2001
From: Ian Johnson
Date: Mon, 29 Jan 2024 13:47:13 -0500
Subject: [PATCH] finish refactor for pip module. local python server works

---
 latentscope/models/__init__.py               |  9 +--
 latentscope/models/providers/cohereai.py     |  5 +-
 latentscope/models/providers/mistralai.py    |  7 +-
 latentscope/models/providers/openai.py       |  7 +-
 latentscope/models/providers/togetherai.py   |  5 +-
 latentscope/models/providers/voyageai.py     |  5 +-
 latentscope/scripts/cluster.py               | 56 ++++++++++-----
 latentscope/scripts/embed.py                 | 40 ++++++-----
 latentscope/scripts/ingest.py                | 13 +++-
 .../{label-clusters.py => label_clusters.py} | 44 ++++++------
 latentscope/scripts/umapper.py               | 70 ++++++++++++-------
 latentscope/server/app.py                    | 12 ++--
 latentscope/server/jobs.py                   | 13 ++--
 latentscope/server/search.py                 |  5 +-
 latentscope/server/tags.py                   |  8 +--
 latentscope/util/__init__.py                 |  2 +-
 latentscope/util/configuration.py            |  7 +-
 setup.py                                     |  4 ++
 web/src/components/DatasetExplore.jsx        | 40 ++++++-----
 web/src/components/Home.jsx                  |  2 +-
 web/src/components/Setup/Umap.jsx            |  6 +-
 21 files changed, 212 insertions(+), 148 deletions(-)
 rename latentscope/scripts/{label-clusters.py => label_clusters.py} (88%)

diff --git a/latentscope/models/__init__.py b/latentscope/models/__init__.py
index 2b4813e..01e014e 100644
--- a/latentscope/models/__init__.py
+++ b/latentscope/models/__init__.py
@@ -1,5 +1,6 @@
 import os
 import json
+import pkg_resources
 from .providers.transformers import TransformersEmbedProvider, TransformersChatProvider
 from .providers.openai import OpenAIEmbedProvider, OpenAIChatProvider
 from .providers.mistralai import MistralAIEmbedProvider, MistralAIChatProvider
@@ -10,8 +11,8 @@

 def get_embedding_model(id):
     """Returns a ModelProvider instance for the given model id."""
-    embed_models_path = os.path.join(os.path.dirname(__file__), "embedding_models.json")
-    with open(embed_models_path, "r") as f:
+    embedding_path = pkg_resources.resource_filename('latentscope.models', 'embedding_models.json')
+    with open(embedding_path, "r") as f:
         embed_model_list = json.load(f)
     embed_model_dict = {model['id']: model for model in embed_model_list}
     model = embed_model_dict[id]
@@ -34,8 +35,8 @@ def get_embedding_model(id):

 def get_chat_model(id):
     """Returns a ModelProvider instance for the given model id."""
-    chat_models_path = os.path.join(os.path.dirname(__file__), "chat_models.json")
-    with open(chat_models_path, "r") as f:
+    chat_path = pkg_resources.resource_filename('latentscope.models', 'chat_models.json')
+    with open(chat_path, "r") as f:
         chat_model_list = json.load(f)
     chat_model_dict = {model['id']: model for model in chat_model_list}
     model = chat_model_dict[id]
diff --git a/latentscope/models/providers/cohereai.py b/latentscope/models/providers/cohereai.py
index 4143467..f671c63 100644
--- a/latentscope/models/providers/cohereai.py
+++ b/latentscope/models/providers/cohereai.py
@@ -3,12 +3,11 @@
 import cohere

 from .base import EmbedModelProvider
-from dotenv import load_dotenv
-load_dotenv()
+from latentscope.util import get_key

 class CohereAIEmbedProvider(EmbedModelProvider):
     def load_model(self):
-        self.client = cohere.Client(os.getenv("COHERE_API_KEY"))
+        self.client = cohere.Client(get_key("COHERE_API_KEY"))

     def embed(self, inputs):
         time.sleep(0.01) # TODO proper rate limiting
diff --git a/latentscope/models/providers/mistralai.py b/latentscope/models/providers/mistralai.py
index e0029c3..dfa482c 100644
--- a/latentscope/models/providers/mistralai.py
+++ b/latentscope/models/providers/mistralai.py
@@ -5,8 +5,7 @@
 from transformers import AutoTokenizer

 from .base import EmbedModelProvider,ChatModelProvider
-from dotenv import load_dotenv
-load_dotenv()
+from latentscope.util import get_key

 # TODO verify these tokenizers somehow
 # derived from:
@@ -20,7 +19,7 @@

 class MistralAIEmbedProvider(EmbedModelProvider):
     def load_model(self):
-        self.client = MistralClient(os.getenv("MISTRAL_API_KEY"))
+        self.client = MistralClient(get_key("MISTRAL_API_KEY"))

     def embed(self, inputs):
         time.sleep(0.1) # TODO proper rate limiting
@@ -29,7 +28,7 @@ def embed(self, inputs):

 class MistralAIChatProvider(ChatModelProvider):
     def load_model(self):
-        self.client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY"))
+        self.client = MistralClient(api_key=get_key("MISTRAL_API_KEY"))
         self.encoder = AutoTokenizer.from_pretrained(encoders[self.name])

     def chat(self, messages):
diff --git a/latentscope/models/providers/openai.py b/latentscope/models/providers/openai.py
index 6dd3939..31a3ab6 100644
--- a/latentscope/models/providers/openai.py
+++ b/latentscope/models/providers/openai.py
@@ -4,12 +4,11 @@
 from openai import OpenAI

 from .base import EmbedModelProvider, ChatModelProvider
-from dotenv import load_dotenv
-load_dotenv()
+from latentscope.util import get_key

 class OpenAIEmbedProvider(EmbedModelProvider):
     def load_model(self):
-        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+        self.client = OpenAI(api_key=get_key("OPENAI_API_KEY"))
         self.encoder = tiktoken.encoding_for_model(self.name)

     def embed(self, inputs):
@@ -27,7 +26,7 @@ def embed(self, inputs):

 class OpenAIChatProvider(ChatModelProvider):
     def load_model(self):
-        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+        self.client = OpenAI(api_key=get_key("OPENAI_API_KEY"))
         self.encoder = tiktoken.encoding_for_model(self.name)

     def chat(self, messages):
diff --git a/latentscope/models/providers/togetherai.py b/latentscope/models/providers/togetherai.py
index a89f593..7afa304 100644
--- a/latentscope/models/providers/togetherai.py
+++ b/latentscope/models/providers/togetherai.py
@@ -4,12 +4,11 @@
 import together

 from .base import EmbedModelProvider
-from dotenv import load_dotenv
-load_dotenv()
+from latentscope.util import get_key

 class TogetherAIEmbedProvider(EmbedModelProvider):
     def load_model(self):
-        together.api_key = os.getenv("TOGETHER_API_KEY")
+        together.api_key = get_key("TOGETHER_API_KEY")
         self.client = together.Together()
         self.encoder = tiktoken.encoding_for_model("text-embedding-ada-002")

diff --git a/latentscope/models/providers/voyageai.py b/latentscope/models/providers/voyageai.py
index 5f432b5..e22eca4 100644
--- a/latentscope/models/providers/voyageai.py
+++ b/latentscope/models/providers/voyageai.py
@@ -3,12 +3,11 @@
 import voyageai

 from .base import EmbedModelProvider
-from dotenv import load_dotenv
-load_dotenv()
+from latentscope.util import get_key

 class VoyageAIEmbedProvider(EmbedModelProvider):
     def load_model(self):
-        self.client = voyageai.Client(os.getenv("VOYAGE_API_KEY"))
+        self.client = voyageai.Client(get_key("VOYAGE_API_KEY"))

     def embed(self, inputs):
         time.sleep(0.1) # TODO proper rate limiting
diff --git a/latentscope/scripts/cluster.py b/latentscope/scripts/cluster.py
index 12f70cd..f0599d2 100644
--- a/latentscope/scripts/cluster.py
+++ b/latentscope/scripts/cluster.py
@@ -2,23 +2,49 @@
 # Example: python cluster.py dadabase-curated umap-001 50 5
 import os
 import re
-import sys
 import json
 import hdbscan
+import argparse
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
 from scipy.spatial import ConvexHull
 from scipy.spatial.distance import cdist
+from latentscope.util import get_data_dir
+
+# TODO move this into shared space
+def calculate_point_size(num_points, min_size=10, max_size=30, base_num_points=100):
+    """
+    Calculate the size of points for a scatter plot based on the number of points.
+    """
+    # TODO fix this to actually calculate a log scale between min and max size
+    if num_points <= base_num_points:
+        return max_size
+    else:
+        return min(min_size + min_size * np.log(num_points / base_num_points), max_size)
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Cluster UMAP embeddings')
+    parser.add_argument('dataset_name', type=str, help='Name of the dataset')
+    parser.add_argument('umap_name', type=str, help='Name of the UMAP file')
+    parser.add_argument('samples', type=int, help='Minimum cluster size')
+    parser.add_argument('min_samples', type=int, help='Minimum samples for HDBSCAN')
+
+    args = parser.parse_args()
+    clusterer(args.dataset_name, args.umap_name, args.samples, args.min_samples)
+

 def clusterer(dataset_name, umap_name, samples, min_samples):
+    DATA_DIR = get_data_dir()
+    cluster_dir = os.path.join(DATA_DIR, dataset_name, "clusters")
     # Check if clusters directory exists, if not, create it
-    if not os.path.exists(f'../data/{dataset_name}/clusters'):
-        os.makedirs(f'../data/{dataset_name}/clusters')
+    if not os.path.exists(cluster_dir):
+        os.makedirs(cluster_dir)

     # determine the index of the last cluster run by looking in the dataset directory
     # for files named umap-.json
-    cluster_files = [f for f in os.listdir(f"../data/{dataset_name}/clusters") if re.match(r"cluster-\d+\.json", f)]
+    cluster_files = [f for f in os.listdir(cluster_dir) if re.match(r"cluster-\d+\.json", f)]
     print("cluster files", sorted(cluster_files))
     if len(cluster_files) > 0:
         last_cluster = sorted(cluster_files)[-1]
@@ -31,7 +57,7 @@ def clusterer(dataset_name, umap_name, samples, min_samples):

     # make the umap name from the number, zero padded to 3 digits
     cluster_name = f"cluster-{next_cluster_number:03d}"
-    umap_embeddings_df = pd.read_parquet(f"../data/{dataset_name}/umaps/{umap_name}.parquet")
+    umap_embeddings_df = pd.read_parquet(os.path.join(DATA_DIR, dataset_name, "umaps", f"{umap_name}.parquet"))
     umap_embeddings = umap_embeddings_df.to_numpy()

     clusterer = hdbscan.HDBSCAN(min_cluster_size=samples, min_samples=min_samples, metric='euclidean')
@@ -63,14 +89,16 @@ def clusterer(dataset_name, umap_name, samples, min_samples):

     # save umap embeddings to a parquet file with columns x,y
     df = pd.DataFrame({"cluster": cluster_labels, "raw_cluster": raw_cluster_labels})
-    output_file = f"../data/{dataset_name}/clusters/{cluster_name}.parquet"
+    output_file = os.path.join(cluster_dir, f"{cluster_name}.parquet")
     df.to_parquet(output_file)
     print(df.head())
     print("wrote", output_file)

     # generate a scatterplot of the umap embeddings and save it to a file
-    fig, ax = plt.subplots(figsize=(6, 6))
-    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=1, alpha=0.5, c=cluster_labels, cmap='Spectral')
+    fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi
+    point_size = calculate_point_size(umap_embeddings.shape[0])
+    print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points")
+    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5, c=cluster_labels, cmap='Spectral')
     # plot a convex hull around each cluster
     for label in non_noise_labels:
         points = umap_embeddings[cluster_labels == label]
@@ -80,9 +108,9 @@ def clusterer(dataset_name, umap_name, samples, min_samples):

     plt.axis('off') # remove axis
     plt.gca().set_position([0, 0, 1, 1]) # remove margins
-    plt.savefig(f"../data/{dataset_name}/clusters/{cluster_name}.png")
+    plt.savefig(os.path.join(cluster_dir, f"{cluster_name}.png"))

-    with open(f'../data/{dataset_name}/clusters/{cluster_name}.json', 'w') as f:
+    with open(os.path.join(cluster_dir,f"{cluster_name}.json"), 'w') as f:
         json.dump({
             "cluster_name": cluster_name,
             "umap_name": umap_name,
@@ -106,12 +134,8 @@ def clusterer(dataset_name, umap_name, samples, min_samples):
         slides_df = pd.concat([slides_df, new_row], ignore_index=True)

     # write the df to parquet
-    slides_df.to_parquet(f"../data/{dataset_name}/clusters/{cluster_name}-labels.parquet")
+    slides_df.to_parquet(os.path.join(cluster_dir, f"{cluster_name}-labels.parquet"))
     print("done")

 if __name__ == "__main__":
-    dataset_name = sys.argv[1]
-    umap_name = sys.argv[2]
-    samples = int(sys.argv[3])
-    min_samples = int(sys.argv[4])
-    clusterer(dataset_name, umap_name, samples, min_samples)
+    main()
\ No newline at end of file
diff --git a/latentscope/scripts/embed.py b/latentscope/scripts/embed.py
index 87dfadb..9469e78 100644
--- a/latentscope/scripts/embed.py
+++ b/latentscope/scripts/embed.py
@@ -1,22 +1,31 @@
-# Usage: python embed-local.py
+# Usage: ls-embed
 import os
-import sys
 import argparse
 import numpy as np
 import pandas as pd
 from tqdm import tqdm

-# TODO is this hacky way to import from the models directory?
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from models import get_embedding_model
+from latentscope.models import get_embedding_model
+from latentscope.util import get_data_dir

 def chunked_iterable(iterable, size):
     """Yield successive chunks from an iterable."""
     for i in range(0, len(iterable), size):
         yield iterable[i:i + size]

-def embedder(dataset_name, text_column="text", model_id="transformers-BAAI___bge-small-en-v1.5"):
-    df = pd.read_parquet(f"../data/{dataset_name}/input.parquet")
+def main():
+    parser = argparse.ArgumentParser(description='Embed a dataset')
+    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
+    parser.add_argument('text_column', type=str, help='Output file', default='text')
+    parser.add_argument('model', type=str, help='ID of embedding model to use', default="transformers-BAAI___bge-small-en-v1.5")
+
+    # Parse arguments
+    args = parser.parse_args()
+    embed(args.name, args.text_column, args.model)
+
+def embed(dataset_name, text_column, model_id):
+    DATA_DIR = get_data_dir()
+    df = pd.read_parquet(os.path.join(DATA_DIR, dataset_name, "input.parquet"))
     sentences = df[text_column].tolist()

     model = get_embedding_model(model_id)
@@ -36,19 +45,12 @@ def embed(dataset_name, text_column, model_id):
     print("sentence embeddings:", np_embeds.shape)

     # Save embeddings as a numpy file
-    if not os.path.exists(f'../data/{dataset_name}/embeddings'):
-        os.makedirs(f'../data/{dataset_name}/embeddings')
+    emb_dir = os.path.join(DATA_DIR, dataset_name, "embeddings")
+    if not os.path.exists(emb_dir):
+        os.makedirs(emb_dir)

-    np.save(f'../data/{dataset_name}/embeddings/{model_id}.npy', np_embeds)
+    np.save(os.path.join(DATA_DIR, dataset_name, "embeddings", f"{model_id}.npy"), np_embeds)
     print("done")

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Embed a dataset')
-    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
-    parser.add_argument('text_column', type=str, help='Output file', default='text')
-    parser.add_argument('model', type=str, help='ID of embedding model to use', default="transformers-BAAI___bge-small-en-v1.5")
-
-    # Parse arguments
-    args = parser.parse_args()
-
-    embedder(args.name, args.text_column, args.model)
+    main()
\ No newline at end of file
diff --git a/latentscope/scripts/ingest.py b/latentscope/scripts/ingest.py
index 9e22eee..9123df9 100644
--- a/latentscope/scripts/ingest.py
+++ b/latentscope/scripts/ingest.py
@@ -1,15 +1,22 @@
-# Usage: python ingest.py
+# Usage: ls-ingest
 import os
-import sys
 import json
+import argparse
 import pandas as pd

 from latentscope.util import get_data_dir

+# TODO: somehow optionally accept a pandas dataframe as input
 def main():
+    parser = argparse.ArgumentParser(description='Ingest a dataset')
+    parser.add_argument('name', type=str, help='Dataset name (directory name in data folder)')
+    args = parser.parse_args()
+    ingest(args.name)
+
+def ingest(dataset_name):
     DATA_DIR = get_data_dir()
-    dataset_name = sys.argv[1]
     directory = os.path.join(DATA_DIR, dataset_name)
+    # TODO: inspect the incoming data to see if it is a csv or parquet file
     csv_file = os.path.join(directory, "input.csv")
     print("reading", csv_file)
     df = pd.read_csv(csv_file)
diff --git a/latentscope/scripts/label-clusters.py b/latentscope/scripts/label_clusters.py
similarity index 88%
rename from latentscope/scripts/label-clusters.py
rename to latentscope/scripts/label_clusters.py
index 995a2d7..5314fc1 100644
--- a/latentscope/scripts/label-clusters.py
+++ b/latentscope/scripts/label_clusters.py
@@ -6,13 +6,9 @@
 import numpy as np
 import pandas as pd
 from tqdm import tqdm

-from dotenv import load_dotenv
-load_dotenv()
-
-# TODO is this hacky way to import from the models directory?
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
-from models import get_chat_model
+from latentscope.util import get_data_dir
+from latentscope.models import get_chat_model

 def chunked_iterable(iterable, size):
     """Yield successive chunks from an iterable."""
@@ -26,13 +22,29 @@ def too_many_duplicates(line, threshold=10):
             word_count[word] = word_count.get(word, 0) + 1
     return any(count > threshold for count in word_count.values())

+def main():
+    parser = argparse.ArgumentParser(description='Label a set of slides using OpenAI')
+    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
+    parser.add_argument('text_column', type=str, help='Output file', default='text')
+    parser.add_argument('cluster_name', type=str, help='name of slides set', default='cluster-001')
+    parser.add_argument('model', type=str, help='Name of model to use', default="openai-gpt-3.5-turbo")
+    parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
+
+    # Parse arguments
+    args = parser.parse_args()
+
+    labeler(args.name, args.text_column, args.cluster_name, args.model, args.context)
+
+
 def labeler(dataset_name, text_column="text", cluster_name="cluster-001", model_id="gpt-3.5-turbo", context=""):
-    df = pd.read_parquet(f"../data/{dataset_name}/input.parquet")
-    # TODO This should be dropped in the preprocessing step
+    DATA_DIR = get_data_dir()
+    df = pd.read_parquet(os.path.join(DATA_DIR, dataset_name, "input.parquet"))
+    # TODO This should be done in the preprocessing step
     df = df.reset_index(drop=True)

     # Load the indices for each cluster from the prepopulated labels file generated by cluster.py
-    clusters = pd.read_parquet(f"../data/{dataset_name}/clusters/{cluster_name}-labels.parquet")
+    cluster_dir = os.path.join(DATA_DIR, dataset_name, "clusters")
+    clusters = pd.read_parquet(os.path.join(cluster_dir, f"{cluster_name}-labels.parquet"))

     model = get_chat_model(model_id)
     model.load_model()
@@ -104,18 +116,8 @@ def labeler(dataset_name, text_column="text", cluster_name="cluster-001", model_
     clusters_df['label_raw'] = labels

     # write the df to parquet
-    clusters_df.to_parquet(f"../data/{dataset_name}/clusters/{cluster_name}-labels-{model_id}.parquet")
+    clusters_df.to_parquet(os.path.join(cluster_dir, f"{cluster_name}-labels-{model_id}.parquet"))
     print("done")

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Label a set of slides using OpenAI')
-    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
-    parser.add_argument('text_column', type=str, help='Output file', default='text')
-    parser.add_argument('cluster_name', type=str, help='name of slides set', default='cluster-001')
-    parser.add_argument('model', type=str, help='Name of model to use', default="openai-gpt-3.5-turbo")
-    parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
-
-    # Parse arguments
-    args = parser.parse_args()
-
-    labeler(args.name, args.text_column, args.cluster_name, args.model, args.context)
+    main()
diff --git a/latentscope/scripts/umapper.py b/latentscope/scripts/umapper.py
index d084069..f5c70fc 100644
--- a/latentscope/scripts/umapper.py
+++ b/latentscope/scripts/umapper.py
@@ -11,19 +11,44 @@
 import pandas as pd
 import matplotlib.pyplot as plt

+from latentscope.util import get_data_dir

-def umapper(dataset_name, model_unsanitized, neighbors=25, min_dist=0.075):
-    # TODO: make sanitize a function
-    model = model_unsanitized.replace("/", "___")
+def main():
+    parser = argparse.ArgumentParser(description='UMAP embeddings for a dataset')
+    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
+    parser.add_argument('model', type=str, help='Name of embedding model to use')
+    parser.add_argument('neighbors', type=int, help='Output file', default=25)
+    parser.add_argument('min_dist', type=float, help='Output file', default=0.075)
+
+    # Parse arguments
+    args = parser.parse_args()
+    umapper(args.name, args.model, args.neighbors, args.min_dist)
+
+
+# TODO move this into shared space
+def calculate_point_size(num_points, min_size=10, max_size=30, base_num_points=100):
+    """
+    Calculate the size of points for a scatter plot based on the number of points.
+    """
+    # TODO fix this to actually calculate a log scale between min and max size
+    if num_points <= base_num_points:
+        return max_size
+    else:
+        return min(min_size + min_size * np.log(num_points / base_num_points), max_size)
+
+
+def umapper(dataset_name, model_id, neighbors=25, min_dist=0.075):
+    DATA_DIR = get_data_dir()
     # read in the embeddings
-    embeddings = np.load(f'../data/{dataset_name}/embeddings/{model}.npy')
+    embeddings = np.load(os.path.join(DATA_DIR, dataset_name, "embeddings", f"{model_id}.npy"))

-    if not os.path.exists(f'../data/{dataset_name}/umaps'):
-        os.makedirs(f'../data/{dataset_name}/umaps')
+    umap_dir = os.path.join(DATA_DIR, dataset_name, "umaps")
+    if not os.path.exists(umap_dir):
+        os.makedirs(umap_dir)

     # determine the index of the last umap run by looking in the dataset directory
     # for files named umap-.json
-    umap_files = [f for f in os.listdir(f"../data/{dataset_name}/umaps") if re.match(r"umap-\d+\.json", f)]
+    umap_files = [f for f in os.listdir(umap_dir) if re.match(r"umap-\d+\.json", f)]
     if len(umap_files) > 0:
         last_umap = sorted(umap_files)[-1]
         last_umap_number = int(last_umap.split("-")[1].split(".")[0])
@@ -33,7 +58,6 @@ def umapper(dataset_name, model_unsanitized, neighbors=25, min_dist=0.075):

     # make the umap name from the number, zero padded to 3 digits
     umap_name = f"umap-{next_umap_number:03d}"
-

     reducer = umap.UMAP(
         n_neighbors=neighbors,
@@ -57,41 +81,35 @@ def umapper(dataset_name, model_unsanitized, neighbors=25, min_dist=0.075):

     # save umap embeddings to a parquet file with columns x,y
     df = pd.DataFrame(umap_embeddings, columns=['x', 'y'])
-    output_file = f"../data/{dataset_name}/umaps/{umap_name}.parquet"
+    output_file = os.path.join(umap_dir, f"{umap_name}.parquet")
     df.to_parquet(output_file)
     print("wrote", output_file)

     # generate a scatterplot of the umap embeddings and save it to a file
-    fig, ax = plt.subplots(figsize=(6, 6))
-    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=1, alpha=0.5)
+    fig, ax = plt.subplots(figsize=(14.22, 14.22)) # 1024px by 1024px at 72 dpi
+    point_size = calculate_point_size(umap_embeddings.shape[0])
+    print("POINT SIZE", point_size, "for", umap_embeddings.shape[0], "points")
+    plt.scatter(umap_embeddings[:, 0], umap_embeddings[:, 1], s=point_size, alpha=0.5)
     plt.axis('off') # remove axis
     plt.gca().set_position([0, 0, 1, 1]) # remove margins
-    plt.savefig(f"../data/{dataset_name}/umaps/{umap_name}.png")
+    plt.savefig(os.path.join(umap_dir, f"{umap_name}.png"))

     # save a json file with the umap parameters
-    with open(f'../data/{dataset_name}/umaps/{umap_name}.json', 'w') as f:
+    with open(os.path.join(umap_dir, f'{umap_name}.json'), 'w') as f:
         json.dump({
             "name": umap_name,
-            "embeddings": model,
+            "embeddings": model_id,
             "neighbors": neighbors,
-            "min_dist": min_dist}, f, indent=2)
+            "min_dist": min_dist
+        }, f, indent=2)
         f.close()

     # save a pickle of the umap
-    with open(f'../data/{dataset_name}/umaps/{umap_name}.pkl', 'wb') as f:
+    with open(os.path.join(umap_dir, f'{umap_name}.pkl'), 'wb') as f:
         pickle.dump(reducer, f)

     print("done")

-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='UMAP embeddings for a dataset')
-    parser.add_argument('name', type=str, help='Dataset name (directory name in data/)')
-    parser.add_argument('model', type=str, help='Name of embedding model to use', default="BAAI/bge-small-en-v1.5")
-    parser.add_argument('neighbors', type=int, help='Output file', default=25)
-    parser.add_argument('min_dist', type=float, help='Output file', default=0.075)
-
-    # Parse arguments
-    args = parser.parse_args()
-    umapper(args.name, args.model, args.neighbors, args.min_dist)
+    main()
diff --git a/latentscope/server/app.py b/latentscope/server/app.py
index 7ffc215..88bc1bf 100644
--- a/latentscope/server/app.py
+++ b/latentscope/server/app.py
@@ -58,7 +58,6 @@ def get_embedding_models():

 @app.route('/api/chat_models', methods=['GET'])
 def get_chat_models():
-    file_path = os.path.join(os.getcwd(), 'models', 'chat_models.json')
     chat_path = pkg_resources.resource_filename('latentscope.models', 'chat_models.json')
     with open(chat_path, 'r', encoding='utf-8') as file:
         models = json.load(file)
@@ -92,14 +91,17 @@ def indexed():
     return rows.to_json(orient="records")


-dist_dir = './web/dist'
 @app.route('/', defaults={'path': ''})
 @app.route('/')
 def catch_all(path):
-    if path != "" and os.path.exists(os.path.join(dist_dir, path)):
-        return send_from_directory(dist_dir, path)
+    if path != "":
+        pth = pkg_resources.resource_filename('latentscope', f"web/dist/{path}")
+        directory = os.path.dirname(pth)
+        return send_from_directory(directory, os.path.basename(pth))
     else:
-        return send_from_directory(dist_dir, 'index.html')
+        pth = pkg_resources.resource_filename('latentscope', "web/dist/index.html")
+        directory = os.path.dirname(pth)
+        return send_from_directory(directory, os.path.basename(pth))

 app.run(host=host, port=port, debug=debug)
diff --git a/latentscope/server/jobs.py b/latentscope/server/jobs.py
index d7da3ee..117a8d8 100644
--- a/latentscope/server/jobs.py
+++ b/latentscope/server/jobs.py
@@ -17,8 +17,11 @@ def run_job(dataset, job_id, command):
         os.makedirs(job_dir)
     progress_file = os.path.join(job_dir, f"{job_id}.json")

+    print("command", command)
     # job_name = command.replace("python ../scripts/", "").replace(".py", "!!!").split("!!!")[0],
-    job_name = command.split(" ")[0].split("-")[1]
+    job_name = command.split(" ")[0]
+    if "ls-" in job_name:
+        job_name = job_name.replace("ls-", "")
     job = {
         "dataset": dataset,
         "job_name": job_name,
@@ -96,7 +99,7 @@ def run_embed():
     model = request.args.get('model') # model id

     job_id = str(uuid.uuid4())
-    command = f'python ../scripts/embed.py {dataset} {text_column} {model}'
+    command = f'ls-embed {dataset} {text_column} {model}'
     threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
     return jsonify({"job_id": job_id})

@@ -110,7 +113,7 @@ def run_umap():
     print("run umap", dataset, embeddings, neighbors, min_dist)

     job_id = str(uuid.uuid4())
-    command = f'python ../scripts/umapper.py {dataset} {embeddings} {neighbors} {min_dist}'
+    command = f'ls-umap {dataset} {embeddings} {neighbors} {min_dist}'
     threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
     return jsonify({"job_id": job_id})

@@ -147,7 +150,7 @@ def run_cluster():
     print("run cluster", dataset, umap_name, samples, min_samples)

     job_id = str(uuid.uuid4())
-    command = f'python ../scripts/cluster.py {dataset} {umap_name} {samples} {min_samples}'
+    command = f'ls-cluster {dataset} {umap_name} {samples} {min_samples}'
     threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
     return jsonify({"job_id": job_id})

@@ -171,6 +174,6 @@ def run_cluster_label():
     print("context", context)

     job_id = str(uuid.uuid4())
-    command = f'python ../scripts/label-clusters.py {dataset} {text_column} {cluster} {model} "{context}"'
+    command = f'ls-label {dataset} {text_column} {cluster} {model} "{context}"'
     threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
     return jsonify({"job_id": job_id})
diff --git a/latentscope/server/search.py b/latentscope/server/search.py
index a415c27..653b2e5 100644
--- a/latentscope/server/search.py
+++ b/latentscope/server/search.py
@@ -3,12 +3,11 @@
 import numpy as np
 from flask import Blueprint, jsonify, request

-# TODO is this hacky way to import from the models directory?
-sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
 from latentscope.models import get_embedding_model

 # Create a Blueprint
 search_bp = Blueprint('search_bp', __name__)
+DATA_DIR = os.getenv('LATENT_SCOPE_DATA')

 # in memory cache of dataset metadata, embeddings, models and tokenizers
 DATASETS = {}
@@ -34,7 +33,7 @@ def nn():

     if dataset not in DATASETS or model_id not in DATASETS[dataset]:
         # load the dataset embeddings
-        embeddings = np.load(os.path.join("../data", dataset, "embeddings", model_id + ".npy"))
+        embeddings = np.load(os.path.join(DATA_DIR, dataset, "embeddings", model_id + ".npy"))
         print("fitting embeddings")
         from sklearn.neighbors import NearestNeighbors
         nne = NearestNeighbors(n_neighbors=num, metric="cosine")
diff --git a/latentscope/server/tags.py b/latentscope/server/tags.py
index 0e6094a..a38572b 100644
--- a/latentscope/server/tags.py
+++ b/latentscope/server/tags.py
@@ -30,7 +30,7 @@ def tags():
     for f in os.listdir(tagdir):
         if f.endswith(".indices"):
             tag = f.split(".")[0]
-            indices = np.loadtxt(os.path.join("../data", dataset, "tags", tag + ".indices"), dtype=int).tolist()
+            indices = np.loadtxt(os.path.join(DATA_DIR, dataset, "tags", tag + ".indices"), dtype=int).tolist()
             if type(indices) == int:
                 indices = [indices]
             tagsets[dataset][tag] = indices
@@ -49,10 +49,10 @@ def new_tag():
         tagsets[dataset] = {}
         # search the dataset directory for all files ending in .indices
         tags = []
-        for f in os.listdir(os.path.join("../data", dataset)):
+        for f in os.listdir(os.path.join(DATA_DIR, dataset)):
             if f.endswith(".indices"):
                 dtag = f.split(".")[0]
-                indices = np.loadtxt(os.path.join("../data", dataset, "tags", dtag + ".indices"), dtype=int).tolist()
+                indices = np.loadtxt(os.path.join(DATA_DIR, dataset, "tags", dtag + ".indices"), dtype=int).tolist()
                 if type(indices) == int:
                     indices = [indices]
                 tagsets[dataset][dtag] = indices
@@ -60,7 +60,7 @@ def new_tag():
     if tag not in tagsets[dataset]:
         tagsets[dataset][tag] = []
         # create an empty file
-        filename = os.path.join("../data", dataset, "tags", tag + ".indices")
+        filename = os.path.join(DATA_DIR, dataset, "tags", tag + ".indices")
         with open(filename, 'w') as f:
             f.write("")
             f.close()
diff --git a/latentscope/util/__init__.py b/latentscope/util/__init__.py
index 15d18d1..7c2b954 100644
--- a/latentscope/util/__init__.py
+++ b/latentscope/util/__init__.py
@@ -1 +1 @@
-from .configuration import get_data_dir, update_data_dir, set_openai_key, set_voyage_key, set_together_key, set_cohere_key, set_mistral_key
+from .configuration import get_data_dir, update_data_dir, get_key, set_openai_key, set_voyage_key, set_together_key, set_cohere_key, set_mistral_key
diff --git a/latentscope/util/configuration.py b/latentscope/util/configuration.py
index bc99875..aa80362 100644
--- a/latentscope/util/configuration.py
+++ b/latentscope/util/configuration.py
@@ -5,7 +5,8 @@

 def get_data_dir():
     DATA_DIR = os.getenv('LATENT_SCOPE_DATA')
     if DATA_DIR is None:
-        print("LATENT_SCOPE_DATA environment variable not set. Please set it to the directory where you want to store your data.")
+        print("""LATENT_SCOPE_DATA environment variable not set. Please set it to the directory where you want to store your data.
+e.g.: export LATENT_SCOPE_DATA=~/latentscope-data""")
         sys.exit(1)
     return DATA_DIR
@@ -20,6 +21,10 @@ def update_data_dir(directory, env_file=".env"):
         os.makedirs(directory)
     return directory

+def get_key(key, env_file=".env"):
+    load_dotenv(env_file)
+    return os.getenv(key)
+
 def set_openai_key(openai_key, env_file=".env"):
     # Load existing .env file, or create one if it doesn't exist
     load_dotenv(env_file)
diff --git a/setup.py b/setup.py
index 7663d04..e11fab4 100644
--- a/setup.py
+++ b/setup.py
@@ -50,6 +50,10 @@ def run(self):
         'console_scripts': [
             'ls-serve=latentscope.server:serve',
             'ls-ingest=latentscope.scripts.ingest:main',
+            'ls-embed=latentscope.scripts.embed:main',
+            'ls-umap=latentscope.scripts.umapper:main',
+            'ls-cluster=latentscope.scripts.cluster:main',
+            'ls-label=latentscope.scripts.label_clusters:main',
         ],
     },
     include_package_data=True,
diff --git a/web/src/components/DatasetExplore.jsx b/web/src/components/DatasetExplore.jsx
index 6aa3596..bdc333c 100644
--- a/web/src/components/DatasetExplore.jsx
+++ b/web/src/components/DatasetExplore.jsx
@@ -133,10 +133,11 @@ function DatasetDetail() {
       let rows = data.map((row, index) => {
         return {
           index: indices[index],
-          text: row[text_column],
-          score: row.score, // TODO: this is custom to one dataset
-          distance: distances[index],
-          date: row.date,
+          ...row
+          // text: row[text_column],
+          // score: row.score, // TODO: this is custom to one dataset
+          // distance: distances[index],
+          // date: row.date,
         }
       })
       rows.sort((a, b) => b.score - a.score)
@@ -167,21 +168,22 @@ function DatasetDetail() {
   const [tagrows, setTagrows] = useState([]);
   useEffect(() => {
     if(tagset[tag]) {
-      fetch(`${apiUrl}/tags/rows?dataset=${dataset.id}&tag=${tag}`)
-        .then(response => response.json())
-        .then(data => {
-          const text_column = dataset.text_column
-          let rows = data.map((row, index) => {
-            return {
-              index: tagset[tag][index],
-              text: row[text_column],
-              score: row.score, // TODO: this is custom to one dataset
-              date: row.date,
-            }
-          })
-          rows.sort((a, b) => b.score - a.score)
-          setTagrows(rows)
-        }).catch(e => console.log(e));
+      hydrateIndices(tagset[tag], setTagrows)
+      // fetch(`${apiUrl}/tags/rows?dataset=${dataset.id}&tag=${tag}`)
+      //   .then(response => response.json())
+      //   .then(data => {
+      //     const text_column = dataset.text_column
+      //     let rows = data.map((row, index) => {
+      //       return {
+      //         index: tagset[tag][index],
+      //         text: row[text_column],
+      //         score: row.score, // TODO: this is custom to one dataset
+      //         date: row.date,
+      //       }
+      //     })
+      //     rows.sort((a, b) => b.score - a.score)
+      //     setTagrows(rows)
+      //   }).catch(e => console.log(e));
     } else {
       setTagrows([])
     }
diff --git a/web/src/components/Home.jsx b/web/src/components/Home.jsx
index 5c6949b..38afe78 100644
--- a/web/src/components/Home.jsx
+++ b/web/src/components/Home.jsx
@@ -122,7 +122,7 @@ function Home() {
         {scopes[dataset.id] && scopes[dataset.id].map && scopes[dataset.id]?.map((scope,i) => (
-          Explore {scope.name} - {scope.label}

+          Explore {scope.name} - {scope.label}



diff --git a/web/src/components/Setup/Umap.jsx b/web/src/components/Setup/Umap.jsx
index 97fc9a1..6bfb8fd 100644
--- a/web/src/components/Setup/Umap.jsx
+++ b/web/src/components/Setup/Umap.jsx
@@ -22,8 +22,8 @@ Umap.propTypes = {
 // New embeddings update the list
 function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) {
   const [umapJob, setUmapJob] = useState(null);
-  const { startJob: startUmapJob } = useStartJobPolling(dataset, setUmapJob, '${apiUrl}/jobs/umap');
-  const { startJob: deleteUmapJob } = useStartJobPolling(dataset, setUmapJob, '${apiUrl}/jobs/delete/umap');
+  const { startJob: startUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/umap`);
+  const { startJob: deleteUmapJob } = useStartJobPolling(dataset, setUmapJob, `${apiUrl}/jobs/delete/umap`);
   const [umaps, setUmaps] = useState([]);

   function fetchUmaps(datasetId, callback) {
@@ -69,7 +69,7 @@ function Umap({ dataset, umap, embedding, clusters, onNew, onChange}) {