forked from karpathy/reinforcejs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpuckworld.html
636 lines (527 loc) · 29.5 KB
/
puckworld.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>PuckWorld</title>
<meta name="description" content="">
<meta name="author" content="">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<!-- jquery and jqueryui -->
<script src="https://code.jquery.com/jquery-2.1.3.min.js"></script>
<link href="external/jquery-ui.min.css" rel="stylesheet">
<script src="external/jquery-ui.min.js"></script>
<!-- bootstrap (fetched over https so it is not blocked as mixed content when this page is served over https) -->
<script src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/js/bootstrap.min.js"></script>
<link href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.4/css/bootstrap.min.css" rel="stylesheet">
<!-- d3js -->
<script type="text/javascript" src="external/d3.min.js"></script>
<!-- markdown -->
<script type="text/javascript" src="external/marked.js"></script>
<script type="text/javascript" src="external/highlight.pack.js"></script>
<link rel="stylesheet" href="external/highlight_default.css">
<script>hljs.initHighlightingOnLoad();</script>
<!-- mathjax: nvm now loaded dynamically
<script type="text/javascript" src="https://cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
-->
<!-- rljs -->
<script type="text/javascript" src="lib/rl.js"></script>
<!-- flotjs -->
<script src="external/jquery.flot.min.js"></script>
<!-- GA -->
<script>
// Google Analytics loader: defines a global `ga` command queue, then creates
// an async <script> element pointing at analytics.js and inserts it before
// the first <script> element in the document.
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','//www.google-analytics.com/analytics.js','ga');
// Queue the tracker creation and a single pageview hit.
ga('create', 'UA-3698471-24', 'auto');
ga('send', 'pageview');
</script>
<style>
/* Fixed-width page column, centered horizontally via auto side margins. */
#wrap {
width:800px;
margin-left: auto;
margin-right: auto;
}
/* Section titles are centered. */
h2 {
text-align: center;
}
body {
font-family: Arial, "Helvetica Neue", Helvetica, sans-serif;
}
/* Offsets the 600px-wide drawing by 100px inside the 800px column, which centers it. */
#draw {
margin-left: 100px;
}
/* Visible frame around every svg element (the puck world drawing). */
svg {
border: 1px solid black;
}
</style>
<script type="application/javascript">
// Size of the drawing area in pixels. World coordinates (roughly 0..1) are
// multiplied by these before being drawn.
var W = 600,
    H = 600;
// d3 selections for the scene elements. They stay null until initDraw()
// creates the elements, and are then repositioned by updateDraw().
var d3line = null,            // arrow indicating the applied force
    d3agent = null,           // the puck
    d3target = null,          // green target
    d3target2 = null,         // red ("bad") target
    d3target2_radius = null;  // translucent disk drawn around the red target
/**
 * (Re)builds the SVG scene inside #draw: the arrowhead marker, the puck, the
 * green target, the red target with its translucent disk, and the force line.
 * The created d3 selections are stored in the file-level d3agent, d3target,
 * d3target2, d3target2_radius and d3line variables. Reads the global `env`
 * for the puck radius, so `env` must exist before this is called.
 * The positions and radii set here are placeholders; updateDraw() overwrites
 * the positions and the disk radius on every redraw.
 */
var initDraw = function() {
  var d3elt = d3.select('#draw');
  d3elt.html(''); // throw away any previous drawing

  // Size the canvas from the shared W/H constants so that the drawing and the
  // coordinate scaling used when redrawing cannot drift apart.
  var svg = d3elt.append('svg').attr('width', W).attr('height', H)
    .append('g').attr('transform', 'scale(1)');

  // define a marker for drawing arrowheads
  svg.append("defs").append("marker")
    .attr("id", "arrowhead")
    .attr("refX", 3)
    .attr("refY", 2)
    .attr("markerWidth", 3)
    .attr("markerHeight", 4)
    .attr("orient", "auto")
    .append("path")
    .attr("d", "M 0,0 V 4 L3,2 Z");

  // draw the puck; its radius is given in world units, hence the scaling by W.
  // (W is referenced directly instead of through `this`, which only resolved
  // to the global object in sloppy mode.)
  d3agent = svg.append('circle')
    .attr('cx', 100)
    .attr('cy', 100)
    .attr('r', env.rad * W)
    .attr('fill', '#FF0')
    .attr('stroke', '#000')
    .attr('id', 'puck');

  // draw the target
  d3target = svg.append('circle')
    .attr('cx', 200)
    .attr('cy', 200)
    .attr('r', 10)
    .attr('fill', '#0F0')
    .attr('stroke', '#000')
    .attr('id', 'target');

  // bad target
  d3target2 = svg.append('circle')
    .attr('cx', 300)
    .attr('cy', 300)
    .attr('r', 10)
    .attr('fill', '#F00')
    .attr('stroke', '#000')
    .attr('id', 'target2');

  // translucent disk around the bad target
  d3target2_radius = svg.append('circle')
    .attr('cx', 300)
    .attr('cy', 300)
    .attr('r', 10)
    .attr('fill', 'rgba(255,0,0,0.1)')
    .attr('stroke', '#000');

  // draw line indicating forces
  d3line = svg.append('line')
    .attr('x1', 0)
    .attr('y1', 0)
    .attr('x2', 0)
    .attr('y2', 0)
    .attr('stroke', 'black')
    .attr('stroke-width', '2')
    .attr("marker-end", "url(#arrowhead)");
}
/**
 * Reflects the current puck world state on screen.
 *
 * @param {number} a  action index: 0 = left, 1 = right, 2 = up, 3 = down,
 *                    4 = no force (the arrow is hidden)
 * @param {Array}  s  state vector; accepted for interface compatibility but
 *                    not used, positions are read from the global `env`
 * @param {number} r  latest reward, used to tint the puck
 */
var updateDraw = function(a, s, r) {
  // world coordinates -> pixels
  var px = env.ppx * W, py = env.ppy * H;
  var badx = env.tx2 * W, bady = env.ty2 * H;

  d3agent.attr('cx', px).attr('cy', py);
  d3target.attr('cx', env.tx * W).attr('cy', env.ty * H);
  d3target2.attr('cx', badx).attr('cy', bady);
  d3target2_radius.attr('cx', badx).attr('cy', bady).attr('r', env.BADRAD * H);

  // force arrow: starts at the puck and points af pixels in the push direction
  var af = 20;
  var x2 = px, y2 = py;
  if (a === 0) { x2 = px - af; }
  else if (a === 1) { x2 = px + af; }
  else if (a === 2) { y2 = py - af; }
  else if (a === 3) { y2 = py + af; }
  d3line.attr('x1', px).attr('y1', py).attr('x2', x2).attr('y2', y2);
  d3line.attr('visibility', a === 4 ? 'hidden' : 'visible');

  // color agent by reward: white at r = -0.5, greener above, redder below.
  // The channels are local variables that default to white, so the reward
  // parameter is never overwritten and vv === 0 (or a NaN reward) still yields
  // a valid color instead of reading undeclared or stale variables.
  var vv = r + 0.5;
  var ms = 255.0;
  var red = 255, green = 255, blue = 255;
  if (vv > 0) { red = 255 - vv*ms; blue = 255 - vv*ms; }
  else if (vv < 0) { green = 255 + vv*ms; blue = 255 + vv*ms; }
  // keep every channel inside the valid 0..255 range for large |reward|
  var channel = function(v) { return Math.max(0, Math.min(255, Math.floor(v))); };
  var vcol = 'rgb(' + channel(red) + ',' + channel(green) + ',' + channel(blue) + ')';
  d3agent.attr('fill', vcol);
}
/**
 * PuckWorld environment: a damped puck that the agent pushes around, a green
 * target it should approach, and a red target that slowly follows the puck.
 * Coordinates are in world units (the playing field spans roughly 0..1).
 */
