Commit

Fix bugs and improve cache
Olivie Franklova (CZ) authored and committed May 28, 2024
1 parent 4abcd30 commit 58bc133
Showing 1 changed file with 30 additions and 5 deletions.
35 changes: 30 additions & 5 deletions column2Vec/Column2Vec.py
@@ -16,10 +16,11 @@ class Cache:
     __cache = pd.DataFrame()
     __read_from_file = False
     __on = True
+    __file = "generated/cache.txt"
 
     def __read(self):
         try:
-            self.__cache = pd.io.parsers.read_csv("generated/cache.txt", index_col=0)
+            self.__cache = pd.io.parsers.read_csv(self.__file, index_col=0)
         except Exception as error:
             pass

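The hardcoded cache path now lives in the class-level __file attribute, so __read and save_persistently resolve the same location, and the path can be swapped at runtime via the set_file setter added in a later hunk. A minimal usage sketch; the import path is an assumption based on the file name column2Vec/Column2Vec.py:

# Sketch only: point the module-level cache at a different file before first use.
# Import path assumed from the repository layout.
from column2Vec.Column2Vec import cache

cache.set_file("generated/alt_cache.txt")  # __read and save_persistently now use this path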
@@ -61,13 +62,25 @@ def save_persistently(self):
             return
         print(self.__cache.index)
         print(self.__cache.columns)
-        self.__cache.to_csv("generated/cache.txt", index=True)
+        self.__cache.to_csv(self.__file, index=True)
 
     def off(self):
         self.__on = False
     def on(self):
         self.__on = True
+
+    def set_file(self, file: str):
+        self.__file = file
+
+    def clear_cache(self):
+        self.__cache = self.__cache[0:0]
+        self.__read_from_file = False
+    def clear_persistent_cache(self):
+        try:
+            open(self.__file, 'w').close()
+        except FileNotFoundError as e:
+            print(e)
 
 
 cache = Cache()

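The new methods split responsibility between the in-memory DataFrame and its on-disk copy: clear_cache empties the frame and resets __read_from_file so a later lookup can re-read the file, while clear_persistent_cache truncates the file itself. A hedged usage sketch, with the import path again assumed:

from column2Vec.Column2Vec import cache  # assumed import path

cache.clear_cache()             # drop in-memory entries; the file is left untouched
cache.clear_persistent_cache()  # truncate the on-disk cache file
cache.off()                     # disable caching for subsequent calls entirely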
@@ -147,7 +160,7 @@ def weighted_create_embed(column: pd.Series, model: SentenceTransformer, key: st
"""
res = cache.get_cache(key, function_string)
if res is not None:
return res
return res, None

uniq_column = column.value_counts(normalize=True)
weights = uniq_column.values
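This is the behavioural change the later hunks build on: on a cache hit, weighted_create_embed now returns a (vector, None) pair instead of a bare vector, so every caller gets a uniform two-element result to unpack. A sketch of the resulting caller pattern; the wrapper name is hypothetical, the other names come from the diff:

import numpy as np

def column2vec_weighted(column, model, key):  # hypothetical caller
    encoded, weights = weighted_create_embed(column, model, key, "column2vec_weighted")
    if weights is None:   # cache hit: 'encoded' is already the final combined vector
        return encoded
    # cache miss: combine the per-value embeddings using their frequencies
    return np.average(encoded, axis=0, weights=weights)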
@@ -182,8 +195,18 @@ def column2vec_weighted_avg(column: pd.Series, model: SentenceTransformer, key:
     Convert each item in the column to a vector and return the weighted average of all the vectors
     """
     function_string = "column2vec_weighted_avg"
-    encoded_columns, weights = weighted_create_embed(column, model, key, function_string)
-    to_ret = np.average(encoded_columns, axis=0, weights=weights)  # counts weighted average
+    res = cache.get_cache(key, function_string)
+    if res is not None:
+        return res
+    uniq_column = column.value_counts(normalize=True)
+    weights = uniq_column.values
+    column_clean = pd.Series(uniq_column.keys()).apply(lambda x: re.sub("[^(0-9 |a-z)]",
+                                                                        " ", str(x).lower())).values
+    res = model.encode(column_clean)
+
+
+    # encoded_columns, weights = weighted_create_embed(column, model, key, function_string)
+    to_ret = np.average(res, axis=0, weights=weights)  # counts weighted average
     cache.save(key, function_string, to_ret)
     return to_ret

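Rather than calling weighted_create_embed (the old call survives as a comment), column2vec_weighted_avg now does its own cache lookup and embedding, since on a hit the cached value is already the finished average and must not be re-weighted. The combination step itself is a plain frequency-weighted mean via np.average; a toy illustration with made-up numbers:

import numpy as np

vectors = np.array([[1.0, 0.0], [0.0, 1.0]])  # stand-ins for two value embeddings
weights = np.array([0.75, 0.25])              # relative frequencies from value_counts
print(np.average(vectors, axis=0, weights=weights))  # [0.75 0.25], i.e. sum(w_i*v_i)/sum(w_i)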
@@ -216,6 +239,8 @@ def column2vec_weighted_sum(column: pd.Series, model: SentenceTransformer, key:
"""
function_string = "column2vec_weighted_sum"
encoded_columns, weights = weighted_create_embed(column, model, key, function_string)
if weights is None:
return encoded_columns
to_ret = 0
for number, weight in zip(encoded_columns, weights):
to_ret += number * weight
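The new guard handles the (vector, None) cache-hit contract from weighted_create_embed, returning the cached result before the weighting loop runs. As an aside, not part of the commit: the accumulation loop is a weighted sum that NumPy can express directly. An equivalent vectorized form on toy data:

import numpy as np

encoded_columns = np.array([[1.0, 2.0], [3.0, 4.0]])  # made-up embeddings
weights = np.array([0.6, 0.4])                        # made-up frequencies
print((encoded_columns * weights[:, np.newaxis]).sum(axis=0))  # [1.8 2.8]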
