-
Notifications
You must be signed in to change notification settings - Fork 91
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Regression #237
base: master
Are you sure you want to change the base?
Regression #237
Changes from 11 commits
6b69c32
940f0dc
739b4ae
a1bfda1
25b616d
d83aafb
a0a770c
f04f9d4
a283367
76f3ed5
50b69ec
56064a5
b005eca
529d66b
3a25215
f541e41
7248b44
502b717
2dbbb97
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# This is a simple Makefile used to build the example source code. | ||
# This example might requires some modifications in order to work correctly on | ||
# your system. | ||
# If you're not using the Armadillo wrapper, replace `armadillo` with linker commands | ||
# for the BLAS and LAPACK libraries that you are using. | ||
|
||
TARGET := avocado_price_prediction | ||
SRC := avocado_price_prediction.cpp | ||
LIBS_NAME := armadillo | ||
|
||
CXX := g++ | ||
CXXFLAGS += -std=c++17 -Wall -Wextra -O3 -DNDEBUG -fopenmp | ||
# Use these CXXFLAGS instead if you want to compile with debugging symbols and | ||
# without optimizations. | ||
# CXXFLAGS += -std=c++17 -Wall -Wextra -g -O0 | ||
|
||
LDFLAGS += -fopenmp | ||
# Add header directories for any includes that aren't on the | ||
# default compiler search path. | ||
INCLFLAGS := -I . | ||
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and | ||
# update the lines below. | ||
# Uncomment the following if you are using the Scatter function for plotting | ||
# INCLFLAGS += -I/usr/include/python3.11 | ||
# INCLFLAGS += -I/path/to/ensmallen/include/ | ||
CXXFLAGS += $(INCLFLAGS) | ||
|
||
OBJS := $(SRC:.cpp=.o) | ||
LIBS := $(addprefix -l,$(LIBS_NAME)) | ||
CLEAN_LIST := $(TARGET) $(OBJS) | ||
|
||
# default rule | ||
default: all | ||
|
||
$(TARGET): $(OBJS) | ||
$(CXX) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS) | ||
|
||
.PHONY: all | ||
all: $(TARGET) | ||
|
||
.PHONY: clean | ||
clean: | ||
@echo CLEAN $(CLEAN_LIST) | ||
@rm -f $(CLEAN_LIST) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
/** | ||
* Predicting Avocado's Average Price using Linear Regression | ||
* Our target is to predict the future price of avocados depending on various | ||
* features (Type, Region, Total Bags, ...). | ||
* | ||
* Dataset | ||
* | ||
* Avocado Prices dataset has the following features: | ||
* PLU - Product Lookup Code in Hass avocado board. | ||
* Date - The date of the observation. | ||
* AveragePrice - Observed average price of single avocado. | ||
* Total Volume - Total number of avocado's sold. | ||
* 4046 - Total number of avocado's with PLU 4046 sold. | ||
* 4225 - Total number of avocado's with PLU 4225 sold. | ||
* 4770 - Total number of avocado's with PLU 4770 sold. | ||
* Total Bags = Small Bags + Large Bags + XLarge Bags. | ||
* Type - Conventional or organic. | ||
* Year - Year of observation. | ||
* Region - City or region of observation. | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* | ||
* Approach | ||
* | ||
* In this example, first we will do EDA on the dataset to find correlation | ||
* between various features. | ||
* Then we'll be using onehot encoding to encode categorical features. | ||
* Finally we will use LinearRegression API from mlpack to learn the correlation | ||
* between various features and the target i.e AveragePrice. | ||
* After training the model, we will use it to do some predictions, followed by | ||
* various evaluation metrics to quantify how well our model behaves. | ||
*/ | ||
#include <mlpack.hpp> | ||
using namespace mlpack; | ||
using namespace mlpack::data; | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
//Drop the dataset header using sed, sed is a Unix utility that parses and transforms text." | ||
//!mkdir -p data && cat avocado.csv | sed 1d > avocado_trim.csv" | ||
//"Drop columns 1 and 2 (\"Unamed: 0\", \"Date\") as these are not required and their presence cause issues while loading the data." | ||
//!rm avocado_trim.csv" | ||
//"!mv avocado_trim2.csv avocado_trim.csv" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this isn't a notebook, we don't have the ability to call out to the shell like this. So maybe we either indicate in the comments that we expect the user to run these commands, or we adapt |
||
|
||
int main() | ||
{ | ||
//!wget -q https://datasets.mlpack.org/avocado.csv.gz" | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Features 9 (Avocado type) and 11 (region of observation) are strings | ||
// (categorical), but armadillo matrices can contain only numeric information; | ||
// so, we have to explicitly define them as categorical in `datasetInfo` | ||
// this allows mlpack to map numeric values to each of those values, | ||
// which can later be unmapped to strings. | ||
// Load the dataset into armadillo matrix. | ||
|
||
arma::mat matrix; | ||
data::DatasetInfo info; | ||
info.Type(9) = data::Datatype::categorical; | ||
info.Type(11) = data::Datatype::categorical; | ||
data::Load("../../../data/avocado.csv", matrix, info); | ||
// Printing header for dataset. | ||
std::cout << std::setw(10) << "AveragePrice" << std::setw(14) | ||
<< "Total Volume" << std::setw(9) << "4046" << std::setw(13) | ||
<< "4225" << std::setw(13) << "4770" << std::setw(17) << "Total Bags" | ||
<< std::setw(13) << "Small Bags" << std::setw(13) << "Large Bags" | ||
<< std::setw(17) << "XLarge Bags" << std::setw(10) << "Type" | ||
<< std::setw(10) << "Year" << std::setw(15) << "Region" << std::endl; | ||
std::cout << matrix.submat(0, 0, matrix.n_rows-1, 5).t() << std::endl; | ||
// Exploratory Data Analysis | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
arma::mat output; | ||
data::OneHotEncoding(matrix, output, info); | ||
arma::Row<double> targets = arma::conv_to<arma::Row<double>>::from(output.row(0)); | ||
// Labels are dropped from the originally loaded data to be used as features. | ||
output.shed_row(0); | ||
// Train Test Split, | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// The dataset has to be split into a training set and a test set. Here the | ||
// dataset has 18249 observations and the `testRatio` is set to 20% of the | ||
// total observations. This indicates the test set should have | ||
// 20% * 18249 = 3649 observations and training test should have | ||
// 14600 observations respectively. | ||
arma::mat Xtrain; | ||
arma::mat Xtest; | ||
arma::Row<double> Ytrain; | ||
arma::Row<double> Ytest; | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
data::Split(output, targets, Xtrain, Xtest, Ytrain, Ytest, 0.2); | ||
// Convert armadillo Rows into rowvec. (Required by mlpacks' LinearRegression API in this format). | ||
arma::rowvec yTrain = arma::conv_to<arma::rowvec>::from(Ytrain); | ||
arma::rowvec yTest = arma::conv_to<arma::rowvec>::from(Ytest); | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
/* Training the linear model. | ||
* Regression analysis is the most widely used method of prediction. | ||
* Linear regression is used when the dataset has a linear correlation | ||
* and as the name suggests, multiple linear regression has one independent | ||
* variable (predictor) and one or more dependent variable(response). | ||
* The simple linear regression equation is represented as | ||
* y = $a + b_{1}x_{1} + b_{2}x_{2} + b_{3}x_{3} + ... + b_{n}x_{n}$ | ||
* where $x_{i}$ is the ith explanatory variable, y is the dependent | ||
* variable, $b_{i}$ is ith coefficient and a is the intercept. | ||
* To perform linear regression we'll be using the `LinearRegression` class from mlpack. | ||
* Create and train Linear Regression model. | ||
*/ | ||
LinearRegression lr(Xtrain, yTrain, 0.5); | ||
arma::rowvec yPreds; | ||
lr.Predict(Xtest, yPreds); | ||
// Save the yTest and yPreds into csv for generating plots. | ||
arma::mat preds; | ||
preds.insert_rows(0, yTest); | ||
preds.insert_rows(1, yPreds); | ||
arma::mat histpreds = yTest - yPreds; | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mlpack::data::Save("./data/predictions.csv", preds); | ||
mlpack::data::Save("./data/predsDiff.csv", yPreds); | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/* | ||
* Model Evaluation, | ||
* Test data is visualized with `yTest` and `yPreds`, the blue points | ||
shrit marked this conversation as resolved.
Show resolved
Hide resolved
|
||
* indicates the data points and the blue line indicates the regression | ||
* line or best fit line. | ||
* Evaluation Metrics for Regression model, | ||
* In the previous cell we have visualized our model performance by plotting | ||
* the best fit line. Now we will use various evaluation metrics to understand | ||
* how well our model has performed. | ||
* Mean Absolute Error (MAE) is the sum of absolute differences between actual | ||
* and predicted values, without considering the direction. | ||
* MAE = \\frac{\\sum_{i=1}^n\\lvert y_{i} - \\hat{y_{i}}\\rvert} {n} | ||
* Mean Squared Error (MSE) is calculated as the mean or average of the | ||
* squared differences between predicted and expected target values in | ||
* a dataset, a lower value is better | ||
* MSE = \\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2, | ||
* Root Mean Squared Error (RMSE), Square root of MSE yields root mean square | ||
* error (RMSE) it indicates the spread of the residual errors. It is always | ||
* positive, and a lower value indicates better performance. | ||
* RMSE = \\sqrt{\\frac {1}{n} \\sum_{i=1}^n (y_{i} - \\hat{y_{i}})^2} | ||
*/ | ||
// Model evaluation metrics. | ||
// From the above metrics, we can notice that our model MAE is ~0.2, | ||
// which is relatively small compared to our average price of $1.405, | ||
// from this and the above plot we can conclude our model is a reasonably | ||
// good fit. | ||
|
||
std::cout << "Mean Absolute Error: " | ||
<< arma::mean(arma::abs(yPreds - yTest)) << std::endl; | ||
std::cout << "Mean Squared Error: " | ||
<< arma::mean(arma::pow(yPreds - yTest,2)) << std::endl; | ||
std::cout << "Root Mean Squared Error: " | ||
<< sqrt(arma::mean(arma::pow(yPreds - yTest,2))) << std::endl; | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# This is a simple Makefile used to build the example source code. | ||
# This example might requires some modifications in order to work correctly on | ||
# your system. | ||
# If you're not using the Armadillo wrapper, replace `armadillo` with linker commands | ||
# for the BLAS and LAPACK libraries that you are using. | ||
|
||
TARGET := california-house-price-prediction | ||
SRC := california_house_price_prediction.cpp | ||
LIBS_NAME := armadillo | ||
|
||
CXX := g++ | ||
CXXFLAGS += -std=c++17 -Wall -Wextra -O3 -DNDEBUG -fopenmp | ||
# Use these CXXFLAGS instead if you want to compile with debugging symbols and | ||
# without optimizations. | ||
# CXXFLAGS += -std=c++17 -Wall -Wextra -g -O0 | ||
|
||
LDFLAGS += -fopenmp | ||
# Add header directories for any includes that aren't on the | ||
# default compiler search path. | ||
INCLFLAGS := -I . | ||
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and | ||
# update the lines below. | ||
# Uncomment the following if you are using the Scatter function for plotting | ||
# INCLFLAGS += -I/usr/include/python3.11 | ||
# INCLFLAGS += -I/path/to/ensmallen/include/ | ||
CXXFLAGS += $(INCLFLAGS) | ||
|
||
OBJS := $(SRC:.cpp=.o) | ||
LIBS := $(addprefix -l,$(LIBS_NAME)) | ||
CLEAN_LIST := $(TARGET) $(OBJS) | ||
|
||
# default rule | ||
default: all | ||
|
||
$(TARGET): $(OBJS) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here with |
||
$(CXX) $(OBJS) -o $(TARGET) $(LDFLAGS) $(LIBS) | ||
|
||
.PHONY: all | ||
all: $(TARGET) | ||
|
||
.PHONY: clean | ||
clean: | ||
@echo CLEAN $(CLEAN_LIST) | ||
@rm -f $(CLEAN_LIST) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to get
CXXFLAGS
in here too?