-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmoons_demo.ml
133 lines (121 loc) · 4.6 KB
/
moons_demo.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
open Base
open Ocannl
module Tn = Arrayjit.Tnode
module IDX = Train.IDX
module TDSL = Operation.TDSL
module NTDSL = Operation.NTDSL
module CDSL = Train.CDSL
module Utils = Arrayjit.Utils
module Asgns = Arrayjit.Assignments
module Rand = Arrayjit.Rand.Lib
module Debug_runtime = Utils.Debug_runtime
(** Half-moons binary classification demo.

    Generates two noisy, interleaved half-moon point clouds. Trains a small MLP
    on them with a margin (hinge-style) loss and an SGD update. Prints the data
    as an ASCII scatterplot, then prints it again overlaid with the learned
    decision boundary.

    @param backend_name
      Name passed to [Arrayjit.Backends.fresh_backend] to select the compute
      backend. Defaults to ["cuda"], the value this demo previously
      hard-coded, so [demo ()] behaves exactly as before. Supply another
      backend name to run on machines without a CUDA device. *)
let demo ?(backend_name = "cuda") () =
  let seed = 3 in
  Rand.init seed;
  Utils.settings.fixed_state_for_init <- Some seed;
  Utils.settings.output_debug_files_in_build_directory <- true;
  (* Utils.enable_runtime_debug (); *)
  let hid_dim = 16 in
  let len = 512 in
  let batch_size = 32 in
  let epochs = 75 in
  (* Utils.settings.debug_log_from_routines <- true; *)
  (* TINY for debugging: *)
  (* let hid_dim = 2 in let len = 16 in let batch_size = 2 in let epochs = 2 in *)
  (* Each of the [len] samples yields two points (one per moon), so there are
     [2 * len] points split into batches of [batch_size]. *)
  let n_batches = 2 * len / batch_size in
  let steps = epochs * n_batches in
  let weight_decay = 0.0002 in
  (* Three-layer MLP with relu activations between layers; "b1" and "b2" are
     declared with width [hid_dim]. The string literals declare the named
     tensors inline via the [%op] syntax. *)
  let%op mlp x =
    "b3" + ("w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x)))))
  in
  let noise () = Rand.float_range (-0.1) 0.1 in
  (* For each sample:
     - pick an angle v = i * pi / len for a random i;
     - emit a point on the first moon, (cos v, sin v);
     - then emit a point on the second moon, (1 - cos v, 0.5 - sin v).
     Every coordinate is perturbed by [noise ()]. *)
  let moons_flat =
    Array.concat_map (Array.create ~len ())
      ~f:
        Float.(
          fun () ->
            let i = Rand.int len in
            let v = of_int i * pi / of_int len in
            let c = cos v and s = sin v in
            [| c + noise (); s + noise (); 1.0 - c + noise (); 0.5 - s + noise () |])
  in
  let moons_flat = TDSL.init_const ~l:"moons_flat" ~o:[ 2 ] moons_flat in
  (* Labels follow the interleaved point order above: even-indexed points
     (first moon) are +1, odd-indexed points (second moon) are -1. *)
  let moons_classes = Array.init (len * 2) ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in
  let moons_classes = TDSL.init_const ~l:"moons_classes" ~o:[ 1 ] moons_classes in
  (* [batch_n] selects the current batch slice (range [n_batches]).
     [step_n] counts optimization steps and drives the learning-rate schedule. *)
  let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
  let step_n, bindings = IDX.get_static_symbol bindings in
  let%op moons_input = moons_flat @| batch_n in
  let%op moons_class = moons_classes @| batch_n in
  (* Hinge loss: relu (1 - y * f x) is zero once y * f x >= 1, i.e. the
     prediction has the correct sign with a margin of at least 1. *)
  let%op margin_loss = relu (1 - (moons_class *. mlp moons_input)) in
  (* Reduce [margin_loss] to a scalar via the [++] spec, then divide by
     [batch_size]. *)
  let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
  let update = Train.grad_update scalar_loss in
  (* Learning rate decays linearly from 0.1 towards 0 over [steps] steps. *)
  let%op learning_rate = 0.1 *. (!..steps - !@step_n) /. !..steps in
  (* The learning rate is read host-side ([learning_rate.@[0]]) in the
     per-epoch log below. *)
  Train.set_hosted learning_rate.value;
  let sgd = Train.sgd_update ~learning_rate ~weight_decay update in
  (* Backend selected by [backend_name]; formerly always "cuda". *)
  let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name ()) in
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let ctx = Backend.make_context stream in
  (* A single routine runs forward + backprop followed by the SGD update. *)
  let routine =
    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
  in
  let points = Tn.points_2d ~xdim:0 ~ydim:1 moons_flat.value in
  let classes = Tn.points_1d ~xdim:0 moons_classes.value in
  (* Split points by label for plotting: "#" for class +1, "+" for class -1. *)
  let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
  let plot_moons =
    PrintBox_utils.plot ~as_canvas:true
      [
        Scatterplot { points = points1; content = PrintBox.line "#" };
        Scatterplot { points = points2; content = PrintBox.line "+" };
      ]
  in
  Stdio.printf "\nHalf-moons scatterplot:\n%!";
  PrintBox_text.output Stdio.stdout plot_moons;
  Stdio.print_endline "\n";
  let open Operation.At in
  (* Host-side references to the index symbols bound into [routine]. *)
  let step_ref = IDX.find_exn routine.bindings step_n in
  let batch_ref = IDX.find_exn routine.bindings batch_n in
  let epoch_loss = ref 0. in
  step_ref := 0;
  let%track_sexp _train_loop : unit =
    for epoch = 0 to epochs - 1 do
      for batch = 0 to n_batches - 1 do
        batch_ref := batch;
        (* NOTE(review): the [fun () ->] body extends to [done]. The loss
           accumulation and step increment therefore also run inside
           [capture_stdout_logs]. *)
        Utils.capture_stdout_logs @@ fun () ->
        Train.run routine;
        epoch_loss := !epoch_loss +. scalar_loss.@[0];
        Int.incr step_ref
      done;
      (* The reported epoch loss is the sum of the per-batch scalar losses. *)
      Stdio.printf "Epoch %d, lr=%f, epoch loss=%f\n%!" epoch learning_rate.@[0] !epoch_loss;
      epoch_loss := 0.
    done
  in
  (* Inference graph: the same MLP applied to an inline "point" input tensor. *)
  let%op mlp_result = mlp "point" in
  Train.set_on_host mlp_result.value;
  (* Compiled from the training routine's context, presumably so the trained
     parameters are reused. TODO confirm. *)
  let result_routine =
    Train.to_routine
      (module Backend)
      routine.context IDX.empty
      [%cd
        ~~("moons infer";
           mlp_result.forward)]
  in
  (* Decision-boundary probe: write (x, y) into the input, run inference, and
     report whether the MLP output is non-negative.
     NOTE(review): [point] is not bound explicitly in this function. It appears
     to be introduced by the [%op] extension from the "point" literal in
     [mlp_result]; confirm. *)
  let callback (x, y) =
    Tn.set_values point.value [| x; y |];
    Utils.capture_stdout_logs @@ fun () ->
    Train.run result_routine;
    Float.(mlp_result.@[0] >= 0.)
  in
  let%track_sexp _plotting : unit =
    let plot_moons =
      PrintBox_utils.plot ~as_canvas:true
        [
          Scatterplot { points = points1; content = PrintBox.line "#" };
          Scatterplot { points = points2; content = PrintBox.line "+" };
          Boundary_map
            { content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
        ]
    in
    Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
    PrintBox_text.output Stdio.stdout plot_moons;
    Stdio.print_endline "\n"
  in
  ()
(* Program entry point: run the half-moons demo once when the executable
   starts. *)
let () =
  demo ()