From 469516ebf06b9612178a9ca789c39ca3df3486d4 Mon Sep 17 00:00:00 2001 From: Pascal Masschelier Date: Wed, 4 Dec 2019 21:37:29 +0100 Subject: [PATCH] base.SourceCloner. NB ExampleGridSearchCV fails. --- base/source.go | 35 +++++++---- base/source_test.go | 2 +- gaussian_process/gpr.go | 2 +- go.mod | 6 +- go.sum | 13 ++++ linear_model/logistic.go | 2 +- model_selection/search.go | 2 +- model_selection/search_test.go | 10 ++-- model_selection/split.go | 32 +++++----- model_selection/split_test.go | 79 +++++++++++++++++++++---- naive_bayes/naivebayes_test.go | 8 +-- neural_network/multilayer_perceptron.go | 2 +- preprocessing/pca.go | 4 +- svm/svm.go | 2 +- svm/svr.go | 2 +- 15 files changed, 143 insertions(+), 58 deletions(-) diff --git a/base/source.go b/base/source.go index 2aa80d8..663d941 100644 --- a/base/source.go +++ b/base/source.go @@ -3,21 +3,17 @@ package base import ( "sync" - "golang.org/x/exp/rand" - "github.com/pa-m/randomkit" + "golang.org/x/exp/rand" ) // A Source represents a source of uniformly-distributed // pseudo-random int64 values in the range [0, 1<<64). -type Source interface { - Uint64() uint64 - Seed(seed uint64) -} +type Source = rand.Source // SourceCloner is an "golang.org/x/exp/rand".Source with a Clone method type SourceCloner interface { - Clone() rand.Source + SourceClone() Source } // RandomState represents a bit more than random_state pythonic attribute. it's not only a seed but a source with a state as it's name states @@ -34,7 +30,14 @@ func NewSource(seed uint64) *randomkit.RKState { // It is just a standard Source with its operations protected by a sync.Mutex. type LockedSource struct { lk sync.Mutex - src Source + src *randomkit.RKState +} + +// WithLock executes f while s is locked +func (s *LockedSource) WithLock(f func(Source)) { + s.lk.Lock() + f(s.src) + s.lk.Unlock() } // Uint64 ... @@ -52,9 +55,9 @@ func (s *LockedSource) Seed(seed uint64) { s.lk.Unlock() } -// Clone ... 
-func (s *LockedSource) Clone() rand.Source { - return &LockedSource{src: s.src.(SourceCloner).Clone()} +// SourceClone ... +func (s *LockedSource) SourceClone() Source { + return &LockedSource{src: s.src.SourceClone().(*randomkit.RKState)} } // NewLockedSource returns a rand.Source safe for concurrent access @@ -78,3 +81,13 @@ type NormFloat64er interface { type Intner interface { Intn(int) int } + +// Permer is implemented by a random source having a method Perm(int) []int +type Permer interface { + Perm(int) []int +} + +// Shuffler is implemented by a random source having a method Shuffle(int,func(int,int)) +type Shuffler interface { + Shuffle(int, func(int, int)) +} diff --git a/base/source_test.go b/base/source_test.go index 5c93495..911a3fa 100644 --- a/base/source_test.go +++ b/base/source_test.go @@ -17,7 +17,7 @@ var ( func TestSource(t *testing.T) { s := NewSource(7) - s2 := s.Clone() + s2 := s.SourceClone() var a [5]float64 for i := range a { a[i] = s.Float64() diff --git a/gaussian_process/gpr.go b/gaussian_process/gpr.go index e8b81ee..2e770ff 100644 --- a/gaussian_process/gpr.go +++ b/gaussian_process/gpr.go @@ -46,7 +46,7 @@ func (m *Regressor) IsClassifier() bool { return false } func (m *Regressor) PredicterClone() base.Predicter { clone := *m if cloner, ok := m.RandomState.(base.SourceCloner); ok { - clone.RandomState = cloner.Clone() + clone.RandomState = cloner.SourceClone() } if m.Xtrain != nil { clone.Xtrain = mat.DenseCopyOf(m.Xtrain) diff --git a/go.mod b/go.mod index c9a177c..2349472 100644 --- a/go.mod +++ b/go.mod @@ -8,11 +8,9 @@ require ( github.com/jung-kurt/gofpdf v1.10.1 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a - github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4 + github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df github.com/pkg/errors v0.8.1 - golang.org/x/exp v0.0.0-20190829153037-c13cbed26979 - golang.org/x/image 
v0.0.0-20190902063713-cb417be4ba39 // indirect - golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5 // indirect + golang.org/x/exp v0.0.0-20191129062945-2f5052295587 gonum.org/v1/gonum v0.0.0-20190929233944-b20cf7805fc4 gonum.org/v1/plot v0.0.0-20190615073203-9aa86143727f gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect diff --git a/go.sum b/go.sum index 9568fae..de7831c 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802 h1:1BDTz0u9nC3//pOCMdNH+CiXJVYJh5UQNCOBG7jbELc= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af h1:wVe6/Ea46ZMeNkQjjBW6xcqyQA/j5e0D6GytH95g0gQ= @@ -18,6 +19,7 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE= github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= @@ -41,6 +43,8 @@ github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a h1:cgsB0XsJwsMq0JifJ github.com/pa-m/optimize v0.0.0-20190612075243-15ee852a6d9a/go.mod h1:gHioqOgOl5Wa4lmyUg/ojarU7Dfdkh/OnTnGA/WexsY= github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4 h1:+LyPTCDcQRARqza7LfS0w7v03e7VYceqQNTE8eRcGA4= 
github.com/pa-m/randomkit v0.0.0-20190612075210-f24d270692b4/go.mod h1:2Ix1Kyeujyr6FhU2SPX4iyiEpEBjHHcRV/Mki06ACcE= +github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df h1:waQf2YvgkQdOEK4IvtzwNIuFAo2FZd34JtAb/wrLbbc= +github.com/pa-m/randomkit v0.0.0-20191001073902-db4fd80633df/go.mod h1:rEyYBR/jbMkj6lX7VpWTAPPrjDIi/aNhAXmFuLMZS4o= github.com/phpdave11/gofpdi v1.0.7/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -61,6 +65,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de h1:xSjD6HQTqT0H/k60N5yYBtnN1OEkVy7WIo/DYyxKRO0= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f h1:9kQ594xxPWRNKfTOnPjPcgrIJ19zM3ic57aI7PbMyAA= @@ -71,8 +76,11 @@ golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495 h1:I6A9Ag9FpEKOjcKrRNjQkPHaw golang.org/x/exp v0.0.0-20190312203227-4b39c73a6495/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522 h1:OeRHuibLsmZkFj773W4LcfAGsSxJgfPONhr8cmO+eLA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190731235908-ec7cb31e5a56/go.mod h1:JhuoJpWY28nO4Vef9tZUw9qufEGTyX1+7lmHxV5q5G4= golang.org/x/exp 
v0.0.0-20190829153037-c13cbed26979 h1:Agxu5KLo8o7Bb634SVDnhIfpTvxmzUwhbYAzBvXt6h4= golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/exp v0.0.0-20191129062945-2f5052295587 h1:5Uz0rkjCFu9BC9gCRN7EkwVvhNyQgGWb8KNJrPwBoHY= +golang.org/x/exp v0.0.0-20191129062945-2f5052295587/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81 h1:00VmoueYNlNz/aHIilyyQz/MHSqGoWJzpFv/HW8xpzI= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067 h1:KYGJGHOQy8oSi1fDlSpcZF0+juKwk/hEMv5SiwHogR0= @@ -87,17 +95,20 @@ golang.org/x/image v0.0.0-20190902063713-cb417be4ba39/go.mod h1:FeLwcggjj3mMvU+o golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190607214518-6fa95d984e88/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mobile v0.0.0-20190830201351-c6da95954960/go.mod h1:mJOp/i0LXPxJZ9weeIadcPqKVfS05Ai7m6/t9z1Hs/Y= golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190611141213-3f473d35a33a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313 h1:pczuHS43Cp2ktBEEmLwScxgjWsBSzdaQiKzUyf3DTTc= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190610200419-93c9922d18ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= @@ -112,6 +123,8 @@ golang.org/x/tools v0.0.0-20190611222205-d73e1c7e250b/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5 h1:xU4gBaA7ny56EkBSp9Uw2MVovJDupIfONnEOZ+FChTY= golang.org/x/tools v0.0.0-20190905235650-93dcc2f048f5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a h1:TwMENskLwU2NnWBzrJGEWHqSiGUkO/B4rfyhwqDxDYQ= +golang.org/x/tools v0.0.0-20191012152004-8de300cfc20a/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= gonum.org/v1/gonum v0.0.0-20190331200053-3d26580ed485 h1:OB/uP/Puiu5vS5QMRPrXCDWUPb+kt8f1KW8oQzFejQw= diff --git a/linear_model/logistic.go b/linear_model/logistic.go index 95cb69a..ddeda9e 100644 --- a/linear_model/logistic.go +++ b/linear_model/logistic.go @@ -170,7 +170,7 @@ func 
NewLogisticRegression() *LogisticRegression { func (m *LogisticRegression) PredicterClone() base.Predicter { clone := *m if sc, ok := m.RandomState.(base.SourceCloner); ok { - clone.RandomState = sc.Clone() + clone.RandomState = sc.SourceClone() } return &clone } diff --git a/model_selection/search.go b/model_selection/search.go index 01081df..6550a8b 100644 --- a/model_selection/search.go +++ b/model_selection/search.go @@ -70,7 +70,7 @@ func (gscv *GridSearchCV) PredicterClone() base.Predicter { } clone := *gscv if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) { - clone.RandomState = sourceCloner.Clone() + clone.RandomState = sourceCloner.SourceClone() } return &clone } diff --git a/model_selection/search_test.go b/model_selection/search_test.go index 36b6ea2..df01daa 100644 --- a/model_selection/search_test.go +++ b/model_selection/search_test.go @@ -88,8 +88,9 @@ func chkRandomState(rs rand.Source) { panic(fmt.Errorf("wrong random state\nexpected:%s\n%s\ngot :%s\n%s", expected, "", got, "")) } } + func ExampleGridSearchCV() { - RandomState := base.NewSource(7) + RandomState := base.NewLockedSource(7) ds := datasets.LoadBoston() X, Y := preprocessing.NewStandardScaler().FitTransform(ds.X, ds.Y) @@ -99,14 +100,15 @@ func ExampleGridSearchCV() { mlp.BatchSize = 20 mlp.LearningRateInit = .005 mlp.MaxIter = 100 + scorer := func(Y, Ypred mat.Matrix) float64 { return metrics.MeanSquaredError(Y, Ypred, nil, "").At(0, 0) } gscv := &GridSearchCV{ Estimator: mlp, ParamGrid: map[string][]interface{}{ - "Alpha": {2e-4, 5e-4, 1e-3}, - "WeightDecay": {.0002, .0001, 0}, + "Alpha": {1e-4, 2e-4, 5e-4, 1e-3}, + "WeightDecay": {1e-4, 1e-5, 1e-6,1e-7,1e-8, 0}, }, Scorer: scorer, LowerScoreIsBetter: true, @@ -120,7 +122,7 @@ func ExampleGridSearchCV() { // Output: // Alpha 0.0002 - // WeightDecay 0 + // WeightDecay 1e-06 } diff --git a/model_selection/split.go b/model_selection/split.go index 684130d..efc4397 100644 --- 
a/model_selection/split.go +++ b/model_selection/split.go @@ -1,12 +1,10 @@ package modelselection import ( - "github.com/pa-m/sklearn/base" "math" - "sort" + "github.com/pa-m/sklearn/base" "golang.org/x/exp/rand" - "gonum.org/v1/gonum/mat" ) @@ -41,7 +39,7 @@ func (splitter *KFold) SplitterClone() Splitter { } clone := *splitter if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) { - clone.RandomState = sourceCloner.Clone() + clone.RandomState = sourceCloner.SourceClone() } return &clone } @@ -114,29 +112,33 @@ func (splitter *KFold) GetNSplits(X, Y *mat.Dense) int { // TrainTestSplit splits X and Y into test set and train set // testsize must be between 0 and 1 -// it does'nt yet produce same sets than scikit-learn du to a different shuffle method +// it produces the same sets as scikit-learn func TrainTestSplit(X, Y mat.Matrix, testsize float64, randomstate uint64) (Xtrain, Xtest, ytrain, ytest *mat.Dense) { NSamples, NFeatures := X.Dims() _, NOutputs := Y.Dims() var testlen int if testsize > 1 { - testlen = int(math.Round(math.Min(float64(NSamples), testsize))) + testlen = int(math.Ceil(math.Min(float64(NSamples), testsize))) } else { - testlen = int(math.Round(float64(NSamples) * testsize)) + testlen = int(math.Ceil(float64(NSamples) * testsize)) } Xtest = mat.NewDense(testlen, NFeatures, nil) ytest = mat.NewDense(testlen, NOutputs, nil) Xtrain = mat.NewDense(NSamples-testlen, NFeatures, nil) ytrain = mat.NewDense(NSamples-testlen, NOutputs, nil) src := base.NewLockedSource(randomstate) - shuffler := rand.New(src) - ind := make([]int, NSamples) - for i := range ind { - ind[i] = i - } - //shuffle ind - slice := sort.IntSlice(ind) - shuffler.Shuffle(slice.Len(), slice.Swap) + + var ind []int + src.WithLock(func(src base.Source) { + permer, ok := src.(base.Permer) + if !ok { + panic("Source does not implement Perm") + } + { + ind = permer.Perm(NSamples) + } + + }) for i := 0; i < NSamples; i++ { j := ind[i] if i 
< testlen { diff --git a/model_selection/split_test.go b/model_selection/split_test.go index dc0dcaf..81b917c 100644 --- a/model_selection/split_test.go +++ b/model_selection/split_test.go @@ -2,8 +2,8 @@ package modelselection import ( "fmt" - "github.com/pa-m/randomkit" - "testing" + + "github.com/pa-m/sklearn/datasets" "github.com/pa-m/sklearn/base" "golang.org/x/exp/rand" @@ -50,15 +50,70 @@ func perm(r base.Intner, n int) []int { return m } -func TestTrainTestSplit(t *testing.T) { - rs := randomkit.NewRandomkitSource(42) - NSamples := 178 - ind := make([]int, NSamples) - for i := range ind { - ind[i] = i - } - permer := rand.New(rs) - ind = permer.Perm(178) - fmt.Println(ind) + +func _ExampleTrainTestSplit() { + + features, target := datasets.LoadWine().GetXY() + RandomState := uint64(42) + _, _, Ytrain, Ytest := TrainTestSplit(features, target, .30, RandomState) + Ntrain, _ := Ytrain.Dims() + ytrain := make([]float64, Ntrain) + mat.Col(ytrain, 0, Ytrain) + fmt.Println(ytrain[:8]) + Ntest, _ := Ytest.Dims() + ytest := make([]float64, Ntest) + mat.Col(ytest, 0, Ytest) + fmt.Println(ytest[:8]) + // Output: + //[2 1 1 0 1 0 2 1] + //[0 0 2 0 1 0 1 2] +} + +func ExampleTrainTestSplit() { + /* + >>> import numpy as np + >>> from sklearn.model_selection import train_test_split + >>> X, y = np.arange(10).reshape((5, 2)), range(5) + >>> X_train, X_test, y_train, y_test = train_test_split( + ... X, y, test_size=0.33, random_state=42) + ... 
+ >>> X_train + array([[4, 5], + [0, 1], + [6, 7]]) + >>> y_train + [2, 0, 3] + >>> X_test + array([[2, 3], + [8, 9]]) + >>> y_test + [1, 4] + + */ + X := mat.NewDense(5, 2, []float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}) + Y := mat.NewDense(5, 1, []float64{0, 1, 2, 3, 4}) + RandomState := uint64(42) + Xtrain, Xtest, Ytrain, Ytest := TrainTestSplit(X, Y, .33, RandomState) + fmt.Printf("X_train:\n%g\n", mat.Formatted(Xtrain)) + fmt.Printf("Y_train:\n%g\n", mat.Formatted(Ytrain)) + fmt.Printf("X_test:\n%g\n", mat.Formatted(Xtest)) + fmt.Printf("Y_test:\n%g\n", mat.Formatted(Ytest)) + + // Output: + //X_train: + //⎡4 5⎤ + //⎢0 1⎥ + //⎣6 7⎦ + //Y_train: + //⎡2⎤ + //⎢0⎥ + //⎣3⎦ + //X_test: + //⎡2 3⎤ + //⎣8 9⎦ + //Y_test: + //⎡1⎤ + //⎣4⎦ + } diff --git a/naive_bayes/naivebayes_test.go b/naive_bayes/naivebayes_test.go index d04840f..d3cbadb 100644 --- a/naive_bayes/naivebayes_test.go +++ b/naive_bayes/naivebayes_test.go @@ -21,7 +21,7 @@ func ExampleGaussianNB() { unscaledClf := pipeline.MakePipeline(pca, gnb) unscaledClf.Fit(Xtrain, Ytrain) - fmt.Printf("Prediction accuracy for the normal test dataset with PCA %.3f\n", unscaledClf.Score(Xtest, Ytest)) + fmt.Printf("Prediction accuracy for the normal test dataset with PCA %.2f %%\n", 100*unscaledClf.Score(Xtest, Ytest)) std = preprocessing.NewStandardScaler() pca = preprocessing.NewPCA() @@ -30,9 +30,9 @@ func ExampleGaussianNB() { clf := pipeline.MakePipeline(std, pca, gnb) clf.Fit(Xtrain, Ytrain) score := clf.Score(Xtest, Ytest) - fmt.Printf("Prediction accuracy for the standardized test dataset with PCA %.3f\n", score) + fmt.Printf("Prediction accuracy for the standardized test dataset with PCA %.2f %%\n", 100*score) // Output: - //Prediction accuracy for the normal test dataset with PCA 0.925 - //Prediction accuracy for the standardized test dataset with PCA 0.981 + // Prediction accuracy for the normal test dataset with PCA 70.37 % + // Prediction accuracy for the standardized test dataset with PCA 98.15 % } diff --git 
a/neural_network/multilayer_perceptron.go b/neural_network/multilayer_perceptron.go index 38c61bc..436735f 100644 --- a/neural_network/multilayer_perceptron.go +++ b/neural_network/multilayer_perceptron.go @@ -37,7 +37,7 @@ func (mlp *MLPRegressor) PredicterClone() base.Predicter { } clone := *mlp if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) { - clone.RandomState = sourceCloner.Clone() + clone.RandomState = sourceCloner.SourceClone() } return &clone } diff --git a/preprocessing/pca.go b/preprocessing/pca.go index debdc2b..067575b 100644 --- a/preprocessing/pca.go +++ b/preprocessing/pca.go @@ -44,7 +44,9 @@ func (m *PCA) Fit(Xmatrix, Ymatrix mat.Matrix) base.Fiter { } m.NComponents = nComponents } else { - m.NComponents = c + if m.NComponents == 0 { + m.NComponents = c + } } return m diff --git a/svm/svm.go b/svm/svm.go index c484fab..06749b9 100644 --- a/svm/svm.go +++ b/svm/svm.go @@ -241,7 +241,7 @@ func (m *SVC) PredicterClone() base.Predicter { } clone := *m if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) { - clone.RandomState = sourceCloner.Clone() + clone.RandomState = sourceCloner.SourceClone() } return &clone } diff --git a/svm/svr.go b/svm/svr.go index fc945fc..780f841 100644 --- a/svm/svr.go +++ b/svm/svr.go @@ -39,7 +39,7 @@ func (m *SVR) PredicterClone() base.Predicter { } clone := *m if sourceCloner, ok := clone.RandomState.(base.SourceCloner); ok && sourceCloner != base.SourceCloner(nil) { - clone.RandomState = sourceCloner.Clone() + clone.RandomState = sourceCloner.SourceClone() } return &clone }