var PuckWorld = function() {
  this.reset();
}
PuckWorld.prototype = {
  /** Randomizes puck position/velocity and both targets; resets the clock. */
  reset: function() {
    this.ppx = Math.random(); // puck x,y
    this.ppy = Math.random();
    this.pvx = Math.random()*0.05 -0.025; // velocity
    this.pvy = Math.random()*0.05 -0.025;
    this.tx = Math.random(); // green target
    this.ty = Math.random();
    this.tx2 = Math.random(); // red (bad) target
    this.ty2 = Math.random();
    this.rad = 0.05; // puck radius, also used as the wall margin
    this.t = 0; // number of simulated steps
    this.BADRAD = 0.25; // radius of the penalty disk around the red target
  },
  /** @returns {number} length of the vector produced by getState(). */
  getNumStates: function() {
    // puck x,y, puck vx,vy, offset to green target dx,dy, offset to red target dx,dy
    return 8;
  },
  /** @returns {number} number of discrete actions. */
  getMaxNumActions: function() {
    return 5; // left, right, up, down, nothing
  },
  /**
   * @returns {number[]} [x - 0.5, y - 0.5, 10*vx, 10*vy,
   *                      green dx, green dy, red dx, red dy]
   */
  getState: function() {
    var s = [this.ppx - 0.5, this.ppy - 0.5, this.pvx * 10, this.pvy * 10, this.tx-this.ppx, this.ty-this.ppy, this.tx2-this.ppx, this.ty2-this.ppy];
    return s;
  },
  /**
   * Advances the world by one step under action `a`.
   *
   * @param {number} a action index: 0 left, 1 right, 2 up, 3 down, anything
   *                   else applies no force
   * @returns {{ns: number[], r: number}} the next state and the reward
   */
  sampleNextState: function(a) {
    // world dynamics
    this.ppx += this.pvx; // newton
    this.ppy += this.pvy;
    this.pvx *= 0.95; // damping
    this.pvy *= 0.95;

    // agent action influences puck velocity
    var accel = 0.002;
    if(a === 0) this.pvx -= accel;
    if(a === 1) this.pvx += accel;
    if(a === 2) this.pvy -= accel;
    if(a === 3) this.pvy += accel;

    // handle boundary conditions and bounce (half the speed is lost)
    if(this.ppx < this.rad) {
      this.pvx *= -0.5; // bounce!
      this.ppx = this.rad;
    }
    if(this.ppx > 1 - this.rad) {
      this.pvx *= -0.5;
      this.ppx = 1 - this.rad;
    }
    if(this.ppy < this.rad) {
      this.pvy *= -0.5; // bounce!
      this.ppy = this.rad;
    }
    if(this.ppy > 1 - this.rad) {
      this.pvy *= -0.5;
      this.ppy = 1 - this.rad;
    }

    this.t += 1;
    if(this.t % 100 === 0) {
      this.tx = Math.random(); // reset the green target location
      this.ty = Math.random();
    }

    // compute distances: d1 to the green target, d2 to the red target
    var gdx = this.ppx - this.tx;
    var gdy = this.ppy - this.ty;
    var d1 = Math.sqrt(gdx*gdx + gdy*gdy);
    var bdx = this.ppx - this.tx2;
    var bdy = this.ppy - this.ty2;
    var d2 = Math.sqrt(bdx*bdx + bdy*bdy);

    // the red target creeps toward the puck at constant speed. When it sits
    // exactly on the puck (d2 === 0) the direction is undefined; skip the move
    // rather than dividing by zero, which would turn tx2/ty2 into NaN for good.
    var speed = 0.001;
    if(d2 > 0) {
      this.tx2 += speed * (bdx/d2);
      this.ty2 += speed * (bdy/d2);
    }

    // compute reward (uses the distances measured before the red target moved)
    var r = -d1; // want to go close to green
    if(d2 < this.BADRAD) {
      // but if we're too close to red that's bad
      r += 2*(d2 - this.BADRAD)/this.BADRAD;
    }

    // package the successor state together with the reward
    var ns = this.getState();
    var out = {'ns':ns, 'r':r};
    return out;
  }
}
// Speed presets for the simulation: how many environment steps the learning
// loop runs on each timer tick.
var SPEED_PRESETS = { fast: 100, normal: 10, slow: 1 };

/** "Go fast" button: 100 steps per tick. */
function gofast() {
  steps_per_tick = SPEED_PRESETS.fast;
}

/** "Go normal" button: 10 steps per tick. */
function gonormal() {
  steps_per_tick = SPEED_PRESETS.normal;
}

/** "Go slow" button: 1 step per tick. */
function goslow() {
  steps_per_tick = SPEED_PRESETS.slow;
}
// flot stuff
// Upper bound of the plot's x axis; the learning loop also uses it to cap the
// number of smoothed-reward samples it keeps.
var nflot = 1000;

/**
 * Creates the reward plot inside #flotreward and starts a timer that refreshes
 * it every 100 ms with the latest data from getFlotRewards().
 */
function initFlot() {
  var container = $("#flotreward");
  // kept local: only the refresh timer below needs it, so it no longer leaks
  // into the global scope
  var series = [{
    data: getFlotRewards(),
    lines: {fill: true}
  }];
  var plot = $.plot(container, series, {
    grid: {
      borderWidth: 1,
      minBorderMargin: 20,
      labelMargin: 10,
      backgroundColor: {
        colors: ["#FFF", "#e4f4f4"]
      },
      margin: {
        top: 10,
        bottom: 10,
        left: 10
      }
    },
    xaxis: {
      min: 0,
      max: nflot
    },
    yaxis: {
      min: -2,
      max: 1
    }
  });
  // periodically push the current reward history into the plot and redraw
  setInterval(function(){
    series[0].data = getFlotRewards();
    plot.setData(series);
    plot.draw();
  }, 100);
}
/**
 * Converts the smoothed reward history into flot's data format.
 *
 * @returns {Array} a list of [index, reward] pairs, one per recorded sample
 */
