Skip to content

Commit

Permalink
GP kernels: started evalGratient (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
pa-m committed Jul 14, 2019
1 parent e3dc539 commit f639580
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 26 deletions.
106 changes: 87 additions & 19 deletions gaussian_process/kernels/kernels.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"gonum.org/v1/gonum/floats"

"gonum.org/v1/gonum/mat"
t "gorgonia.org/tensor"
)

// hyperparameter specification
Expand Down Expand Up @@ -71,7 +72,7 @@ type Kernel interface {
Theta() mat.Matrix
Bounds() mat.Matrix
CloneWithTheta(theta mat.Matrix) Kernel
Eval(X, Y mat.Matrix) *mat.Dense
Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense)
Diag(X mat.Matrix) (K *mat.DiagDense)
IsStationary() bool
String() string
Expand Down Expand Up @@ -130,7 +131,7 @@ func (k KernelOperator) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval ...
func (k KernelOperator) Eval(X, Y mat.Matrix) *mat.Dense {
func (k KernelOperator) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {
panic("Eval must be implemented by wrapper")
}

Expand Down Expand Up @@ -160,12 +161,15 @@ func (k *Sum) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval return the kernel k(X, Y) and optionally its gradient
func (k *Sum) Eval(X, Y mat.Matrix) *mat.Dense {
func (k *Sum) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {

K1 := k.k1.Eval(X, Y)
K2 := k.k2.Eval(X, Y)
K1,K1g := k.k1.Eval(X, Y, evalGradient)
K2,K2g := k.k2.Eval(X, Y, evalGradient)
K1.Add(K1, K2)
return K1
if evalGradient {
K1g.Add(K2g, t.UseUnsafe(), t.WithReuse(K1g))
}
return K1,K1g
}

// Diag returns the diagonal of the kernel k(X, X)
Expand Down Expand Up @@ -196,12 +200,16 @@ func (k *Product) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval return the kernel k(X, Y) and optionally its gradient
func (k *Product) Eval(X, Y mat.Matrix) *mat.Dense {
func (k *Product) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {

K1 := k.k1.Eval(X, Y)
K2 := k.k2.Eval(X, Y)
K1,K1g := k.k1.Eval(X, Y, evalGradient)
K2,K2g := k.k2.Eval(X, Y, evalGradient)
K1.MulElem(K1, K2)
return K1
if evalGradient {
K1g.Mul(K2g, t.UseUnsafe(), t.WithReuse(K1g))
}

return K1,K1g
}

// Diag returns the diagonal of the kernel k(X, X)
Expand Down Expand Up @@ -238,12 +246,19 @@ func (k Exponentiation) hyperparameters() hyperparameters {
}

// Eval return the kernel k(X, Y) and optionally its gradient
func (k *Exponentiation) Eval(X, Y mat.Matrix) *mat.Dense {
K := k.Kernel.Eval(X, Y)
func (k *Exponentiation) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {
K,Kg := k.Kernel.Eval(X, Y, evalGradient)
K.Apply(func(_, _ int, v float64) float64 {
return math.Pow(v, k.Exponent)
}, K)
return K
if evalGradient {
Kdata:=K.RawMatrix().Data
Kgdata:=Kg.Data().([]float64)
for i:=range Kgdata {
Kgdata[i]*=k.Exponent*math.Pow(Kdata[i],k.Exponent-1)
}
}
return K,Kg
}

// Diag returns the diagonal of the kernel k(X, X)
Expand Down Expand Up @@ -296,7 +311,7 @@ func (k *ConstantKernel) CloneWithTheta(theta mat.Matrix) Kernel {
// Eval returns
// K : array, shape (n_samples_X, n_samples_Y)
// Kernel k(X, Y)
func (k *ConstantKernel) Eval(X, Y mat.Matrix) *mat.Dense {
func (k *ConstantKernel) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {
nx, _ := X.Dims()
if Y == mat.Matrix(nil) {
Y = X
Expand All @@ -307,7 +322,20 @@ func (k *ConstantKernel) Eval(X, Y mat.Matrix) *mat.Dense {
for i := range kdata {
kdata[i] = k.ConstantValue
}
return K
var Kg *t.Dense
if evalGradient && Y==X {
if k.hyperparameters()[0].IsFixed(){
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,0})
}else{
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,1},t.WithBacking(make([]float64,nx*nx)))
it:=Kg.Iterator()
Kgdata:=Kg.Data().([]float64)
for i,e:=it.Start();e==nil && !it.Done();i,e=it.Next() {
Kgdata[i]=k.ConstantValue
}
}
}
return K,Kg
}

// Diag returns the diagonal of the kernel k(X, X).
Expand Down Expand Up @@ -363,7 +391,7 @@ func (k *WhiteKernel) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval return the kernel k(X, Y)
func (k *WhiteKernel) Eval(X, Y mat.Matrix) *mat.Dense {
func (k *WhiteKernel) Eval(X, Y mat.Matrix, evalGradient bool) (*mat.Dense,*t.Dense) {
nx, nfeat := X.Dims()
if Y == mat.Matrix(nil) {
Y = X
Expand All @@ -381,7 +409,22 @@ func (k *WhiteKernel) Eval(X, Y mat.Matrix) *mat.Dense {
}
}
}
return K
var Kg *t.Dense
if evalGradient && Y==X {
if k.hyperparameters()[0].IsFixed(){
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,0})
}else{
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,1},t.WithBacking(make([]float64,nx*nx)))
it:=Kg.Iterator()
Kgdata:=Kg.Data().([]float64)
for i,e:=it.Start();e==nil && !it.Done();i,e=it.Next() {
if i%(nx+1)==0 {
Kgdata[i] = k.NoiseLevel
}
}
}
}
return K,Kg
}

