Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
foamliu committed Oct 12, 2018
1 parent f3c5b75 commit 8ce4478
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion demo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# import the necessary packages
import os
import random

import jieba
import pandas as pd
import torch

from config import device, save_folder, test_a_folder, test_a_filename
from data_gen import parse_user_reviews, batch2TrainData
from utils import *
from utils import Lang, encode_text

if __name__ == '__main__':
voc = Lang('data/WORDMAP.json')
Expand All @@ -24,6 +27,8 @@
# Set dropout layers to eval mode
encoder.eval()

filename = os.path.join(test_a_folder, test_a_filename)
user_reviews = pd.read_csv(filename)
samples = parse_user_reviews('test_a')

samples = random.sample(samples, 10)
Expand Down
7 changes: 6 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import time

import numpy as np
import torch
from torch import nn
from torch import optim

from config import device, label_names, print_every, hidden_size, encoder_n_layers, dropout, learning_rate, start_epoch, \
epochs
from data_gen import SaDataset
from models import EncoderRNN
from utils import *
from utils import AverageMeter, ExpoAverageMeter, accuracy, Lang, timestamp, adjust_learning_rate, save_checkpoint


def train(epoch, train_data, encoder, optimizer):
Expand Down

0 comments on commit 8ce4478

Please sign in to comment.