Error when attempting to Benchmark the model for survival analysis #282

Open
uclrmhigid opened this issue Aug 5, 2024 · 1 comment
Labels
question Further information is requested

Comments

@uclrmhigid

## Description
I am attempting to generate synthetic data conditional on ethnicity for my survival data. I am able to generate the data, but I get an error about the time-to-event column when attempting to benchmark the model. Running Benchmarks.evaluate raises "ValueError: The time_to_event_column contains 1 values less than or equal to zero. Please remove them."

## How to Reproduce
    import sys
    import warnings

    import numpy as np

    import synthcity.logger as log
    from synthcity.benchmark import Benchmarks
    from synthcity.plugins import Plugins
    from synthcity.plugins.core.dataloader import SurvivalAnalysisDataLoader
    from synthcity.utils.serialization import load, save

    # Set up logging and filter warnings
    log.add(sink=sys.stderr, level="INFO")
    warnings.filterwarnings("ignore")

    # subset_df: the survival DataFrame (not shown here)
    loader = SurvivalAnalysisDataLoader(
        subset_df,
        target_column="event_cmp",
        time_to_event_column="tstop",
    )

    # Fit survival_gan conditional on the ethnicity indicator column
    syn_model = Plugins().get("survival_gan")
    cond = subset_df["Race=Asian or Pacific Islander"]
    syn_model.fit(loader, cond=cond)

    # Generating conditional samples works
    count = 10
    syn_model.generate(count=count, cond=np.ones(count)).dataframe()

    # Saving and reloading the model also works
    buff = save(syn_model)
    type(buff)
    reloaded = load(buff)
    reloaded.name()

    # This call raises the ValueError
    score = Benchmarks.evaluate(
        [(f"test_{model}", model, {}) for model in ["adsgan", "survival_gan", "survae"]],
        loader,
        synthetic_size=1000,
        repeats=2,
        task_type="survival_analysis",
    )

## Expected Behavior
A quality score for each plugin.

## Screenshots
Get error:
{
"name": "ValueError",
"message": "The time_to_event_column contains 1 values less than or equal to zero. Please remove them.",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [83], in <cell line: 5>()
1 # synthcity absolute
2 #Can't get to work
3 from synthcity.benchmark import Benchmarks
----> 5 score = Benchmarks.evaluate(
6 [(f"test_{model}", model, {}) for model in ["adsgan", "survival_gan", "survae"]],
7 loader,
8 synthetic_size=1000,
9 repeats=1,
10 task_type="survival_analysis",
11 )

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\benchmark\__init__.py:288, in Benchmarks.evaluate(tests, X, X_test, metrics, repeats, synthetic_size, synthetic_constraints, synthetic_cache, synthetic_reuse_if_exists, augmented_reuse_if_exists, task_type, workspace, augmentation_rule, strict_augmentation, ad_hoc_augment_vals, use_metric_cache, **generate_kwargs)
286 else:
287 X_augmented = None
--> 288 evaluation = Metrics.evaluate(
289 X_test if X_test is not None else X.test(),
290 X_syn,
291 X.train(),
292 X_ref_syn,
293 X_augmented,
294 metrics=metrics,
295 task_type=task_type,
296 workspace=workspace,
297 use_cache=use_metric_cache,
298 )
300 mean_score = evaluation["mean"].to_dict()
301 errors = evaluation["errors"].to_dict()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\metrics\eval.py:204, in Metrics.evaluate(X_gt, X_syn, X_train, X_ref_syn, X_augmented, reduction, n_histogram_bins, metrics, task_type, random_state, workspace, use_cache)
201 metrics = Metrics.list()
203 X_gt, _ = X_gt.encode()
--> 204 X_syn, _ = X_syn.encode()
206 if X_train:
207 X_train, _ = X_train.encode()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:244, in DataLoader.encode(self, encoders)
242 encoded[col] = encoder.transform(encoded[col]).values
243 encoders[col] = encoder
--> 244 return self.from_info(encoded, self.info()), encoders

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:647, in SurvivalAnalysisDataLoader.from_info(data, info)
644 if not isinstance(data, pd.DataFrame):
645 raise ValueError(f"Invalid data type {type(data)}")
--> 647 return SurvivalAnalysisDataLoader(
648 data,
649 target_column=info["target_column"],
650 time_to_event_column=info["time_to_event_column"],
651 sensitive_features=info["sensitive_features"],
652 important_features=info["important_features"],
653 time_horizons=info["time_horizons"],
654 fairness_column=info["fairness_column"],
655 )

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:40, in pydantic.decorator.validate_arguments.validate.wrapper_function()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:134, in pydantic.decorator.ValidatedFunction.call()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pydantic\decorator.py:206, in pydantic.decorator.ValidatedFunction.execute()

File ~\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\synthcity\plugins\core\dataloader.py:531, in SurvivalAnalysisDataLoader.__init__(self, data, time_to_event_column, target_column, time_horizons, sensitive_features, important_features, fairness_column, random_state, train_size, **kwargs)
529 row_diff = data.shape[0] - data_filtered.shape[0]
530 if row_diff > 0:
--> 531 raise ValueError(
532 f"The time_to_event_column contains {row_diff} values less than or equal to zero. Please remove them."
533 )
535 if len(time_horizons) == 0:
536 time_horizons = np.linspace(T.min(), T.max(), num=5)[1:-1].tolist()

ValueError: The time_to_event_column contains 1 values less than or equal to zero. Please remove them."
}

## System Information
Python 3.10.11

## Additional Context
I'm not sure if it's relevant, but even if I set subset_df = subset_df[subset_df['tstop'] > 0] before running any of this code, I still get the same error.

@robsdavis
Contributor

robsdavis commented Sep 19, 2024

Hi @uclrmhigid, thanks for submitting this issue.

The error you are seeing comes from this part of the code:

        T = data[time_to_event_column]
        data_filtered = data[T > 0]
        row_diff = data.shape[0] - data_filtered.shape[0]
        if row_diff > 0:
            raise ValueError(
                f"The time_to_event_column contains {row_diff} values less than or equal to zero. Please remove them."
            )

Does your time-to-event column contain any values that are not greater than 0? If so, this is the expected behaviour, and you will need to remove or re-label those datapoints.
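As a rough way to answer that question, the same `T > 0` comparison can be run by hand. The sketch below uses only pandas; the helper name `count_non_positive` and the toy values are hypothetical, and only the column names are taken from the report. Since the traceback reaches the check through `X_syn.encode()`, it may also be worth running the same count on the generated data, not only on the input frame.

```python
import pandas as pd


def count_non_positive(df: pd.DataFrame, time_col: str) -> int:
    """Count rows whose time-to-event value is not strictly positive.

    NaN compares False against `> 0`, so missing times are counted too,
    just as a `data[T > 0]` filter would drop them.
    """
    return int((~(df[time_col] > 0)).sum())


# Hypothetical toy frame reusing the column names from the report.
toy = pd.DataFrame(
    {
        "tstop": [12.0, 0.0, 30.5, -1.0, 7.0],
        "event_cmp": [1, 0, 1, 0, 1],
    }
)

n_bad = count_non_positive(toy, "tstop")

# Keep only rows with a strictly positive time-to-event.
clean = toy[toy["tstop"] > 0].reset_index(drop=True)

print(f"non-positive rows: {n_bad}, rows kept: {len(clean)}")
```

The same helper could be pointed at the frame returned by `syn_model.generate(...).dataframe()` from the reproduction script to see whether the offending row appears in the synthetic output.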

If not, are you able to share your data so that I can reproduce this? If you can't, could you create a toy dataset that triggers the issue?
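One possible starting point for such a toy dataset is sketched below. It is an assumption about what a minimal stand-in could look like, not the reporter's data: the function name `make_toy_survival_df`, the `age` covariate, and all value ranges are made up, while the other column names mirror the ones in the reproduction script. Every `tstop` value is strictly positive by construction, so any error that still appears would not come from the input frame.

```python
import numpy as np
import pandas as pd


def make_toy_survival_df(n_rows: int = 200, seed: int = 0) -> pd.DataFrame:
    """Build a small random survival-style frame (hypothetical stand-in)."""
    rng = np.random.default_rng(seed)
    return pd.DataFrame(
        {
            # Made-up covariate, only there to give the models a feature.
            "age": rng.normal(60, 10, size=n_rows).round(1),
            # Binary indicator used as the conditioning column in the report.
            "Race=Asian or Pacific Islander": rng.integers(0, 2, size=n_rows),
            # Binary event indicator (target column).
            "event_cmp": rng.integers(0, 2, size=n_rows),
            # Times drawn from 1..365, so every value is strictly positive.
            "tstop": rng.integers(1, 366, size=n_rows).astype(float),
        }
    )


toy_df = make_toy_survival_df()
print(toy_df.describe())
```

Passing `toy_df` in place of `subset_df` in the reproduction script should show whether the error depends on the original data.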

@robsdavis robsdavis added question Further information is requested labels Sep 19, 2024