// Diag returns the diagonal of the kernel k(X, X)
Expand Down Expand Up @@ -447,7 +490,7 @@ func (k *RBF) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval return the kernel k(X, Y)
func (k *RBF) Eval(X, Y mat.Matrix) (K *mat.Dense) {
func (k *RBF) Eval(X, Y mat.Matrix, evalGradient bool) (K *mat.Dense,Kg *t.Dense) {
nx, nfeat := X.Dims()
if Y == mat.Matrix(nil) {
Y = X
Expand Down Expand Up @@ -481,6 +524,19 @@ func (k *RBF) Eval(X, Y mat.Matrix) (K *mat.Dense) {
K.Set(ix, iy, math.Exp(-.5*(d2)))
}
}
if evalGradient && Y==X {
if k.hyperparameters()[0].IsFixed(){
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,0})
}else{
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,1},t.WithBacking(make([]float64,nx*nx)))
it:=Kg.Iterator()
Kgdata:=Kg.Data().([]float64)
for i,e:=it.Start();e==nil && !it.Done();i,e=it.Next() {
Kgdata[i]=0// TODO Kgdata[i]=K.At(r,c)+ mat.Norm(X.Row(r)-X.Row(c),2)/lengscale

}
}
}
return
}

Expand Down Expand Up @@ -534,7 +590,7 @@ func (k *DotProduct) CloneWithTheta(theta mat.Matrix) Kernel {
}

// Eval return the kernel k(X, Y)
func (k *DotProduct) Eval(X, Y mat.Matrix) (K *mat.Dense) {
func (k *DotProduct) Eval(X, Y mat.Matrix, evalGradient bool) (K *mat.Dense,Kg*t.Dense) {
nx, nfeat := X.Dims()
if Y == mat.Matrix(nil) {
Y = X
Expand All @@ -551,6 +607,18 @@ func (k *DotProduct) Eval(X, Y mat.Matrix) (K *mat.Dense) {
K.Set(ix, iy, s2+mat.Dot(mat.NewVecDense(nfeat, Xrow), mat.NewVecDense(nfeat, Yrow)))
}
}
if evalGradient && Y==X {
if k.hyperparameters()[0].IsFixed(){
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,0})
}else{
Kg=t.NewDense(t.Float64,t.Shape{nx,nx,1},t.WithBacking(make([]float64,nx*nx)))
it:=Kg.Iterator()
Kgdata:=Kg.Data().([]float64)
for i,e:=it.Start();e==nil && !it.Done();i,e=it.Next() {
Kgdata[i]=2*k.Sigma0*k.Sigma0
}
}
}
return
}

