Skip to content

Commit

Permalink
refine tests a bit and accept strings
Browse files Browse the repository at this point in the history
Fix #20
  • Loading branch information
robertleeplummerjr committed Jan 10, 2017
1 parent f823d93 commit b95855e
Show file tree
Hide file tree
Showing 14 changed files with 201 additions and 200 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
![](logo.png)
# brain

[![Gitter](https://badges.gitter.im/Join Chat.svg)](https://gitter.im/harthur/brain?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
Expand Down
65 changes: 44 additions & 21 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -2391,10 +2391,6 @@ Object.defineProperty(exports, "__esModule", {

var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();

var _lookup = require('../lookup');

var _lookup2 = _interopRequireDefault(_lookup);

var _matrix = require('./matrix');

var _matrix2 = _interopRequireDefault(_matrix);
Expand Down Expand Up @@ -2452,7 +2448,6 @@ var RNN = function () {

this.stepCache = {};
this.runs = 0;
this.totalPerplexity = null;
this.totalCost = null;
this.ratioClipped = null;
this.model = null;
Expand Down Expand Up @@ -2621,18 +2616,18 @@ var RNN = function () {
*
* @param {Number[]} input
* @param {Number} [learningRate]
* @returns {*}
* @returns {number}
*/

}, {
key: 'trainPattern',
value: function trainPattern(input) {
var learningRate = arguments.length <= 1 || arguments[1] === undefined ? null : arguments[1];

var err = this.runInput(input);
var error = this.runInput(input);
this.runBackpropagate(input);
this.step(learningRate);
return err;
return error;
}

/**
Expand All @@ -2649,15 +2644,16 @@ var RNN = function () {
var max = input.length;
var log2ppl = 0;
var cost = 0;

var error = 0;
var equation = void 0;
while (model.equations.length <= input.length + 1) {
//first and last are zeros
this.bindEquation();
}
for (var inputIndex = -1, inputMax = input.length; inputIndex < inputMax; inputIndex++) {
// start and end tokens are zeros
equation = model.equations[inputIndex + 1];
var equationIndex = inputIndex + 1;
equation = model.equations[equationIndex];

var source = inputIndex === -1 ? 0 : input[inputIndex] + 1; // first step: start with START token
var target = inputIndex === max - 1 ? 0 : input[inputIndex + 1] + 1; // last step: end with END token
Expand All @@ -2668,14 +2664,13 @@ var RNN = function () {

log2ppl += -Math.log2(probabilities.weights[target]); // accumulate base 2 log prob and do smoothing
cost += -Math.log(probabilities.weights[target]);

// write gradients into log probabilities
logProbabilities.recurrence = probabilities.weights;
logProbabilities.recurrence[target] -= 1;
}

this.totalCost = cost;
return this.totalPerplexity = Math.pow(2, log2ppl / (max - 1));
return Math.pow(2, log2ppl / (max - 1));
}

/**
Expand Down Expand Up @@ -2832,7 +2827,7 @@ var RNN = function () {

/**
*
* @param {Object[]} data a collection of objects: `{input: 'string', output: 'string'}`
* @param {Object[]|String[]} data an array of objects: `{input: 'string', output: 'string'}` or an array of strings
* @param {Object} [options]
* @returns {{error: number, iterations: number}}
*/
Expand Down Expand Up @@ -3103,21 +3098,43 @@ RNN.defaults = {
regc: 0.000001,
clipval: 5,
json: null,
/**
*
* @param {*[]} data
* @returns {Number[]}
*/
setupData: function setupData(data) {
if (!data[0].hasOwnProperty('input') || !data[0].hasOwnProperty('output')) {
if (typeof data[0] !== 'string' && !Array.isArray(data[0]) && (!data[0].hasOwnProperty('input') || !data[0].hasOwnProperty('output'))) {
return data;
}
var values = [];
for (var i = 0; i < data.length; i++) {
values = values.concat(data[i].input, data[i].output);
}
this.vocab = _vocab2.default.fromArrayInputOutput(values);
var result = [];
for (var _i2 = 0, max = data.length; _i2 < max; _i2++) {
result.push(this.formatDataIn(data[_i2].input, data[_i2].output));
if (typeof data[0] === 'string' || Array.isArray(data[0])) {
for (var i = 0; i < data.length; i++) {
values = values.concat(data[i]);
}
this.vocab = new _vocab2.default(values);

for (var _i2 = 0, max = data.length; _i2 < max; _i2++) {
result.push(this.formatDataIn(data[_i2]));
}
} else {
for (var _i3 = 0; _i3 < data.length; _i3++) {
values = values.concat(data[_i3].input, data[_i3].output);
}
this.vocab = _vocab2.default.fromArrayInputOutput(values);
for (var _i4 = 0, _max = data.length; _i4 < _max; _i4++) {
result.push(this.formatDataIn(data[_i4].input, data[_i4].output));
}
}
return result;
},
/**
*
* @param {*[]} input
* @param {*[]} output
* @returns {Number[]}
*/
formatDataIn: function formatDataIn(input) {
var output = arguments.length <= 1 || arguments[1] === undefined ? null : arguments[1];

Expand All @@ -3130,6 +3147,12 @@ RNN.defaults = {
}
return input;
},
/**
*
* @param {Number[]} input
* @param {Number[]} output
* @returns {*}
*/
formatDataOut: function formatDataOut(input, output) {
if (this.vocab !== null) {
return this.vocab.toCharacters(output).join('');
Expand All @@ -3150,7 +3173,7 @@ RNN.trainDefaults = {
keepNetworkIntact: false
};

},{"../lookup":3,"../utilities/random":37,"../utilities/vocab":41,"../utilities/zeros":42,"./matrix":13,"./matrix/copy":11,"./matrix/equation":12,"./matrix/max-i":14,"./matrix/random-matrix":20,"./matrix/sample-i":25,"./matrix/softmax":28}],32:[function(require,module,exports){
},{"../utilities/random":37,"../utilities/vocab":41,"../utilities/zeros":42,"./matrix":13,"./matrix/copy":11,"./matrix/equation":12,"./matrix/max-i":14,"./matrix/random-matrix":20,"./matrix/sample-i":25,"./matrix/softmax":28}],32:[function(require,module,exports){
'use strict';

Object.defineProperty(exports, "__esModule", {
Expand Down
8 changes: 4 additions & 4 deletions browser.min.js

Large diffs are not rendered by default.

63 changes: 43 additions & 20 deletions dist/recurrent/rnn.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/recurrent/rnn.js.map

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/html/lstm-math/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,14 @@
window.onload = function() {
var canvas = document.getElementById('canvas');
var ctx = canvas.getContext('2d');
var perplexities = [];
var errors = [];
var costs = [];
var labels = [];
var chart = new Chart.Line(ctx, {
data: {
labels: labels,
datasets: [{
data: perplexities,
data: errors,
backgroundColor: 'rgba(255, 99, 132, 0.25)',
borderColor: 'rgba(255,99,132,1)',
borderWidth: 2,
Expand All @@ -190,7 +190,7 @@

setInterval(function () {
labels.push(new Date());
perplexities.push(net.totalPerplexity);
errors.push(net.error);
//chart.addData(new Date(), , net.totalCost);
chart.update();
}, 1000);
Expand Down
Binary file added logo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit b95855e

Please sign in to comment.