function getFlotRewards() {
  return smooth_reward_history.map(function(reward, index) {
    return [index, reward];
  });
}
// Shared state of the learning loop (see togglelearn).
var steps_per_tick = 1,          // environment steps simulated per timer tick
    sid = -1,                    // interval id of the running loop, -1 while stopped
    action,                      // most recent action chosen by the agent
    state,                       // state the most recent action was chosen in
    smooth_reward_history = [],  // samples of smooth_reward shown on the plot
    smooth_reward = null,        // running average of the reward, null until the first step
    flott = 0;                   // steps since the last sample was recorded
/**
 * Starts the learning loop if it is stopped, stops it if it is running.
 * While running, a 20 ms timer simulates steps_per_tick environment steps per
 * tick, lets the agent act and learn on each one, and then refreshes the
 * drawing and the status readouts.
 */
function togglelearn() {
  // already running: cancel the timer and mark the loop as stopped
  if (sid !== -1) {
    clearInterval(sid);
    sid = -1;
    return;
  }

  var tick = function() {
    var obs;
    for (var step = 0; step < steps_per_tick; step++) {
      state = env.getState();
      action = agent.act(state);
      obs = env.sampleNextState(action);
      agent.learn(obs.r);

      // exponentially smoothed reward, seeded with the first observed reward
      if (smooth_reward == null) { smooth_reward = obs.r; }
      smooth_reward = smooth_reward * 0.999 + obs.r * 0.001;

      flott += 1;
      if (flott === 200) {
        // record smooth reward, dropping the oldest sample once the plot is full
        if (smooth_reward_history.length >= nflot) {
          smooth_reward_history = smooth_reward_history.slice(1);
        }
        smooth_reward_history.push(smooth_reward);
        flott = 0;
      }
    }

    // draw only once per tick, using the last simulated step
    updateDraw(action, state, obs.r);

    // status readouts; shown only when the agent exposes the fields
    if (typeof agent.expi !== 'undefined') {
      $("#expi").html(agent.expi);
    }
    if (typeof agent.tderror !== 'undefined') {
      $("#tde").html(agent.tderror.toFixed(3));
    }
  };

  sid = setInterval(tick, 20);
}
/**
 * Reveals the #mysterybox textarea and fills it with the current agent
 * serialized as JSON.
 */
function saveAgent() {
  var box = $("#mysterybox");
  box.fadeIn();
  var serialized = JSON.stringify(agent.toJSON());
  box.val(serialized);
}
/**
 * Fetches the pretrained agent from agentzoo/puckagent.json and loads it into
 * the current agent. Exploration is then lowered to 0.05 (the slider and the
 * label are updated to match) and the learning rate is set to 0 so the loaded
 * weights are not changed by further learning.
 * A failed download or an unparsable file is reported on the console instead
 * of being silently ignored.
 */
function loadAgent() {
  $.getJSON("agentzoo/puckagent.json", function(data) {
    agent.fromJSON(data); // cross your fingers...
    // set epsilon to be much lower for more optimal behavior
    agent.epsilon = 0.05;
    $("#slider").slider('value', agent.epsilon);
    $("#eps").html(agent.epsilon.toFixed(2));
    // kill learning rate to not learn
    agent.alpha = 0;
  }).fail(function(jqxhr, textStatus, error) {
    if (window.console && console.error) {
      console.error('Could not load pretrained agent: ' + textStatus + (error ? ' (' + error + ')' : ''));
    }
  });
}
/**
 * Replaces the current agent with a fresh DQN agent built from the spec typed
 * into the #agentspec textarea, then syncs the epsilon slider and label with
 * the new agent so the UI does not keep showing the previous agent's value.
 */
function resetAgent() {
  // NOTE(review): eval() executes whatever is in the textarea. It has to stay a
  // direct eval so that the `spec` variable it declares is visible below.
  eval($("#agentspec").val())
  agent = new RL.DQNAgent(env, spec);
  $("#slider").slider('value', agent.epsilon);
  $("#eps").html(agent.epsilon.toFixed(2));
}
// NOTE(review): none of these three globals is referenced elsewhere in this
// page; they look like leftovers. Confirm before removing.
var w,                    // global world object
    current_interval_id,
    skipdraw = false;
