diff --git a/datasets/base.go b/datasets/base.go index 1a1213c..7708da6 100644 --- a/datasets/base.go +++ b/datasets/base.go @@ -14,6 +14,19 @@ import ( "gonum.org/v1/gonum/mat" ) +var ( + // Dir stores the resolved import path of github.com/pa-m/sklearn + Dir string +) + +func init() { + pkg, err := build.Import("github.com/pa-m/sklearn", ".", build.FindOnly) + if err != nil { + panic(err) + } + Dir = pkg.Dir +} + // MLDataset structure returned by LoadIris,LoadBreastCancer,LoadDiabetes,LoadBoston type MLDataset struct { Data [][]float64 `json:"data,omitempty"` @@ -49,27 +62,27 @@ func loadJSON(filepath string) (ds *MLDataset) { // LoadIris load the iris dataset func LoadIris() (ds *MLDataset) { - return loadJSON(localPath("/src/github.com/pa-m/sklearn/datasets/data/iris.json")) + return loadJSON(localPath("datasets/data/iris.json")) } // LoadBreastCancer load the breat cancer dataset func LoadBreastCancer() (ds *MLDataset) { - return loadJSON(localPath("/src/github.com/pa-m/sklearn/datasets/data/cancer.json")) + return loadJSON(localPath("datasets/data/cancer.json")) } // LoadDiabetes load the diabetes dataset func LoadDiabetes() (ds *MLDataset) { - return loadJSON(localPath("/src/github.com/pa-m/sklearn/datasets/data/diabetes.json")) + return loadJSON(localPath("datasets/data/diabetes.json")) } // LoadBoston load the boston housing dataset func LoadBoston() (ds *MLDataset) { - return loadJSON(localPath("/src/github.com/pa-m/sklearn/datasets/data/boston.json")) + return loadJSON(localPath("datasets/data/boston.json")) } // LoadWine load the boston housing dataset func LoadWine() (ds *MLDataset) { - return loadJSON(localPath("/src/github.com/pa-m/sklearn/datasets/data/wine.json")) + return loadJSON(localPath("datasets/data/wine.json")) } // GetXY returns X,Y matrices for dataset @@ -87,30 +100,30 @@ func (ds *MLDataset) GetXY() (X, Y *mat.Dense) { } func localPath(s string) string { - p := strings.Split(build.Default.GOPATH, string(filepath.ListSeparator))[0] + p := strings.Split(Dir, string(filepath.ListSeparator))[0] return filepath.Join(p, s) } // LoadExamScore loads data from ex2data1 from Andrew Ng machine learning course func LoadExamScore() (X, Y *mat.Dense) { - return loadCsv(localPath("/src/github.com/pa-m/sklearn/datasets/data/ex2data1.txt"), nil, 1) + return loadCsv(localPath("datasets/data/ex2data1.txt"), nil, 1) } // LoadMicroChipTest loads data from ex2data2 from Andrew Ng machine learning course func LoadMicroChipTest() (X, Y *mat.Dense) { - return loadCsv(localPath("/src/github.com/pa-m/sklearn/datasets/data/ex2data2.txt"), nil, 1) + return loadCsv(localPath("datasets/data/ex2data2.txt"), nil, 1) } // LoadMnist loads mnist data 5000x400,5000x1 func LoadMnist() (X, Y *mat.Dense) { - mats := LoadOctaveBin(localPath("/src/github.com/pa-m/sklearn/datasets/data/ex4data1.dat.gz")) + mats := LoadOctaveBin(localPath("datasets/data/ex4data1.dat.gz")) return mats["X"], mats["y"] } // LoadMnistWeights loads mnist weights func LoadMnistWeights() (Theta1, Theta2 *mat.Dense) { - mats := LoadOctaveBin(localPath("/src/github.com/pa-m/sklearn/datasets/data/ex4weights.dat.gz")) + mats := LoadOctaveBin(localPath("datasets/data/ex4weights.dat.gz")) return mats["Theta1"], mats["Theta2"] } @@ -144,7 +157,7 @@ func check(err error) { // LoadInternationalAirlinesPassengers ... func LoadInternationalAirlinesPassengers() (Y *mat.Dense) { - f, err := os.Open(realPath(os.Getenv("GOPATH") + "/src/github.com/pa-m/sklearn/datasets/data/international-airline-passengers.csv")) + f, err := os.Open(realPath(localPath("datasets/data/international-airline-passengers.csv"))) check(err) defer f.Close() fb := bufio.NewReader(f)