From e984c3b58f7672486b17da3e5b8b8ff4e6f86605 Mon Sep 17 00:00:00 2001
From: "dhsig552@163.com" <dhsig552@163.com>
Date: Wed, 6 Sep 2017 16:54:36 +0100
Subject: [PATCH] add comments

---
 main_simple_seq2seq.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/main_simple_seq2seq.py b/main_simple_seq2seq.py
index 9c78d53e..93dc9d96 100644
--- a/main_simple_seq2seq.py
+++ b/main_simple_seq2seq.py
@@ -1,13 +1,11 @@
 #! /usr/bin/python
 # -*- coding: utf8 -*-
-
 """Sequence to Sequence Learning for Twitter/Cornell Chatbot.
 
 References
 ----------
 http://suriyadeepan.github.io/2016-12-31-practical-seq2seq/
 """
-
 import tensorflow as tf
 import tensorlayer as tl
 from tensorlayer.layers import *
@@ -16,7 +14,7 @@
 import numpy as np
 import time
 
-## select dataset
+###============= prepare data
 from data.twitter import data
 metadata, idx_q, idx_a = data.load_data(PATH='data/twitter/') # Twitter
 # from data.cornell_corpus import data
@@ -37,7 +35,7 @@
 validX = tl.prepro.remove_pad_sequences(validX)
 validY = tl.prepro.remove_pad_sequences(validY)
 
-## parameters
+###============= parameters
 xseq_len = len(trainX)#.shape[-1]
 yseq_len = len(trainY)#.shape[-1]
 assert xseq_len == yseq_len
@@ -83,7 +81,7 @@
 print(len(target_seqs), len(decode_seqs), len(target_mask))
 # exit()
 
-## model
+###============= model
 def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
     with tf.variable_scope("model", reuse=reuse):
         # for chatbot, you can use the same embedding layer,
@@ -115,18 +113,20 @@ def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
         net_out = DenseLayer(net_rnn, n_units=xvocab_size, act=tf.identity, name='output')
     return net_out, net_rnn
 
+# model for training
 encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="encode_seqs")
 decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs")
 target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs")
 target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask") # tl.prepro.sequences_get_mask()
-
 net_out, _ = model(encode_seqs, decode_seqs, is_train=True, reuse=False)
 
+# model for inference
 encode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="encode_seqs")
 decode_seqs2 = tf.placeholder(dtype=tf.int64, shape=[1, None], name="decode_seqs")
 net, net_rnn = model(encode_seqs2, decode_seqs2, is_train=False, reuse=True)
 y = tf.nn.softmax(net.outputs)
 
+# loss for training
 # print(net_out.outputs) # (?, 8004)
 # print(target_seqs) # (32, ?)
 # loss_weights = tf.ones_like(target_seqs, dtype=tf.float32)
@@ -148,7 +148,7 @@ def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
 tl.layers.initialize_global_variables(sess)
 tl.files.load_and_assign_npz(sess=sess, name='n.npz', network=net)
 
-## train
+###============= train
 n_epoch = 50
 for epoch in range(n_epoch):
     epoch_time = time.time()
@@ -189,7 +189,7 @@ def model(encode_seqs, decode_seqs, is_train=True, reuse=False):
 
         total_err += err; n_iter += 1
 
-        ## inference
+        ###============= inference
         if n_iter % 1000 == 0:
             seeds = ["happy birthday have a nice day",
                     "donald trump won last nights presidential debate according to snap online polls"]
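
Note for reviewers: the new "# model for inference" block only builds the reusable single-sample decoding graph; the actual reply generation happens under "###============= inference". For context, here is a minimal greedy-decoding sketch of how that graph gets driven with one seed sentence. It assumes the `w2idx`/`idx2w` vocabularies and the `start_id`/`end_id` markers defined earlier in the script (not visible in these hunks), and it reuses `sess`, `y`, `net_rnn`, `encode_seqs2`, and `decode_seqs2` from above:

    seed = "happy birthday have a nice day"
    seed_id = [w2idx[w] for w in seed.split(" ")]

    # 1. run the whole question through the encoder, keep its final state
    state = sess.run(net_rnn.final_state_encode,
                     {encode_seqs2: [seed_id]})

    # 2. decode step by step: start from start_id, feed each sampled word
    #    back in as the next decoder input until end_id (or a length cap)
    w_id = start_id
    sentence = []
    for _ in range(30):  # cap the reply length
        o, state = sess.run([y, net_rnn.final_state_decode],
                            {net_rnn.initial_state_decode: state,
                             decode_seqs2: [[w_id]]})
        w_id = tl.nlp.sample_top(o[0], top_k=3)  # sample among the 3 most probable words
        if w_id == end_id:
            break
        sentence.append(idx2w[w_id])
    print(" >", " ".join(sentence))

Sampling with `tl.nlp.sample_top` rather than a pure argmax keeps the chatbot's replies from collapsing onto one generic answer; set top_k=1 for deterministic greedy decoding.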