Skip to content

Commit

Permalink
Update populate_db.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bezoar17 authored Feb 17, 2017
1 parent d100c78 commit 99f7cbb
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def populate_products():
# get products in the product_values list
global productids
reader = csv.DictReader(open('db/db_products.csv'))
product_values=list()
product_values=list() # local variable
for row in reader:
product_values.append((row['product_name'],row['persona'],row['trending']))

Expand All @@ -46,19 +46,18 @@ def populate_products():

def populate_trts(train_set_size):

# populate the train and test dbs; the user info and their likes are split in 80:20 fashion
global productids,useremailtoid,train_userids_back,test_userids_back

reader = csv.DictReader(open('db/db_userinfo.csv'))
all_users=list()
all_users=list() # list of tuple(id,email,persona) of user-info

for row in reader:
all_users.append((row['user_id'],row['user_email'],row['persona']))
useremailtoid[row['user_email']]=row['user_id']

random.shuffle(all_users) # comment this line for a fixed simulation of first 15 in training and last 5 in test set.
train_userids=[i[0] for i in all_users[0:train_set_size]]
test_userids=[i[0] for i in all_users[train_set_size:]]
random.shuffle(all_users) # randomize before parting the training and testing sets
train_userids=[i[0] for i in all_users[0:train_set_size]] # list of original user id's in training set
test_userids=[i[0] for i in all_users[train_set_size:]] # list of original user id's in testing set

insert_many_to_db('train.db','INSERT INTO user_info_table(email_id,persona) VALUES (?,?)',[i[1:] for i in all_users[0:train_set_size]])
insert_many_to_db('test.db','INSERT INTO user_info_table(email_id,persona) VALUES (?,?)', [i[1:] for i in all_users[train_set_size:]])
Expand All @@ -80,10 +79,10 @@ def populate_trts(train_set_size):
db.close()

reader = csv.DictReader(open('db/db_userinputs.csv'))
train_userinputs=list()
test_userinputs=list()
train_userinputs=list() # list of user inputs to go in training set
test_userinputs=list() # list of user inputs to go in testing set

for row in reader: # build the user inputs list to add to train and test db
for row in reader: # build the user inputs list to add to train and test db
if row['user_id'] in train_userids:
train_userinputs.append((train_userids_back[row['user_id']],productids[row['product_name']],row['user_input']))
elif row['user_id'] in test_userids:
Expand Down

0 comments on commit 99f7cbb

Please sign in to comment.