Regression #237
base: master
New file: Makefile for the avocado price prediction example (+44 lines).

```make
# This is a simple Makefile used to build the example source code.
# This example might require some modifications in order to work correctly on
# your system.
# If you're not using the Armadillo wrapper, replace `armadillo` with linker
# commands for the BLAS and LAPACK libraries that you are using.

TARGET := avocado_price_prediction
SRC := avocado_price_prediction.cpp
LIBS_NAME := armadillo

CXX := g++
CXXFLAGS += -std=c++17 -Wall -Wextra -O3 -DNDEBUG -fopenmp
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++17 -Wall -Wextra -g -O0

LDFLAGS += -fopenmp
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I .
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment
# and update the lines below.
# Uncomment the following if you are using the Scatter function for plotting.
# INCLFLAGS += -I/usr/include/python3.11
# INCLFLAGS += -I/path/to/ensmallen/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
	$(CXX) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
	@echo CLEAN $(CLEAN_LIST)
	@rm -f $(CLEAN_LIST)
```
New file: avocado_price_prediction.cpp (+111 lines).

```cpp
/**
 * Predicting Avocado's Average Price using Linear Regression
 *
 * Our target is to predict the future price of avocados depending on various
 * features (Type, Region, Total Bags, ...).
 *
 * Approach
 *
 * In this example, we will be using one-hot encoding to encode categorical
 * features. Then, we will use the LinearRegression API from mlpack to learn
 * the correlation between various features and the target, i.e. AveragePrice.
 * After training the model, we will use it to do some predictions, followed
 * by various evaluation metrics to quantify how well our model behaves.
 */

#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  /** Dataset
   *
   * The Avocado Prices dataset has the following features:
   * PLU - Product Lookup Code in the Hass avocado board.
   * Date - The date of the observation.
   * AveragePrice - Observed average price of a single avocado.
   * Total Volume - Total number of avocados sold.
   * 4046 - Total number of avocados with PLU 4046 sold.
   * 4225 - Total number of avocados with PLU 4225 sold.
   * 4770 - Total number of avocados with PLU 4770 sold.
   * Total Bags = Small Bags + Large Bags + XLarge Bags.
   * Type - Conventional or organic.
   * Year - Year of observation.
   * Region - City or region of observation.
   *
   * Columns 9 (avocado type) and 11 (region of observation) are categorical
   * strings, but Armadillo matrices can contain only numeric information.
   * Therefore, we explicitly define them as categorical in `datasetInfo`;
   * this allows mlpack to map a numeric value to each of those strings,
   * which can later be unmapped back.
   */
  // PLEASE delete the header of the dataset once you have downloaded the
  // dataset to your data/ directory.

  // Load the dataset into an Armadillo matrix.
  arma::mat matrix;
  data::DatasetInfo info;
  info.Type(9) = data::Datatype::categorical;
  info.Type(11) = data::Datatype::categorical;
  data::Load("../../../data/avocado.csv", matrix, info);

  arma::mat output;
  data::OneHotEncoding(matrix, output, info);
  arma::rowvec targets = arma::conv_to<arma::rowvec>::from(output.row(0));

  // Labels are dropped from the originally loaded data so that the remaining
  // rows can be used as features.
  output.shed_row(0);

  // Train/test split:
  // The dataset has to be split into a training set and a test set. Here the
  // dataset has 18249 observations and the `testRatio` is set to 20% of the
  // total observations. This means the test set has about
  // 20% * 18249 = 3649 observations, and the training set has the remaining
  // 14600 observations.
  arma::mat Xtrain, Xtest;
  arma::rowvec Ytrain, Ytest;
  data::Split(output, targets, Xtrain, Xtest, Ytrain, Ytest, 0.2);

  /* Training the linear model.
   * Regression analysis is the most widely used method of prediction.
   * Linear regression is used when the dataset has a linear correlation;
   * as the name suggests, multiple linear regression has two or more
   * independent variables (predictors) and one dependent variable (response).
   * The multiple linear regression equation is represented as
   * y = $a + b_{1}x_{1} + b_{2}x_{2} + b_{3}x_{3} + ... + b_{n}x_{n}$
   * where $x_{i}$ is the ith explanatory variable, y is the dependent
   * variable, $b_{i}$ is the ith coefficient, and a is the intercept.
   * To perform linear regression we'll be using the `LinearRegression` class
   * from mlpack. The line below creates and trains the model.
   */
  LinearRegression lr(Xtrain, Ytrain, 0.5);

  arma::rowvec Ypreds;
  lr.Predict(Xtest, Ypreds);

  /*
   * Model evaluation:
   * To evaluate the model we use Mean Absolute Error (MAE), which is the mean
   * of the absolute differences between actual and predicted values, without
   * considering the direction:
   * MAE = \frac{\sum_{i=1}^n \lvert y_{i} - \hat{y_{i}} \rvert}{n}
   * Mean Squared Error (MSE) is the mean of the squared differences between
   * predicted and expected target values; a lower value is better:
   * MSE = \frac{1}{n} \sum_{i=1}^n (y_{i} - \hat{y_{i}})^2
   * The square root of MSE yields the Root Mean Squared Error (RMSE), which
   * indicates the spread of the residual errors. It is always positive, and
   * a lower value indicates better performance:
   * RMSE = \sqrt{\frac{1}{n} \sum_{i=1}^n (y_{i} - \hat{y_{i}})^2}
   *
   * From the metrics below, we can notice that our model's MAE is ~0.2,
   * which is relatively small compared to our average price of $1.405;
   * from this we can conclude that our model is a reasonably good fit.
   */
  std::cout << "Mean Absolute Error: "
            << arma::mean(arma::abs(Ypreds - Ytest)) << std::endl;
  std::cout << "Mean Squared Error: "
            << arma::mean(arma::pow(Ypreds - Ytest, 2)) << std::endl;
  std::cout << "Root Mean Squared Error: "
            << std::sqrt(arma::mean(arma::pow(Ypreds - Ytest, 2)))
            << std::endl;
}
```

