Merge pull request #552 from LinglongQian/main
Refactor CSAI imputation & classification
WenjieDu authored Feb 22, 2025
2 parents adbac5b + cc88a76 commit ddcce10
Showing 10 changed files with 92 additions and 323 deletions.
17 changes: 2 additions & 15 deletions pypots/classification/csai/core.py
@@ -24,7 +24,6 @@ def __init__(
         n_classes: int,
         step_channels: int,
         dropout: float = 0.5,
-        intervals=None,
     ):
         super().__init__()
         self.n_steps = n_steps
@@ -35,13 +34,11 @@ def __init__(
         self.classification_weight = classification_weight
         self.n_classes = n_classes
         self.step_channels = step_channels
-        self.intervals = intervals

         # create models
-        self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals)
+        self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels)
         self.f_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
         self.b_classifier = nn.Linear(self.rnn_hidden_size, n_classes)
         self.imputer = nn.Linear(self.rnn_hidden_size, n_features)
         self.dropout = nn.Dropout(dropout)

     def forward(self, inputs: dict, training: bool = True) -> dict:
@@ -56,16 +53,9 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
             reconstruction_loss,
         ) = self.model(inputs)

-        results = {
-            "imputed_data": imputed_data,
-        }
-
         f_logits = self.f_classifier(self.dropout(f_hidden_states))
         b_logits = self.b_classifier(self.dropout(b_hidden_states))

-        # f_prediction = torch.sigmoid(f_logits)
-        # b_prediction = torch.sigmoid(b_logits)
-
         f_prediction = torch.softmax(f_logits, dim=1)
         b_prediction = torch.softmax(b_logits, dim=1)
         classification_pred = (f_prediction + b_prediction) / 2
@@ -80,11 +70,8 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
             # criterion = DiceBCELoss().to(imputed_data.device)
             results["consistency_loss"] = consistency_loss
             results["reconstruction_loss"] = reconstruction_loss
-            # print(inputs["labels"].unsqueeze(1))
             f_classification_loss = F.nll_loss(torch.log(f_prediction), inputs["y"])
             b_classification_loss = F.nll_loss(torch.log(b_prediction), inputs["y"])
-            # f_classification_loss, _ = criterion(f_prediction, f_logits, inputs["labels"].unsqueeze(1).float())
-            # b_classification_loss, _ = criterion(b_prediction, b_logits, inputs["labels"].unsqueeze(1).float())
             classification_loss = f_classification_loss + b_classification_loss

             loss = (
@@ -97,5 +84,5 @@ def forward(self, inputs: dict, training: bool = True) -> dict:
             results["classification_loss"] = classification_loss
         results["f_reconstruction"] = f_reconstruction
         results["b_reconstruction"] = b_reconstruction

         return results
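
Taken together, these hunks leave the classification head's logic intact: forward and backward hidden states are projected to logits, softmaxed, averaged into the final prediction, and trained with NLL loss on the log-probabilities. A minimal, self-contained sketch of that logic (the tensor shapes and values are illustrative assumptions, not taken from the diff):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes only, not taken from the diff.
batch_size, rnn_hidden_size, n_classes = 8, 64, 2

f_classifier = nn.Linear(rnn_hidden_size, n_classes)
b_classifier = nn.Linear(rnn_hidden_size, n_classes)
dropout = nn.Dropout(0.5)

# Stand-ins for the forward/backward RNN hidden states and labels.
f_hidden_states = torch.randn(batch_size, rnn_hidden_size)
b_hidden_states = torch.randn(batch_size, rnn_hidden_size)
y = torch.randint(0, n_classes, (batch_size,))

# Per-direction logits -> class probabilities.
f_prediction = torch.softmax(f_classifier(dropout(f_hidden_states)), dim=1)
b_prediction = torch.softmax(b_classifier(dropout(b_hidden_states)), dim=1)

# The final prediction is the mean of the two directions.
classification_pred = (f_prediction + b_prediction) / 2

# NLL on log-probabilities, one term per direction, as in the hunk above.
classification_loss = (
    F.nll_loss(torch.log(f_prediction), y) + F.nll_loss(torch.log(b_prediction), y)
)
```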
8 changes: 0 additions & 8 deletions pypots/classification/csai/data.py
@@ -18,11 +18,7 @@ def __init__(
         return_y: bool = True,
         removal_percent: float = 0.0,
         increase_factor: float = 0.1,
-        compute_intervals: bool = False,
         replacement_probabilities=None,
-        normalise_mean: list = [],
-        normalise_std: list = [],
-        training: bool = True,
     ):
         super().__init__(
             data=data,
@@ -31,9 +27,5 @@ def __init__(
             file_type=file_type,
             removal_percent=removal_percent,
             increase_factor=increase_factor,
-            compute_intervals=compute_intervals,
             replacement_probabilities=replacement_probabilities,
-            normalise_mean=normalise_mean,
-            normalise_std=normalise_std,
-            training=training,
         )
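
After this change the classification dataset forwards only the sampling-related arguments to its imputation parent class; interval computation and normalisation statistics are no longer plumbed through here. A hedged usage sketch; the toy data layout and hyperparameter values are assumptions, not from the diff:

```python
import numpy as np
from pypots.classification.csai.data import DatasetForCSAI

# Assumed toy layout: 100 samples, 48 steps, 35 features, binary labels.
data = {
    "X": np.random.randn(100, 48, 35),
    "y": np.random.randint(0, 2, 100),
}

train_dataset = DatasetForCSAI(
    data=data,
    return_y=True,
    removal_percent=10,   # percent of observed values to remove artificially
    increase_factor=0.1,  # factor raising missing-value frequency
)
```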
32 changes: 5 additions & 27 deletions pypots/classification/csai/model.py
@@ -53,9 +53,6 @@ class CSAI(BaseNNClassifier):
     increase_factor :
         The factor to increase the frequency of missing value occurrences.

-    compute_intervals :
-        Whether to compute time intervals between observations during data processing.
-
     step_channels :
         The number of step channels for the model.
@@ -122,7 +119,6 @@ def __init__(
         n_classes: int,
         removal_percent: int,
         increase_factor: float,
-        compute_intervals: bool,
         step_channels: int,
         dropout: float = 0.5,
         batch_size: int = 32,
@@ -160,12 +156,7 @@ def __init__(
         self.removal_percent = removal_percent
         self.increase_factor = increase_factor
         self.step_channels = step_channels
-        self.compute_intervals = compute_intervals
         self.dropout = dropout
         self.intervals = None
         self.replacement_probabilities = None
-        self.mean_set = None
-        self.std_set = None

         # Initialise empty model
         self.model = _BCSAI(
@@ -178,7 +169,6 @@ def __init__(
             n_classes=self.n_classes,
             step_channels=self.step_channels,
             dropout=self.dropout,
-            intervals=self.intervals,
         )

         self._send_model_to_given_device()
@@ -210,6 +200,7 @@ def _assemble_input_for_training(self, data: list, training=True) -> dict:
                 "deltas": back_deltas,
                 "last_obs": back_last_obs,
             },
+            "intervals": self.intervals,
         }
         return inputs

@@ -229,8 +220,6 @@ def _assemble_input_for_testing(self, data: list) -> dict:
             back_missing_mask,
             back_deltas,
             back_last_obs,
-            X_ori,
-            indicating_mask,
         ) = self._send_data_to_given_device(sample)

         # assemble input data
@@ -248,8 +237,7 @@ def _assemble_input_for_testing(self, data: list) -> dict:
                 "deltas": back_deltas,
                 "last_obs": back_last_obs,
             },
-            # "X_ori": X_ori,
-            # "indicating_mask": indicating_mask,
+            "intervals": self.intervals,
         }

         return inputs
@@ -263,7 +251,7 @@ def fit(
         # Create dataset
         if isinstance(train_set, str):
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                "CSAI does not support lazy loading because intervals need to be calculated ahead. "
                 "Hence the whole train set will be loaded into memory."
             )
             train_set = load_dict_from_h5(train_set)
@@ -273,13 +261,10 @@ def fit(
             return_y=True,
             removal_percent=self.removal_percent,
             increase_factor=self.increase_factor,
-            compute_intervals=self.compute_intervals,
         )

         self.intervals = training_set.intervals
         self.replacement_probabilities = training_set.replacement_probabilities
-        self.mean_set = training_set.mean_set
-        self.std_set = training_set.std_set

         train_loader = DataLoader(
             training_set,
@@ -291,7 +276,7 @@ def fit(
         if val_set is not None:
             if isinstance(val_set, str):
                 logger.warning(
-                    "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                    "CSAI does not support lazy loading because intervals need to be calculated ahead. "
                     "Hence the whole val set will be loaded into memory."
                 )
                 val_set = load_dict_from_h5(val_set)
@@ -304,10 +289,7 @@ def fit(
                 return_y=True,
                 removal_percent=self.removal_percent,
                 increase_factor=self.increase_factor,
-                compute_intervals=self.compute_intervals,
                 replacement_probabilities=self.replacement_probabilities,
-                normalise_mean=self.mean_set,
-                normalise_std=self.std_set,
             )
             val_loader = DataLoader(
                 val_set,
@@ -333,7 +315,7 @@ def predict(

         if isinstance(test_set, str):
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
+                "CSAI does not support lazy loading because intervals need to be calculated ahead. "
                 "Hence the whole test set will be loaded into memory."
             )
             test_set = load_dict_from_h5(test_set)
@@ -343,11 +325,7 @@ def predict(
             return_y=False,
             removal_percent=self.removal_percent,
             increase_factor=self.increase_factor,
-            compute_intervals=self.compute_intervals,
             replacement_probabilities=self.replacement_probabilities,
-            normalise_mean=self.mean_set,
-            normalise_std=self.std_set,
-            training=False,
         )
         test_loader = DataLoader(
             test_set,
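
The thread running through model.py is that per-feature time intervals are now computed once from the training set inside fit(), cached on the model, and attached to every assembled batch under the "intervals" key, rather than being baked into the backbone at construction time. A toy sketch of that flow (class and method names abridged; the interval computation itself is a placeholder assumption):

```python
import numpy as np

class CSAISketch:
    """Toy stand-in for how intervals travel through CSAI after this refactor."""

    def __init__(self):
        # No intervals at construction time; fit() fills them in.
        self.intervals = None

    def fit(self, train_X: np.ndarray) -> None:
        # Placeholder for the dataset's interval computation: here we take
        # the mean observed time index per feature, purely for illustration.
        steps = np.arange(train_X.shape[1])[None, :, None]
        observed = ~np.isnan(train_X)
        self.intervals = (steps * observed).sum(axis=(0, 1)) / observed.sum(axis=(0, 1))

    def _assemble_input(self, batch: dict) -> dict:
        # Every training/testing batch now carries the cached intervals.
        return {**batch, "intervals": self.intervals}

model = CSAISketch()
model.fit(np.random.randn(16, 48, 35))
inputs = model._assemble_input({"forward": {}, "backward": {}})
print(inputs["intervals"].shape)  # (35,)
```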
3 changes: 3 additions & 0 deletions pypots/cli/tuning.py
@@ -36,6 +36,7 @@
     FiLM,
     Pyraformer,
     Autoformer,
+    CSAI,
     CSDI,
     Informer,
     USGAN,
@@ -75,6 +76,7 @@
     # imputation models, sorted by the first letter of the model name
     "pypots.imputation.Autoformer": Autoformer,
     "pypots.imputation.BRITS": BRITS,
+    "pypots.imputation.CSAI": CSAI,
     "pypots.imputation.CSDI": CSDI,
     "pypots.imputation.Crossformer": Crossformer,
     "pypots.imputation.DLinear": DLinear,
@@ -112,6 +114,7 @@
     "pypots.imputation.TimeLLM": TimeLLM,
     # classification models
     "pypots.classification.BRITS": BRITS_classification,
+    "pypots.classification.CSAI": CSAI_classification,
     "pypots.classification.GRUD": GRUD_classification,
     "pypots.classification.Raindrop": Raindrop,
     "pypots.classification.CSAI": CSAI_classification,
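
These three additions register CSAI in the tuning CLI's dotted-path-to-class map, so the tuner can resolve a model name to its class. A minimal illustration of that lookup pattern; the dict here is a trimmed stand-in, and its real name inside tuning.py is not shown in this diff:

```python
from pypots.imputation import CSAI
from pypots.classification import CSAI as CSAI_classification

# Trimmed stand-in for the registry dict in pypots/cli/tuning.py.
MODEL_REGISTRY = {
    "pypots.imputation.CSAI": CSAI,
    "pypots.classification.CSAI": CSAI_classification,
}

# The tuning CLI resolves a dotted model path to its class like this.
model_class = MODEL_REGISTRY["pypots.classification.CSAI"]
assert model_class is CSAI_classification
```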
10 changes: 1 addition & 9 deletions pypots/imputation/csai/core.py
@@ -25,9 +25,6 @@ class _BCSAI(nn.Module):
     step_channels :
         number of channels for each step in the sequence

-    intervals :
-        time intervals between the observations, used for handling irregular time-series
-
     consistency_weight :
         weight assigned to the consistency loss during training
@@ -51,9 +48,6 @@ class _BCSAI(nn.Module):
     step_channels :
         number of channels for each step in the sequence

-    intervals :
-        time intervals between observations
-
     consistency_weight :
         weight assigned to the consistency loss
@@ -77,18 +71,16 @@ def __init__(
         step_channels,
         consistency_weight,
         imputation_weight,
-        intervals=None,
     ):
         super().__init__()
         self.n_steps = n_steps
         self.n_features = n_features
         self.rnn_hidden_size = rnn_hidden_size
         self.step_channels = step_channels
-        self.intervals = intervals
         self.consistency_weight = consistency_weight
         self.imputation_weight = imputation_weight

-        self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels, intervals)
+        self.model = BackboneBCSAI(n_steps, n_features, rnn_hidden_size, step_channels)

     def forward(self, inputs: dict, training: bool = True) -> dict:
         (
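
With intervals gone from the constructor, the imputation module is built from shape and loss-weight arguments alone. A hedged construction sketch (_BCSAI is internal API, and the hyperparameter values are placeholders, not from the diff):

```python
from pypots.imputation.csai.core import _BCSAI

# Placeholder hyperparameters for illustration only.
model = _BCSAI(
    n_steps=48,
    n_features=35,
    rnn_hidden_size=108,
    step_channels=512,
    consistency_weight=0.1,
    imputation_weight=0.3,
)
```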