-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathload-csv.py
61 lines (54 loc) · 2.09 KB
/
load-csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st #used here just for secrets
import csv
import os
from redis import Redis
import numpy as np
from sentence_transformers import SentenceTransformer
from redis.commands.search.query import Query
from redis.commands.search.field import (
NumericField,
TagField,
TextField,
VectorField,
)
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
# Connection settings are read from the environment so that no credentials
# live in the source. The defaults are localhost:6379 with an empty password.
host = os.getenv('REDIS_HOST', default='localhost')
# Environment values are always strings; normalise the port to an int
# whether it came from the environment or from the default.
port = int(os.getenv('REDIS_PORT', default='6379'))
pwd = os.getenv('REDIS_PWD', default='')
redis = Redis(host=host, port=port, password=pwd)
# WARNING: destructive. FLUSHDB removes every key in the selected database
# so that the load below starts from an empty keyspace.
redis.flushdb()
# Load the sentence-embedding model used to vectorise each tweet's text.
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')

# Number of HSET commands buffered client-side before being sent to Redis.
PIPELINE_BATCH_SIZE = 500

# Each CSV row becomes one Redis hash named "tweet:<column 0>" holding the
# raw text (column 2) and its embedding serialised as raw float32 bytes.
# NOTE(review): every row is loaded, including a header row if the CSV has
# one -- confirm against the data file.
# NOTE(review): the file is opened with the platform default encoding --
# confirm the CSV's encoding and pass encoding=... explicitly if it differs.
with open('./Labelled Tweets.csv', newline='') as csvfile:
    # A non-transactional pipeline batches commands to cut network round
    # trips; it is reusable after each execute().
    pipe = redis.pipeline(transaction=False)
    pending = 0
    for tweet in csv.reader(csvfile):
        # csv.reader yields [] for blank lines; a row without a text column
        # would otherwise raise IndexError and abort the whole load.
        if len(tweet) < 3:
            continue
        text = tweet[2]
        embedding = model.encode(text).astype(np.float32).tobytes()
        # Build a fresh mapping per row instead of mutating a shared dict.
        pipe.hset(
            "tweet:{}".format(tweet[0]),
            mapping={"text": text, "text_embeddings": embedding},
        )
        pending += 1
        if pending >= PIPELINE_BATCH_SIZE:
            pipe.execute()
            pending = 0
    # Send whatever is left over from the final, partially filled batch.
    if pending:
        pipe.execute()
# Build the search index over every hash whose key starts with "tweet:".
number_of_vectors = 2000

# Attributes of the FLAT vector field; DIM is the length of each stored
# float32 embedding.
vector_attributes = {
    "TYPE": "FLOAT32",
    "DIM": 768,
    "DISTANCE_METRIC": "COSINE",
    "INITIAL_CAP": number_of_vectors,
    "BLOCK_SIZE": number_of_vectors,
}

# Schema: the tweet text (sortable, no stemming) plus its embedding vector.
schema = (
    TextField("text", no_stem=True, sortable=True),
    VectorField("text_embeddings", "FLAT", vector_attributes),
)

indexDefinition = IndexDefinition(prefix=["tweet:"], index_type=IndexType.HASH)

redis.ft("tweet:idx").create_index(schema, definition=indexDefinition)