From c9f5b176ef8affd8bcefd21d416a17b416849bf0 Mon Sep 17 00:00:00 2001
From: "Olivie Franklova (CZ)"
Date: Tue, 14 May 2024 16:44:25 +0200
Subject: [PATCH] Add caching to column2Vec

---
 column2Vec/Column2Vec.py    |   4 +-
 column2Vec/README.md        |  19 ++--
 column2Vec/functions.py     |  70 +++++++++++-
 column2Vec/playground.ipynb | 194 ++++++++++++------------------------
 constants.py                |  21 ++++
 requirements.txt            | Bin 4380 -> 4412 bytes
 test/test_column2Vec.py     |   8 +-
 7 files changed, 165 insertions(+), 151 deletions(-)

diff --git a/column2Vec/Column2Vec.py b/column2Vec/Column2Vec.py
index 60f3a1c..58a961b 100644
--- a/column2Vec/Column2Vec.py
+++ b/column2Vec/Column2Vec.py
@@ -43,7 +43,9 @@ def save(self, key: str, function: str, embedding: list):
         :param function: Function name
         :param embedding: to save
         """
-        self.__cache.loc[function, key] = embedding
+        print(f"|{function}| : |{key}|")  # debug print of the cache write
+        self.__cache.at[function, key] = embedding
+        # self.__cache.loc[function, key] = embedding
 
     def save_persistently(self):
         """
diff --git a/column2Vec/README.md b/column2Vec/README.md
index 518d021..10bbb48 100644
--- a/column2Vec/README.md
+++ b/column2Vec/README.md
@@ -1,21 +1,26 @@
 # What is column2Vec
-Is word2Vec type tool for creating embeddings vectors for string columns
+column2Vec is a word2Vec-style tool for creating embedding vectors for string columns in tables.
+We have implemented seven different approaches.
 ## Structure
-folder [**generated**](generated) contains all generated files. Mostly html files representing
+The folder [**generated**](generated) contains all generated files,
+mostly html files representing
 2D clusters, created by clustering vectors.
 The file [**Column2Vec.py**](Column2Vec.py) contains 7 different implementations of column2Vec.
+
+## Implementation description
 - **column2vec_as_sentence** creates one string from the column and then transforms it into a vector.
 - **column2vec_as_sentence_clean** creates one string from the column, keeping only numbers and the letters a-z, and then transforms the cleaned string into a vector.
 - **column2vec_as_sentence_clean_uniq** creates one string from the unique values in the column, keeping only numbers and the letters a-z, and then transforms the cleaned string into a vector.
-- **column2vec_avg** transforms every element in column into vector and then it makes average of them.
-- **column2vec_weighted_avg** transforms every element in column into vector and then it makes weighted average of them (based on occurrence).
+- **column2vec_avg** transforms every element of the column into a vector and then averages the vectors.
+- **column2vec_weighted_avg** transforms every element of the column into a vector and then takes a weighted average of the vectors (weighted by occurrence).
 - **column2vec_sum** transforms every unique element of the column into a vector and then sums the vectors.
 - **column2vec_weighted_sum** transforms every element of the column, duplicates included, into a vector and then sums the vectors, which weights each unique value by its occurrence count.
+
+> Inspired by [Michael J. Mior, Alexander G. Ororbia](https://arxiv.org/pdf/1903.08621)
 ---
 # Data and cluster description
 #### Used tables
@@ -78,12 +83,12 @@ rating .
 duration .
 date_added .
 ```
-## How Did I cluster by copilot
+## Making clusters by Microsoft Copilot
 - I wrote: `I will send you few rows of diferent tables could you please clustered columns of these tables ?`
 - I wrote: `I will send you all tables in cvs format i will say done when i will be done`
-- Then I send 15 rows of each table to copilot and I worote done.
+- Then I sent 15 rows of each table to Copilot.
 - I wrote all names of columns in the list above.
-- I wrote `Could you please guess the clusters`, this does not worke and copilot response was `As an AI, I can provide a high-level approach to clustering the data based on the columns you’ve provided. However, I’m unable to perform the actual clustering operation or guess the clusters without running a specific clustering algorithm on the data. Here’s a general approach:`
+- I wrote `Could you please guess the clusters`; this did not work, and Copilot's response was `As an AI, I can provide a high-level approach to clustering the data based on the columns you’ve provided. However, I’m unable to perform the actual clustering operation or guess the clusters without running a specific clustering algorithm on the data. Here’s a general approach:`
 - I wrote `Could you show similar groups of columns` and I got the response below (see Clustering by Microsoft Copilot).
 - I wrote `Could you split it to more groups ?` and I got the response below (see Granular Clustering by Microsoft Copilot).
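For comparison with the manual Copilot workflow above, the same kind of grouping can be produced programmatically with the helpers this patch adds to column2Vec/functions.py. A minimal sketch, assuming the patch is applied and the script is run from the column2Vec directory; the csv paths and the cluster count are illustrative:

```python
from column2Vec.Column2Vec import column2vec_avg
from column2Vec.functions import get_nonnumerical_data, get_vectors, get_clusters

# Read the example tables and keep only their nonnumerical columns.
data = get_nonnumerical_data(["../data/netflix_titles.csv",
                              "../data/imdb_top_1000.csv"])

# Embed every column with one of the seven implementations.
vectors = get_vectors(column2vec_avg, data)

# Group the column embeddings with KMeans and print each cluster.
for cluster in get_clusters(vectors, n_clusters=12):
    print(cluster)
```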
 ### Clustering by Microsoft Copilot
diff --git a/column2Vec/functions.py b/column2Vec/functions.py
index 7ef71dc..6ee78ca 100644
--- a/column2Vec/functions.py
+++ b/column2Vec/functions.py
@@ -1,23 +1,29 @@
 """
 Functions useful for column2Vec.
 """
+import time
 from typing import Any
+from collections.abc import Callable
 
 import numpy as np
 import pandas as pd
+import plotly.express as px
+from sentence_transformers import SentenceTransformer
 from sklearn.cluster import KMeans
+from sklearn.manifold import TSNE
 
+from constants import trained_model
 from similarity.Comparator import cosine_sim
 from similarity.DataFrameMetadataCreator import DataFrameMetadataCreator
 from similarity.Types import NONNUMERICAL
 
 
-def get_data(files: list[str]) -> dict[str, Any]:
+def get_nonnumerical_data(files: list[str]) -> dict[str, Any]:
     """
     Reads all csv files (whose names are in files) and creates metadata for them.
-    Save only nonnumerical columns into dictionary. Key is name of column.
+    Saves only nonnumerical columns into a dictionary. Key is the name of the column.
     Value is the column.
 
-    :param files: list names of csv files
+    :param files: List of csv file names
     :return: dictionary of all tables.
     """
     result = {}
@@ -36,13 +42,36 @@
     return result
 
 
+def get_vectors(function: Callable[[pd.Series, SentenceTransformer, str], list],
+                data: dict[str, Any]) -> dict[str, Any]:
+    """
+    Creates embedding vectors from columns by using one of
+    the column2Vec implementations.
+    It also prints the progress percentage and the elapsed time.
+
+    :param function: One of the column2Vec implementations
+    :param data: Result of get_nonnumerical_data, a dictionary
+                 of all columns in all tables
+    :return: Dictionary of embeddings; each column has its own embedding.
+    """
+    start = time.time()
+    result = {}
+    count = 1
+    for key in data:
+        print("Processing column: " + key + " " + str(round((count / len(data)) * 100, 2)) + "%")
+        result[key] = function(data[key], trained_model.get_module(), key)
+        count += 1
+    end = time.time()
+    print(f"ELAPSED TIME :{end - start}")
+    return result
+
+
 def get_clusters(vectors_to_cluster: pd.DataFrame, n_clusters: int) -> list[list[str]]:
     """
     Creates clusters by KMeans for given vectors.
 
-    :param vectors_to_cluster: embeddings for columns
-    :param n_clusters: number of clusters we want
-    :return: List, for each cluster number it contains list of column names
+    :param vectors_to_cluster: Embeddings for all columns
+    :param n_clusters: Number of clusters we want
+    :return: List; for each cluster number it contains a list of column names
     """
     kmeans = KMeans(n_clusters=n_clusters, random_state=0)  # Change n_clusters as needed
     list_of_vectors = np.array(list(vectors_to_cluster.values()))
@@ -59,11 +88,38 @@ def get_clusters(vectors_to_cluster: pd.DataFrame, n_clusters: int) -> list[list
     return clusters
 
 
+def plot_clusters(vectors_to_plot: pd.DataFrame, title: str):
+    """
+    Creates clusters from the vectors with KMeans, then transforms them
+    into 2D by t-SNE (t-distributed Stochastic Neighbor Embedding).
+    It plots the result, and it saves the plot as an html file.
+
+    :param vectors_to_plot: dataframe of embedding vectors
+    :param title: title of the plot, containing the name of the used function
+    """
+    n_clusters = 12
+    kmeans = KMeans(n_clusters=n_clusters, random_state=0)  # Change n_clusters as needed
+    list_of_vectors = np.array(list(vectors_to_plot.values()))
+    kmeans.fit(list_of_vectors)
+
+    tsne = TSNE(n_components=2, random_state=0)
+    reduced_vectors = tsne.fit_transform(list_of_vectors)
+
+    df = pd.DataFrame(reduced_vectors, columns=['x', 'y'])
+    df['names'] = vectors_to_plot.keys()
+    # The cluster labels are returned in kmeans.labels_
+    df['cluster'] = kmeans.labels_
+
+    fig = px.scatter(df, x='x', y='y', color='cluster', hover_data=['names'])
+    fig.update_layout(title=title)
+    fig.write_html(title.replace(" ", "_") + ".html")
+    fig.show()
+
+
 def compute_distances(vectors: dict):
     """
     Compute distance for each pair of vectors.
 
-    :param vectors: dictionary of embedding vectors
+    :param vectors: Dictionary of embedding vectors
     :return: matrix with distances
     """
     res = {}
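An aside on the caching change in column2Vec/Column2Vec.py above: `DataFrame.at` addresses exactly one cell, so a whole embedding list can be stored in it as a single object, whereas `.loc` with a scalar row/column pair treats a list value as something to align element-wise and typically raises `ValueError: Must have equal len keys and value when setting with an iterable`. Note also that the debug print must not cast its string arguments with `int()`; the notebook traceback recorded further below shows the `ValueError` that the `int()` variant produces. A minimal, self-contained sketch of the `.at` pattern, with an illustrative frame and values rather than the project's real cache:

```python
import pandas as pd

# Object dtype lets a single cell hold an arbitrary Python object.
cache = pd.DataFrame(index=["column2vec_avg"], columns=["reg_city1"], dtype=object)

embedding = [0.12, -0.05, 0.33]  # stand-in for a real embedding vector

# `.at` writes exactly one cell, so the whole list is stored in it.
cache.at["column2vec_avg", "reg_city1"] = embedding

print(cache.at["column2vec_avg", "reg_city1"])  # [0.12, -0.05, 0.33]
```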
diff --git a/column2Vec/playground.ipynb b/column2Vec/playground.ipynb index f78ee09..4078955 100644 --- a/column2Vec/playground.ipynb +++ b/column2Vec/playground.ipynb @@ -14,27 +14,28 @@ "from column2Vec.Column2Vec import column2vec_as_sentence_clean_uniq\n", "from column2Vec.Column2Vec import column2vec_weighted_avg\n", "import time\n", - "from column2Vec.functions import get_clusters" + "from column2Vec.functions import get_clusters\n", + "from column2Vec.functions import get_vectors" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-26T09:57:59.420458Z", - "start_time": "2024-03-26T09:57:59.414725Z" + "end_time": "2024-05-14T14:15:40.050299Z", + "start_time": "2024-05-14T14:15:33.880395Z" } }, "id": "d2f663cd8db4d03b", - "execution_count": 11 + "execution_count": 1 }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 2, "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2024-03-26T09:57:59.444903Z", - "start_time": "2024-03-26T09:57:59.439511Z" + "end_time": "2024-05-14T14:15:40.057810Z", + "start_time": "2024-05-14T14:15:40.052810Z" } }, "outputs": [], @@ -59,51 +60,6 @@ "# dataM2 = pd.read_csv(fileM2)" ] }, - { - "cell_type": "code", - "outputs": [], - "source": [ - "\n", - "\n", - "model = SentenceTransformer('bert-base-nli-mean-tokens')\n", - "def get_data():\n", - " result = {}\n", - " index = 0\n", - " for i in files:\n", - " index += 1\n", - " data = pd.read_csv(i)\n", - " metadata_creator = (DataFrameMetadataCreator(data).\n", - " compute_advanced_structural_types().\n", - " compute_column_kind())\n", - " metadata1 = metadata_creator.get_metadata()\n", - " column_names
= metadata1.get_column_names_by_type(NONNUMERICAL)\n", - " for name in column_names:\n", - " print(f\" {i} : {name}\")\n", - " result[name + str(index)] = data[name]\n", - " return result\n", - "\n", - "def get_vectors(function, data):\n", - " start = time.time()\n", - " result = {}\n", - " count = 1\n", - " for key in data:\n", - " print(\"Processing column: \" + key + \" \" + str(round((count/len(data))*100, 2)) + \"%\")\n", - " result[key] = function(data[key], model)\n", - " count += 1\n", - " end = time.time()\n", - " print(f\"ELAPSED TIME :{end - start}\")\n", - " return result" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2024-03-26T09:58:00.132126Z", - "start_time": "2024-03-26T09:57:59.501655Z" - } - }, - "id": "74ad1f08faa50a70", - "execution_count": 13 - }, { "cell_type": "code", "outputs": [], @@ -135,12 +91,12 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-26T09:58:00.137581Z", - "start_time": "2024-03-26T09:58:00.133237Z" + "end_time": "2024-05-14T14:15:40.066746Z", + "start_time": "2024-05-14T14:15:40.059808Z" } }, "id": "19c03920fae6aab8", - "execution_count": 14 + "execution_count": 3 }, { "cell_type": "code", @@ -149,13 +105,13 @@ "name": "stdout", "output_type": "stream", "text": [ - " ../data/aircraft-data_nov_dec.csv : reg_state\n", " ../data/aircraft-data_nov_dec.csv : reg_city\n", + " ../data/aircraft-data_nov_dec.csv : reg_state\n", " ../data/aircraft-data_nov_dec.csv : tail_number\n", " ../data/aircraft-data_nov_dec.csv : flight\n", " ../data/aircraft-data_nov_dec.csv : reg_expiration\n", - " ../data/aircraft-data_nov_dec.csv : manufacturer\n", " ../data/aircraft-data_nov_dec.csv : reg_owner\n", + " ../data/aircraft-data_nov_dec.csv : manufacturer\n", " ../data/aircraft-data_nov_dec.csv : model\n" ] }, @@ -171,38 +127,38 @@ "name": "stdout", "output_type": "stream", "text": [ - " ../data/Airplane_Cleaned.csv : Multi Engine\n", " ../data/Airplane_Cleaned.csv : TP mods\n", + " ../data/Airplane_Cleaned.csv : Multi Engine\n", " ../data/Airplane_Cleaned.csv : Engine Type\n", - " ../data/Airplane_Cleaned.csv : Model\n", " ../data/Airplane_Cleaned.csv : Company\n", + " ../data/Airplane_Cleaned.csv : Model\n", " ../data/autoscout24-germany-dataset.csv : make\n", " ../data/autoscout24-germany-dataset.csv : gear\n", " ../data/autoscout24-germany-dataset.csv : model\n", - " ../data/autoscout24-germany-dataset.csv : fuel\n", " ../data/autoscout24-germany-dataset.csv : offerType\n", + " ../data/autoscout24-germany-dataset.csv : fuel\n", " ../data/CARS_1.csv : fuel_type\n", " ../data/CARS_1.csv : transmission_type\n", - " ../data/CARS_1.csv : body_type\n", " ../data/CARS_1.csv : car_name\n", - " ../data/USA_cars_datasets.csv : country\n", + " ../data/CARS_1.csv : body_type\n", " ../data/USA_cars_datasets.csv : model\n", - " ../data/USA_cars_datasets.csv : vin\n", " ../data/USA_cars_datasets.csv : brand\n", - " ../data/USA_cars_datasets.csv : condition\n", + " ../data/USA_cars_datasets.csv : country\n", + " ../data/USA_cars_datasets.csv : vin\n", " ../data/USA_cars_datasets.csv : title_status\n", + " ../data/USA_cars_datasets.csv : condition\n", " ../data/USA_cars_datasets.csv : state\n", " ../data/USA_cars_datasets.csv : color\n", - " ../data/imdb_top_1000.csv : Certificate\n", " ../data/imdb_top_1000.csv : Poster_Link\n", " ../data/imdb_top_1000.csv : Gross\n", - " ../data/imdb_top_1000.csv : Director\n", + " ../data/imdb_top_1000.csv : Certificate\n", + " ../data/imdb_top_1000.csv : Series_Title\n", " 
../data/imdb_top_1000.csv : Star3\n", + " ../data/imdb_top_1000.csv : Director\n", " ../data/imdb_top_1000.csv : Star2\n", " ../data/imdb_top_1000.csv : Star1\n", - " ../data/imdb_top_1000.csv : Overview\n", " ../data/imdb_top_1000.csv : Star4\n", - " ../data/imdb_top_1000.csv : Series_Title\n", + " ../data/imdb_top_1000.csv : Overview\n", " ../data/imdb_top_1000.csv : Genre\n" ] }, @@ -219,31 +175,33 @@ "output_type": "stream", "text": [ " ../data/netflix_titles.csv : show_id\n", - " ../data/netflix_titles.csv : cast\n", - " ../data/netflix_titles.csv : description\n", " ../data/netflix_titles.csv : title\n", + " ../data/netflix_titles.csv : description\n", " ../data/netflix_titles.csv : director\n", - " ../data/netflix_titles.csv : type\n", + " ../data/netflix_titles.csv : cast\n", " ../data/netflix_titles.csv : country\n", - " ../data/netflix_titles.csv : listed_in\n", " ../data/netflix_titles.csv : rating\n", + " ../data/netflix_titles.csv : type\n", " ../data/netflix_titles.csv : duration\n", + " ../data/netflix_titles.csv : listed_in\n", " ../data/netflix_titles.csv : date_added\n" ] } ], "source": [ - "data = get_data()" + "from column2Vec.functions import get_nonnumerical_data\n", + "\n", + "data = get_nonnumerical_data(files)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2024-03-26T09:58:32.212781Z", - "start_time": "2024-03-26T09:58:00.137581Z" + "end_time": "2024-05-14T14:16:20.221739Z", + "start_time": "2024-05-14T14:15:40.068761Z" } }, "id": "cfe57003e670ba15", - "execution_count": 15 + "execution_count": 4 }, { "cell_type": "code", @@ -288,74 +246,46 @@ "name": "stdout", "output_type": "stream", "text": [ - "Processing column: reg_state1 1.92%\n", - "Processing column: reg_city1 3.85%\n", - "Processing column: tail_number1 5.77%\n", - "Processing column: flight1 7.69%\n", - "Processing column: reg_expiration1 9.62%\n", - "Processing column: manufacturer1 11.54%\n", - "Processing column: reg_owner1 13.46%\n", - "Processing column: model1 15.38%\n", - "Processing column: Multi Engine2 17.31%\n", - "Processing column: TP mods2 19.23%\n", - "Processing column: Engine Type2 21.15%\n", - "Processing column: Model2 23.08%\n", - "Processing column: Company2 25.0%\n", - "Processing column: make3 26.92%\n", - "Processing column: gear3 28.85%\n", - "Processing column: model3 30.77%\n", - "Processing column: fuel3 32.69%\n", - "Processing column: offerType3 34.62%\n", - "Processing column: fuel_type4 36.54%\n", - "Processing column: transmission_type4 38.46%\n", - "Processing column: body_type4 40.38%\n", - "Processing column: car_name4 42.31%\n", - "Processing column: country5 44.23%\n", - "Processing column: model5 46.15%\n", - "Processing column: vin5 48.08%\n", - "Processing column: brand5 50.0%\n", - "Processing column: condition5 51.92%\n", - "Processing column: title_status5 53.85%\n", - "Processing column: state5 55.77%\n", - "Processing column: color5 57.69%\n", - "Processing column: Certificate6 59.62%\n", - "Processing column: Poster_Link6 61.54%\n", - "Processing column: Gross6 63.46%\n", - "Processing column: Director6 65.38%\n", - "Processing column: Star36 67.31%\n", - "Processing column: Star26 69.23%\n", - "Processing column: Star16 71.15%\n", - "Processing column: Overview6 73.08%\n", - "Processing column: Star46 75.0%\n", - "Processing column: Series_Title6 76.92%\n", - "Processing column: Genre6 78.85%\n", - "Processing column: show_id7 80.77%\n", - "Processing column: cast7 82.69%\n", - "Processing column: description7 84.62%\n", - "Processing 
column: title7 86.54%\n", - "Processing column: director7 88.46%\n", - "Processing column: type7 90.38%\n", - "Processing column: country7 92.31%\n", - "Processing column: listed_in7 94.23%\n", - "Processing column: rating7 96.15%\n", - "Processing column: duration7 98.08%\n", - "Processing column: date_added7 100.0%\n", - "ELAPSED TIME :549.7098529338837\n" + "Processing column: reg_city1 1.92%\n" + ] + }, + { + "ename": "ValueError", + "evalue": "invalid literal for int() with base 10: 'column2vec_avg'", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mValueError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[1;32mIn[5], line 1\u001B[0m\n\u001B[1;32m----> 1\u001B[0m vectors_avg \u001B[38;5;241m=\u001B[39m \u001B[43mget_vectors\u001B[49m\u001B[43m(\u001B[49m\u001B[43mcolumn2vec_avg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mdata\u001B[49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\Desktop\\thesis\\simillarity\\column2Vec\\functions.py:61\u001B[0m, in \u001B[0;36mget_vectors\u001B[1;34m(function, data)\u001B[0m\n\u001B[0;32m 59\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m key \u001B[38;5;129;01min\u001B[39;00m data:\n\u001B[0;32m 60\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mProcessing column: \u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;241m+\u001B[39m key \u001B[38;5;241m+\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m \u001B[39m\u001B[38;5;124m\"\u001B[39m \u001B[38;5;241m+\u001B[39m \u001B[38;5;28mstr\u001B[39m(\u001B[38;5;28mround\u001B[39m((count \u001B[38;5;241m/\u001B[39m \u001B[38;5;28mlen\u001B[39m(data)) \u001B[38;5;241m*\u001B[39m \u001B[38;5;241m100\u001B[39m, \u001B[38;5;241m2\u001B[39m)) \u001B[38;5;241m+\u001B[39m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m%\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m---> 61\u001B[0m result[key] \u001B[38;5;241m=\u001B[39m \u001B[43mfunction\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdata\u001B[49m\u001B[43m[\u001B[49m\u001B[43mkey\u001B[49m\u001B[43m]\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrained_model\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mget_module\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mkey\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 62\u001B[0m count \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[0;32m 63\u001B[0m end \u001B[38;5;241m=\u001B[39m time\u001B[38;5;241m.\u001B[39mtime()\n", + "File \u001B[1;32m~\\Desktop\\thesis\\simillarity\\column2Vec\\Column2Vec.py:159\u001B[0m, in \u001B[0;36mcolumn2vec_avg\u001B[1;34m(column, model, key)\u001B[0m\n\u001B[0;32m 157\u001B[0m encoded_columns \u001B[38;5;241m=\u001B[39m model\u001B[38;5;241m.\u001B[39mencode(column_clean)\n\u001B[0;32m 158\u001B[0m to_ret \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mmean(encoded_columns, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m) \u001B[38;5;66;03m# counts arithmetic mean (average)\u001B[39;00m\n\u001B[1;32m--> 159\u001B[0m \u001B[43mcache\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msave\u001B[49m\u001B[43m(\u001B[49m\u001B[43mkey\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfunction_string\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mto_ret\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 160\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m 
to_ret\n", "File \u001B[1;32m~\\Desktop\\thesis\\simillarity\\column2Vec\\Column2Vec.py:46\u001B[0m, in \u001B[0;36mCache.save\u001B[1;34m(self, key, function, embedding)\u001B[0m\n\u001B[0;32m 39\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21msave\u001B[39m(\u001B[38;5;28mself\u001B[39m, key: \u001B[38;5;28mstr\u001B[39m, function: \u001B[38;5;28mstr\u001B[39m, embedding: \u001B[38;5;28mlist\u001B[39m):\n\u001B[0;32m 40\u001B[0m \u001B[38;5;250m \u001B[39m\u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 41\u001B[0m \u001B[38;5;124;03m Saves cache\u001B[39;00m\n\u001B[0;32m 42\u001B[0m \u001B[38;5;124;03m :param key: Column name\u001B[39;00m\n\u001B[0;32m 43\u001B[0m \u001B[38;5;124;03m :param function: Function name\u001B[39;00m\n\u001B[0;32m 44\u001B[0m \u001B[38;5;124;03m :param embedding: to save\u001B[39;00m\n\u001B[0;32m 45\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n\u001B[1;32m---> 46\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m|\u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28;43mint\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfunction\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m| : |\u001B[39m\u001B[38;5;132;01m{\u001B[39;00m\u001B[38;5;28mint\u001B[39m(key)\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m|\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 47\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m__cache\u001B[38;5;241m.\u001B[39mat[function, key] \u001B[38;5;241m=\u001B[39m embedding\n", "\u001B[1;31mValueError\u001B[0m: invalid literal for int() with base 10: 'column2vec_avg'" ] } ], "source": [ "\n", "\n", "vectors_avg = get_vectors(column2vec_avg, data)" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-05-14T14:16:23.159447Z", "start_time": "2024-05-14T14:16:20.222742Z" } }, "id": "d18443a1c921f509", "execution_count": 5 }, { "cell_type": "markdown", "source": [], "metadata": { "collapsed": false }, "id": "826f4adb5cef5607" }, { "cell_type": "code",
diff --git a/constants.py b/constants.py
index 835fc12..0d9e6d7 100644
--- a/constants.py
+++ b/constants.py
@@ -1,6 +1,7 @@
 """
 This file contains constants
 """
+from sentence_transformers import SentenceTransformer
 
 
 class WarningEnable:
@@ -37,5 +38,25 @@ def get_timezone(self):
         """
         return self.__timezone
 
+class TrainedModel:
+    """
+    Class encapsulating the trained model
+    """
+    __model = SentenceTransformer('bert-base-nli-mean-tokens')
+
+    def set_module(self, model: SentenceTransformer):
+        """
+        Sets __model
+        :param model: the model to be set
+        """
+        self.__model = model
+
+    def get_module(self) -> SentenceTransformer:
+        """
+        :return: __model
+        """
+        return self.__model
+
 warning_enable = WarningEnable()
+trained_model = TrainedModel()
diff --git a/requirements.txt b/requirements.txt
index a3ffea0eb0eb48eced4c6fed3e655e35c2b44aab..a6f3575d105fb31df64f7795d91ee555f8f3851d 100644
GIT binary patch
delta 40
scmbQEv`1-!jGzK90~bR9Lk>eeLkW
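As a closing aside, the `TrainedModel` holder added in constants.py lets callers swap the embedding model that `get_vectors` picks up via `trained_model.get_module()`. A minimal sketch; `all-MiniLM-L6-v2` is only an illustrative alternative model name, and the csv path is a placeholder:

```python
from sentence_transformers import SentenceTransformer

from column2Vec.Column2Vec import column2vec_as_sentence
from column2Vec.functions import get_nonnumerical_data, get_vectors
from constants import trained_model

# Swap the default 'bert-base-nli-mean-tokens' model for another one.
trained_model.set_module(SentenceTransformer("all-MiniLM-L6-v2"))

# get_vectors now embeds the columns with the swapped-in model.
data = get_nonnumerical_data(["../data/CARS_1.csv"])
vectors = get_vectors(column2vec_as_sentence, data)
```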