Review comment (on the note about deleting the dataset header): Definitely we should support headers soon, but for now should we modify the dataset downloader to strip the header?
New file: Makefile for the California house price prediction example (+44 lines).

```make
# This is a simple Makefile used to build the example source code.
# This example might require some modifications in order to work correctly on
# your system.
# If you're not using the Armadillo wrapper, replace `armadillo` with linker
# commands for the BLAS and LAPACK libraries that you are using.

TARGET := california-house-price-prediction
SRC := california_house_price_prediction.cpp
LIBS_NAME := armadillo

CXX := g++
CXXFLAGS += -std=c++17 -Wall -Wextra -O3 -DNDEBUG -fopenmp
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
# CXXFLAGS += -std=c++17 -Wall -Wextra -g -O0

LDFLAGS += -fopenmp
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I .
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment
# and update the lines below.
# Uncomment the following if you are using the Scatter function for plotting.
# INCLFLAGS += -I/usr/include/python3.11
# INCLFLAGS += -I/path/to/ensmallen/include/
CXXFLAGS += $(INCLFLAGS)

OBJS := $(SRC:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET) $(OBJS)

# default rule
default: all

$(TARGET): $(OBJS)
	$(CXX) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET)

.PHONY: clean
clean:
	@echo CLEAN $(CLEAN_LIST)
	@rm -f $(CLEAN_LIST)
```

Review comment (on the `$(TARGET)` link rule): Same here with
New file: california_house_price_prediction.cpp (+138 lines).

