Skip to content

Commit

Permalink
GP: add unfitted predict
Browse files Browse the repository at this point in the history
  • Loading branch information
pa-m committed Sep 6, 2019
1 parent 86b455d commit 3275a0a
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 24 deletions.
50 changes: 44 additions & 6 deletions gaussian_process/gpr.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Regressor struct {
base.RandomState
Xtrain *mat.Dense
Ytrain *mat.Dense
YtrainMean *mat.Dense
KernelOpt kernels.Kernel
L *mat.Cholesky
LogMarginalLikelihoodValue float64
Expand Down Expand Up @@ -64,6 +65,9 @@ func (m *Regressor) PredicterClone() base.Predicter {

// GetNOutputs returns Y columns count
func (m *Regressor) GetNOutputs() int {
if m.Ytrain == nil {
return 1
}
return m.Ytrain.RawMatrix().Cols
}

Expand All @@ -78,16 +82,50 @@ func (m *Regressor) Fit(X, Y mat.Matrix) base.Fiter {
return m
}

// Predict using the Gaussian process regression model
func (m *Regressor) Predict(X mat.Matrix, Y mat.Mutable) *mat.Dense {
// PredictEx predicts using the Gaussian process regression model, returning Ymean and std or cov
func (m *Regressor) PredictEx(X mat.Matrix, Y mat.Mutable, returnStd, returnCov bool) (*mat.Dense, *mat.DiagDense, *mat.Dense) {
NSamples, _ := X.Dims()
var Yd *mat.Dense
var Ymean, Ycov *mat.Dense
var Ystd *mat.DiagDense
if _, ok := Y.(*mat.Dense); ok {
Yd = Y.(*mat.Dense)
Ymean = Y.(*mat.Dense)
} else {
Ymean = mat.NewDense(NSamples, m.GetNOutputs(), nil)
}
if m.Xtrain == nil {
// # Unfitted;predict based on GP prior
// y_mean = np.zeros(X.shape[0])
if returnCov {
Ycov, _ = m.Kernel.Eval(X, nil, false)
} else if returnStd {
Ystd = m.Kernel.Diag(X)
for i := 0; i < NSamples; i++ {
Ystd.SetDiag(i, math.Sqrt(Ystd.At(i, i)))
}
}

} else {
Yd = mat.NewDense(NSamples, m.GetNOutputs(), nil)
// # Predict based on GP posterior
//K_trans = self.kernel_(X, self.X_train_)
//y_mean = K_trans.dot(self.alpha_) # Line 4 (y_mean = f_star)
//y_mean = self._y_train_mean + y_mean # undo normal.

Ktrans, _ := m.Kernel.Eval(X, m.Xtrain, false)
Ymean.Mul(Ktrans, mat.NewDense(NSamples, 1, m.Alpha))
Ymean.Add(Ymean, m.YtrainMean)
if returnStd {
// TODO
} else if returnCov {
// TODO
}
}
return base.FromDense(Y, Yd)
return base.FromDense(Y, Ymean), Ystd, Ycov
}

// Predict using the Gaussian process regression model
func (m *Regressor) Predict(X mat.Matrix, Y mat.Mutable) *mat.Dense {
Ymean, _, _ := m.PredictEx(X, Y, false, false)
return base.FromDense(Y, Ymean)
}

// Score returns R2 score
Expand Down
23 changes: 22 additions & 1 deletion gaussian_process/gpr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ func TestRegressor_LogMarginalLikelihood(t *testing.T) {
}}

//# now the noisy case
//X = np.linspace(0.1, 9.9, 20)
gp := NewRegressor(kernel)
X := mat.NewDense(20, 1, nil)
//X = np.linspace(0.1, 9.9, 20)
{
x := X.RawMatrix().Data
for i := range x {
Expand Down Expand Up @@ -81,3 +81,24 @@ func TestRegressor_LogMarginalLikelihood(t *testing.T) {
t.Errorf("expected grad %g, got %g", expectedGrad, grad)
}
}

func TestRegressor_Predict(t *testing.T) {
//gp = Regressor(kernel=kernel, n_restarts_optimizer=9)
kernel := &kernels.Product{KernelOperator: kernels.KernelOperator{
K1: &kernels.ConstantKernel{ConstantValue: 1, ConstantValueBounds: [2]float64{1e-3, 1e3}},
K2: &kernels.RBF{LengthScale: []float64{10}, LengthScaleBounds: [][2]float64{{1e-2, 1e2}}},
}}
gp := NewRegressor(kernel)
X := mat.NewDense(20, 1, nil)
//X = np.linspace(0.1, 9.9, 20)
{
x := X.RawMatrix().Data
for i := range x {
x[i] = .1 + ((9.9 - .1) * float64(i) / (float64(len(x) - 1)))
}
}
Ymean := gp.Predict(X, nil)
if mat.Norm(Ymean, math.Inf(1)) > 0 {
t.Error("unfitted predict, expected 0")
}
}
30 changes: 13 additions & 17 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,28 +1,24 @@
module github.com/pa-m/sklearn

require (
github.com/ajstarks/svgo v0.0.0-20181006003313-6ce6a3bcf6cd // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.0.0
github.com/ajstarks/svgo v0.0.0-20190826172357-de52242f3d65 // indirect
github.com/chewxy/math32 v1.0.4
github.com/fogleman/gg v1.3.0 // indirect
github.com/jung-kurt/gofpdf v1.5.4 // indirect
github.com/jung-kurt/gofpdf v1.10.1 // indirect
github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4
github.com/phpdave11/gofpdi v1.0.5 // indirect
github.com/pkg/errors v0.8.1
github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237 // indirect
github.com/stretchr/objx v0.2.0 // indirect
github.com/xtgo/set v1.0.0 // indirect
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 // indirect
golang.org/x/exp v0.0.0-20190718202018-cfdd5522f6f6
golang.org/x/image v0.0.0-20190729225735-1bd0cf576493 // indirect
golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028 // indirect
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 // indirect
golang.org/x/sys v0.0.0-20190730183949-1393eb018365 // indirect
golang.org/x/tools v0.0.0-20190730215328-ed3277de2799
gonum.org/v1/gonum v0.0.0-20190724213354-3129c79de289
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 // indirect
golang.org/x/exp v0.0.0-20190829153037-c13cbed26979
golang.org/x/image v0.0.0-20190902063713-cb417be4ba39 // indirect
golang.org/x/mobile v0.0.0-20190830201351-c6da95954960 // indirect
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 // indirect
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd // indirect
golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5
gonum.org/v1/gonum v0.0.0-20190904110519-2065cbd6b42a
gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f
gorgonia.org/tensor v0.8.1
gorgonia.org/vecf32 v0.7.0 // indirect
gorgonia.org/vecf64 v0.7.0 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gorgonia.org/tensor v0.9.1
)
Loading

0 comments on commit 3275a0a

Please sign in to comment.