/**
 * Page entry point (invoked from the body's onload handler). Creates the
 * environment and its drawing, builds the DQN agent from the spec in the
 * #agentspec textarea, sets up the reward plot and the epsilon slider, starts
 * the learning loop, and finally renders the markdown blocks and loads MathJax.
 */
function start() {
  // environment first: initDraw() reads the puck radius from it
  env = new PuckWorld();
  initDraw();

  // the textarea content declares `spec`; evaluating it here brings it into scope
  eval($("#agentspec").val())
  agent = new RL.DQNAgent(env, spec);

  initFlot();

  // slider sets agent epsilon
  $("#slider").slider({
    value: agent.epsilon,
    min: 0,
    max: 1,
    step: 0.01,
    slide: function(event, ui) {
      agent.epsilon = ui.value;
      $("#eps").html(ui.value.toFixed(2));
    }
  });
  $("#eps").html(agent.epsilon.toFixed(2));

  // kick off the learning loop
  togglelearn();

  // render markdown: every .md element has its content run through marked()
  $(".md").each(function(index, element) {
    var block = $(element);
    block.html(marked(block.html()));
  });
  renderJax();
}
// Set once the MathJax <script> element has been added to the page.
var jaxrendered = false;

/**
 * Loads MathJax by appending a <script> element to <head>. Does nothing on
 * later calls, so the library is requested at most once.
 * The script is fetched over https from cdnjs: the original cdn.mathjax.org
 * host has been retired, and a plain-http script is blocked as mixed content
 * when this page is served over https.
 */
