-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathknn_loss.go
49 lines (41 loc) · 1.15 KB
/
knn_loss.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
package gortex
import (
"github.com/vseledkin/gortex/assembler"
"github.com/vseledkin/gortex/vptree"
)
// KnnLoss holds, for every point ID, a precomputed centroid of that point's
// k nearest neighbours; Loss compares a vector against that centroid with
// G.MSE_t. Instances are built by CreateKnnLoss.
type KnnLoss struct {
	// index is the VP-tree over all points, used for nearest-neighbour search.
	index *vptree.VPTree
	// centroid maps a point ID to the mean of its nearest neighbours' vectors.
	centroid map[int][]float32
	// NOTE(review): ids is never assigned in this file; it may be unused.
	ids []int
	// k is the number of neighbours averaged into each centroid.
	k int
}
// CreateKnnLoss builds a KnnLoss over a set of points.
//
// keys[i] is the ID of the point whose coordinates are vectors[i]. vectors is
// indexed in step with keys, so it must be at least as long as keys. All
// points are inserted into a VP-tree using vptree.Euclidean. For every point,
// the mean of up to k nearest neighbours is precomputed and stored as that
// point's centroid; the point itself is excluded, matched by ID.
//
// A point with no neighbours (k <= 0, or a single-point set) gets an all-zero
// centroid instead of NaNs. With no keys at all, the returned KnnLoss has an
// empty centroid map and a nil index.
func CreateKnnLoss(k int, keys []int, vectors [][]float32) *KnnLoss {
	kl := &KnnLoss{k: k, centroid: make(map[int][]float32, len(keys))}
	if len(keys) == 0 {
		// Nothing to index; this also avoids the vectors[0] access below.
		return kl
	}
	dim := len(vectors[0])

	items := make([]*vptree.Item, len(keys))
	for i, id := range keys {
		// Give each item its own copy of the vector. copy() into a nil
		// slice copies zero elements, so the destination must be
		// allocated first.
		vec := make([]float32, len(vectors[i]))
		copy(vec, vectors[i])
		items[i] = &vptree.Item{ID: id, Vector: vec}
	}
	kl.index = vptree.NewVPTree(vptree.Euclidean, items)

	for i, id := range keys {
		// Request one extra result because the query point is normally
		// returned as its own nearest neighbour. The second return value
		// of Search is not needed here.
		found, _ := kl.index.Search(&vptree.Item{ID: id, Vector: vectors[i]}, kl.k+1, -1.)
		centroid := make([]float32, dim)
		n := 0
		for _, nb := range found {
			if n >= kl.k {
				break
			}
			// Skip the query point by ID rather than assuming it is found[0].
			// With duplicate vectors, another point may tie at distance 0.
			if nb.ID == id {
				continue
			}
			assembler.Sxpy(nb.Vector, centroid)
			n++
		}
		if n > 0 {
			// Turn the accumulated sum into a mean. Skipping this when
			// n == 0 avoids scaling by 1/0, which would produce NaNs.
			assembler.Sscale(1.0/float32(n), centroid)
		}
		kl.centroid[id] = centroid
	}
	return kl
}
// Loss returns G.MSE_t between vector and the precomputed neighbour centroid
// for id, wrapped as a column matrix.
//
// The target matrix's W field is set to the stored centroid slice itself,
// not to a copy, so the target shares memory with kl.centroid[id].
// NOTE(review): for an id that CreateKnnLoss never saw, the centroid is nil
// and the target has zero rows — confirm how MSE_t handles that.
func (kl *KnnLoss) Loss(G *Graph, id int, vector *Matrix) *Matrix {
	c := kl.centroid[id]
	t := Mat(len(c), 1)
	t.W = c
	return G.MSE_t(vector, t)
}