Skip to content

Commit

Permalink
Make Standard Scaler fit on segments only
Browse files Browse the repository at this point in the history
  • Loading branch information
Eve-ning committed Feb 21, 2024
1 parent 2873d2f commit 4a179c0
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions tests/model_tests/chestnut_dec_may/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@
# os.environ["WANDB_MODE"] = "offline"


def get_y_encoder(targets):
oe = OrdinalEncoder(
handle_unknown="use_encoded_value",
unknown_value=np.nan,
)
oe.fit(np.array(targets).reshape(-1, 1))
return oe


def get_x_scaler(segments):
ss = StandardScaler()
ss.fit(
np.concatenate([segm.reshape(-1, segm.shape[-1]) for segm in segments])
)
return ss


def main(
batch_size=32,
epochs=10,
Expand All @@ -47,17 +64,7 @@ def main(
train_unl_ds = ds.chestnut_20201218.unlabelled(
transform=train_unl_preprocess(2)
)
val_ds = ds.chestnut_20210510_43m(transform=preprocess)

oe = OrdinalEncoder(
handle_unknown="use_encoded_value",
unknown_value=np.nan,
)
oe.fit(np.array(train_lab_ds.targets).reshape(-1, 1))
n_classes = len(oe.categories_[0])

ss = StandardScaler()
ss.fit(train_lab_ds.ar.reshape(-1, train_lab_ds.ar.shape[-1]))
val_ds = ds.chestnut_20210510_43m(transform=val_preprocess)

# Prepare the datamodule and trainer
dm = FRDCDataModule(
Expand Down Expand Up @@ -90,9 +97,12 @@ def main(
),
)

oe = get_y_encoder(train_lab_ds.targets)
ss = get_x_scaler(train_lab_ds.ar_segments)

m = InceptionV3MixMatchModule(
in_channels=train_lab_ds.ar.shape[-1],
n_classes=n_classes,
n_classes=len(oe.categories_[0]),
lr=lr,
x_scaler=ss,
y_encoder=oe,
Expand Down

0 comments on commit 4a179c0

Please sign in to comment.