function renderJax() {
  if(jaxrendered) { return; }
  var script = document.createElement("script");
  script.type = "text/javascript";
  script.src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_HTMLorMML";
  document.getElementsByTagName("head")[0].appendChild(script);
  jaxrendered = true;
}
</script>
<style type="text/css">
/* NOTE(review): the page markup contains no <canvas>; presumably this targets the canvases the flot plot creates. Confirm. */
canvas { border: 1px solid white; }
</style>
</head>
<body onload="start();">
<a href="https://github.com/karpathy/reinforcejs"><img style="position: absolute; top: 0; right: 0; border: 0;" src="https://camo.githubusercontent.com/38ef81f8aca64bb9a64448d0d70f1308ef5341ab/68747470733a2f2f73332e616d617a6f6e6177732e636f6d2f6769746875622f726962626f6e732f666f726b6d655f72696768745f6461726b626c75655f3132313632312e706e67" alt="Fork me on GitHub" data-canonical-src="https://s3.amazonaws.com/github/ribbons/forkme_right_darkblue_121621.png"></a>
<div id="wrap">
<div id="mynav" style="border-bottom:1px solid #999; padding-bottom: 10px; margin-bottom:50px;">
<div>
<!-- decorative logo: the adjacent h1 already carries the site name, so the alt text is empty -->
<img src="loop.svg" alt="" style="width:50px;height:50px;float:left;">
<h1 style="font-size:50px;">REINFORCE<span style="color:#058;">js</span></h1>
</div>
<ul class="nav nav-pills">
<li role="presentation"><a href="index.html">About</a></li>
<li role="presentation"><a href="gridworld_dp.html">GridWorld: DP</a></li>
<li role="presentation"><a href="gridworld_td.html">GridWorld: TD</a></li>
<li role="presentation" class="active"><a href="puckworld.html">PuckWorld: DQN</a></li>
<li role="presentation"><a href="waterworld.html">WaterWorld: DQN</a></li>
</ul>
</div>
<h2>PuckWorld: Deep Q Learning</h2>
<textarea id="agentspec" style="width:100%;height:230px;">
// agent parameter spec to play with (this gets eval()'d on Agent reset)
var spec = {}
spec.update = 'qlearn'; // qlearn | sarsa
spec.gamma = 0.9; // discount factor, [0, 1)
spec.epsilon = 0.2; // initial epsilon for epsilon-greedy policy, [0, 1)
spec.alpha = 0.01; // value function learning rate
spec.experience_add_every = 10; // number of time steps before we add another experience to replay memory
spec.experience_size = 5000; // size of experience replay memory
spec.learning_steps_per_iteration = 20;
spec.tderror_clamp = 1.0; // for robustness
spec.num_hidden_units = 100 // number of neurons in hidden layer
</textarea>
<button class="btn btn-danger" onclick="resetAgent()" style="width:150px;height:50px;margin-bottom:5px;">Reinit agent</button>
<button class="btn btn-primary" onclick="togglelearn()" style="width:150px;height:50px;margin-bottom:5px;">Toggle Run</button>
<button class="btn btn-success" onclick="gofast()" style="width:150px;height:50px;margin-bottom:5px;">Go fast</button>
<button class="btn btn-success" onclick="gonormal()" style="width:150px;height:50px;margin-bottom:5px;">Go normal</button>
<button class="btn btn-success" onclick="goslow()" style="width:150px;height:50px;margin-bottom:5px;">Go slow</button>
<br>
Exploration epsilon: <span id="eps">0.15</span> <div id="slider"></div>
<br>
<div id="draw"></div>
<br>
<div style="float:right;">
<button class="btn btn-primary" onclick="loadAgent()" style="width:200px;height:35px;margin-bottom:5px;margin-right:20px;">Load a Pretrained Agent</button>
</div>
<textarea id="mysterybox" style="width:100%;display:none;">mystery text box</textarea>
<div> Experience write pointer: <div id="expi" style="display:inline-block;"></div> </div>
<div> Latest TD error: <div id="tde" style="display:inline-block;"></div> </div>
<div id="tdest"></div>
<div>
<div style="text-align:center">Average reward graph (high is good)</div>
<div id="flotreward" style="width:800px; height: 400px;"></div>
</div>
<div id="exp" class="md">
### Setup
This demo is a modification of **PuckWorld**:
- The **state space** is now large and continuous: The agent observes its own location (x,y), velocity (vx,vy), the locations of the green target, and the red target (total of 8 numbers).
- There are 5 **actions** available to the agent: to apply thrusters to the left, right, up or down, or to apply no thrust at all. This gives the agent control over its velocity.
- The PuckWorld **dynamics** integrate the velocity of the agent to change its position. The green target moves once in a while to a random location. The red target always slowly follows the agent.
- The **reward** awarded to the agent is based on its distance to the green target (low is good). However, if the agent is in the vicinity of the red target (inside the disk), the agent gets an additional negative reward that grows linearly the closer it gets to the red target.
The optimal strategy for the agent is to always go towards the green target (this is the regular PuckWorld), but to also avoid the red target's area of effect. This makes things more interesting because the agent has to learn to avoid it. Also, sometimes it's fun to watch the red target corner the agent. In this case, the optimal thing to do is to temporarily pay the price to zoom by quickly and away, rather than getting cornered and paying much more reward price when that happens.
**Interface**: The current reward experienced by the agent is shown as its color (green = high, red = low). The action taken by the agent (the medium-sized circle moving around) is shown by the arrow.
### Deep Q Learning
We're going to extend the Temporal Difference Learning (Q-Learning) to continuous state spaces. In the previous demo we've estimated the Q learning function \\(Q(s,a)\\) as a lookup table. Now, we are going to use a function approximator to model \\(Q(s,a) = f\_{\theta}(s,a)\\), where \\(\theta\\) are some parameters (e.g. weights and biases of neurons in a network). However, as we will see everything else remains exactly the same. The first paper that showed impressive results with this approach was [Playing Atari with Deep Reinforcement Learning](http://arxiv.org/abs/1312.5602) at the NIPS workshop in 2013, and more recently the Nature paper [Human-level control through deep reinforcement learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html), both by Mnih et al. However, more impressive than what we'll develop here, their function \\(f\_{\theta,a}\\) was an entire Convolutional Neural Network that looked at the raw pixels of an Atari game console. It's hard to get all that to work in JS :(
Recall that in Q-Learning we had training data that is a single chain of \\(s\_t, a\_t, r\_t, s\_{t+1}, a\_{t+1}, r\_{t+1}, s\_{t+2}, \ldots \\) where the states \\(s\\) and the rewards \\(r\\) are samples from the environment dynamics, and the actions \\(a\\) are sampled from the agent's policy \\(a\_t \sim \pi(a \mid s\_t)\\). The agent's policy in Q-learning is to execute the action that is currently estimated to have the highest value (\\( \pi(a \mid s) = \arg\max\_a Q(s,a) \\)), or with a probability \\(\epsilon\\) to take a random action to ensure some exploration. The Q-Learning update at each time step then had the following form:
$$
Q(s\_t, a\_t) \leftarrow Q(s\_t, a\_t) + \alpha \left[ r\_t + \gamma \max\_a Q(s\_{t+1}, a) - Q(s\_t, a\_t) \right]
$$
As mentioned, this equation is describing a regular Stochastic Gradient Descent update with learning rate \\(\alpha\\) and the loss function at time \\(t\\):
$$
L\_t = (r\_t + \gamma \max\_a Q(s\_{t+1}, a) - Q(s\_t, a\_t))^2
$$
Here \\(y = r\_t + \gamma \max\_a Q(s\_{t+1}, a)\\) is thought of as a scalar-valued fixed target, while we backpropagate through the neural network that produced the prediction \\(f\_{\theta} = Q(s\_t,a\_t)\\). Note that the loss has a standard L2 norm regression form, and that we nudge the parameter vector \\(\theta\\) in a way that makes the computed \\(Q(s,a)\\) slightly closer to what it should be (i.e. to satisfy the Bellman equation). This softly encourages the constraint that the reward should be properly diffused, in expectation, backwards through the environment dynamics and the agent's policy.
In other words, Deep Q Learning is a 1-dimensional regression problem with a vanilla neural network, solved with vanilla stochastic gradient descent, except our training data is not fixed but generated by interacting with the environment.
### Bells and Whistles
There are several bells and whistles that make Q Learning with function approximation tractable in practical problems:
**Modeling Q(s,a)**. First, as mentioned we are using a function approximator to model the Q function: \\(Q(s,a) = f\_{\theta}(s,a)\\). The natural approach to take is to perhaps have a neural network that takes as inputs the state vector \\(s\\) and an action vector \\(a\\) (e.g. encoded with a 1-of-k indicator vector), and output a single number that gives \\(Q(s,a)\\). However, this approach leads to a practical problem because the policy of the agent is to take an action that maximizes Q, that is: \\( \pi(a \mid s) = \arg\max\_a Q(s,a) \\). Therefore, with this naive approach, when the agent wants to produce an action we'd have to iterate over all actions, evaluate Q for each one and take the action that gave the highest Q.
A better idea is to instead predict the value \\(Q(s,a)\\) as a neural network that only takes the state \\(s\\), and produces multiple outputs, each interpreted as the Q value of taking that action in the given state. This way, deciding what action to take reduces to performing a single forward pass of the network and finding the argmax action. A diagram may help get the idea across:
<div style="text-align:center; margin: 20px;">
<img src="img/qsa.jpeg" style="max-width:500px;"><br>
Simple example with 3-dimensional state space (blue) and 2 possible actions (red). Green nodes are neurons of a neural network f. Left: A naive approach in which it takes multiple forward passes to find the argmax action. Right: A more efficient approach where the Q(s,a) computation is effectively shared among the neurons in the network. Only a single forward pass is required to find the best action to take in any given state.
</div>
Formulated in this way, the update algorithm becomes:
1. Experience a \\(s\_t, a\_t, r\_t, s\_{t+1}\\) transition in the environment.
2. Forward \\(s\_{t+1}\\) to evaluate the (fixed) target \\(y = r\_t + \gamma \max\_a f\_{\theta}(s\_{t+1})\\).
3. Forward \\(f\_{\theta}(s\_t)\\) and apply L2 regression loss on the dimension \\(a\_t\\) of the output which we want to equal to \\(y\\). Due to the L2 loss, the gradient has a simple form where the predicted value is simply subtracted from \\(y\\). The output dimensions corresponding to the other actions have zero gradient.
4. Backpropagate the gradient and perform a parameter update.
**Experience Replay Memory**. An important contribution of the Mnih et al. papers was to suggest an experience replay memory. This is loosely inspired by the brain, and in particular the way it syncs memory traces in the hippocampus with the cortex. What this amounts to is that instead of performing an update and then throwing away the experience tuple \\(s\_t, a\_t, r\_t, s\_{t+1}\\), we keep it around and effectively build up a training set of experiences. Then, we don't learn based on the new experience that comes in at time \\(t\\), but instead sample random experiences from the replay memory and perform an update on each sample. Again, this is exactly as if we were training a Neural Net with SGD on a dataset in a regular Machine Learning setup, except here the dataset is a result of agent interaction. This feature has the effect of removing correlations in the observed state,action,reward sequence and reduces gradual drift and forgetting. The algorithm hence becomes:
1. Experience \\(s\_t, a\_t, r\_t, s\_{t+1}\\) in the environment and add it to the training dataset \\(\mathcal{D}\\). If the size of \\(\mathcal{D}\\) is greater than some threshold, start replacing old experiences.
2. Sample **N** experiences from \\(\mathcal{D}\\) at random and update the Q function.
**Clamping TD Error**. Another interesting feature is to clamp the TD Error gradient at some fixed maximum value. That is, if the TD Error \\(r\_t + \gamma \max\_a Q(s\_{t+1}, a) - Q(s\_t, a\_t)\\) is greater in magnitude than some threshold `spec.tderror_clamp`, then we cap it at that value. This makes the learning more robust to outliers and has the interpretation of using Huber loss, which is an L2 penalty in a small region around the target value and an L1 penalty further away.
**Periodic Target Q Value Updates**. The last modification suggested in Mnih et al. also aims to reduce correlations between updates and the immediately undertaken behavior. The idea is to freeze the Q network once in a while into a frozen, copied network \\(\hat{Q}\\) which is used to only compute the targets. This target network is once in a while updated to the actual current \\(Q\\) network. That is, we use the target network \\(r\_t + \gamma \max\_a \hat{Q}(s\_{t+1},a)\\) to compute the target, and \\(Q\\) to evaluate the \\(Q(s\_t, a\_t)\\) part. In terms of the implementation, this requires one additional line of code where every now and then we sync \\(\hat{Q} \leftarrow Q\\). I played around with this idea but at least on my simple toy examples I did not find it to give a substantial benefit, so I took it out of REINFORCEjs in the current implementation for sake of brevity.
### REINFORCEjs API use of DQN
If you'd like to use REINFORCEjs DQN in your application, define an `env` object that has the following methods:
- `env.getNumStates()` returns an integer for the dimensionality of the state feature space
- `env.getMaxNumActions()` returns an integer with the number of actions the agent can use
This seems kind of silly and the name `getNumStates` is a bit of a misnomer because it refers to the size of the state feature vector, not the raw number of unique states, but I chose this interface so that it is consistent with the tabular TD code and DP method. Similar to the tabular TD agent, the IO with the agent has an extremely simple interface:
<pre><code class="js">
// create environment
var env = {};
env.getNumStates = function() { return 8; }
env.getMaxNumActions = function() { return 4; }
// create the agent, yay!
var spec = { alpha: 0.01 } // see full options on top of this page
agent = new RL.DQNAgent(env, spec);
setInterval(function(){ // start the learning loop
var action = agent.act(s); // s is an array of length 8
// execute action in environment and get the reward
agent.learn(reward); // the agent improves its Q,policy,model, etc. reward is a float
}, 0);
</code></pre>
As for the available parameters in the DQN agent `spec`:
- `spec.gamma` is the discount rate. When it is zero, the agent will be maximally greedy and won't plan ahead at all. It will grab all the reward it can get right away. For example, children that fail the marshmallow experiment have a very low gamma. This parameter goes up to 1, but cannot be greater than or equal to 1 (this would make the discounted reward infinite).
- `spec.epsilon` controls the epsilon-greedy policy. High epsilon (up to 1) will cause the agent to take more random actions. It is a good idea to start with a high epsilon (e.g. 0.2 or even a bit higher) and decay it over time to be lower (e.g. 0.05).
- `spec.num_hidden_units`: currently the DQN agent is hardcoded to use a neural net with one hidden layer, the size of which is controlled with this parameter. For easier problems you may get away with smaller networks.
- `spec.alpha` controls the learning rate. Everyone sets this by trial and error and that's pretty much the best thing we have.
- `spec.experience_add_every`: REINFORCEjs won't add a new experience to replay every single frame to try to conserve resources and get more variety. You can turn this off by setting this parameter to 1. Default = 5
- `spec.experience_size`: size of memory. More difficult problems may need bigger memory
- `spec.learning_steps_per_iteration`: the more the better, but slower. Default = 20
- `spec.tderror_clamp`: for robustness, clamp the TD Error gradient at this value.
### Pros and Cons
The very nice property of DQN agents is that the I/O interface is quite generic and very simple: The agent accepts state vectors, produces discrete actions, and learns from relatively arbitrary rewards. However, there are several things to keep in mind when applying this agent in practice:
- If the rewards are very sparse in the environment the agent will have trouble learning. Right now there is no priority sweeping support, but one might imagine oversampling experiences that have high TD errors. It's not clear how this can be done in the most principled way. Similarly, there are no eligibility traces right now though this could be added with a few modifications in future versions.
- The exploration is rather naive, since a random action is taken once in a while. If the environment requires longer sequences of precise actions to get a reward, the agent might have a lot of difficulty finding these by chance, and then also learning from them sufficiently.
- DQN only supports a set number of discrete actions and it is not obvious how one can incorporate (high-dimensional) continuous action spaces.
Speaking of high-dimensional continuous action spaces, this is what Policy Gradient Actor Critic methods are quite good at! Head over to the Policy Gradient (PG) demo. (Oops, coming soon...)
</div>
</div>
<br><br><br><br><br>
</body>
</html>