-
-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathmain.v
46 lines (38 loc) · 832 Bytes
/
main.v
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
module main
import vsl.ml
// main builds a small hand-written two-class dataset, fits a weighted KNN
// classifier from vsl.ml on it, prints the predicted class for one query
// point, and finally asks the model's plotter to display itself.
// Any error returned by the vsl.ml calls is propagated with `!`.
fn main() {
	// Twelve fixed 2-D sample points (hard-coded, not randomly generated).
	features := [
		[1.0, 2.0],
		[2.0, 3.0],
		[3.0, 3.0],
		[2.0, 1.0],
		[6.0, 7.0],
		[8.0, 6.0],
		[7.0, 8.0],
		[8.0, 7.0],
		[4.0, 5.0],
		[5.0, 5.0],
		[4.5, 6.0],
		[7.0, 6.0],
	]
	// One class label (0.0 or 1.0) per row of `features`, in the same order.
	labels := [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0]

	// Per-class weights keyed by label: class 1.0 is given twice the weight
	// of class 0.0.
	class_weights := {
		0.0: 1.0
		1.0: 2.0
	}

	// Point to classify; note it coincides with a training sample labelled 0.0.
	query := [4.0, 5.0]

	// Wrap the raw samples and labels in the library's data container.
	mut dataset := ml.Data.from_raw_xy_sep(features, labels)!

	// Construct the classifier over the dataset, apply the class weights,
	// then train it.
	mut model := ml.KNN.new(mut dataset, 'Example KNN')!
	model.set_weights(class_weights)!
	model.train()

	// Classify the query point using its 3 nearest neighbours.
	prediction := model.predict(k: 3, to_pred: query)!
	println('Prediction: ${prediction}')

	// Display the model's plot (presumably opens an external viewer; confirm
	// against the vsl.plot documentation).
	model.get_plotter().show()!
}