```cpp
/**
 * Predicting California House Prices with Linear Regression
 *
 * Objective
 *
 * To predict California housing prices using the simplest linear regression
 * model and see how it performs, and to understand the modeling workflow
 * using mlpack.
 *
 * Approach
 *
 * Here, we will try to recreate the workflow from the book this example is
 * based on:
 * Pre-process the data for the ML algorithm.
 * Create new features.
 * Split the data.
 * Train the ML model using mlpack.
 * Residuals, errors, and conclusion.
 */

#include <mlpack.hpp>

using namespace mlpack;

int main()
{
  /**
   * Dataset structure
   *
   * This dataset is a modified version of the California Housing
   * dataset available from Luís Torgo's page (University of Porto).
   * Luís Torgo obtained it from the StatLib repository. The dataset
   * may also be downloaded from StatLib mirrors.
   *
   * Longitude : Longitude coordinate of the houses.
   * Latitude : Latitude coordinate of the houses.
   * Housing Median Age : Average lifespan of houses.
   * Total Rooms : Number of rooms in a location.
   * Total Bedrooms : Number of bedrooms in a location.
   * Population : Population in that location.
   * Median Income : Median income of households in a location.
   * Median House Value : Median house value in a location.
   * Ocean Proximity : Closeness to the shore.
   *
   * We need to load the dataset as an Armadillo matrix for further
   * operations. Our dataset has a total of 9 features: 8 numerical and
   * 1 categorical (ocean proximity). We need to map the categorical
   * feature, as Armadillo operates on numeric values only.
   */
  arma::mat dataset;
  data::DatasetInfo info;
  info.Type(9) = data::Datatype::categorical;
  // Please remove the header of the file if it exists; otherwise the results
  // will not be correct.
  data::Load("../../../data/housing.csv", dataset, info);

  arma::mat encodedDataset;
  // Here, we choose the pre-built one-hot encoding method to deal with the
  // categorical values.
  data::OneHotEncoding(dataset, encodedDataset, info);

  // The dataset needs to be split into a training and a testing set before
  // we learn any model. The labels are median_house_value, which is row 8.
  arma::rowvec labels =
      arma::conv_to<arma::rowvec>::from(encodedDataset.row(8));
  encodedDataset.shed_row(8);

  arma::mat trainSet, testSet;
  arma::rowvec trainLabels, testLabels;
  data::Split(encodedDataset, labels, trainSet, testSet, trainLabels,
      testLabels, 0.2 /* Percentage of dataset to use for test set. */);

  /* Training the linear model.
   * Regression analysis is the most widely used method of prediction.
   * Linear regression is used when the dataset has a linear correlation;
   * as the name suggests, multiple linear regression has two or more
   * independent variables (predictors) and one dependent variable (response).
   *
   * The multiple linear regression equation is represented as
   * y = $a + b_{1}x_{1} + b_{2}x_{2} + b_{3}x_{3} + ... + b_{n}x_{n}$
   * where:
   * $x_{i}$ is the ith explanatory variable,
   * y is the dependent variable,
   * $b_{i}$ is the ith coefficient, and a is the intercept.
   *
   * To perform linear regression we'll be using the `LinearRegression`
   * class from mlpack. The line below creates and trains the model.
   */
  LinearRegression lr(trainSet, trainLabels, 0.5);

  // Let's create an output vector for storing the predictions, and print the
  // training-set error.
  arma::rowvec output;
  lr.Predict(testSet, output);
  std::cout << "Training error: " << lr.ComputeError(trainSet, trainLabels)
            << std::endl;

  // Let's manually check some predictions.
  std::cout << testLabels[1] << std::endl;
  std::cout << output[1] << std::endl;
  std::cout << testLabels[7] << std::endl;
  std::cout << output[7] << std::endl;
  arma::mat preds;
  preds.insert_rows(0, testLabels);
  preds.insert_rows(1, output);

  arma::mat diffs = preds.row(1) - preds.row(0);
  data::Save("preds.csv", preds);
  data::Save("predsDiff.csv", diffs);

  /**
   * Model evaluation
   *
   * We will use various evaluation metrics to understand how well our model
   * has performed.
   * Mean Absolute Error (MAE) is the mean of the absolute differences between
   * actual and predicted values, without considering the direction:
   * MAE = \frac{\sum_{i=1}^n \lvert y_{i} - \hat{y_{i}} \rvert}{n}
   * Mean Squared Error (MSE) is the mean of the squared differences between
   * predicted and expected target values; a lower value is better:
   * MSE = \frac{1}{n} \sum_{i=1}^n (y_{i} - \hat{y_{i}})^2
   * The square root of MSE yields the Root Mean Squared Error (RMSE), which
   * indicates the spread of the residual errors. It is always positive, and
   * a lower value indicates better performance:
   * RMSE = \sqrt{\frac{1}{n} \sum_{i=1}^n (y_{i} - \hat{y_{i}})^2}
   */
  std::cout << "Mean Absolute Error: "
            << arma::mean(arma::abs(output - testLabels)) << std::endl;
  std::cout << "Mean Squared Error: "
            << arma::mean(arma::pow(output - testLabels, 2)) << std::endl;
  std::cout << "Root Mean Squared Error: "
            << std::sqrt(arma::mean(arma::pow(output - testLabels, 2)))
            << std::endl;
}
```

Review comment (on this file): Code-wise, this is basically the same example as the avocado one, but with a different dataset. Do you think it is worth including? I think quality over quantity is better, so if this is not demonstrating some new or different functionality vs. the avocado example, I would vote to only keep one of the two.
Review comment: Do we need to get `CXXFLAGS` in here too?