Assert clause prevents finding important features. #9

Open
franciscomalveiro opened this issue May 10, 2023 · 1 comment


franciscomalveiro commented May 10, 2023

Hello!
First of all, thank you for the development of this project.
I've been trying to use it to extract feature importances from LSTM and Transformer models, but I've run into an assert clause that stops the process.

More specifically:

  • I have developed these models using PyTorch, and have developed a model wrapper to adapt them to your framework;
  • My dataset consists of 222 instances of sequence length 4, where each element of the sequence contains 6 features.

When no features are considered important, it all goes smoothly, displaying on the terminal:
No important features identified, skipping window feature importance window visualization.

However, when the framework detects important features, it hits an assert clause, this one:

assert perturbed_slice.base is X_hat

At that point perturbed_slice.base is None, whereas the assert clause expects it to be X_hat.
As a result, the process stops.
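For context, NumPy sets an array's `.base` to the array whose memory it borrows when it is a view, and to something else (typically `None`) when it holds its own data. The following is a minimal sketch of what the assert appears to be checking. The array here is a made-up stand-in with the sizes from this report; the axis order and the variable names other than `X_hat` are assumptions, not anamod's internals.

```python
import numpy as np

# Stand-in for the perturbed data: 222 instances, 6 features, 4 timesteps.
# (Axis order is an assumption made for illustration only.)
X_hat = np.zeros((222, 6, 4))

# Basic slicing returns a view that borrows X_hat's memory.
view = X_hat[:, 2:3, 1:3]
# Indexing with an integer array returns a new array with its own memory.
copy = X_hat[np.array([3, 1, 2]), ...]

view_is_backed_by_X_hat = view.base is X_hat
copy_is_backed_by_X_hat = copy.base is X_hat

# Writing through the view changes X_hat; writing to the copy does not.
view[...] = 1.0
copy[...] = 5.0
```

If the assert is there to guarantee that perturbations land inside `X_hat`, then a result that is no longer a view of `X_hat` would mean the perturbation never reaches the array passed to the model.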

Following the guidelines you provide at Contributions:

  • I am running Manjaro Linux 22.1.0;
  • I think no local details are necessary, as I was able to reproduce the bug without using my data or my models.

To narrow down the issue (it could lie in my data, models, or model wrapper), I tried to reproduce the problem using only the functions you provide:

import anamod
import synmod

output_dir = '.'
num_instances = 222
num_features = 6
fraction_relevant_features = 0.9  # 0.1 works, 0.9 hits the assert
sequence_length = 4

# Synthesize temporal data and a classifier with synmod
synthesized_features, X, model = synmod.synthesize(
    output_dir=output_dir, num_instances=num_instances, seed=100,
    num_features=num_features, fraction_relevant_features=fraction_relevant_features,
    synthesis_type='temporal', sequence_length=sequence_length, model_type='classifier')

y = model.predict(X, labels=True)

importance_level = 0.1
loss_function = 'binary_cross_entropy'
feature_names = ['A', 'B', 'C', 'D', 'E', 'F']

# Analyze the synthesized model with anamod
explainer = anamod.TemporalModelAnalyzer(
    model, X, y,
    output_dir=output_dir,
    loss_function=loss_function,
    feature_names=feature_names,
    importance_significance_level=importance_level,
    visualize=True)

explainer.analyze()

Changing the value of fraction_relevant_features toggles between working and not working:

  • fraction_relevant_features = 0.1 makes the analyser detect no important features, so it finishes successfully;
  • fraction_relevant_features = 0.9 makes the program hit the assert clause.

PS: To check whether that was the only problem in the process, I tried commenting out that assert clause. The process then finishes, but displays a «deformed» plot.

[Image: feature_importance_windows]

However, the results may not be correct (the assert was probably there for a reason...).
It would be nice if the plot size were adjusted automatically, or could be set beforehand, to avoid this.

Thanks!

@franciscomalveiro
Author

I took a look at the source code, and found the following:

  • The problem appears at

    assert perturbed_slice.base is X_hat

    The assert clause fails given that perturbed_slice.base = None.

  • perturbed_slice is generated by

    perturbed_slice = self._perturbation_fn.operate(X_hat[axis0, axis1, axis2])

  • perturbed_slice.base is X_hat, unless execution enters the following block

    if self.pool is not None:
        idx = self.pool.__next__() # Caller needs to catch StopIteration
        return X[idx, ...]

    When the code reaches this block, perturbed_slice.base = None, and for that reason the assert clause fails.

I've noticed the comment you have left there, so maybe there is something missing in the implementation...?
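To illustrate the behaviour described above, here is a small sketch that runs outside anamod. Both functions and all names except `X_hat` are hypothetical: the first mimics the quoted branch, and the second is only one possible alternative (writing the permuted rows back into the view), not a confirmed fix or the maintainers' intended design.

```python
import numpy as np

def operate_returning_new_array(X, pool):
    """Mimics the quoted branch: integer-array indexing builds a new array."""
    idx = next(pool)  # raises StopIteration once the pool is exhausted
    return X[idx, ...]

def operate_in_place(X, pool):
    """Hypothetical alternative: write the permuted rows back into X itself."""
    idx = next(pool)
    X[...] = X[idx, ...]  # the right-hand side is materialized before assignment
    return X

rng = np.random.default_rng(0)
# .copy() makes X_hat own its memory, so views of it report X_hat as their base.
X_hat = np.arange(24, dtype=float).reshape(4, 3, 2).copy()
# A pool of two precomputed permutations of the instance axis (illustrative).
pool = iter([rng.permutation(4) for _ in range(2)])

window = X_hat[:, 1:2, :]  # view into X_hat, like a slice selected for perturbation
out_new = operate_returning_new_array(window, pool)
out_same = operate_in_place(window, pool)

new_keeps_base = out_new.base is X_hat
same_keeps_base = out_same.base is X_hat

# The pool is now empty, so a third call raises StopIteration.
try:
    operate_in_place(window, pool)
    pool_exhausted = False
except StopIteration:
    pool_exhausted = True
```

In this sketch only the in-place variant satisfies a check of the form `result.base is X_hat`. Whether writing back into `X_hat` matches what the library intends is an open question for the maintainers.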
