Skip to content

Commit

Permalink
Node HelloWorld added
Browse files Browse the repository at this point in the history
  • Loading branch information
a-saraf committed Oct 3, 2023
1 parent 9bcb40a commit 23e3aab
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions proxy-pipeline/train_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.metrics import mean_squared_error as mse
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.datasets import fetch_california_housing

import torch
import torch.nn as nn
Expand All @@ -28,6 +29,7 @@
flags.DEFINE_bool('visualize', False, 'enable visualization of the data')
flags.DEFINE_bool('train', False, 'enable training of the model')
flags.DEFINE_integer('output_index', 0, 'Index of the output to train the model on')
flags.DEFINE_bool('custom_dataset', False, 'Whether to use a custom dataset or not')

# Hyperparameters for the model
# Dense Block Parameters
Expand Down Expand Up @@ -59,6 +61,10 @@ def preprocess_data(actions, observations, exp_path):
observations = observations.to_frame()
# Categorical features
categorical_cols = list(set(actions.columns) - set(actions._get_numeric_data().columns))
if len(categorical_cols) == 0:
NO_CAT = True
else:
NO_CAT = False
categorical_actions = actions[categorical_cols]

# Numerical features
Expand All @@ -69,7 +75,7 @@ def preprocess_data(actions, observations, exp_path):
os.makedirs(encoder_path)

# Encode categorical features
if FLAGS.encode == 'one_hot':
if FLAGS.encode == 'one_hot' and not NO_CAT:
# One-hot encode categorical features
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(categorical_actions)
Expand All @@ -79,7 +85,7 @@ def preprocess_data(actions, observations, exp_path):
# Transform the categorical features
dummy_col_names = pd.get_dummies(categorical_actions).columns
categorical_actions = pd.DataFrame(enc.transform(categorical_actions).toarray(), columns=dummy_col_names)
elif FLAGS.encode == 'label':
elif FLAGS.encode == 'label' and not NO_CAT:
dummy_actions = pd.DataFrame()
for categorical_col in categorical_cols:
# Label encode categorical features
Expand All @@ -91,7 +97,7 @@ def preprocess_data(actions, observations, exp_path):
# Transform the categorical features
dummy_actions[categorical_col] = enc.transform(categorical_actions[categorical_col])
categorical_actions = pd.DataFrame(dummy_actions, columns=categorical_cols)
else:
elif not NO_CAT:
raise ValueError('Encoding method not supported')

preprocess_data_path = os.path.join(exp_path, 'preprocess_data')
Expand Down Expand Up @@ -137,7 +143,10 @@ def preprocess_data(actions, observations, exp_path):
raise ValueError('Preprocessing method not supported')

# Concatenate numerical and categorical features
actions = pd.concat([numerical_actions, categorical_actions], axis = 1).to_numpy()
if NO_CAT:
actions = numerical_actions.to_numpy()
else:
actions = pd.concat([numerical_actions, categorical_actions], axis = 1).to_numpy()
observations = observations.to_numpy()

return actions, observations
Expand Down Expand Up @@ -211,11 +220,15 @@ def main(_):
os.makedirs(exp_path)

# Load the data
actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv')
observations_path = os.path.join(FLAGS.data_path, 'observations_feasible.csv')

actions = pd.read_csv(actions_path)
observations = pd.read_csv(observations_path)
if FLAGS.custom_dataset:
actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv')
observations_path = os.path.join(FLAGS.data_path, 'observations_feasible.csv')
actions = pd.read_csv(actions_path)
observations = pd.read_csv(observations_path)
else:
california = fetch_california_housing()
actions = pd.DataFrame(california.data, columns=california.feature_names)
observations = pd.DataFrame(california.target, columns=['MEDV'])

output = observations.copy()
if FLAGS.output_index >= output.shape[1]:
Expand Down

0 comments on commit 23e3aab

Please sign in to comment.