Expand Down
26 changes: 19 additions & 7 deletions gaussian_process/kernels/kernels_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ func ExampleConstantKernel() {
X, Y := sample(state, 3, 2), sample(state, 3, 2)
K := &ConstantKernel{ConstantValue: 1.23}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=1.11**2, stationary:true
// X=
Expand Down Expand Up @@ -59,7 +61,9 @@ func ExampleWhiteKernel() {
K := &WhiteKernel{NoiseLevel: 1.23}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())

fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=WhiteKernel(noise_level=1.23), stationary:true
// X=
Expand Down Expand Up @@ -89,7 +93,9 @@ func ExampleRBF() {
K := &RBF{LengthScale: []float64{1.23}}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())

fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=RBF([1.23]), stationary:true
// X=
Expand Down Expand Up @@ -120,7 +126,9 @@ func ExampleDotProduct() {
K := &DotProduct{Sigma0: 1.23}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())

fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=DotProduct(sigma_0=1.23), stationary:false
// X=
Expand Down Expand Up @@ -151,7 +159,9 @@ func ExampleSum() {
K := &Sum{KernelOperator{k1: &ConstantKernel{ConstantValue: 1.23}, k2: &WhiteKernel{NoiseLevel: 1.23}}}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())

fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=1.11**2 + WhiteKernel(noise_level=1.23), stationary:true
// X=
Expand Down Expand Up @@ -185,7 +195,9 @@ func ExampleProduct() {
}
fmt.Printf("K=%s, stationary:%v\n", K, K.IsStationary())

fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(K.Eval(X, Y)), mat.Formatted(K.Diag(X)))
KXY,_:=K.Eval(X,Y,false)
KXX:= K.Diag(X)
fmt.Printf("X=\n%.8f\nY=\n%.8f\nK(X,Y)=\n%.8f\nK(X,X)=\n%.8f\n", mat.Formatted(X), mat.Formatted(Y), mat.Formatted(KXY), mat.Formatted(KXX))
// Output:
// K=1.11**2 * DotProduct(sigma_0=1.23), stationary:false
// X=
Expand Down Expand Up @@ -273,7 +285,7 @@ func TestExponentiation(t *testing.T) {
state := randomkit.NewRandomkitSource(1)
// X=np.reshape(np.random.sample(6),(3,2))
X, Y := sample(state, 3, 2), sample(state, 3, 2)
actual := kernel.Eval(X, Y)
actual,_ := kernel.Eval(X, Y,false)
assertEq(
t,
mat.NewDense(3, 3, []float64{
Expand Down
1 change: 1 addition & 0 deletions gaussian_process/kernels/matutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ func matCopy(dst mat.Mutable, src mat.Matrix) {
}
}
}

5 changes: 5 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,19 @@ module github.com/pa-m/sklearn

require (
github.com/ajstarks/svgo v0.0.0-20181006003313-6ce6a3bcf6cd // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.0.0
github.com/fogleman/gg v1.3.0 // indirect
github.com/jung-kurt/gofpdf v1.5.2 // indirect
github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4
github.com/xtgo/set v1.0.0 // indirect
golang.org/x/exp v0.0.0-20190627132806-fd42eb6b336f
golang.org/x/image v0.0.0-20190703141733-d6a02ce849c9 // indirect
golang.org/x/tools v0.0.0-20190708203411-c8855242db9c // indirect
gonum.org/v1/gonum v0.0.0-20190704103327-70ddf0df3d53
gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f
gorgonia.org/tensor v0.8.1
gorgonia.org/vecf32 v0.7.0 // indirect
gorgonia.org/vecf64 v0.7.0 // indirect
)
11 changes: 11 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3
github.com/ajstarks/svgo v0.0.0-20181006003313-6ce6a3bcf6cd h1:JdtityihAc6A+gVfYh6vGXfZQg+XOLyBvla/7NbXFCg=
github.com/ajstarks/svgo v0.0.0-20181006003313-6ce6a3bcf6cd/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0 h1:RTt2SACA7BTzvbsAKVQJLZpV6zY2MZw4bW9L2HEKkHg=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
Expand All @@ -23,13 +25,16 @@ github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a/go.mod h1:gHioqOgOl5
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4 h1:+LyPTCDcQRARqza7LfS0w7v03e7VYceqQNTE8eRcGA4=
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4/go.mod h1:2Ix1Kyeujyr6FhU2SPX4iyiEpEBjHHcRV/Mki06ACcE=
github.com/phpdave11/gofpdi v1.0.3/go.mod h1:B7ryN7q4MLItB8BDM5PJAplblJegAAcaI98viOZUihg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20170806203942-52369c62f446/go.mod h1:uYEyJGbgTkfkS4+E/PavXkNJcbFIpEtjt2B0KDQ5+9M=
github.com/remyoudompheng/bigfft v0.0.0-20190512091148-babf20351dd7/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
Expand Down Expand Up @@ -91,6 +96,12 @@ gonum.org/v1/netlib v0.0.0-20190331212654-76723241ea4e h1:jRyg0XfpwWlhEV8mDfdNGB
gonum.org/v1/netlib v0.0.0-20190331212654-76723241ea4e/go.mod h1:kS+toOQn6AQKjmKJ7gzohV1XkqsFehRA2FbsbkopSuQ=
gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f h1:5+IdMldM5iTBk6wFBDtdVSSCaIPL922N3xbxmPE3Z1g=
gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
gorgonia.org/tensor v0.8.1 h1:PTJ81ku5uYs/qsZLMFq02q0DWI4YuJeu0ikieFkkh1o=
gorgonia.org/tensor v0.8.1/go.mod h1:05Y4laKuVlj4qFoZIZW1q/9n1jZkgDBOLmKXZdBLG1w=
gorgonia.org/vecf32 v0.7.0 h1:mkpVzSyT7/Cput5/ZxaMzzp2xbmOtqOyJlTf7AdSMe0=
gorgonia.org/vecf32 v0.7.0/go.mod h1:iHG+kvTMqGYA0SgahfO2k62WRnxmHsqAREGbayRDzy8=
gorgonia.org/vecf64 v0.7.0 h1:ZphOGJfnWlFfY7x8WAJAfO64IAtYqPPq9TEGem+ItZE=
gorgonia.org/vecf64 v0.7.0/go.mod h1:1y4pmcSd+wh3phG+InwWQjYrqwyrtN9h27WLFVQfV1Q=
modernc.org/cc v1.0.0/go.mod h1:1Sk4//wdnYJiUIxnW8ddKpaOJCF37yAdqYnkxUpaYxw=
modernc.org/golex v1.0.0/go.mod h1:b/QX9oBD/LhixY6NDh+IdGv17hgB+51fET1i2kPSmvk=
modernc.org/mathutil v1.0.0/go.mod h1:wU0vUrJsVWBZ4P6e7xtFJEhFSNsfRLJ8H458uRjg03k=
Expand Down

0 comments on commit f639580

Please sign in to comment.