diff --git a/algo/curve_anomaly/cont_det/cont_device.py b/algo/curve_anomaly/cont_det/cont_device.py index ab8a687..8ac79a8 100644 --- a/algo/curve_anomaly/cont_det/cont_device.py +++ b/algo/curve_anomaly/cont_det/cont_device.py @@ -118,7 +118,11 @@ def get_reconstruction_errors(model_input_data_array, model, use_cuda): model_input = torch.Tensor(data_series) if use_cuda: model_input = model_input.cuda() - model_output = model(model_input) + try: + model_output = model(model_input) + except RuntimeError: + print("Probably not enough data collected yet!") + return None, None errors.append(abs(model_output.detach().cpu().numpy()-data_series).sum()/205) model.train() return errors, model_output.detach().cpu().numpy() @@ -134,6 +138,9 @@ def test(data_list, model, use_cuda, anomalies, training_max, reconstruction_err model_input_data_array = np.array(data_series_smooth[-model_input_window_length:]).reshape(1,-1) new_reconstruction_error = get_reconstruction_errors(model_input_data_array, model, use_cuda)[0][0] reconstructed_curve = get_reconstruction_errors(model_input_data_array, model, use_cuda)[1].flatten() + if new_reconstruction_error == None: # This happens if not enough data is collected yet! + model.train() + return None, anomalies, reconstruction_errors if reconstruction_errors == None: reconstruction_errors = [new_reconstruction_error] else: