diff --git a/proxy-pipeline/train_node.py b/proxy-pipeline/train_node.py
index 3eee21c1..4bf7b4c3 100644
--- a/proxy-pipeline/train_node.py
+++ b/proxy-pipeline/train_node.py
@@ -13,6 +13,7 @@
 from sklearn.metrics import mean_squared_error as mse
 from sklearn.preprocessing import MinMaxScaler, StandardScaler
 from sklearn.preprocessing import OneHotEncoder, LabelEncoder
+from sklearn.datasets import fetch_california_housing
 
 import torch
 import torch.nn as nn
@@ -28,6 +29,7 @@
 flags.DEFINE_bool('visualize', False, 'enable visualization of the data')
 flags.DEFINE_bool('train', False, 'enable training of the model')
 flags.DEFINE_integer('output_index', 0, 'Index of the output to train the model on')
+flags.DEFINE_bool('custom_dataset', False, 'Whether to use a custom dataset or not')
 
 # Hyperparameters for the model
 # Dense Block Parameters
@@ -59,6 +61,10 @@ def preprocess_data(actions, observations, exp_path):
         observations = observations.to_frame()
     # Categorical features
     categorical_cols = list(set(actions.columns) - set(actions._get_numeric_data().columns))
+    if len(categorical_cols) == 0:
+        NO_CAT = True
+    else:
+        NO_CAT = False
     categorical_actions = actions[categorical_cols]
 
     # Numerical features
@@ -69,7 +75,7 @@ def preprocess_data(actions, observations, exp_path):
         os.makedirs(encoder_path)
 
     # Encode categorical features
-    if FLAGS.encode == 'one_hot':
+    if FLAGS.encode == 'one_hot' and not NO_CAT:
         # One-hot encode categorical features
         enc = OneHotEncoder(handle_unknown='ignore')
         enc.fit(categorical_actions)
@@ -79,7 +85,7 @@ def preprocess_data(actions, observations, exp_path):
         # Transform the categorical features
         dummy_col_names = pd.get_dummies(categorical_actions).columns
         categorical_actions = pd.DataFrame(enc.transform(categorical_actions).toarray(), columns=dummy_col_names)
-    elif FLAGS.encode == 'label':
+    elif FLAGS.encode == 'label' and not NO_CAT:
         dummy_actions = pd.DataFrame()
         for categorical_col in categorical_cols:
             # Label encode categorical features
@@ -91,7 +97,7 @@ def preprocess_data(actions, observations, exp_path):
             # Transform the categorical features
             dummy_actions[categorical_col] = enc.transform(categorical_actions[categorical_col])
         categorical_actions = pd.DataFrame(dummy_actions, columns=categorical_cols)
-    else:
+    elif not NO_CAT:
        raise ValueError('Encoding method not supported')
 
     preprocess_data_path = os.path.join(exp_path, 'preprocess_data')
@@ -137,7 +143,10 @@ def preprocess_data(actions, observations, exp_path):
         raise ValueError('Preprocessing method not supported')
 
     # Concatenate numerical and categorical features
-    actions = pd.concat([numerical_actions, categorical_actions], axis = 1).to_numpy()
+    if NO_CAT:
+        actions = numerical_actions.to_numpy()
+    else:
+        actions = pd.concat([numerical_actions, categorical_actions], axis = 1).to_numpy()
     observations = observations.to_numpy()
 
     return actions, observations
@@ -211,11 +220,15 @@ def main(_):
         os.makedirs(exp_path)
 
     # Load the data
-    actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv')
-    observations_path = os.path.join(FLAGS.data_path, 'observations_feasible.csv')
-
-    actions = pd.read_csv(actions_path)
-    observations = pd.read_csv(observations_path)
+    if FLAGS.custom_dataset:
+        actions_path = os.path.join(FLAGS.data_path, 'actions_feasible.csv')
+        observations_path = os.path.join(FLAGS.data_path, 'observations_feasible.csv')
+        actions = pd.read_csv(actions_path)
+        observations = pd.read_csv(observations_path)
+    else:
+        california = fetch_california_housing()
+        actions = pd.DataFrame(california.data, columns=california.feature_names)
+        observations = pd.DataFrame(california.target, columns=['MEDV'])
 
     output = observations.copy()
     if FLAGS.output_index >= output.shape[1]:
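
Reviewer note: a minimal, self-contained sketch of how the new NO_CAT guard plays out on the California-housing fallback, whose feature columns are all numeric. This is illustrative code, not part of train_node.py; the variable `features` and the printed shapes are assumptions made for the example only.

# Illustrative sketch (not part of the diff): the fallback dataset has no categorical
# columns, so NO_CAT is True and every encoding branch in preprocess_data() is skipped.
import pandas as pd
from sklearn.datasets import fetch_california_housing

california = fetch_california_housing()
actions = pd.DataFrame(california.data, columns=california.feature_names)
observations = pd.DataFrame(california.target, columns=['MEDV'])

# Same detection used in preprocess_data(): any non-numeric column counts as categorical.
categorical_cols = list(set(actions.columns) - set(actions._get_numeric_data().columns))
NO_CAT = len(categorical_cols) == 0

if NO_CAT:
    # Only the numerical features are kept; nothing is concatenated.
    features = actions.to_numpy()
else:
    # With categorical columns present, the encoded block would be concatenated here.
    features = pd.concat([actions[actions.columns.difference(categorical_cols)],
                          pd.get_dummies(actions[categorical_cols])], axis=1).to_numpy()

print(NO_CAT, features.shape, observations.shape)  # -> True (20640, 8) (20640, 1)

With the flag added above, the fallback dataset is the default; passing --custom_dataset (together with the existing --data_path flag pointing at actions_feasible.csv / observations_feasible.csv) restores the original CSV loading path.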