diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index c3c2e8ac..0529e537 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -9,20 +9,24 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9]
- os: [ubuntu-latest, macos-10.15, windows-latest]
+ python-version: ['3.7', '3.8', '3.9', '3.10']
+ os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - if: matrix.os == 'windows-latest'
+ - if: matrix.os == 'windows-latest' && matrix.python-version != '3.10'
name: Install dependencies - Windows
run: |
python -m pip install --upgrade pip
python -m pip install 'torch==1.8.0' -f https://download.pytorch.org/whl/cpu/torch/
- python -m pip install 'torchvision==0.9.0' -f https://download.pytorch.org/whl/cpu/torchvision/
+ - if: matrix.os == 'windows-latest' && matrix.python-version == '3.10'
+ name: Install dependencies - Windows
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install 'torch==1.11.0' -f https://download.pytorch.org/whl/cpu/torch/
- name: Install dependencies
run: |
python -m pip install --upgrade pip
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 0630c219..07dde39e 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -10,7 +10,7 @@ jobs:
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.8
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
diff --git a/.github/workflows/minimum.yml b/.github/workflows/minimum.yml
index cb2f3af5..5c067bb7 100644
--- a/.github/workflows/minimum.yml
+++ b/.github/workflows/minimum.yml
@@ -9,20 +9,24 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9]
- os: [ubuntu-latest, macos-10.15, windows-latest]
+ python-version: ['3.7', '3.8', '3.9', '3.10']
+ os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - if: matrix.os == 'windows-latest'
+    - if: matrix.os == 'windows-latest' && matrix.python-version != '3.10'
name: Install dependencies - Windows
run: |
python -m pip install --upgrade pip
python -m pip install 'torch==1.8.0' -f https://download.pytorch.org/whl/cpu/torch/
- python -m pip install 'torchvision==0.9.0' -f https://download.pytorch.org/whl/cpu/torchvision/
+    - if: matrix.os == 'windows-latest' && matrix.python-version == '3.10'
+ name: Install dependencies - Windows
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install 'torch==1.11.0' -f https://download.pytorch.org/whl/cpu/torch/
- name: Install dependencies
run: |
python -m pip install --upgrade pip
diff --git a/.github/workflows/readme.yml b/.github/workflows/readme.yml
index 2fe4b64c..e96c87fb 100644
--- a/.github/workflows/readme.yml
+++ b/.github/workflows/readme.yml
@@ -9,12 +9,12 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9]
- os: [ubuntu-latest, macos-10.15] # skip windows bc rundoc fails
+ python-version: ['3.7', '3.8', '3.9', '3.10']
+ os: [ubuntu-latest, macos-latest] # skip windows bc rundoc fails
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
diff --git a/.github/workflows/unit.yml b/.github/workflows/unit.yml
index 18558731..ec7bd348 100644
--- a/.github/workflows/unit.yml
+++ b/.github/workflows/unit.yml
@@ -9,20 +9,24 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
- python-version: [3.6, 3.7, 3.8, 3.9]
- os: [ubuntu-latest, macos-10.15, windows-latest]
+ python-version: ['3.7', '3.8', '3.9', '3.10']
+ os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v1
+ uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- - if: matrix.os == 'windows-latest'
+    - if: matrix.os == 'windows-latest' && matrix.python-version != '3.10'
name: Install dependencies - Windows
run: |
python -m pip install --upgrade pip
python -m pip install 'torch==1.8.0' -f https://download.pytorch.org/whl/cpu/torch/
- python -m pip install 'torchvision==0.9.0' -f https://download.pytorch.org/whl/cpu/torchvision/
+    - if: matrix.os == 'windows-latest' && matrix.python-version == '3.10'
+ name: Install dependencies - Windows
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install 'torch==1.11.0' -f https://download.pytorch.org/whl/cpu/torch/
- name: Install dependencies
run: |
python -m pip install --upgrade pip
diff --git a/AUTHORS.rst b/AUTHORS.rst
index d80a376d..312c3c48 100644
--- a/AUTHORS.rst
+++ b/AUTHORS.rst
@@ -1,10 +1 @@
-Credits
-=======
-
-Contributors
-------------
-
-* Lei Xu
-* Kalyan Veeramachaneni
-* Manuel Alvarez
-* Carles Sala
+See: https://github.com/sdv-dev/SDGym/graphs/contributors
diff --git a/BENCHMARK.md b/BENCHMARK.md
index c5a6bb6f..acd1ea77 100644
--- a/BENCHMARK.md
+++ b/BENCHMARK.md
@@ -87,13 +87,13 @@ The most basic scenario is to pass a synthesizer function, but the sdgym.run fun
can also be used to evaluate any `Synthesizer` class, as far as it is a subclass of
`sdgym.synthesizers.BaseSynthesizer`.
-For example, if we want to evaluate the `Independent` we can do so by passing the class
+For example, if we want to evaluate the `IndependentSynthesizer` we can do so by passing the class
directly to the sdgym.run function:
```python3
-In [5]: from sdgym.synthesizers import Independent
+In [5]: from sdgym.synthesizers import IndependentSynthesizer
-In [6]: scores = sdgym.run(synthesizers=Independent)
+In [6]: scores = sdgym.run(synthesizers=IndependentSynthesizer)
```
#### Evaluating multiple Synthesizers
@@ -103,7 +103,7 @@ The `sdgym.run` function can be used to evaluate more than one Synthesizer at a
In order to do this, all you need to do is pass a list of functions instead of a single
object.
-For example, if we want to evaluate our synthesizer function and also the `Independent`
+For example, if we want to evaluate our synthesizer function and also the `IndependentSynthesizer`
we can pass both of them inside a list:
```python3
@@ -113,12 +113,12 @@ In [8]: scores = sdgym.run(synthesizers=synthesizers)
```
Or, if we wanted to evaluate all the SDGym Synthesizers at once (note that this takes a lot of time
-to run!), we could just pass all the subclasses of `Baseline`:
+to run!), we could just pass all the subclasses of `BaselineSynthesizer`:
```python3
-In [9]: from sdgym.synthesizers import Baseline
+In [9]: from sdgym.synthesizers import BaselineSynthesizer
-In [10]: scores = sdgym.run(Baseline.get_subclasses())
+In [10]: scores = sdgym.run(BaselineSynthesizer.get_subclasses())
```
#### Customizing the Synthesizer names.
@@ -132,7 +132,7 @@ putting the names as keys and the functions or classes as the values:
```python3
In [11]: synthesizers = {
...: 'My Synthesizer': my_synthesizer_function,
- ...: 'SDGym Independent': Independent
+ ...: 'SDGym Independent': IndependentSynthesizer
...: }
In [12]: scores = sdgym.run(synthesizers=synthesizers)
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index ab9843a4..0c5afb3f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -120,8 +120,8 @@ Before you submit a pull request, check that it meets these guidelines:
4. If the pull request adds functionality, the docs should be updated. Put
your new functionality into a function with a docstring, and add the
feature to the list in README.rst.
-5. The pull request should work for Python 3.5 and 3.6. Check
- https://travis-ci.org/sdv-dev/SDGym/pull_requests
+5. The pull request should work for Python 3.7, 3.8, 3.9, and 3.10. Check
+ https://github.com/sdv-dev/SDGym/actions
and make sure that all the checks pass.
Unit Testing Guidelines
diff --git a/Dockerfile b/Dockerfile
index b62959fe..580f56bf 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,8 +1,8 @@
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04
CMD nvidia-smi
-RUN apt-get update && apt-get install -y build-essential && apt-get -y install curl
-RUN apt-get -y install python3.8 python3-distutils && ln -s /usr/bin/python3.8 /usr/bin/python
+RUN apt-get update && apt-get install -y build-essential curl python3.7 python3.7-dev \
+ python3-distutils && ln -s /usr/bin/python3.7 /usr/bin/python
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && \
python get-pip.py && ln -s /usr/bin/pip3 /usr/bin/pip
@@ -18,8 +18,8 @@ COPY /privbayes/ /SDGym/privbayes
WORKDIR /SDGym
# Install project
-RUN make install-all compile
-RUN pip install -U numpy==1.20
+RUN pip install . --no-binary pomegranate
+RUN make compile
ENV PRIVBAYES_BIN /SDGym/privbayes/privBayes.bin
ENV TF_CPP_MIN_LOG_LEVEL 2
diff --git a/HISTORY.md b/HISTORY.md
index fd6af46b..50972c93 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,5 +1,50 @@
# History
+## v0.6.0 - 2023-02-01
+This release introduces methods for benchmarking single table data and creating custom synthesizers, which can be based on existing SDGym-defined synthesizers or on user-defined functions. This release also adds support for Python 3.10 and drops support for Python 3.6.
+
+### New Features
+* Benchmarking progress bar should update on one line - Issue [#204](https://github.com/sdv-dev/SDGym/issues/204) by @katxiao
+* Support local additional datasets folder with zip files - Issue [#186](https://github.com/sdv-dev/SDGym/issues/186) by @katxiao
+* Enforce that each synthesizer is unique in benchmark_single_table - Issue [#190](https://github.com/sdv-dev/SDGym/issues/190) by @katxiao
+* Simplify the file names inside the detailed_results_folder - Issue [#191](https://github.com/sdv-dev/SDGym/issues/191) by @katxiao
+* Use SDMetrics silent report generation - Issue [#179](https://github.com/sdv-dev/SDGym/issues/179) by @katxiao
+* Remove arguments in get_available_datasets - Issue [#197](https://github.com/sdv-dev/SDGym/issues/197) by @katxiao
+* Accept metadata.json as valid metadata file - Issue [#194](https://github.com/sdv-dev/SDGym/issues/194) by @katxiao
+* Check if file or folder exists before writing benchmarking results - Issue [#196](https://github.com/sdv-dev/SDGym/issues/196) by @katxiao
+* Rename benchmarking argument "evaluate_quality" to "compute_quality_score" - Issue [#195](https://github.com/sdv-dev/SDGym/issues/195) by @katxiao
+* Add option to disable sdmetrics in benchmarking - Issue [#182](https://github.com/sdv-dev/SDGym/issues/182) by @katxiao
+* Prefix remote bucket with 's3' - Issue [#183](https://github.com/sdv-dev/SDGym/issues/183) by @katxiao
+* Benchmarking error handling - Issue [#177](https://github.com/sdv-dev/SDGym/issues/177) by @katxiao
+* Allow users to specify custom synthesizers' display names - Issue [#174](https://github.com/sdv-dev/SDGym/issues/174) by @katxiao
+* Update benchmarking results columns - Issue [#172](https://github.com/sdv-dev/SDGym/issues/172) by @katxiao
+* Allow custom datasets - Issue [#166](https://github.com/sdv-dev/SDGym/issues/166) by @katxiao
+* Use new datasets s3 bucket - Issue [#161](https://github.com/sdv-dev/SDGym/issues/161) by @katxiao
+* Create benchmark_single_table method - Issue [#151](https://github.com/sdv-dev/SDGym/issues/151) by @katxiao
+* Update summary metrics - Issue [#134](https://github.com/sdv-dev/SDGym/issues/134) by @katxiao
+* Benchmark individual methods - Issue [#159](https://github.com/sdv-dev/SDGym/issues/159) by @katxiao
+* Add method to create a sdv variant synthesizer - Issue [#152](https://github.com/sdv-dev/SDGym/issues/152) by @katxiao
+* Add method to generate a multi table synthesizer - Issue [#149](https://github.com/sdv-dev/SDGym/issues/149) by @katxiao
+* Add method to create single table synthesizers - Issue [#148](https://github.com/sdv-dev/SDGym/issues/148) by @katxiao
+* Updating existing synthesizers to new API - Issue [#154](https://github.com/sdv-dev/SDGym/issues/154) by @katxiao
+
+### Bug Fixes
+* Pip encounters dependency issues with ipython - Issue [#187](https://github.com/sdv-dev/SDGym/issues/187) by @katxiao
+* IndependentSynthesizer is printing out ConvergeWarning too many times - Issue [#192](https://github.com/sdv-dev/SDGym/issues/192) by @katxiao
+* Size values in benchmarking results seems inaccurate - Issue [#184](https://github.com/sdv-dev/SDGym/issues/184) by @katxiao
+* Import error in the example for benchmarking the synthesizers - Issue [#139](https://github.com/sdv-dev/SDGym/issues/139) by @katxiao
+* Updates and bugfixes - Issue [#132](https://github.com/sdv-dev/SDGym/issues/132) by @csala
+
+### Maintenance
+* Update README - Issue [#203](https://github.com/sdv-dev/SDGym/issues/203) by @katxiao
+* Support Python Versions >=3.7 and <3.11 - Issue [#170](https://github.com/sdv-dev/SDGym/issues/170) by @katxiao
+* SDGym Package Maintenance Updates documentation - Issue [#163](https://github.com/sdv-dev/SDGym/issues/163) by @katxiao
+* Remove YData - Issue [#168](https://github.com/sdv-dev/SDGym/issues/168) by @katxiao
+* Update to newest SDV - Issue [#157](https://github.com/sdv-dev/SDGym/issues/157) by @katxiao
+* Update slack invite link. - Issue [#144](https://github.com/sdv-dev/SDGym/issues/144) by @pvk-developer
+* updating workflows to work with windows - Issue [#136](https://github.com/sdv-dev/SDGym/issues/136) by @amontanez24
+* Update conda dependencies - Issue [#130](https://github.com/sdv-dev/SDGym/issues/130) by @katxiao
+
## v0.5.0 - 2021-12-13
This release adds support for Python 3.9, and updates dependencies to accept the latest versions when possible.
diff --git a/INSTALL.md b/INSTALL.md
index 497304d8..b46c78cf 100644
--- a/INSTALL.md
+++ b/INSTALL.md
@@ -2,7 +2,7 @@
## Requirements
-**SDGym** has been developed and tested on [Python 3.6, 3.7, 3.8 and 3.9](https://www.python.org/downloads/)
+**SDGym** has been developed and tested on [Python 3.7, 3.8, 3.9, and 3.10](https://www.python.org/downloads/)
Also, although it is not strictly required, the usage of a [virtualenv](
https://virtualenv.pypa.io/en/latest/) is highly recommended in order to avoid
diff --git a/LICENSE b/LICENSE
index 05f84df8..96e0846c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,21 +1,108 @@
-MIT License
-
-Copyright (c) 2019, MIT Data To AI Lab
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+Business Source License 1.1
+
+Parameters
+
+Licensor: DataCebo, Inc.
+
+Licensed Work: SDGym
+ The Licensed Work is (c) DataCebo, Inc.
+
+Additional Use Grant: You may make use of the Licensed Work, and derivatives of the Licensed
+ Work, provided that you do not use the Licensed Work, or derivatives of
+ the Licensed Work, for a Synthetic Data Benchmarking Service.
+
+ A "Synthetic Data Benchmarking Service" is a commercial offering
+ that allows third parties (other than your employees and
+ contractors) to access the functionality of the Licensed
+ Work so that such third parties directly benefit from the
+ data specification, data transformation, machine learning, synthetic
+ data creation, synthetic data evaluation or performance evaluation
+ features of the Licensed Work.
+
+
+Change Date: Change date is four years from release date.
+ Please see https://github.com/sdv-dev/SDGym/releases
+ for exact dates.
+
+Change License: MIT License
+
+
+Notice
+
+The Business Source License (this document, or the "License") is not an Open
+Source license. However, the Licensed Work will eventually be made available
+under an Open Source License, as stated in this License.
+
+License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved.
+"Business Source License" is a trademark of MariaDB Corporation Ab.
+
+-----------------------------------------------------------------------------
+
+Business Source License 1.1
+
+Terms
+
+The Licensor hereby grants you the right to copy, modify, create derivative
+works, redistribute, and make non-production use of the Licensed Work. The
+Licensor may make an Additional Use Grant, above, permitting limited
+production use.
+
+Effective on the Change Date, or the fourth anniversary of the first publicly
+available distribution of a specific version of the Licensed Work under this
+License, whichever comes first, the Licensor hereby grants you rights under
+the terms of the Change License, and the rights granted in the paragraph
+above terminate.
+
+If your use of the Licensed Work does not comply with the requirements
+currently in effect as described in this License, you must purchase a
+commercial license from the Licensor, its affiliated entities, or authorized
+resellers, or you must refrain from using the Licensed Work.
+
+All copies of the original and modified Licensed Work, and derivative works
+of the Licensed Work, are subject to this License. This License applies
+separately for each version of the Licensed Work and the Change Date may vary
+for each version of the Licensed Work released by Licensor.
+
+You must conspicuously display this License on each original or modified copy
+of the Licensed Work. If you receive the Licensed Work in original or
+modified form from a third party, the terms and conditions set forth in this
+License apply to your use of that work.
+
+Any use of the Licensed Work in violation of this License will automatically
+terminate your rights under this License for the current and all other
+versions of the Licensed Work.
+
+This License does not grant you any right in any trademark or logo of
+Licensor or its affiliates (provided that you may use a trademark or logo of
+Licensor as expressly required by this License).
+
+TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON
+AN "AS IS" BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS,
+EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND
+TITLE.
+
+MariaDB hereby grants you permission to use this License’s text to license
+your works, and to refer to it using the trademark "Business Source License",
+as long as you comply with the Covenants of Licensor below.
+
+Covenants of Licensor
+
+In consideration of the right to use this License’s text and the "Business
+Source License" name and trademark, Licensor covenants to MariaDB, and to all
+other recipients of the licensed work to be provided by Licensor:
+
+1. To specify as the Change License the GPL Version 2.0 or any later version,
+ or a license that is compatible with GPL Version 2.0 or a later version,
+ where "compatible" means that software provided under the Change License can
+ be included in a program with software provided under GPL Version 2.0 or a
+ later version. Licensor may specify additional Change Licenses without
+ limitation.
+
+2. To either: (a) specify an additional grant of rights to use that does not
+ impose any additional restriction on the right granted in this License, as
+ the Additional Use Grant; or (b) insert the text "None".
+
+3. To specify a Change Date.
+
+4. Not to modify this License in any other way.
diff --git a/Makefile b/Makefile
index d3f74d57..af9acbcd 100644
--- a/Makefile
+++ b/Makefile
@@ -84,29 +84,6 @@ install-test: clean-build clean-compile clean-pyc compile ## install the package
install-develop: clean-build clean-pyc clean-pyc compile ## install the package in editable mode and dependencies for development
pip install -e .[dev]
-.PHONY: install-ydata
-install-ydata: clean-build clean-compile clean-pyc compile ## install the package with ydata
- pip install 'ydata-synthetic>=0.3.0,<0.4'
- pip install .
-
-.PHONY: install-ydata-develop
-install-ydata-develop: clean-build clean-compile clean-pyc compile ## install the package with ydata and dependencies for development
- pip install 'ydata-synthetic>=0.3.0,<0.4'
- pip install -e .[dev]
-
-.PHONY: install-gretel
-install-gretel: clean-build clean-compile clean-pyc compile ## install the package with gretel
- pip install .[gretel]
-
-.PHONY: install-gretel-develop
-install-gretel-develop: clean-build clean-compile clean-pyc compile ## install the package with gretel and dependencies for development
- pip install -e .[dev,gretel]
-
-.PHONY: install-all
-install-all: clean-build clean-compile clean-pyc compile ## install the package with gretel and ydata
- pip install 'ydata-synthetic>=0.3.0,<0.4'
- pip install .[gretel]
-
# LINT TARGETS
.PHONY: lint
diff --git a/README.md b/README.md
index 1653a5a0..721759cf 100644
--- a/README.md
+++ b/README.md
@@ -1,167 +1,158 @@
-
-
-
-
- An Open Source Project from the Data to AI Lab, at MIT
+
+
+
+ This repository is part of The Synthetic Data Vault Project, a project from DataCebo.
[![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
[![Travis](https://travis-ci.org/sdv-dev/SDGym.svg?branch=master)](https://travis-ci.org/sdv-dev/SDGym)
[![PyPi Shield](https://img.shields.io/pypi/v/sdgym.svg)](https://pypi.python.org/pypi/sdgym)
[![Downloads](https://pepy.tech/badge/sdgym)](https://pepy.tech/project/sdgym)
+[![Slack](https://img.shields.io/badge/Community-Slack-blue?style=plastic&logo=slack)](https://bit.ly/sdv-slack-invite)
+
+
-
-
-Benchmarking framework for Synthetic Data Generators
-
-* Website: https://sdv.dev
-* Documentation: https://sdv.dev/SDV
-* Repository: https://github.com/sdv-dev/SDGym
-* License: [MIT](https://github.com/sdv-dev/SDGym/blob/master/LICENSE)
-* Development Status: [Pre-Alpha](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
+
# Overview
-Synthetic Data Gym (SDGym) is a framework to benchmark the performance of synthetic data
-generators based on [SDV](https://github.com/sdv-dev/SDV) and [SDMetrics](
-https://github.com/sdv-dev/SDMetrics).
-
-SDGym is a part of the [The Synthetic Data Vault](https://sdv.dev/) project.
-
-## What is a Synthetic Data Generator?
-
-A **Synthetic Data Generator** is a Python function (or method) that takes as input some
-data, which we call the *real* data, learns a model from it, and outputs new *synthetic* data that
-has the same structure and similar mathematical properties as the *real* one.
-
-Please refer to the [synthesizers documentation](SYNTHESIZERS.md) for instructions about how to
-implement your own Synthetic Data Generator and integrate with SDGym. You can also read about how
-to use the ones already included in **SDGym** and see how to run them.
+The Synthetic Data Gym (SDGym) is a benchmarking framework for modeling and generating
+synthetic data. Measure performance and memory usage across different synthetic data modeling
+techniques – classical statistics, deep learning and more!
-## Benchmark datasets
+
-**SDGym** evaluates the performance of **Synthetic Data Generators** using *single table*,
-*multi table* and *timeseries* datasets stored as CSV files alongside an [SDV Metadata](
-https://sdv.dev/SDV/user_guides/relational/relational_metadata.html) JSON file.
+The SDGym library integrates with the Synthetic Data Vault ecosystem. You can use any of its
+synthesizers, datasets or metrics for benchmarking. You can also customize the process to
+include your own work.
-Further details about the list of available datasets and how to add your own datasets to
-the collection can be found in the [datasets documentation](DATASETS.md).
+* **Datasets**: Select any of the publicly available datasets from the SDV project, or input your own data.
+* **Synthesizers**: Choose from any of the SDV synthesizers and baselines. Or write your own custom
+machine learning model.
+* **Evaluation**: In addition to performance and memory usage, you can also measure synthetic data
+quality and privacy through a variety of metrics.
# Install
-**SDGym** can be installed using the following commands:
-
-**Using `pip`:**
+Install SDGym using pip or conda. We recommend using a virtual environment to avoid conflicts with other software on your device.
```bash
pip install sdgym
```
-**Using `conda`:**
-
```bash
-conda install -c sdv-dev -c conda-forge sdgym
+conda install -c conda-forge sdgym
```
-For more installation options please visit the [SDGym installation Guide](INSTALL.md)
+For more information about using SDGym, visit the [SDGym Documentation](https://docs.sdv.dev/sdgym).
# Usage
-## Benchmarking your own Synthesizer
-
-SDGym evaluates **Synthetic Data Generators**, which are Python functions (or classes) that take
-as input some data, which we call the *real* data, learn a model from it, and output new
-*synthetic* data that has the same structure and similar mathematical properties as the *real* one.
+Let's benchmark synthetic data generation for single tables. First, let's define which modeling
+techniques we want to use. Let's choose a few synthesizers from the SDV library and a few others
+to use as baselines.
-As an example, let use define a synthesizer function that applies the [GaussianCopula model from SDV
-](https://sdv.dev/SDV/user_guides/single_table/gaussian_copula.html) with `gaussian` distribution.
+```python
+# these synthesizers come from the SDV library
+# each one uses different modeling techniques
+sdv_synthesizers = ['GaussianCopulaSynthesizer', 'CTGANSynthesizer']
-```python3
-import numpy as np
-from sdv.tabular import GaussianCopula
+# these basic synthesizers are available in SDGym
+# as baselines
+baseline_synthesizers = ['UniformSynthesizer']
+```
+Now, we can benchmark the different techniques:
+```python
+import sdgym
-def gaussian_copula(real_data, metadata):
- gc = GaussianCopula(default_distribution='gaussian')
- table_name = metadata.get_tables()[0]
- gc.fit(real_data[table_name])
- return {table_name: gc.sample()}
+sdgym.benchmark_single_table(
+ synthesizers=(sdv_synthesizers + baseline_synthesizers)
+)
```
-|:information_source: You can learn how to create your own synthesizer function [here](SYNTHESIZERS.md).|
-|:-|
+The result is a detailed performance, memory and quality evaluation across the synthesizers
+on a variety of publicly available datasets.
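+
+To keep a copy of those results, you can also write them to a file. A minimal sketch (the file
+name here is only an example; the path must not already exist):
+
+```python
+sdgym.benchmark_single_table(
+    synthesizers=(sdv_synthesizers + baseline_synthesizers),
+    output_filepath='benchmark_results.csv'
+)
+```
+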
-We can now try to evaluate this function on the `asia` and `alarm` datasets:
+## Supplying a custom synthesizer
-```python3
-import sdgym
+Benchmark your own synthetic data generation techniques. Define your synthesizer by
+specifying the training logic (using machine learning) and the sampling logic.
-scores = sdgym.run(synthesizers=gaussian_copula, datasets=['asia', 'alarm'])
+```python
+def my_training_logic(data, metadata):
+ # create an object to represent your synthesizer
+ # train it using the data
+ return synthesizer
+
+def my_sampling_logic(trained_synthesizer, num_rows):
+ # use the trained synthesizer to create
+ # num_rows of synthetic data
+ return synthetic_data
```
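+
+As a minimal sketch of how these two functions can be plugged into the benchmark (the helper
+name, its import location and the display-name argument shown here are assumptions; see the
+guide linked below for the exact API):
+
+```python
+import sdgym
+from sdgym.synthesizers import create_single_table_synthesizer
+
+# assumed helper: wrap the two functions into a synthesizer variant
+my_synthesizer = create_single_table_synthesizer(
+    'MySynthesizer', my_training_logic, my_sampling_logic)
+
+# pass the variant through the custom_synthesizers argument
+results = sdgym.benchmark_single_table(custom_synthesizers=[my_synthesizer])
+```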
-|:information_source: You can learn about different arguments for `sdgym.run` function [here](BENCHMARK.md).|
-|:-|
+Learn more in the [Custom Synthesizers Guide](https://docs.sdv.dev/sdgym/customization/synthesizers/custom-synthesizers).
-The output of the `sdgym.run` function will be a `pd.DataFrame` containing the results obtained
-by your synthesizer on each dataset.
+## Customizing your datasets
-| synthesizer | dataset | modality | metric | score | metric_time | model_time |
-|-----------------|---------|--------------|-----------------|------------|-------------|------------|
-| gaussian_copula | asia | single-table | BNLogLikelihood | -2.842690 | 2.762427 | 0.752364 |
-| gaussian_copula | alarm | single-table | BNLogLikelihood | -20.223178 | 7.009401 | 3.173832 |
+The SDGym library includes many publicly available datasets that you can use right away.
+List these using the ``get_available_datasets`` feature.
-## Benchmarking the SDGym Synthesizers
+```python
+sdgym.get_available_datasets()
+```
-If you want to run the SDGym benchmark on the SDGym Synthesizers you can directly pass the
-corresponding class, or a list of classes, to the `sdgym.run` function.
+```
+dataset_name size_MB num_tables
+KRK_v1 0.072128 1
+adult 3.907448 1
+alarm 4.520128 1
+asia 1.280128 1
+...
+```
-For example, if you want to run the complete benchmark suite to evaluate all the existing
-synthesizers you can run (:warning: this will take a lot of time to run!):
+You can also include any custom, private datasets that are stored on your computer or in an
+Amazon S3 bucket.
-```python
-from sdgym.synthesizers import (
- CLBN, CopulaGAN, CTGAN, HMA1, Identity, Independent,
- MedGAN, PAR, PrivBN, SDV, TableGAN, TVAE,
- Uniform, VEEGAN)
-
-all_synthesizers = [
- CLBN,
- CTGAN,
- CopulaGAN,
- HMA1,
- Identity,
- Independent,
- MedGAN,
- PAR,
- PrivBN,
- SDV,
- TVAE,
- TableGAN,
- Uniform,
- VEEGAN,
-]
-scores = sdgym.run(synthesizers=all_synthesizers)
+```python
+my_datasets_folder = 's3://my-datasets-bucket'
```
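+
+A minimal sketch of pointing the benchmark at that folder (any local path or s3 URI that follows
+the expected dataset layout can be used):
+
+```python
+results = sdgym.benchmark_single_table(
+    additional_datasets_folder=my_datasets_folder
+)
+```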
-For further details about all the arguments and possibilities that the `benchmark` function offers
-please refer to the [benchmark documentation](BENCHMARK.md)
+For more information, see the docs for [Customized Datasets](https://docs.sdv.dev/sdgym/customization/datasets).
-# Additional References
+# What's next?
-* Datasets used in SDGym are detailed [here](DATASETS.md).
-* How to write a synthesizer is detailed [here](SYNTHESIZERS.md).
-* How to use benchmark function is detailed [here](BENCHMARK.md).
-* Detailed leaderboard results for all the releases are available [here](
-https://docs.google.com/spreadsheets/d/1iNJDVG_tIobcsGUG5Gn4iLa565vVhz2U/edit).
+Visit the [SDGym Documentation](https://docs.sdv.dev/sdgym) to learn more!
-# The Synthetic Data Vault
+---
-
-
-
-
-
This repository is part of The Synthetic Data Vault Project
-
-* Website: https://sdv.dev
-* Documentation: https://sdv.dev/SDV
+
+
+
+
+
+
+[The Synthetic Data Vault Project](https://sdv.dev) was first created at MIT's [Data to AI Lab](
+https://dai.lids.mit.edu/) in 2016. After 4 years of research and traction with enterprise, we
+created [DataCebo](https://datacebo.com) in 2020 with the goal of growing the project.
+Today, DataCebo is the proud developer of SDV, the largest ecosystem for
+synthetic data generation & evaluation. It is home to multiple libraries that support synthetic
+data, including:
+
+* 🔄 Data discovery & transformation. Reverse the transforms to reproduce realistic data.
+* 🧠 Multiple machine learning models -- ranging from Copulas to Deep Learning -- to create tabular,
+ multi table and time series data.
+* 📊 Measuring quality and privacy of synthetic data, and comparing different synthetic data
+ generation models.
+
+[Get started using the SDV package](https://sdv.dev/SDV/getting_started/install.html) -- a fully
+integrated solution and your one-stop shop for synthetic data. Or, use the standalone libraries
+for specific needs.
diff --git a/conda/meta.yaml b/conda/meta.yaml
index 014ff08d..25750609 100644
--- a/conda/meta.yaml
+++ b/conda/meta.yaml
@@ -1,5 +1,5 @@
{% set name = 'sdgym' %}
-{% set version = '0.5.0' %}
+{% set version = '0.6.0.dev2' %}
package:
name: "{{ name|lower }}"
@@ -20,47 +20,47 @@ requirements:
- pip
- pytest-runner
- graphviz
- - python >=3.6,<3.9
+    - python >=3.7,<3.11
- appdirs >=1.3,<2
- boto3 >=1.15.0,<2
- botocore >=1.18,<2
- humanfriendly >=8.2,<11
- numpy >=1.18.0,<2
- pandas >=1.1.3,<2
- - pomegranate >=0.14.1,<15
+ - pomegranate >=0.13.4,<15
- psutil >=5.7,<6
- scikit-learn >=0.24,<2
- tabulate >=0.8.3,<0.9
- pytorch >=1.8.0,<2
- - tqdm >=4.14,<5
+ - tqdm >=4.15,<5
- XlsxWriter >=1.2.8,<4
- - rdt >=0.4.1,<0.6
+ - rdt >=0.6.1,<0.7
- sdmetrics >=0.4.1,<0.5
- - sdv >=0.9.0
+ - sdv >=0.13.0
run:
- - python >=3.6,<3.9
+    - python >=3.7,<3.11
- appdirs >=1.3,<2
- boto3 >=1.15.0,<2
- botocore >=1.18,<2
- humanfriendly >=8.2,<11
- numpy >=1.18.0,<2
- pandas >=1.1.3,<2
- - pomegranate >=0.14.1,<15
+ - pomegranate >=0.13.4,<15
- psutil >=5.7,<6
- scikit-learn >=0.24,<2
- tabulate >=0.8.3,<0.9
- pytorch >=1.8.0,<2
- - tqdm >=4.14,<5
+ - tqdm >=4.15,<5
- XlsxWriter >=1.2.8,<4
- - rdt >=0.4.1,<0.6
+ - rdt >=0.6.1,<0.7
- sdmetrics >=0.4.1,<0.5
- - sdv >=0.9.0
+ - sdv >=0.13.0
about:
home: "https://github.com/sdv-dev/SDGym"
- license: MIT
- license_family: MIT
- license_file:
+ license: BUSL-1.1
+ license_family:
+ license_file: LICENSE
summary: "A framework to benchmark the performance of synthetic data generators for non-temporal tabular data"
doc_url:
dev_url:
diff --git a/docs/images/SDGym_Results.png b/docs/images/SDGym_Results.png
new file mode 100644
index 00000000..ec97e9c3
Binary files /dev/null and b/docs/images/SDGym_Results.png differ
diff --git a/sdgym/__init__.py b/sdgym/__init__.py
index 381bb1a7..04e66e22 100644
--- a/sdgym/__init__.py
+++ b/sdgym/__init__.py
@@ -4,23 +4,29 @@
tabular data.
"""
-__author__ = 'MIT Data To AI Lab'
-__copyright__ = 'Copyright (c) 2018, MIT Data To AI Lab'
-__email__ = 'dailabmit@gmail.com'
-__license__ = 'MIT'
-__version__ = '0.5.0'
+__author__ = 'DataCebo, Inc.'
+__copyright__ = 'Copyright (c) 2022 DataCebo, Inc.'
+__email__ = 'info@sdv.dev'
+__license__ = 'BSL-1.1'
+__version__ = '0.6.0.dev2'
+
+import logging
from sdgym import benchmark, synthesizers
-from sdgym.benchmark import run
+from sdgym.benchmark import benchmark_single_table
from sdgym.collect import collect_results
from sdgym.datasets import load_dataset
from sdgym.summary import make_summary_spreadsheet
+# Clear the logging incorrectly configured by tensorflow/absl
+list(map(logging.root.removeHandler, logging.root.handlers))
+list(map(logging.root.removeFilter, logging.root.filters))
+
__all__ = [
'benchmark',
'synthesizers',
- 'run',
'load_dataset',
'collect_results',
- 'make_summary_spreadsheet'
+ 'make_summary_spreadsheet',
+ 'benchmark_single_table',
]
diff --git a/sdgym/__main__.py b/sdgym/__main__.py
index aa31ee3d..eb20254a 100644
--- a/sdgym/__main__.py
+++ b/sdgym/__main__.py
@@ -13,7 +13,7 @@
import tqdm
import sdgym
-from sdgym.synthesizers.base import Baseline
+from sdgym.synthesizers.base import BaselineSynthesizer
from sdgym.utils import get_synthesizers
@@ -137,7 +137,7 @@ def _list_available(args):
def _list_synthesizers(args):
- synthesizers = Baseline.get_baselines()
+ synthesizers = BaselineSynthesizer.get_baselines()
_print_table(pd.DataFrame(get_synthesizers(list(synthesizers))))
diff --git a/sdgym/benchmark.py b/sdgym/benchmark.py
index 35e82016..dc0b4ea0 100644
--- a/sdgym/benchmark.py
+++ b/sdgym/benchmark.py
@@ -4,6 +4,8 @@
import logging
import multiprocessing
import os
+import pickle
+import tracemalloc
import uuid
from datetime import datetime
from pathlib import Path
@@ -12,36 +14,90 @@
import numpy as np
import pandas as pd
import tqdm
+from sdmetrics.reports.multi_table import QualityReport as MultiTableQualityReport
+from sdmetrics.reports.single_table import QualityReport as SingleTableQualityReport
from sdgym.datasets import get_dataset_paths, load_dataset, load_tables
from sdgym.errors import SDGymError
from sdgym.metrics import get_metrics
from sdgym.progress import TqdmLogger, progress
from sdgym.s3 import is_s3_path, write_csv, write_file
-from sdgym.synthesizers.base import Baseline
+from sdgym.synthesizers import CTGANSynthesizer, FastMLPreset, GaussianCopulaSynthesizer
+from sdgym.synthesizers.base import BaselineSynthesizer, SingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import get_num_gpus
from sdgym.utils import (
- build_synthesizer, format_exception, get_synthesizers, import_object, used_memory)
+ build_synthesizer, format_exception, get_duplicates, get_size_of, get_synthesizers,
+ import_object, used_memory)
LOGGER = logging.getLogger(__name__)
+DEFAULT_SYNTHESIZERS = [GaussianCopulaSynthesizer, FastMLPreset, CTGANSynthesizer]
+DEFAULT_DATASETS = [
+ 'adult',
+ 'alarm',
+ 'census',
+ 'child',
+ 'expedia_hotel_logs',
+ 'insurance',
+ 'intrusion',
+ 'news',
+ 'covtype',
+]
+DEFAULT_METRICS = [('NewRowSynthesis', {'synthetic_sample_size': 1000})]
+N_BYTES_IN_MB = 1000 * 1000
def _synthesize(synthesizer_dict, real_data, metadata):
synthesizer = synthesizer_dict['synthesizer']
+ get_synthesizer = None
+ sample_from_synthesizer = None
if isinstance(synthesizer, str):
synthesizer = import_object(synthesizer)
if isinstance(synthesizer, type):
- if issubclass(synthesizer, Baseline):
- synthesizer = synthesizer().fit_sample
+ if issubclass(synthesizer, BaselineSynthesizer):
+ s_obj = synthesizer()
+ get_synthesizer = s_obj.get_trained_synthesizer
+ sample_from_synthesizer = s_obj.sample_from_synthesizer
else:
- synthesizer = build_synthesizer(synthesizer, synthesizer_dict)
-
+ get_synthesizer, sample_from_synthesizer = build_synthesizer(
+ synthesizer, synthesizer_dict)
+
+ if isinstance(synthesizer, tuple):
+ get_synthesizer, sample_from_synthesizer = synthesizer
+
+ data = real_data.copy()
+ num_samples = None
+ modalities = getattr(synthesizer, 'MODALITIES', [])
+ is_single_table = (
+ isinstance(synthesizer, type)
+ and issubclass(synthesizer, SingleTableBaselineSynthesizer)
+ ) or (
+ len(modalities) == 1
+ and 'single-table' in modalities
+ )
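+    # Single-table synthesizers receive a single DataFrame and table-level metadata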
+ if is_single_table:
+ table_name = list(real_data.keys())[0]
+ metadata = metadata.get_table_meta(table_name)
+ data = list(real_data.values())[0]
+ num_samples = len(data)
+
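+    # Track peak memory usage across model fitting and sampling with tracemalloc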
+ tracemalloc.start()
now = datetime.utcnow()
- synthetic_data = synthesizer(real_data.copy(), metadata)
- elapsed = datetime.utcnow() - now
- return synthetic_data, elapsed
+ synthesizer_obj = get_synthesizer(data, metadata)
+ synthesizer_size = len(pickle.dumps(synthesizer_obj)) / N_BYTES_IN_MB
+ train_now = datetime.utcnow()
+ synthetic_data = sample_from_synthesizer(synthesizer_obj, num_samples)
+ sample_now = datetime.utcnow()
+
+ peak_memory = tracemalloc.get_traced_memory()[1] / N_BYTES_IN_MB
+ tracemalloc.stop()
+ tracemalloc.clear_traces()
+
+ if is_single_table:
+ synthetic_data = {list(real_data.keys())[0]: synthetic_data}
+
+ return synthetic_data, train_now - now, sample_now - train_now, synthesizer_size, peak_memory
def _prepare_metric_args(real_data, synthetic_data, metadata):
@@ -57,42 +113,64 @@ def _prepare_metric_args(real_data, synthetic_data, metadata):
return real_data, synthetic_data, metadata
-def _compute_scores(metrics, real_data, synthetic_data, metadata, output):
- metrics = get_metrics(metrics, metadata)
- metric_args = _prepare_metric_args(real_data, synthetic_data, metadata)
+def _compute_scores(metrics, real_data, synthetic_data, metadata, output, compute_quality_score):
+ metrics = metrics or []
+ if len(metrics) > 0:
+ metrics, metric_kwargs = get_metrics(metrics, metadata)
+ metric_args = _prepare_metric_args(real_data, synthetic_data, metadata)
+
+        scores = []
+        output['scores'] = scores
+        for metric_name, metric in metrics.items():
+            scores.append({
+                'metric': metric_name,
+                'error': 'Metric Timeout',
+            })
+            output['scores'] = scores  # re-inject list to multiprocessing output
+
+            error = None
+            score = None
+            normalized_score = None
+            start = datetime.utcnow()
+            try:
+                LOGGER.info('Computing %s on dataset %s', metric_name, metadata._metadata['name'])
+                score = metric.compute(*metric_args, **metric_kwargs.get(metric_name, {}))
+                normalized_score = metric.normalize(score)
+            except Exception:
+                LOGGER.exception('Metric %s failed on dataset %s. Skipping.',
+                                 metric_name, metadata._metadata['name'])
+                _, error = format_exception()
+
+            scores[-1].update({
+                'score': score,
+                'normalized_score': normalized_score,
+                'error': error,
+                'metric_time': (datetime.utcnow() - start).total_seconds()
+            })
+            output['scores'] = scores  # re-inject list to multiprocessing output
+
+ if compute_quality_score:
+ start = datetime.utcnow()
+ if metadata.modality == 'single-table':
+ table_name = list(real_data.keys())[0]
+ table_metadata = metadata.get_table_meta(table_name)
+ table_real_data = list(real_data.values())[0]
+ table_synthetic_data = list(synthetic_data.values())[0]
+
+ quality_report = SingleTableQualityReport()
+ quality_report.generate(
+ table_real_data, table_synthetic_data, table_metadata, verbose=False)
+ else:
+ quality_report = MultiTableQualityReport()
+ quality_report.generate(real_data, synthetic_data, metadata, verbose=False)
- scores = []
- output['scores'] = scores
- for metric_name, metric in metrics.items():
- scores.append({
- 'metric': metric_name,
- 'error': 'Metric Timeout',
- })
- output['scores'] = scores # re-inject list to multiprocessing output
+ output['quality_score_time'] = (datetime.utcnow() - start).total_seconds()
+ output['quality_score'] = quality_report.get_score()
- error = None
- score = None
- normalized_score = None
- start = datetime.utcnow()
- try:
- LOGGER.info('Computing %s on dataset %s', metric_name, metadata._metadata['name'])
- score = metric.compute(*metric_args)
- normalized_score = metric.normalize(score)
- except Exception:
- LOGGER.exception('Metric %s failed on dataset %s. Skipping.',
- metric_name, metadata._metadata['name'])
- _, error = format_exception()
-
- scores[-1].update({
- 'score': score,
- 'normalized_score': normalized_score,
- 'error': error,
- 'metric_time': (datetime.utcnow() - start).total_seconds()
- })
- output['scores'] = scores # re-inject list to multiprocessing output
-
-
-def _score(synthesizer, metadata, metrics, iteration, output=None, max_rows=None):
+
+def _score(synthesizer, metadata, metrics, output=None, max_rows=None,
+ compute_quality_score=False):
if output is None:
output = {}
@@ -102,44 +180,58 @@ def _score(synthesizer, metadata, metrics, iteration, output=None, max_rows=None
output['error'] = 'Load Timeout' # To be deleted if there is no error
try:
real_data = load_tables(metadata, max_rows)
+ output['dataset_size'] = get_size_of(real_data) / N_BYTES_IN_MB
- LOGGER.info('Running %s on %s dataset %s; iteration %s; %s',
- name, metadata.modality, metadata._metadata['name'], iteration, used_memory())
+ LOGGER.info('Running %s on %s dataset %s; %s',
+ name, metadata.modality, metadata._metadata['name'], used_memory())
output['error'] = 'Synthesizer Timeout' # To be deleted if there is no error
- synthetic_data, model_time = _synthesize(synthesizer, real_data.copy(), metadata)
+ synthetic_data, train_time, sample_time, synthesizer_size, peak_memory = _synthesize(
+ synthesizer, real_data.copy(), metadata)
output['synthetic_data'] = synthetic_data
- output['model_time'] = model_time.total_seconds()
+ output['train_time'] = train_time.total_seconds()
+ output['sample_time'] = sample_time.total_seconds()
+ output['synthesizer_size'] = synthesizer_size
+ output['peak_memory'] = peak_memory
- LOGGER.info('Scoring %s on %s dataset %s; iteration %s; %s',
- name, metadata.modality, metadata._metadata['name'], iteration, used_memory())
+ LOGGER.info('Scoring %s on %s dataset %s; %s',
+ name, metadata.modality, metadata._metadata['name'], used_memory())
del output['error'] # No error so far. _compute_scores tracks its own errors by metric
- _compute_scores(metrics, real_data, synthetic_data, metadata, output)
+ _compute_scores(
+ metrics, real_data, synthetic_data, metadata, output, compute_quality_score)
output['timeout'] = False # There was no timeout
except Exception:
- LOGGER.exception('Error running %s on dataset %s; iteration %s',
- name, metadata._metadata['name'], iteration)
+ LOGGER.exception('Error running %s on dataset %s;',
+ name, metadata._metadata['name'])
exception, error = format_exception()
output['exception'] = exception
output['error'] = error
output['timeout'] = False # There was no timeout
finally:
- LOGGER.info('Finished %s on dataset %s; iteration %s; %s',
- name, metadata._metadata['name'], iteration, used_memory())
+ LOGGER.info('Finished %s on dataset %s; %s',
+ name, metadata._metadata['name'], used_memory())
return output
-def _score_with_timeout(timeout, synthesizer, metadata, metrics, iteration):
+def _score_with_timeout(timeout, synthesizer, metadata, metrics, max_rows=None,
+ compute_quality_score=False):
with multiprocessing.Manager() as manager:
output = manager.dict()
process = multiprocessing.Process(
target=_score,
- args=(synthesizer, metadata, metrics, iteration, output),
+ args=(
+ synthesizer,
+ metadata,
+ metrics,
+ output,
+ max_rows,
+ compute_quality_score,
+ ),
)
process.start()
@@ -148,8 +240,8 @@ def _score_with_timeout(timeout, synthesizer, metadata, metrics, iteration):
output = dict(output)
if output['timeout']:
- LOGGER.error('Timeout running %s on dataset %s; iteration %s',
- synthesizer['name'], metadata._metadata['name'], iteration)
+ LOGGER.error('Timeout running %s on dataset %s;',
+ synthesizer['name'], metadata._metadata['name'])
return output
@@ -158,47 +250,76 @@ def _run_job(args):
# Reset random seed
np.random.seed()
- synthesizer, metadata, metrics, iteration, cache_dir, \
- timeout, run_id, aws_key, aws_secret, max_rows = args
+ synthesizer, metadata, metrics, cache_dir, \
+ timeout, run_id, max_rows, compute_quality_score = args
name = synthesizer['name']
dataset_name = metadata._metadata['name']
- LOGGER.info('Evaluating %s on %s dataset %s with timeout %ss; iteration %s; %s',
- name, metadata.modality, dataset_name, timeout, iteration, used_memory())
+ LOGGER.info('Evaluating %s on %s dataset %s with timeout %ss; %s',
+ name, metadata.modality, dataset_name, timeout, used_memory())
- if timeout:
- output = _score_with_timeout(timeout, synthesizer, metadata, metrics, iteration)
- else:
- output = _score(synthesizer, metadata, metrics, iteration, max_rows=max_rows)
-
- scores = output.get('scores')
- if not scores:
- scores = pd.DataFrame({'score': [None]})
- else:
- scores = pd.DataFrame(scores)
-
- scores.insert(0, 'synthesizer', name)
- scores.insert(1, 'dataset', metadata._metadata['name'])
- scores.insert(2, 'modality', metadata.modality)
- scores.insert(3, 'iteration', iteration)
- scores['model_time'] = output.get('model_time')
- scores['run_id'] = run_id
+ output = {}
+ try:
+ if timeout:
+ output = _score_with_timeout(
+ timeout,
+ synthesizer,
+ metadata,
+ metrics,
+ max_rows=max_rows,
+ compute_quality_score=compute_quality_score,
+ )
+ else:
+ output = _score(
+ synthesizer,
+ metadata,
+ metrics,
+ max_rows=max_rows,
+ compute_quality_score=compute_quality_score,
+ )
+ except Exception as error:
+        output['exception'] = str(error)
+
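+    # Evaluation time: quality report generation plus the NewRowSynthesis metric, when present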
+ evaluate_time = None
+ if 'scores' in output or 'quality_score_time' in output:
+ evaluate_time = output.get('quality_score_time', 0)
+
+ for score in output.get('scores', []):
+ if score['metric'] == 'NewRowSynthesis':
+ evaluate_time += score['metric_time']
+
+ scores = pd.DataFrame({
+ 'Synthesizer': [name],
+ 'Dataset': [metadata._metadata['name']],
+ 'Dataset_Size_MB': [output.get('dataset_size')],
+ 'Train_Time': [output.get('train_time')],
+ 'Peak_Memory_MB': [output.get('peak_memory')],
+ 'Synthesizer_Size_MB': [output.get('synthesizer_size')],
+ 'Sample_Time': [output.get('sample_time')],
+ 'Evaluate_Time': [evaluate_time],
+ })
+
+ if compute_quality_score:
+ scores.insert(len(scores.columns), 'Quality_Score', output.get('quality_score'))
+
+ for score in output.get('scores', []):
+ scores.insert(len(scores.columns), score['metric'], score['normalized_score'])
if 'error' in output:
scores['error'] = output['error']
if cache_dir:
cache_dir_name = str(cache_dir)
- base_path = f'{cache_dir_name}/{name}_{dataset_name}_{iteration}_{run_id}'
+ base_path = f'{cache_dir_name}/{name}_{dataset_name}'
if scores is not None:
- write_csv(scores, f'{base_path}_scores.csv', aws_key, aws_secret)
+ write_csv(scores, f'{base_path}_scores.csv', None, None)
if 'synthetic_data' in output:
synthetic_data = compress_pickle.dumps(output['synthetic_data'], compression='gzip')
- write_file(synthetic_data, f'{base_path}.data.gz', aws_key, aws_secret)
+ write_file(synthetic_data, f'{base_path}.data.gz', None, None)
if 'exception' in output:
exception = output['exception'].encode('utf-8')
- write_file(exception, f'{base_path}_error.txt', aws_key, aws_secret)
+ write_file(exception, f'{base_path}_error.txt', None, None)
return scores
@@ -226,153 +347,170 @@ def _run_on_dask(jobs, verbose):
return dask.compute(*persisted)
-def run(synthesizers=None, datasets=None, datasets_path=None, modalities=None, bucket=None,
- metrics=None, iterations=1, workers=1, cache_dir=None, show_progress=False,
- timeout=None, output_path=None, aws_key=None, aws_secret=None, jobs=None,
- max_rows=None, max_columns=None):
- """Run the SDGym benchmark and return a leaderboard.
+def benchmark_single_table(synthesizers=DEFAULT_SYNTHESIZERS, custom_synthesizers=None,
+ sdv_datasets=DEFAULT_DATASETS, additional_datasets_folder=None,
+ limit_dataset_size=False, compute_quality_score=True,
+ sdmetrics=DEFAULT_METRICS, timeout=None, output_filepath=None,
+ detailed_results_folder=None, show_progress=False,
+ multi_processing_config=None):
+ """Run the SDGym benchmark on single-table datasets.
The ``synthesizers`` object can either be a single synthesizer or, an iterable of
synthesizers or a dict containing synthesizer names as keys and synthesizers as values.
- If ``add_leaderboard`` is ``True``, append the obtained scores to the leaderboard
- stored in the ``lederboard_path``. By default, the leaderboard used is the one which
- is included in the package, which contains the scores obtained by the SDGym Synthesizers.
-
- If ``replace_existing`` is ``True`` and any of the given synthesizers already existed
- in the leaderboard, the old rows are dropped.
-
Args:
- synthesizers (function, class, list, tuple or dict):
- The synthesizer or synthesizers to evaluate. It can be a single synthesizer
- (function or method or class), or an iterable of synthesizers, or a dict
- containing synthesizer names as keys and synthesizers as values. If the input
- is not a dict, synthesizer names will be extracted from the given object.
- datasets (list[str]):
- Names of the datasets to use for the benchmark. Defaults to all the ones available.
- datasets_path (str):
- Path to where the datasets can be found. If not given, use the default path.
- modalities (list[str]):
- Filter datasets by the given modalities. If not given, filter datasets by the
- synthesizer modalities.
- metrics (list[str]):
- List of metrics to apply.
- bucket (str):
- Name of the bucket from which the datasets must be downloaded if not found locally.
- iterations (int):
- Number of iterations to perform over each dataset and synthesizer. Defaults to 1.
- workers (int or str):
- If ``workers`` is given as an integer value other than 0 or 1, a multiprocessing
- Pool is used to distribute the computation across the indicated number of workers.
- If ``workers`` is -1, the number of workers will be automatically determined by
- the number of GPUs (if available) or the number of CPU cores. If the string ``dask``
- is given, the computation is distributed using ``dask``. In this case, setting up the
- ``dask`` cluster and client is expected to be handled outside of this function.
- cache_dir (str):
- If a ``cache_dir`` is given, intermediate results are stored in the indicated directory
- as CSV files as they get computted. This allows inspecting results while the benchmark
- is still running and also recovering results in case the process does not finish
- properly. Defaults to ``None``.
+        synthesizers (list[str]):
+            The synthesizer(s) to evaluate. Defaults to ``[GaussianCopulaSynthesizer, FastMLPreset,
+            CTGANSynthesizer]``. The available options are:
+
+ - ``GaussianCopulaSynthesizer``
+ - ``CTGANSynthesizer``
+ - ``CopulaGANSynthesizer``
+ - ``TVAESynthesizer``
+            - ``FastMLPreset``
+ - any custom created synthesizer or variant
+
+ custom_synthesizers (list[class]):
+ A list of custom synthesizer classes to use. These can be completely custom or
+ they can be synthesizer variants (the output from ``create_single_table_synthesizer``
+ or ``create_sdv_synthesizer_variant``). Defaults to ``None``.
+ sdv_datasets (list[str] or ``None``):
+ Names of the SDV demo datasets to use for the benchmark. Defaults to
+ ``[adult, alarm, census, child, expedia_hotel_logs, insurance, intrusion, news,
+ covtype]``. Use ``None`` to disable using any sdv datasets.
+ additional_datasets_folder (str or ``None``):
+ The path to a folder (local or an S3 bucket). Datasets found in this folder are
+ run in addition to the SDV datasets. If ``None``, no additional datasets are used.
+ limit_dataset_size (bool):
+ Use this flag to limit the size of the datasets for faster evaluation. If ``True``,
+ limit the size of every table to 1,000 rows (randomly sampled) and the first 10
+ columns.
+ compute_quality_score (bool):
+ Whether or not to evaluate an overall quality score.
+ sdmetrics (list[str]):
+ A list of the different SDMetrics to use. If you'd like to input specific parameters
+ into the metric, provide a tuple with the metric name followed by a dictionary of
+ the parameters.
+        timeout (int or ``None``):
+ The maximum number of seconds to wait for synthetic data creation. If ``None``, no
+ timeout is enforced.
+ output_filepath (str or ``None``):
+ A file path for where to write the output as a csv file. If ``None``, no output
+ is written.
+ detailed_results_folder (str or ``None``):
+ The folder for where to store the intermediary results. If ``None``, do not store
+ the intermediate results anywhere.
show_progress (bool):
Whether to use tqdm to keep track of the progress. Defaults to ``False``.
- timeout (int):
- Maximum number of seconds to wait for each dataset to
- finish the evaluation process. If not passed, wait until
- all the datasets are done.
- output_path (str):
- If an ``output_path`` is given, the generated leaderboard will be stored in the
- indicated path as a CSV file. The given path must be a complete path including
- the ``.csv`` filename.
- aws_key (str):
- If an ``aws_key`` is provided, the given access key id will be used to read
- from the specified bucket.
- aws_secret (str):
- If an ``aws_secret`` is provided, the given secret access key will be used to read
- from the specified bucket.
- jobs (list[tuple]):
- List of jobs to execute, as a sequence of tuple-like objects containing synthesizer,
- dataset and iteration-id specifications. If not passed, the jobs list is build by
- combining the synthesizers and datasets given instead.
- max_rows (int):
- Cap the number of rows to model from each dataset. Rows will be selected in order.
- max_columns (int):
- Cap the number of columns to model from each dataset.
- Columns will be selected in order.
+ multi_processing_config (dict or ``None``):
+ The config to use if multi-processing is desired. For example,
+ {
+ 'package_name': 'dask' or 'multiprocessing',
+ 'num_workers': 4
+ }
Returns:
pandas.DataFrame:
- A table containing one row per synthesizer + dataset + metric + iteration.
+            A table containing one row per synthesizer + dataset combination.
"""
- if cache_dir and not is_s3_path(cache_dir):
- cache_dir = Path(cache_dir)
- os.makedirs(cache_dir, exist_ok=True)
+ if output_filepath and os.path.exists(output_filepath):
+ raise ValueError(
+ f'{output_filepath} already exists. '
+ 'Please provide a file that does not already exist.'
+ )
+
+ if detailed_results_folder and os.path.exists(detailed_results_folder):
+ raise ValueError(
+ f'{detailed_results_folder} already exists. '
+ 'Please provide a folder that does not already exist.'
+ )
+
+ duplicates = get_duplicates(synthesizers) if synthesizers else {}
+ if custom_synthesizers:
+ duplicates.update(get_duplicates(custom_synthesizers))
+ if len(duplicates) > 0:
+ raise ValueError(
+ 'Synthesizers must be unique. Please remove repeated values in the `synthesizers` '
+ 'and `custom_synthesizers` parameters.'
+ )
+
+ if detailed_results_folder and not is_s3_path(detailed_results_folder):
+ detailed_results_folder = Path(detailed_results_folder)
+ os.makedirs(detailed_results_folder, exist_ok=True)
+
+ max_rows, max_columns = (1000, 10) if limit_dataset_size else (None, None)
run_id = os.getenv('RUN_ID') or str(uuid.uuid4())[:10]
- if workers == -1:
- num_gpus = get_num_gpus()
- if num_gpus > 0:
- workers = num_gpus
- else:
- workers = multiprocessing.cpu_count()
+ synthesizers = get_synthesizers(synthesizers)
+ if custom_synthesizers:
+ custom_synthesizers = get_synthesizers(custom_synthesizers)
+ synthesizers.extend(custom_synthesizers)
- if jobs is None:
- synthesizers = get_synthesizers(synthesizers)
- datasets = get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret)
+ datasets = []
+ if sdv_datasets is not None:
+ datasets = get_dataset_paths(sdv_datasets, None, None, None, None)
- job_tuples = list()
- for dataset in datasets:
- for synthesizer in synthesizers:
- for iteration in range(iterations):
- job_tuples.append((synthesizer, dataset, iteration))
+ if additional_datasets_folder:
+ additional_datasets = get_dataset_paths(None, None, additional_datasets_folder, None, None)
+ datasets.extend(additional_datasets)
- else:
- job_tuples = list()
- for synthesizer, dataset, iteration in jobs:
- job_tuples.append((
- get_synthesizers([synthesizer])[0],
- get_dataset_paths([dataset], datasets_path, bucket, aws_key, aws_secret)[0],
- iteration
- ))
+ job_tuples = list()
+ for dataset in datasets:
+ for synthesizer in synthesizers:
+ job_tuples.append((synthesizer, dataset))
job_args = list()
- for synthesizer, dataset, iteration in job_tuples:
- metadata = load_dataset(dataset, max_columns=max_columns)
- modalities_ = modalities or synthesizer.get('modalities')
- if not modalities_ or metadata.modality in modalities_:
- args = (
- synthesizer,
- metadata,
- metrics,
- iteration,
- cache_dir,
- timeout,
- run_id,
- aws_key,
- aws_secret,
- max_rows,
- )
- job_args.append(args)
+ for synthesizer, dataset in job_tuples:
+ metadata = load_dataset('single_table', dataset, max_columns=max_columns)
+ dataset_modality = metadata.modality
+ synthesizer_modalities = synthesizer.get('modalities')
+ if (dataset_modality and dataset_modality != 'single-table') or (
+ synthesizer_modalities and 'single-table' not in synthesizer_modalities
+ ):
+ continue
+
+ args = (
+ synthesizer,
+ metadata,
+ sdmetrics,
+ detailed_results_folder,
+ timeout,
+ run_id,
+ max_rows,
+ compute_quality_score,
+ )
+ job_args.append(args)
- if workers == 'dask':
- scores = _run_on_dask(job_args, show_progress)
- else:
- if workers in (0, 1):
- scores = map(_run_job, job_args)
+ workers = 1
+ if multi_processing_config:
+ if multi_processing_config['package_name'] == 'dask':
+ workers = 'dask'
+ scores = _run_on_dask(job_args, show_progress)
else:
- pool = concurrent.futures.ProcessPoolExecutor(workers)
- scores = pool.map(_run_job, job_args)
-
- scores = tqdm.tqdm(scores, total=len(job_args), file=TqdmLogger())
- if show_progress:
- scores = tqdm.tqdm(scores, total=len(job_args))
+ num_gpus = get_num_gpus()
+ if num_gpus > 0:
+ workers = num_gpus
+ else:
+ workers = multiprocessing.cpu_count()
+
+ if workers in (0, 1):
+ scores = map(_run_job, job_args)
+ elif workers != 'dask':
+ pool = concurrent.futures.ProcessPoolExecutor(workers)
+ scores = pool.map(_run_job, job_args)
+
+ if show_progress:
+ scores = tqdm.tqdm(scores, total=len(job_args), position=0, leave=True)
+ else:
+ scores = tqdm.tqdm(scores, total=len(job_args), file=TqdmLogger(), position=0, leave=True)
if not scores:
raise SDGymError("No valid Dataset/Synthesizer combination given")
- scores = pd.concat(scores)
+ scores = pd.concat(scores, ignore_index=True)
- if output_path:
- write_csv(scores, output_path, aws_key, aws_secret)
+ if output_filepath:
+ write_csv(scores, output_filepath, None, None)
return scores
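As a usage sketch of the reworked entry point documented above, assuming it is exported at the package level as ``sdgym.benchmark_single_table`` (that import path is an assumption, not shown in this diff), a call could look like:

# Sketch only: the name `benchmark_single_table` is assumed; keyword names follow the docstring above.
import sdgym

scores = sdgym.benchmark_single_table(
    synthesizers=['GaussianCopulaSynthesizer', 'CTGANSynthesizer'],
    sdv_datasets=['covtype'],                 # or None to skip the SDV datasets
    additional_datasets_folder=None,          # local folder or s3:// bucket with extra datasets
    limit_dataset_size=True,                  # cap every table at 1,000 rows and 10 columns
    compute_quality_score=True,
    sdmetrics=[('KSComplement', {})],         # a tuple carries per-metric parameters
    timeout=600,                              # max seconds for synthetic data creation
    output_filepath='results.csv',            # must not already exist (checked above)
    detailed_results_folder=None,
    show_progress=True,
    multi_processing_config={'package_name': 'multiprocessing', 'num_workers': 4},
)
print(scores.head())                          # one row per synthesizer + dataset + metric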
diff --git a/sdgym/datasets.py b/sdgym/datasets.py
index ccbc4440..14d2a5bb 100644
--- a/sdgym/datasets.py
+++ b/sdgym/datasets.py
@@ -2,6 +2,7 @@
import itertools
import json
import logging
+import os
from pathlib import Path
from zipfile import ZipFile
@@ -14,18 +15,26 @@
LOGGER = logging.getLogger(__name__)
DATASETS_PATH = Path(appdirs.user_data_dir()) / 'SDGym' / 'datasets'
-BUCKET = 'sdv-datasets'
+BUCKET = 's3://sdv-demo-datasets'
BUCKET_URL = 'https://{}.s3.amazonaws.com/'
TIMESERIES_FIELDS = ['sequence_index', 'entity_columns', 'context_columns', 'deepecho_version']
+MODALITIES = ['single_table', 'multi_table', 'sequential']
+S3_PREFIX = 's3://'
-def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None, aws_secret=None):
- datasets_path = datasets_path or DATASETS_PATH
+def _get_bucket_name(bucket):
+ return bucket[len(S3_PREFIX):] if bucket.startswith(S3_PREFIX) else bucket
+
+
+def download_dataset(modality, dataset_name, datasets_path=None, bucket=None, aws_key=None,
+ aws_secret=None):
+ datasets_path = datasets_path or DATASETS_PATH / dataset_name
bucket = bucket or BUCKET
+ bucket_name = _get_bucket_name(bucket)
LOGGER.info('Downloading dataset %s from %s', dataset_name, bucket)
s3 = get_s3_client(aws_key, aws_secret)
- obj = s3.get_object(Bucket=bucket, Key=f'{dataset_name}.zip')
+ obj = s3.get_object(Bucket=bucket_name, Key=f'{modality.upper()}/{dataset_name}.zip')
bytes_io = io.BytesIO(obj['Body'].read())
LOGGER.info('Extracting dataset into %s', datasets_path)
@@ -33,7 +42,8 @@ def download_dataset(dataset_name, datasets_path=None, bucket=None, aws_key=None
zf.extractall(datasets_path)
-def _get_dataset_path(dataset, datasets_path, bucket=None, aws_key=None, aws_secret=None):
+def _get_dataset_path(modality, dataset, datasets_path, bucket=None, aws_key=None,
+ aws_secret=None):
dataset = Path(dataset)
if dataset.exists():
return dataset
@@ -43,7 +53,13 @@ def _get_dataset_path(dataset, datasets_path, bucket=None, aws_key=None, aws_sec
if dataset_path.exists():
return dataset_path
- download_dataset(dataset, datasets_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)
+ if not bucket.startswith(S3_PREFIX):
+ local_path = Path(bucket) / dataset if bucket else Path(dataset)
+ if local_path.exists():
+ return local_path
+
+ download_dataset(
+ modality, dataset, dataset_path, bucket=bucket, aws_key=aws_key, aws_secret=aws_secret)
return dataset_path
@@ -61,10 +77,14 @@ def _apply_max_columns_to_metadata(metadata, max_columns):
structure['states'] = structure['states'][:max_columns]
-def load_dataset(dataset, datasets_path=None, bucket=None, aws_key=None, aws_secret=None,
- max_columns=None):
- dataset_path = _get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
- with open(dataset_path / 'metadata.json') as metadata_file:
+def load_dataset(modality, dataset, datasets_path=None, bucket=None, aws_key=None,
+ aws_secret=None, max_columns=None):
+ dataset_path = _get_dataset_path(
+ modality, dataset, datasets_path, bucket, aws_key, aws_secret)
+ metadata_filename = 'metadata.json'
+ if not os.path.exists(f'{dataset_path}/{metadata_filename}'):
+ metadata_filename = 'metadata_v0.json'
+ with open(dataset_path / metadata_filename) as metadata_file:
metadata_content = json.load(metadata_file)
if max_columns:
@@ -113,17 +133,34 @@ def load_tables(metadata, max_rows=None):
return real_data
-def get_available_datasets(bucket=None, aws_key=None, aws_secret=None):
+def get_available_datasets():
+ return _get_available_datasets('single_table')
+
+
+def _get_available_datasets(modality, bucket=None, aws_key=None, aws_secret=None):
+ if modality not in MODALITIES:
+ modalities_list = ', '.join(MODALITIES)
+ raise ValueError(
+ f'Modality `{modality}` not recognized. Must be one of {modalities_list}')
+
s3 = get_s3_client(aws_key, aws_secret)
- response = s3.list_objects(Bucket=bucket or BUCKET)
+ bucket = bucket or BUCKET
+ bucket_name = _get_bucket_name(bucket)
+
+ response = s3.list_objects(Bucket=bucket_name, Prefix=modality.upper())
datasets = []
for content in response['Contents']:
key = content['Key']
- size = int(content['Size'])
+ metadata = s3.head_object(Bucket=bucket_name, Key=key)['ResponseMetadata']['HTTPHeaders']
+ size = metadata.get('x-amz-meta-size-mb')
+ size = float(size) if size is not None else size
+ num_tables = metadata.get('x-amz-meta-num-tables')
+ num_tables = int(num_tables) if num_tables is not None else num_tables
if key.endswith('.zip'):
datasets.append({
- 'name': key[:-len('.zip')],
- 'size': size
+ 'dataset_name': key[len(f'{modality.upper()}/'):-len('.zip')],
+ 'size_MB': size,
+ 'num_tables': num_tables,
})
return pd.DataFrame(datasets)
@@ -149,18 +186,33 @@ def get_downloaded_datasets(datasets_path=None):
def get_dataset_paths(datasets, datasets_path, bucket, aws_key, aws_secret):
"""Build the full path to datasets and ensure they exist."""
+ bucket = bucket or BUCKET
+ is_remote = bucket.startswith(S3_PREFIX)
+
if datasets_path is None:
datasets_path = DATASETS_PATH
datasets_path = Path(datasets_path)
if datasets is None:
- if datasets_path.exists():
- datasets = list(datasets_path.iterdir())
-
- if not datasets:
- datasets = get_available_datasets()['name'].tolist()
+ # local path
+ if not is_remote and Path(bucket).exists():
+ datasets = []
+ folder_items = list(Path(bucket).iterdir())
+ for dataset in folder_items:
+ if not dataset.name.startswith('.'):
+ if dataset.name.endswith('zip'):
+ dataset_name = os.path.splitext(dataset.name)[0]
+ dataset_path = datasets_path / dataset_name
+ ZipFile(dataset).extractall(dataset_path)
+
+ datasets.append(dataset_path)
+ elif dataset not in datasets:
+ datasets.append(dataset)
+ else:
+ datasets = _get_available_datasets(
+ 'single_table', bucket=bucket)['dataset_name'].tolist()
return [
- _get_dataset_path(dataset, datasets_path, bucket, aws_key, aws_secret)
+ _get_dataset_path('single_table', dataset, datasets_path, bucket, aws_key, aws_secret)
for dataset in datasets
]
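A short sketch of the reworked dataset helpers in this module, assuming only what the functions above expose:

# Sketch: list the single-table datasets in the new default bucket and resolve them to
# local paths, downloading and extracting the ZIPs into the appdirs cache on first use.
from sdgym.datasets import get_available_datasets, get_dataset_paths

available = get_available_datasets()        # columns: dataset_name, size_MB, num_tables
names = available['dataset_name'].tolist()

# datasets_path=None falls back to DATASETS_PATH and bucket=None to BUCKET.
paths = get_dataset_paths(names, None, None, None, None)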
diff --git a/sdgym/metrics.py b/sdgym/metrics.py
index c10592a9..d2d8ca93 100644
--- a/sdgym/metrics.py
+++ b/sdgym/metrics.py
@@ -12,6 +12,9 @@ def __init__(self, metric, **kwargs):
def compute(self, real_data, synthetic_data, metadata):
return self._metric.compute(real_data, synthetic_data, metadata, **self._kwargs)
+ def normalize(self, raw_score):
+ return self._metric.normalize(raw_score)
+
# Metrics to use by default for specific problem types and data
# modalities if no metrics have been explicitly specified.
@@ -47,16 +50,11 @@ def compute(self, real_data, synthetic_data, metadata):
DATA_MODALITY_METRICS = {
'single-table': [
'CSTest',
- 'KSTest',
- 'KSTestExtended',
- 'LogisticDetection',
+ 'KSComplement',
],
'multi-table': [
'CSTest',
- 'KSTest',
- 'KSTestExtended',
- 'LogisticDetection',
- 'LogisticParentChildDetection',
+ 'KSComplement',
],
'timeseries': [
'TSFClassifierEfficacy',
@@ -84,10 +82,12 @@ def get_metrics(metrics, metadata):
metrics = DATA_MODALITY_METRICS[modality]
final_metrics = {}
+ metric_kwargs = {}
for metric in metrics:
+ metric_args = None
if isinstance(metric, tuple):
- metric_name, metric = metric
- elif isinstance(metric, str):
+ metric, metric_args = metric
+ if isinstance(metric, str):
metric_name = metric
try:
metric = metric_classes[metric]
@@ -98,5 +98,7 @@ def get_metrics(metrics, metadata):
metric_name = metric.__name__
final_metrics[metric_name] = metric
+ if metric_args:
+ metric_kwargs[metric_name] = metric_args
- return final_metrics
+ return final_metrics, metric_kwargs
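A sketch of the two-value return introduced here, assuming ``real_data``, ``synthetic_data`` and ``metadata`` are already loaded; the empty kwargs dict is a placeholder, since valid parameters depend on the SDMetrics class being configured:

# Sketch: get_metrics now also returns the kwargs collected from (name, params) tuples.
from sdgym.metrics import get_metrics

metrics_spec = ['CSTest', ('KSComplement', {})]   # {} stands in for per-metric parameters
final_metrics, metric_kwargs = get_metrics(metrics_spec, metadata)

for name, metric in final_metrics.items():
    kwargs = metric_kwargs.get(name, {})
    score = metric.compute(real_data, synthetic_data, metadata, **kwargs)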
diff --git a/sdgym/s3.py b/sdgym/s3.py
index db5c6c07..b538fda1 100644
--- a/sdgym/s3.py
+++ b/sdgym/s3.py
@@ -107,6 +107,8 @@ def write_file(contents, path, aws_key, aws_secret):
if path.endswith('gz') or path.endswith('gzip'):
content_encoding = 'gzip'
write_mode = 'wb'
+ elif isinstance(contents, bytes):
+ write_mode = 'wb'
if is_s3_path(path):
s3 = get_s3_client(aws_key, aws_secret)
diff --git a/sdgym/summary.py b/sdgym/summary.py
index c1419ca7..ed97bbd4 100644
--- a/sdgym/summary.py
+++ b/sdgym/summary.py
@@ -18,6 +18,11 @@
'timeseries': []
}
+LIBRARIES = {
+ 'SDV': ['ctgan', 'copulagan', 'gaussiancopula', 'tvae', 'hma1', 'par'],
+ 'YData': ['dragan', 'vanillagan', 'wgan'],
+}
+
def preprocess(data):
if isinstance(data, str):
@@ -52,8 +57,8 @@ def _mean_score(data):
return data.groupby('synthesizer').normalized_score.mean()
-def _best(data):
- ranks = data.groupby('dataset').rank(method='min', ascending=False)['normalized_score'] == 1
+def _best(data, rank, field, ascending):
+ ranks = data.groupby('dataset').rank(method='dense', ascending=ascending)[field] == rank
return ranks.groupby(data.synthesizer).sum()
@@ -76,7 +81,7 @@ def summarize(data, baselines=(), datasets=None):
"""Obtain an overview of the performance of each synthesizer.
Optionally compare the synthesizers with the indicated baselines or analyze
- only some o the datasets.
+ only some of the datasets.
Args:
data (pandas.DataFrame):
@@ -94,7 +99,7 @@ def summarize(data, baselines=(), datasets=None):
baselines_data = data[data.synthesizer.isin(baselines)]
data = data[~data.synthesizer.isin(baselines)]
- no_identity = data[data.synthesizer != 'Identity']
+ no_identity = data[data.synthesizer != 'DataIdentity']
coverage_perc, coverage_str = _coverage(data)
solved = data.groupby('synthesizer').apply(lambda x: x.normalized_score.notnull().sum())
@@ -105,9 +110,13 @@ def summarize(data, baselines=(), datasets=None):
'coverage': coverage_str,
'coverage_perc': coverage_perc,
'time': _seconds(data),
- 'best': _best(no_identity),
+ 'best': _best(no_identity, 1, 'normalized_score', False),
'avg score': _mean_score(data),
+ 'best_time': _best(no_identity, 1, 'model_time', True),
+ 'second_best_time': _best(no_identity, 2, 'model_time', True),
+ 'third_best_time': _best(no_identity, 3, 'model_time', True),
}
+
for baseline in baselines:
baseline_data = baselines_data[baselines_data.synthesizer == baseline]
baseline_scores = baseline_data.set_index('dataset').normalized_score
@@ -190,12 +199,38 @@ def add_sheet(dfs, name, writer, cell_fmt, index_fmt, header_fmt):
worksheet.set_column(idx, idx, width + 1, fmt)
+def _find_library(synthesizer):
+ for library, library_synthesizers in LIBRARIES.items():
+ for library_synthesizer in library_synthesizers:
+ if library_synthesizer in synthesizer.lower():
+ return library
+
+ return None
+
+
+def _add_summary_libraries(summary_data):
+ summary_data['library'] = summary_data.index.map(_find_library)
+ summary_data['library'].fillna('Other', inplace=True)
+ return summary_data
+
+
def _add_summary(data, modality, baselines, writer):
total_summary = summarize(data, baselines=baselines)
- summary = total_summary[['coverage_perc', 'time', 'avg score']].rename({
+
+ summary = total_summary[[
+ 'coverage_perc',
+ 'best_time',
+ 'second_best_time',
+ 'third_best_time',
+ ]].rename({
'coverage_perc': 'coverage %',
- 'time': 'avg time'
+ 'best_time': '# of Wins',
+ 'second_best_time': '# of 2nd best',
+ 'third_best_time': '# of 3rd best',
}, axis=1)
+ summary.drop(index='Identity', inplace=True, errors='ignore')
+ summary = _add_summary_libraries(summary)
+
beat_baseline_headers = ['beat_' + b.lower() for b in baselines]
quality = total_summary[['total', 'solved', 'best'] + beat_baseline_headers]
performance = total_summary[['time']]
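A small sketch of the new library attribution; the lookup is a lowercase substring match, so variant and custom prefixes still resolve:

# Sketch of the substring matching behind the new 'library' column in the summary sheet.
from sdgym.summary import _find_library

_find_library('CTGANSynthesizer')           # 'SDV'
_find_library('Variant:GaussianCopula')     # 'SDV'
_find_library('Custom:MyModel')             # None, later filled in as 'Other'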
diff --git a/sdgym/synthesizers/__init__.py b/sdgym/synthesizers/__init__.py
index 8336ce23..89fae288 100644
--- a/sdgym/synthesizers/__init__.py
+++ b/sdgym/synthesizers/__init__.py
@@ -1,38 +1,40 @@
-from sdgym.synthesizers.clbn import CLBN
-from sdgym.synthesizers.gretel import Gretel, PreprocessedGretel
-from sdgym.synthesizers.identity import Identity
-from sdgym.synthesizers.independent import Independent
-from sdgym.synthesizers.medgan import MedGAN
-from sdgym.synthesizers.privbn import PrivBN
+from sdgym.synthesizers.clbn import CLBNSynthesizer
+from sdgym.synthesizers.generate import (
+ SYNTHESIZER_MAPPING, create_multi_table_synthesizer, create_sdv_synthesizer_variant,
+ create_sequential_synthesizer, create_single_table_synthesizer)
+from sdgym.synthesizers.identity import DataIdentity
+from sdgym.synthesizers.independent import IndependentSynthesizer
+from sdgym.synthesizers.medgan import MedGANSynthesizer
+from sdgym.synthesizers.privbn import PrivBNSynthesizer
from sdgym.synthesizers.sdv import (
- CTGAN, CopulaGAN, GaussianCopulaCategorical, GaussianCopulaCategoricalFuzzy,
- GaussianCopulaOneHot)
-from sdgym.synthesizers.tablegan import TableGAN
-from sdgym.synthesizers.uniform import Uniform
-from sdgym.synthesizers.veegan import VEEGAN
-from sdgym.synthesizers.ydata import (
- DRAGAN, WGAN_GP, PreprocessedDRAGAN, PreprocessedVanillaGAN, PreprocessedWGAN_GP, VanillaGAN)
+ CopulaGANSynthesizer, CTGANSynthesizer, FastMLPreset, GaussianCopulaSynthesizer,
+ HMASynthesizer, PARSynthesizer, SDVRelationalSynthesizer, SDVTabularSynthesizer,
+ TVAESynthesizer)
+from sdgym.synthesizers.tablegan import TableGANSynthesizer
+from sdgym.synthesizers.uniform import UniformSynthesizer
+from sdgym.synthesizers.veegan import VEEGANSynthesizer
__all__ = (
- 'CLBN',
- 'Identity',
- 'Independent',
- 'MedGAN',
- 'PrivBN',
- 'TableGAN',
- 'CTGAN',
- 'Uniform',
- 'VEEGAN',
- 'CopulaGAN',
- 'GaussianCopulaCategorical',
- 'GaussianCopulaCategoricalFuzzy',
- 'GaussianCopulaOneHot',
- 'Gretel',
- 'PreprocessedGretel',
- 'VanillaGAN',
- 'DRAGAN',
- 'WGAN_GP',
- 'PreprocessedDRAGAN',
- 'PreprocessedWGAN_GP',
- 'PreprocessedVanillaGAN',
+ 'CLBNSynthesizer',
+ 'DataIdentity',
+ 'IndependentSynthesizer',
+ 'MedGANSynthesizer',
+ 'PrivBNSynthesizer',
+ 'TableGANSynthesizer',
+ 'CTGANSynthesizer',
+ 'TVAESynthesizer',
+ 'UniformSynthesizer',
+ 'VEEGANSynthesizer',
+ 'CopulaGANSynthesizer',
+ 'GaussianCopulaSynthesizer',
+ 'HMASynthesizer',
+ 'PARSynthesizer',
+ 'FastMLPreset',
+ 'SDVTabularSynthesizer',
+ 'SDVRelationalSynthesizer',
+ 'create_single_table_synthesizer',
+ 'create_multi_table_synthesizer',
+ 'create_sdv_synthesizer_variant',
+ 'create_sequential_synthesizer',
+ 'SYNTHESIZER_MAPPING',
)
diff --git a/sdgym/synthesizers/base.py b/sdgym/synthesizers/base.py
index adb70772..9d76087e 100644
--- a/sdgym/synthesizers/base.py
+++ b/sdgym/synthesizers/base.py
@@ -9,7 +9,7 @@
LOGGER = logging.getLogger(__name__)
-class Baseline(abc.ABC):
+class BaselineSynthesizer(abc.ABC):
"""Base class for all the ``SDGym`` baselines."""
MODALITIES = ()
@@ -42,11 +42,37 @@ def get_baselines(cls):
return synthesizers
- def fit_sample(self, real_data, metadata):
- pass
+ def get_trained_synthesizer(self, data, metadata):
+ """Get a synthesizer that has been trained on the provided data and metadata.
+
+ Args:
+ data (pandas.DataFrame or dict):
+ The data to train on.
+ metadata (sdv.Metadata):
+ The metadata.
+
+ Returns:
+ obj:
+ The synthesizer object
+ """
+
+ def sample_from_synthesizer(self, synthesizer, n_samples):
+ """Sample data from the provided synthesizer.
+
+ Args:
+ synthesizer (obj):
+ The synthesizer object to sample data from.
+ n_samples (int):
+ The number of samples to create.
+
+ Returns:
+ pandas.DataFrame or dict:
+ The sampled data. If single-table, should be a DataFrame. If multi-table,
+ should be a dict mapping table name to DataFrame.
+ """
-class SingleTableBaseline(Baseline, abc.ABC):
+
+class SingleTableBaselineSynthesizer(BaselineSynthesizer, abc.ABC):
"""Base class for all the SingleTable Baselines.
Subclasses can choose to implement ``_fit_sample``, which will
@@ -59,36 +85,63 @@ class SingleTableBaseline(Baseline, abc.ABC):
MODALITIES = ('single-table', )
CONVERT_TO_NUMERIC = False
- def _transform_fit_sample(self, real_data, metadata):
- ht = rdt.HyperTransformer()
+ def _get_transformed_trained_synthesizer(self, real_data, metadata):
+ self.ht = rdt.HyperTransformer()
columns_to_transform = list()
fields_metadata = metadata['fields']
- id_fields = list()
+ self.id_fields = list()
for field in fields_metadata:
if fields_metadata.get(field).get('type') != 'id':
columns_to_transform.append(field)
else:
- id_fields.append(field)
+ self.id_fields.append(field)
+
+ self.id_field_values = real_data[self.id_fields]
- ht.fit(real_data[columns_to_transform])
- transformed_data = ht.transform(real_data)
- synthetic_data = self._fit_sample(transformed_data, metadata)
- reverse_transformed_synthetic_data = ht.reverse_transform(synthetic_data)
- reverse_transformed_synthetic_data[id_fields] = real_data[id_fields]
+ self.ht.fit(real_data[columns_to_transform])
+ transformed_data = self.ht.transform(real_data)
+ return self._get_trained_synthesizer(transformed_data, metadata)
+
+ def _get_reverse_transformed_samples(self, synthesizer, n_samples):
+ synthetic_data = self._sample_from_synthesizer(synthesizer, n_samples)
+ reverse_transformed_synthetic_data = self.ht.reverse_transform(synthetic_data)
+ reverse_transformed_synthetic_data[self.id_fields] = self.id_field_values
return reverse_transformed_synthetic_data
- def fit_sample(self, real_data, metadata):
- _fit_sample = self._transform_fit_sample if self.CONVERT_TO_NUMERIC else self._fit_sample
- if isinstance(real_data, dict):
- return {
- table_name: _fit_sample(table, metadata.get_table_meta(table_name))
- for table_name, table in real_data.items()
- }
+ def get_trained_synthesizer(self, data, metadata):
+ """Get a synthesizer that has been trained on the provided data and metadata.
+
+ Args:
+ data (pandas.DataFrame):
+ The data to train on.
+ metadata (sdv.Metadata):
+ The metadata.
+
+ Returns:
+ obj:
+ The synthesizer object
+ """
+ return self._get_transformed_trained_synthesizer(data, metadata) if (
+ self.CONVERT_TO_NUMERIC) else self._get_trained_synthesizer(data, metadata)
+
+ def sample_from_synthesizer(self, synthesizer, n_samples):
+ """Sample data from the provided synthesizer.
- return _fit_sample(real_data, metadata)
+ Args:
+ synthesizer (obj):
+ The synthesizer object to sample data from.
+ n_samples (int):
+ The number of samples to create.
+
+ Returns:
+ pandas.DataFrame:
+ The sampled data.
+ """
+ return self._get_reverse_transformed_samples(synthesizer, n_samples) if (
+ self.CONVERT_TO_NUMERIC) else self._sample_from_synthesizer(synthesizer, n_samples)
-class MultiSingleTableBaseline(Baseline, abc.ABC):
+class MultiSingleTableBaselineSynthesizer(BaselineSynthesizer, abc.ABC):
"""Base class for SingleTableBaselines that are used on multi table scenarios.
These classes model and sample each table independently and then just
@@ -97,32 +150,62 @@ class MultiSingleTableBaseline(Baseline, abc.ABC):
MODALITIES = ('multi-table', 'single-table')
- def fit_sample(self, real_data, metadata):
- if isinstance(real_data, dict):
- tables = {
- table_name: self._fit_sample(table, metadata.get_table_meta(table_name))
- for table_name, table in real_data.items()
- }
+ def get_trained_synthesizer(self, data, metadata):
+ """Get the trained synthesizer.
- for table_name, table in tables.items():
- parents = metadata.get_parents(table_name)
- for parent_name in parents:
- parent = tables[parent_name]
- primary_key = metadata.get_primary_key(parent_name)
- foreign_keys = metadata.get_foreign_keys(parent_name, table_name)
- length = len(table)
- for foreign_key in foreign_keys:
- foreign_key_values = parent[primary_key].sample(length, replace=True)
- table[foreign_key] = foreign_key_values.values
+ Args:
+ data (dict):
+ A dict mapping table name to table data.
+ metadata (sdv.Metadata):
+ The multi-table metadata.
+
+ Returns:
+ dict:
+ A mapping of table name to synthesizers.
+ """
+ self.metadata = metadata
+ synthesizers = {
+ table_name: self._get_trained_synthesizer(table, metadata.get_table_meta(table_name))
+ for table_name, table in data.items()
+ }
+ self.table_columns = {table_name: data[table_name].columns for table_name in data.keys()}
+
+ return synthesizers
+
+ def sample_from_synthesizer(self, synthesizers, n_samples):
+ """Sample from the given synthesizers.
- tables[table_name] = table[real_data[table_name].columns]
+ Args:
+ synthesizers (dict):
+ A dict mapping table name to table synthesizer.
+ n_samples (int):
+ The number of samples.
+
+ Returns:
+ dict:
+ A mapping of table name to sampled table data.
+ """
+ tables = {
+ table_name: self._sample_from_synthesizer(synthesizer, n_samples)
+ for table_name, synthesizer in synthesizers.items()
+ }
- return tables
+ for table_name, table in tables.items():
+ parents = self.metadata.get_parents(table_name)
+ for parent_name in parents:
+ parent = tables[parent_name]
+ primary_key = self.metadata.get_primary_key(parent_name)
+ foreign_keys = self.metadata.get_foreign_keys(parent_name, table_name)
+ for foreign_key in foreign_keys:
+ foreign_key_values = parent[primary_key].sample(len(table), replace=True)
+ table[foreign_key] = foreign_key_values.values
- return self._fit_sample(real_data, metadata)
+ tables[table_name] = table[self.table_columns[table_name]]
+ return tables
-class LegacySingleTableBaseline(SingleTableBaseline, abc.ABC):
+
+class LegacySingleTableBaselineSynthesizer(SingleTableBaselineSynthesizer, abc.ABC):
"""Single table baseline which passes ordinals and categoricals down.
This class exists here to support the legacy baselines which do not operate
@@ -151,15 +234,32 @@ def _get_columns(self, real_data, table_metadata):
return model_columns, categorical_columns
- def _fit_sample(self, real_data, table_metadata):
- columns, categoricals = self._get_columns(real_data, table_metadata)
- real_data = real_data[columns]
+ def get_trained_synthesizer(self, data, metadata):
+ """Get the trained synthesizer.
- ht = rdt.HyperTransformer(default_data_type_transformers={
- 'categorical': 'LabelEncodingTransformer',
- })
- ht.fit(real_data.iloc[:, categoricals])
- model_data = ht.transform(real_data)
+ Args:
+ data (pandas.DataFrame):
+ The data to train on.
+ metadata (sdv.Metadata):
+ The single-table metadata.
+
+ Returns:
+ obj:
+ The trained synthesizer.
+ """
+ self.columns, self.categoricals = self._get_columns(data, metadata)
+ data = data[self.columns]
+
+ if self.categoricals:
+ self.ht = rdt.HyperTransformer(default_data_type_transformers={
+ 'categorical': 'LabelEncodingTransformer',
+ })
+ self.ht.fit(data.iloc[:, self.categoricals])
+ model_data = self.ht.transform(data)
+ else:
+ model_data = data
+
+ self.model_columns = model_data.columns
supported = set(model_data.select_dtypes(('number', 'bool')).columns)
unsupported = set(model_data.columns) - supported
@@ -173,12 +273,25 @@ def _fit_sample(self, real_data, table_metadata):
raise UnsupportedDataset(f'Null values found in columns {unsupported_columns}')
LOGGER.info("Fitting %s", self.__class__.__name__)
- self.fit(model_data.to_numpy(), categoricals, ())
+ self.fit(model_data.to_numpy(), self.categoricals, ())
+
+ def sample_from_synthesizer(self, synthesizer, n_samples):
+ """Sample from the given synthesizers.
+
+ Args:
+ synthesizer:
+ The table synthesizer.
+ n_samples (int):
+ The number of samples.
+
+ Returns:
+ pandas.DataFrame:
+ The sampled data.
+ """
+ sampled_data = self.sample(n_samples)
+ sampled_data = pd.DataFrame(sampled_data, columns=self.model_columns)
- LOGGER.info("Sampling %s", self.__class__.__name__)
- sampled_data = self.sample(len(model_data))
- sampled_data = pd.DataFrame(sampled_data, columns=columns)
+ if self.categoricals:
+ sampled_data = self.ht.reverse_transform(sampled_data)
- synthetic_data = real_data.copy()
- synthetic_data.update(ht.reverse_transform(sampled_data))
- return synthetic_data
+ return sampled_data
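A compact sketch of the two-step contract that replaces ``fit_sample``; the subclass below is hypothetical and only illustrates the call order the benchmark now uses:

# Sketch: train once via get_trained_synthesizer, then sample from the returned object.
import pandas as pd

from sdgym.synthesizers.base import BaselineSynthesizer


class ColumnShuffleSynthesizer(BaselineSynthesizer):
    """Hypothetical baseline that resamples every column independently with replacement."""

    def get_trained_synthesizer(self, data, metadata):
        return data                       # "training" just keeps the real table around

    def sample_from_synthesizer(self, synthesizer, n_samples):
        sampled = {
            column: synthesizer[column].sample(n_samples, replace=True).to_numpy()
            for column in synthesizer.columns
        }
        return pd.DataFrame(sampled)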
diff --git a/sdgym/synthesizers/clbn.py b/sdgym/synthesizers/clbn.py
index e26c1598..8289bf83 100644
--- a/sdgym/synthesizers/clbn.py
+++ b/sdgym/synthesizers/clbn.py
@@ -3,11 +3,11 @@
import numpy as np
from pomegranate import BayesianNetwork, ConditionalProbabilityTable, DiscreteDistribution
-from sdgym.synthesizers.base import LegacySingleTableBaseline
+from sdgym.synthesizers.base import LegacySingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import DiscretizeTransformer
-class CLBN(LegacySingleTableBaseline):
+class CLBNSynthesizer(LegacySingleTableBaselineSynthesizer):
"""CLBNSynthesizer."""
def fit(self, data, categorical_columns=tuple(), ordinal_columns=tuple()):
diff --git a/sdgym/synthesizers/generate.py b/sdgym/synthesizers/generate.py
new file mode 100644
index 00000000..cbaba077
--- /dev/null
+++ b/sdgym/synthesizers/generate.py
@@ -0,0 +1,268 @@
+"""Synthesizers module."""
+
+from sdv.lite import TabularPreset
+from sdv.relational import HMA1
+from sdv.tabular import CTGAN, TVAE, CopulaGAN, GaussianCopula
+from sdv.timeseries import PAR
+
+from sdgym.synthesizers.base import (
+ BaselineSynthesizer, MultiSingleTableBaselineSynthesizer, SingleTableBaselineSynthesizer)
+from sdgym.synthesizers.sdv import FastMLPreset, SDVRelationalSynthesizer, SDVTabularSynthesizer
+
+SYNTHESIZER_MAPPING = {
+ 'FastMLPreset': TabularPreset,
+ 'GaussianCopulaSynthesizer': GaussianCopula,
+ 'CTGANSynthesizer': CTGAN,
+ 'CopulaGANSynthesizer': CopulaGAN,
+ 'TVAESynthesizer': TVAE,
+ 'PARSynthesizer': PAR,
+ 'HMASynthesizer': HMA1,
+}
+
+
+def create_sdv_synthesizer_variant(display_name, synthesizer_class, synthesizer_parameters):
+ """Create a new synthesizer that is a variant of an SDV tabular synthesizer.
+
+ Args:
+ display_name (string):
+ A string with the name of this synthesizer, used for display purposes only
+ when the results are generated.
+ synthesizer_class (string):
+ The name of the SDV synthesizer class. The available options are:
+
+ * 'FastMLPreset'
+ * 'GaussianCopulaSynthesizer'
+ * 'CTGANSynthesizer'
+ * 'CopulaGANSynthesizer'
+ * 'TVAESynthesizer'
+ * 'PARSynthesizer'
+ * 'HMASynthesizer'
+
+ synthesizer_parameters (dict):
+ A dictionary of the parameter names and values that will be used for the synthesizer.
+
+ Returns:
+ class:
+ The synthesizer class.
+ """
+ if synthesizer_class not in SYNTHESIZER_MAPPING.keys():
+ raise ValueError(
+ f'Synthesizer class {synthesizer_class} is not recognized. '
+ f"The supported options are {', '.join(SYNTHESIZER_MAPPING.keys())}"
+ )
+
+ baseclass = SDVTabularSynthesizer
+ if synthesizer_class == 'HMASynthesizer':
+ baseclass = SDVRelationalSynthesizer
+ if synthesizer_class == 'FastMLPreset':
+ baseclass = FastMLPreset
+
+ class NewSynthesizer(baseclass):
+ """New Synthesizer class.
+
+ Args:
+ synthesizer_class (string):
+ The name of the SDV synthesizer class. The available options are:
+
+ * 'FastMLPreset'
+ * 'GaussianCopulaSynthesizer'
+ * 'CTGANSynthesizer'
+ * 'CopulaGANSynthesizer'
+ * 'TVAESynthesizer'
+ * 'PARSynthesizer'
+
+ synthesizer_parameters (dict):
+ A dictionary of the parameter names and values that will be used for
+ the synthesizer.
+ """
+
+ _MODEL = SYNTHESIZER_MAPPING.get(synthesizer_class)
+ _MODEL_KWARGS = synthesizer_parameters
+
+ NewSynthesizer.__name__ = f'Variant:{display_name}'
+
+ return NewSynthesizer
+
+
+def create_single_table_synthesizer(display_name, get_trained_synthesizer_fn,
+ sample_from_synthesizer_fn):
+ """Create a new single-table synthesizer.
+
+ Args:
+ display_name (string):
+ A string with the name of this synthesizer, used for display purposes only when
+ the results are generated.
+ get_trained_synthesizer_fn (callable):
+ A function to generate and train a synthesizer, given the real data and metadata.
+ sample_from_synthesizer_fn (callable):
+ A function to sample from the given synthesizer.
+
+ Returns:
+ class:
+ The synthesizer class.
+ """
+ class NewSynthesizer(SingleTableBaselineSynthesizer):
+ """New Synthesizer class.
+
+ Args:
+ get_trained_synthesizer_fn (callable):
+ Function to replace the ``get_trained_synthesizer`` method.
+ sample_from_synthesizer_fn (callable):
+ Function to replace the ``sample_from_synthesizer`` method.
+ """
+
+ def get_trained_synthesizer(self, data, metadata):
+ """Create and train a synthesizer, given the real data and metadata.
+
+ Args:
+ data (pandas.DataFrame):
+ The real data.
+ metadata (sdv.Metadata):
+ The single table metadata.
+
+ Returns:
+ obj:
+ The trained synthesizer.
+ """
+ return get_trained_synthesizer_fn(data, metadata)
+
+ def sample_from_synthesizer(self, synthesizer, num_samples):
+ """Sample the desired number of samples from the given synthesizer.
+
+ Args:
+ synthesizer (obj):
+ The trained synthesizer.
+ num_samples (int):
+ The number of samples to generate.
+
+ Returns:
+ pandas.DataFrame:
+ The synthetic data.
+ """
+ return sample_from_synthesizer_fn(synthesizer, num_samples)
+
+ NewSynthesizer.__name__ = f'Custom:{display_name}'
+
+ return NewSynthesizer
+
+
+def create_multi_table_synthesizer(display_name, get_trained_synthesizer_fn,
+ sample_from_synthesizer_fn):
+ """Create a new multi-table synthesizer.
+
+ Args:
+ display_name (string):
+ A string with the name of this synthesizer, used for display purposes only when
+ the results are generated.
+ get_trained_synthesizer_fn (callable):
+ A function to generate and train a synthesizer, given the real data and metadata.
+ sample_from_synthesizer_fn (callable):
+ A function to sample from the given synthesizer.
+
+ Returns:
+ class:
+ The synthesizer class.
+ """
+ class NewSynthesizer(MultiSingleTableBaselineSynthesizer):
+ """New Synthesizer class.
+
+ Args:
+ get_trained_synthesizer_fn (callable):
+ Function to replace the ``get_trained_synthesizer`` method.
+ sample_from_synthesizer_fn (callable):
+ Function to replace the ``sample_from_synthesizer`` method.
+ """
+
+ def get_trained_synthesizer(self, data, metadata):
+ """Create and train a synthesizer, given the real data and metadata.
+
+ Args:
+ data (dict):
+ The real data. A mapping of table names to table data.
+ metadata (sdv.Metadata):
+ The multi table metadata.
+
+ Returns:
+ obj:
+ The trained synthesizer.
+ """
+ return get_trained_synthesizer_fn(data, metadata)
+
+ def sample_from_synthesizer(self, synthesizer):
+ """Sample from the given synthesizer.
+
+ Args:
+ synthesizer (obj):
+ The trained synthesizer.
+
+ Returns:
+ dict:
+ The synthetic data. A mapping of table names to table data.
+ """
+ return sample_from_synthesizer_fn(synthesizer)
+
+ NewSynthesizer.__name__ = f'Custom:{display_name}'
+
+ return NewSynthesizer
+
+
+def create_sequential_synthesizer(display_name, get_trained_synthesizer_fn,
+ sample_from_synthesizer_fn):
+ """Create a new sequential synthesizer.
+
+ Args:
+ display_name (string):
+ A string with the name of this synthesizer, used for display purposes only when
+ the results are generated.
+ get_trained_synthesizer_fn (callable):
+ A function to generate and train a synthesizer, given the real data and metadata.
+ sample_from_synthesizer_fn (callable):
+ A function to sample from the given synthesizer.
+
+ Returns:
+ class:
+ The synthesizer class.
+ """
+ class NewSynthesizer(BaselineSynthesizer):
+ """New Synthesizer class.
+
+ Args:
+ get_trained_synthesizer_fn (callable):
+ Function to replace the ``get_trained_synthesizer`` method.
+ sample_from_synthesizer_fn (callable):
+ Function to replace the ``sample_from_synthesizer`` method.
+ """
+
+ def get_trained_synthesizer(self, data, metadata):
+ """Create and train a synthesizer, given the real data and metadata.
+
+ Args:
+ data (dict):
+ The real data. A mapping of table names to table data.
+ metadata (sdv.Metadata):
+ The multi table metadata.
+
+ Returns:
+ obj:
+ The trained synthesizer.
+ """
+ return get_trained_synthesizer_fn(data, metadata)
+
+ def sample_from_synthesizer(self, synthesizer, n_sequences):
+ """Sample from the given synthesizer.
+
+ Args:
+ synthesizer (obj):
+ The trained synthesizer.
+ n_sequences (int):
+ The number of sequences to generate.
+
+ Returns:
+ dict:
+ The synthetic data. A mapping of table names to table data.
+ """
+ return sample_from_synthesizer_fn(synthesizer, n_sequences)
+
+ NewSynthesizer.__name__ = f'Custom:{display_name}'
+
+ return NewSynthesizer
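Two usage sketches for the factory helpers added in this module; the callables and the ``epochs`` value are illustrative only:

# Sketch 1: wrap plain train/sample functions into a benchmark-ready synthesizer class.
from sdgym.synthesizers import create_sdv_synthesizer_variant, create_single_table_synthesizer


def train_fn(data, metadata):
    return data.copy()                    # hypothetical "model": just remember the data


def sample_fn(synthesizer, num_samples):
    return synthesizer.sample(num_samples, replace=True)


MyModel = create_single_table_synthesizer('MyModel', train_fn, sample_fn)

# Sketch 2: an SDV variant with non-default parameters.
CTGANFast = create_sdv_synthesizer_variant(
    'CTGANFast', 'CTGANSynthesizer', synthesizer_parameters={'epochs': 10})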
diff --git a/sdgym/synthesizers/gretel.py b/sdgym/synthesizers/gretel.py
deleted file mode 100644
index ec3b3dfb..00000000
--- a/sdgym/synthesizers/gretel.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import tempfile
-
-import numpy as np
-
-from sdgym.synthesizers.base import SingleTableBaseline
-
-try:
- from gretel_synthetics.batch import DataFrameBatch
-except ImportError:
- DataFrameBatch = None
-
-
-class Gretel(SingleTableBaseline):
- """Class to represent Gretel's neural network model."""
-
- def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000,
- gen_lines=None, dp=False, field_delimiter=",", overwrite=True,
- checkpoint_dir=None):
- if DataFrameBatch is None:
- raise ImportError('Please install gretel-synthetics using `pip install sdgym[gretel]`')
-
- self.max_lines = max_lines
- self.max_line_len = max_line_len
- self.epochs = epochs
- self.vocab_size = vocab_size
- self.gen_lines = gen_lines
- self.dp = dp
- self.field_delimiter = field_delimiter
- self.overwrite = overwrite
- self.checkpoint_dir = checkpoint_dir or tempfile.TemporaryDirectory().name
-
- def _fit_sample(self, data, metadata):
- config = {
- 'max_lines': self.max_lines,
- 'max_line_len': self.max_line_len,
- 'epochs': self.epochs or data.shape[1] * 3, # value recommended by Gretel
- 'vocab_size': self.vocab_size,
- 'gen_lines': self.gen_lines or data.shape[0],
- 'dp': self.dp,
- 'field_delimiter': self.field_delimiter,
- 'overwrite': self.overwrite,
- 'checkpoint_dir': self.checkpoint_dir
- }
- batcher = DataFrameBatch(df=data, config=config)
- batcher.create_training_data()
- batcher.train_all_batches()
- batcher.generate_all_batch_lines()
- synth_data = batcher.batches_to_df()
- return synth_data
-
-
-class PreprocessedGretel(Gretel):
- """Class that uses RDT to make all columns numeric before using Gretel's model."""
-
- CONVERT_TO_NUMERIC = True
-
- @staticmethod
- def make_numeric(val):
- if type(val) in [float, int]:
- return val
-
- if isinstance(val, str) and val.isnumeric():
- return float(val)
-
- return np.nan
-
- def _fix_numeric_columns(self, data, metadata):
- fields_metadata = metadata['fields']
- for field in data:
- if field in fields_metadata and fields_metadata.get(field).get('type') == 'id':
- continue
-
- data[field] = data[field].apply(self.make_numeric)
- avg = data[field].mean() if not np.isnan(data[field].mean()) else 0
- data[field] = data[field].fillna(round(avg))
-
- def _fit_sample(self, data, metadata):
- synth_data = super()._fit_sample(data, metadata)
- self._fix_numeric_columns(synth_data, metadata)
- return synth_data
diff --git a/sdgym/synthesizers/identity.py b/sdgym/synthesizers/identity.py
index c853c81c..f78a0544 100644
--- a/sdgym/synthesizers/identity.py
+++ b/sdgym/synthesizers/identity.py
@@ -1,14 +1,20 @@
import copy
-from sdgym.synthesizers.base import Baseline
+from sdgym.synthesizers.base import BaselineSynthesizer
-class Identity(Baseline):
+class DataIdentity(BaselineSynthesizer):
"""Trivial synthesizer.
Returns the same exact data that is used to fit it.
"""
- def fit_sample(self, real_data, metadata):
- del metadata
- return copy.deepcopy(real_data)
+ def __init__(self):
+ self._data = None
+
+ def get_trained_synthesizer(self, data, metadata):
+ self._data = data
+ return None
+
+ def sample_from_synthesizer(self, synthesizer, n_samples):
+ return copy.deepcopy(self._data)
diff --git a/sdgym/synthesizers/independent.py b/sdgym/synthesizers/independent.py
index a70ba22b..a32e09b5 100644
--- a/sdgym/synthesizers/independent.py
+++ b/sdgym/synthesizers/independent.py
@@ -2,32 +2,43 @@
from sdv.metadata import Table
from sklearn.mixture import GaussianMixture
-from sdgym.synthesizers.base import MultiSingleTableBaseline
+from sdgym.synthesizers.base import MultiSingleTableBaselineSynthesizer
-class Independent(MultiSingleTableBaseline):
+class IndependentSynthesizer(MultiSingleTableBaselineSynthesizer):
"""Synthesizer that learns each column independently.
Categorical columns are sampled using empirical frequencies.
Continuous columns are learned and sampled using a GMM.
"""
- @staticmethod
- def _fit_sample(real_data, metadata):
+ def _get_trained_synthesizer(self, real_data, metadata):
metadata = Table(metadata, dtype_transformers={'O': None, 'i': None})
metadata.fit(real_data)
transformed = metadata.transform(real_data)
+ self.length = len(real_data)
+ gm_models = {}
+ for name, column in transformed.items():
+ kind = column.dtype.kind
+ if kind != 'O':
+ num_components = min(column.nunique(), 5)
+ model = GaussianMixture(num_components)
+ model.fit(column.values.reshape(-1, 1))
+ gm_models[name] = model
+
+ return (metadata, transformed, gm_models)
+
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
+ metadata, transformed, gm_models = synthesizer
sampled = pd.DataFrame()
- length = len(real_data)
for name, column in transformed.items():
kind = column.dtype.kind
if kind == 'O':
- values = column.sample(length, replace=True).values
+ values = column.sample(self.length, replace=True).values
else:
- model = GaussianMixture(5)
- model.fit(column.values.reshape(-1, 1))
- values = model.sample(length)[0].ravel().clip(column.min(), column.max())
+ model = gm_models.get(name)
+ values = model.sample(self.length)[0].ravel().clip(column.min(), column.max())
sampled[name] = values
diff --git a/sdgym/synthesizers/medgan.py b/sdgym/synthesizers/medgan.py
index 65bb7988..30fe5e15 100644
--- a/sdgym/synthesizers/medgan.py
+++ b/sdgym/synthesizers/medgan.py
@@ -6,7 +6,7 @@
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
-from sdgym.synthesizers.base import LegacySingleTableBaseline
+from sdgym.synthesizers.base import LegacySingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import GeneralTransformer, select_device
@@ -118,7 +118,7 @@ def aeloss(fake, real, output_info):
return sum(loss) / fake.size()[0]
-class MedGAN(LegacySingleTableBaseline):
+class MedGANSynthesizer(LegacySingleTableBaselineSynthesizer):
"""docstring for MedGAN."""
def __init__(self,
diff --git a/sdgym/synthesizers/privbn.py b/sdgym/synthesizers/privbn.py
index 3346e301..d368fea9 100644
--- a/sdgym/synthesizers/privbn.py
+++ b/sdgym/synthesizers/privbn.py
@@ -9,7 +9,7 @@
import numpy as np
from sdgym.constants import CATEGORICAL, ORDINAL
-from sdgym.synthesizers.base import LegacySingleTableBaseline
+from sdgym.synthesizers.base import LegacySingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import Transformer
LOGGER = logging.getLogger(__name__)
@@ -20,7 +20,7 @@ def try_mkdirs(dir):
os.makedirs(dir)
-class PrivBN(LegacySingleTableBaseline):
+class PrivBNSynthesizer(LegacySingleTableBaselineSynthesizer):
"""docstring for PrivBN."""
def __init__(self, theta=20, max_samples=25000):
diff --git a/sdgym/synthesizers/sdv.py b/sdgym/synthesizers/sdv.py
index 8244b18b..cc86b8c1 100644
--- a/sdgym/synthesizers/sdv.py
+++ b/sdgym/synthesizers/sdv.py
@@ -4,137 +4,132 @@
import sdv
import sdv.timeseries
-from sdgym.synthesizers.base import Baseline, SingleTableBaseline
+from sdgym.synthesizers.base import BaselineSynthesizer, SingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import select_device
LOGGER = logging.getLogger(__name__)
-class SDV(Baseline, abc.ABC):
+class FastMLPreset(SingleTableBaselineSynthesizer):
- MODALITIES = ('single-table', 'multi-table')
+ MODALITIES = ('single-table', )
+ _MODEL = None
+ _MODEL_KWARGS = None
- def fit_sample(self, data, metadata):
- LOGGER.info('Fitting SDV')
- model = sdv.SDV()
- model.fit(metadata, data)
+ def _get_trained_synthesizer(self, data, metadata):
+ model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
+ model = sdv.lite.TabularPreset(name='FAST_ML', metadata=metadata, **model_kwargs)
+ model.fit(data)
- LOGGER.info('Sampling SDV')
- return model.sample_all()
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
+ return synthesizer.sample(n_samples)
-class SDVTabular(SingleTableBaseline, abc.ABC):
+
+class SDVTabularSynthesizer(SingleTableBaselineSynthesizer, abc.ABC):
MODALITIES = ('single-table', )
_MODEL = None
_MODEL_KWARGS = None
- def _fit_sample(self, data, metadata):
+ def _get_trained_synthesizer(self, data, metadata):
LOGGER.info('Fitting %s', self.__class__.__name__)
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
model = self._MODEL(table_metadata=metadata, **model_kwargs)
model.fit(data)
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
LOGGER.info('Sampling %s', self.__class__.__name__)
- return model.sample()
-
+ return synthesizer.sample(n_samples)
-class GaussianCopulaCategorical(SDVTabular):
-
- _MODEL = sdv.tabular.GaussianCopula
- _MODEL_KWARGS = {
- 'categorical_transformer': 'categorical'
- }
-
-
-class GaussianCopulaCategoricalFuzzy(SDVTabular):
-
- _MODEL = sdv.tabular.GaussianCopula
- _MODEL_KWARGS = {
- 'categorical_transformer': 'categorical_fuzzy'
- }
-
-class GaussianCopulaOneHot(SDVTabular):
+class GaussianCopulaSynthesizer(SDVTabularSynthesizer):
_MODEL = sdv.tabular.GaussianCopula
- _MODEL_KWARGS = {
- 'categorical_transformer': 'OneHotEncodingTransformer'
- }
-class CUDATabular(SDVTabular, abc.ABC):
+class CUDATabularSynthesizer(SDVTabularSynthesizer, abc.ABC):
- def _fit_sample(self, data, metadata):
- LOGGER.info('Fitting %s', self.__class__.__name__)
+ def _get_trained_synthesizer(self, data, metadata):
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
model_kwargs.setdefault('cuda', select_device())
+ LOGGER.info('Fitting %s with kwargs %s', self.__class__.__name__, model_kwargs)
model = self._MODEL(table_metadata=metadata, **model_kwargs)
model.fit(data)
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
LOGGER.info('Sampling %s', self.__class__.__name__)
- return model.sample()
+ return synthesizer.sample(n_samples)
-class CTGAN(CUDATabular):
+class CTGANSynthesizer(CUDATabularSynthesizer):
_MODEL = sdv.tabular.CTGAN
-class TVAE(CUDATabular):
+class TVAESynthesizer(CUDATabularSynthesizer):
_MODEL = sdv.tabular.TVAE
-class CopulaGAN(CUDATabular):
+class CopulaGANSynthesizer(CUDATabularSynthesizer):
_MODEL = sdv.tabular.CopulaGAN
-class SDVRelational(Baseline, abc.ABC):
+class SDVRelationalSynthesizer(BaselineSynthesizer, abc.ABC):
MODALITIES = ('single-table', 'multi-table')
_MODEL = None
_MODEL_KWARGS = None
- def fit_sample(self, data, metadata):
+ def _get_trained_synthesizer(self, data, metadata):
LOGGER.info('Fitting %s', self.__class__.__name__)
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
model = self._MODEL(metadata=metadata, **model_kwargs)
model.fit(data)
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
LOGGER.info('Sampling %s', self.__class__.__name__)
- return model.sample()
+ return synthesizer.sample()
-class HMA1(SDVRelational):
+class HMASynthesizer(SDVRelationalSynthesizer):
_MODEL = sdv.relational.HMA1
-class SDVTimeseries(SingleTableBaseline, abc.ABC):
+class SDVTimeseriesSynthesizer(SingleTableBaselineSynthesizer, abc.ABC):
MODALITIES = ('timeseries', )
_MODEL = None
_MODEL_KWARGS = None
- def _fit_sample(self, data, metadata):
+ def _get_trained_synthesizer(self, data, metadata):
LOGGER.info('Fitting %s', self.__class__.__name__)
model_kwargs = self._MODEL_KWARGS.copy() if self._MODEL_KWARGS else {}
model = self._MODEL(table_metadata=metadata, **model_kwargs)
model.fit(data)
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
LOGGER.info('Sampling %s', self.__class__.__name__)
- return model.sample()
+ return synthesizer.sample()
-class PAR(SDVTimeseries):
+class PARSynthesizer(SDVTimeseriesSynthesizer):
- def _fit_sample(self, data, metadata):
+ def _get_trained_synthesizer(self, data, metadata):
LOGGER.info('Fitting %s', self.__class__.__name__)
model = sdv.timeseries.PAR(table_metadata=metadata, epochs=1024, verbose=False)
model.device = select_device()
model.fit(data)
+ return model
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
LOGGER.info('Sampling %s', self.__class__.__name__)
- return model.sample()
+ return synthesizer.sample()
diff --git a/sdgym/synthesizers/tablegan.py b/sdgym/synthesizers/tablegan.py
index ec77742d..5fc536e9 100644
--- a/sdgym/synthesizers/tablegan.py
+++ b/sdgym/synthesizers/tablegan.py
@@ -9,7 +9,7 @@
from torch.utils.data import DataLoader, TensorDataset
from sdgym.constants import CATEGORICAL
-from sdgym.synthesizers.base import LegacySingleTableBaseline
+from sdgym.synthesizers.base import LegacySingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import TableganTransformer, select_device
@@ -116,8 +116,8 @@ def weights_init(m):
init.constant_(m.bias.data, 0)
-class TableGAN(LegacySingleTableBaseline):
- """docstring for TableganSynthesizer??"""
+class TableGANSynthesizer(LegacySingleTableBaselineSynthesizer):
+ """docstring for TableganSynthesizer"""
def __init__(self,
random_dim=100,
diff --git a/sdgym/synthesizers/uniform.py b/sdgym/synthesizers/uniform.py
index bc39b201..ef96ce66 100644
--- a/sdgym/synthesizers/uniform.py
+++ b/sdgym/synthesizers/uniform.py
@@ -2,28 +2,30 @@
import pandas as pd
from sdv.metadata import Table
-from sdgym.synthesizers.base import MultiSingleTableBaseline
+from sdgym.synthesizers.base import MultiSingleTableBaselineSynthesizer
-class Uniform(MultiSingleTableBaseline):
+class UniformSynthesizer(MultiSingleTableBaselineSynthesizer):
"""Synthesizer that samples each column using a Uniform distribution."""
- @staticmethod
- def _fit_sample(real_data, metadata):
+ def _get_trained_synthesizer(self, real_data, metadata):
metadata = Table(metadata, dtype_transformers={'O': None, 'i': None})
metadata.fit(real_data)
transformed = metadata.transform(real_data)
+ self.length = len(real_data)
+ return (metadata, transformed)
+ def _sample_from_synthesizer(self, synthesizer, n_samples):
+ metadata, transformed = synthesizer
sampled = pd.DataFrame()
- length = len(real_data)
for name, column in transformed.items():
kind = column.dtype.kind
if kind == 'i':
- values = np.random.randint(column.min(), column.max() + 1, size=length)
+ values = np.random.randint(column.min(), column.max() + 1, size=self.length)
elif kind == 'O':
- values = np.random.choice(column.unique(), size=length)
+ values = np.random.choice(column.unique(), size=self.length)
else:
- values = np.random.uniform(column.min(), column.max(), size=length)
+ values = np.random.uniform(column.min(), column.max(), size=self.length)
sampled[name] = values
diff --git a/sdgym/synthesizers/veegan.py b/sdgym/synthesizers/veegan.py
index 3652ff74..851ca372 100644
--- a/sdgym/synthesizers/veegan.py
+++ b/sdgym/synthesizers/veegan.py
@@ -5,7 +5,7 @@
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
-from sdgym.synthesizers.base import LegacySingleTableBaseline
+from sdgym.synthesizers.base import LegacySingleTableBaselineSynthesizer
from sdgym.synthesizers.utils import GeneralTransformer, select_device
@@ -77,7 +77,7 @@ def forward(self, input, output_info):
return torch.cat(data_t, dim=1)
-class VEEGAN(LegacySingleTableBaseline):
+class VEEGANSynthesizer(LegacySingleTableBaselineSynthesizer):
"""VEEGANSynthesizer."""
def __init__(
diff --git a/sdgym/synthesizers/ydata.py b/sdgym/synthesizers/ydata.py
deleted file mode 100644
index af8ace7f..00000000
--- a/sdgym/synthesizers/ydata.py
+++ /dev/null
@@ -1,109 +0,0 @@
-import abc
-
-from sdgym.synthesizers.base import SingleTableBaseline
-
-try:
- import ydata_synthetic.synthesizers.regular as ydata
-except ImportError:
- ydata = None
-
-
-class YData(SingleTableBaseline, abc.ABC):
-
- def _fit_sample(self, real_data, table_metadata):
- if ydata is None:
- raise ImportError('Please install ydata using `make install-ydata`.')
-
- columns = real_data.columns
- synthesizer = self._build_ydata_synthesizer(real_data)
- synthetic_data = synthesizer.sample(len(real_data))
- synthetic_data.columns = columns
-
- return synthetic_data
-
-
-class VanillaGAN(YData):
-
- def __init__(self, noise_dim=32, dim=128, batch_size=128, log_step=100,
- epochs=201, learning_rate=5e-4, beta_1=0.5, beta_2=0.9):
- self.noise_dim = noise_dim
- self.dim = dim
- self.batch_size = batch_size
- self.log_step = log_step
- self.epochs = epochs
- self.learning_rate = learning_rate
- self.beta_1 = beta_1
- self.beta_2 = beta_2
-
- def _build_ydata_synthesizer(self, data):
- model_args = [self.batch_size, self.learning_rate, self.beta_1, self.beta_2,
- self.noise_dim, data.shape[1], self.dim]
- train_args = ['', self.epochs, self.log_step]
-
- synthesizer = ydata.VanilllaGAN(model_args)
- synthesizer.train(data, train_args)
-
- return synthesizer
-
-
-class WGAN_GP(YData):
-
- def __init__(self, noise_dim=32, dim=128, batch_size=128, log_step=100,
- epochs=201, learning_rate=5e-4, beta_1=0.5, beta_2=0.9):
- self.noise_dim = noise_dim
- self.dim = dim
- self.batch_size = batch_size
- self.log_step = log_step
- self.epochs = epochs
- self.learning_rate = learning_rate
- self.beta_1 = beta_1
- self.beta_2 = beta_2
-
- def _build_ydata_synthesizer(self, data):
- model_args = [self.batch_size, self.learning_rate, self.beta_1, self.beta_2,
- self.noise_dim, data.shape[1], self.dim]
- train_args = ['', self.epochs, self.log_step]
-
- synthesizer = ydata.WGAN_GP(model_args, n_critic=2)
- synthesizer.train(data, train_args)
-
- return synthesizer
-
-
-class DRAGAN(YData):
-
- def __init__(self, noise_dim=128, dim=128, batch_size=500, log_step=100,
- epochs=201, learning_rate=1e-5, beta_1=0.5, beta_2=0.9):
- self.noise_dim = noise_dim
- self.dim = dim
- self.batch_size = batch_size
- self.log_step = log_step
- self.epochs = epochs
- self.learning_rate = learning_rate
- self.beta_1 = beta_1
- self.beta_2 = beta_2
-
- def _build_ydata_synthesizer(self, data):
- gan_args = [self.batch_size, self.learning_rate, self.beta_1, self.beta_2,
- self.noise_dim, data.shape[1], self.dim]
- train_args = ['', self.epochs, self.log_step]
-
- synthesizer = ydata.DRAGAN(gan_args, n_discriminator=3)
- synthesizer.train(data, train_args)
-
- return synthesizer
-
-
-class PreprocessedDRAGAN(DRAGAN):
-
- CONVERT_TO_NUMERIC = True
-
-
-class PreprocessedWGAN_GP(WGAN_GP):
-
- CONVERT_TO_NUMERIC = True
-
-
-class PreprocessedVanillaGAN(VanillaGAN):
-
- CONVERT_TO_NUMERIC = True
diff --git a/sdgym/utils.py b/sdgym/utils.py
index 12fcee12..4b9aa4d8 100644
--- a/sdgym/utils.py
+++ b/sdgym/utils.py
@@ -10,10 +10,11 @@
import types
import humanfriendly
+import pandas as pd
import psutil
from sdgym.errors import SDGymError
-from sdgym.synthesizers.base import Baseline
+from sdgym.synthesizers.base import BaselineSynthesizer
from sdgym.synthesizers.utils import select_device
LOGGER = logging.getLogger(__name__)
@@ -67,7 +68,7 @@ def _get_synthesizer_name(synthesizer):
if isinstance(synthesizer, types.MethodType):
synthesizer_name = synthesizer.__self__.__class__.__name__
else:
- synthesizer_name = synthesizer.__name__
+ synthesizer_name = getattr(synthesizer, '__name__', 'undefined')
return synthesizer_name
@@ -82,7 +83,7 @@ def _get_synthesizer(synthesizer, name=None):
with open(synthesizer, 'r') as json_file:
return json.load(json_file)
- baselines = Baseline.get_subclasses(include_parents=True)
+ baselines = BaselineSynthesizer.get_subclasses(include_parents=True)
if synthesizer in baselines:
LOGGER.info('Trying to import synthesizer by name.')
synthesizer = baselines[synthesizer]
@@ -110,16 +111,17 @@ def _get_synthesizer(synthesizer, name=None):
}
-def get_synthesizers(synthesizers):
+def get_synthesizers(synthesizers=None):
"""Get the dict of synthesizers from the input value.
If the input is a synthesizer or an iterable of synthesizers, get their names
- and put them on a dict.
+ and put them on a dict. If None is given, get all the available synthesizers.
Args:
- synthesizers (function, class, list, tuple or dict):
+ synthesizers (function, class, list, tuple, dict or None):
A synthesizer (function or method or class) or an iterable of synthesizers
or a dict containing synthesizer names as keys and synthesizers as values.
+ If no synthesizers are given, all the available ones are returned.
Returns:
dict[str, function]:
@@ -129,10 +131,13 @@ def get_synthesizers(synthesizers):
TypeError:
if neither a synthesizer or an iterable or a dict is passed.
"""
- if callable(synthesizers):
+ if callable(synthesizers) or isinstance(synthesizers, tuple):
return [_get_synthesizer(synthesizers)]
- if isinstance(synthesizers, (list, tuple)):
+ if not synthesizers:
+ synthesizers = BaselineSynthesizer.get_baselines()
+
+ if isinstance(synthesizers, list):
return [
_get_synthesizer(synthesizer)
for synthesizer in synthesizers
@@ -191,7 +196,7 @@ def build_synthesizer(synthesizer, synthesizer_dict):
_synthesizer_dict = copy.deepcopy(synthesizer_dict)
- def _synthesizer_function(real_data, metadata):
+ def _synthesizer_fit_function(real_data, metadata):
metadata_keyword = _synthesizer_dict.get('metadata', '$metadata')
real_data_keyword = _synthesizer_dict.get('real_data', '$real_data')
device_keyword = _synthesizer_dict.get('device', '$device')
@@ -218,11 +223,63 @@ def _synthesizer_function(real_data, metadata):
setattr(instance, device_attribute, device)
instance.fit(**fit_kwargs)
+ return instance
+ def _synthesizer_sample_function(instance, n_samples=None):
sampled = instance.sample()
- if not multi_table:
- sampled = {table: sampled}
return sampled
- return _synthesizer_function
+ return _synthesizer_fit_function, _synthesizer_sample_function
+
+
+def get_size_of(obj, obj_ids=None):
+ """Get the memory used by a given object in bytes.
+
+ Args:
+ obj (object):
+ The object to get the size of.
+ obj_ids (set):
+ The ids of the objects that have already been evaluated.
+
+ Returns:
+ int:
+ The size in bytes.
+ """
+ size = 0
+ if obj_ids is None:
+ obj_ids = set()
+
+ obj_id = id(obj)
+ if obj_id in obj_ids:
+ return 0
+
+ obj_ids.add(obj_id)
+ if isinstance(obj, dict):
+ size += sum([get_size_of(v, obj_ids) for v in obj.values()])
+ elif isinstance(obj, pd.DataFrame):
+ size += obj.memory_usage(index=True).sum()
+ elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
+ size += sum([get_size_of(i, obj_ids) for i in obj])
+ else:
+ size += sys.getsizeof(obj)
+
+ return size
+
+
+def get_duplicates(items):
+ """Get any duplicate items in the given list.
+
+ Args:
+ items (list):
+ The list of items to check for duplicates.
+
+ Returns:
+ set:
+ The duplicate items.
+ """
+ seen = set()
+ return set(
+ item for item in items
+ if item in seen or seen.add(item)
+ )
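For illustration only (not part of the diff), a minimal sketch of how the two helpers added to sdgym/utils.py behave; the values below are made up:

    import pandas as pd
    from sdgym.utils import get_duplicates, get_size_of

    # DataFrames are measured with memory_usage(); other objects fall back to
    # sys.getsizeof, and already-seen object ids are skipped to avoid double counting.
    frame = pd.DataFrame({'a': [1, 2, 3]})
    print(get_size_of({'frame': frame, 'label': 'demo'}))

    # get_duplicates returns every item that appears more than once.
    assert get_duplicates(['a', 'a', 'b', 'c', 'c']) == {'a', 'c'}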
diff --git a/setup.cfg b/setup.cfg
index 253a878f..a3a9dcf0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 0.5.0
+current_version = 0.6.0.dev2
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
diff --git a/setup.py b/setup.py
index 7e5c95a7..384f9508 100644
--- a/setup.py
+++ b/setup.py
@@ -17,21 +17,24 @@
'botocore>=1.18,<2',
'compress-pickle>=1.2.0,<3',
'humanfriendly>=8.2,<11',
- "numpy>=1.18.0,<1.20.0;python_version<'3.7'",
- "numpy>=1.20.0,<2;python_version>='3.7'",
- 'pandas>=1.1.3,<2',
- "pomegranate>=0.13.4,<0.14.2;python_version<'3.7'",
- "pomegranate>=0.14.1,<0.15;python_version>='3.7'",
+ "numpy>=1.20.0,<2;python_version<'3.10'",
+ "numpy>=1.23.3,<2;python_version>='3.10'",
+ "pandas>=1.1.3,<2;python_version<'3.10'",
+ "pandas>=1.5.0,<2;python_version>='3.10'",
+ "pomegranate>=0.14.3,<0.15",
'psutil>=5.7,<6',
- 'scikit-learn>=0.24,<2',
- 'scipy>=1.5.4,<2',
+ "scikit-learn>=0.24,<2;python_version<'3.10'",
+ "scikit-learn>=1.1.3,<2;python_version>='3.10'",
+ "scipy>=1.5.4,<2;python_version<'3.10'",
+ "scipy>=1.9.2,<2;python_version>='3.10'",
'tabulate>=0.8.3,<0.9',
- 'torch>=1.8.0,<2',
+ "torch>=1.8.0,<2;python_version<'3.10'",
+ "torch>=1.11.0,<2;python_version>='3.10'",
'tqdm>=4.15,<5',
'XlsxWriter>=1.2.8,<4',
- 'rdt>=0.6.1,<0.7',
- 'sdmetrics>=0.4.1,<0.5',
- 'sdv>=0.13.0',
+ 'rdt>=1.3.0,<2.0',
+ 'sdmetrics>=0.9.0,<1.0',
+ 'sdv>=0.18.0',
]
@@ -40,18 +43,6 @@
'distributed',
]
-
-ydata_requires = [
- # preferably install using make install-ydata
- 'ydata-synthetic>=0.3.0,<0.4',
-]
-
-gretel_requires = [
- 'gretel-synthetics>=0.15.4,<0.16',
- 'tensorflow==2.4.0rc1',
- 'wheel~=0.35',
-]
-
setup_requires = [
'pytest-runner>=2.11.1',
]
@@ -91,23 +82,22 @@
]
setup(
- author='MIT Data To AI Lab',
- author_email='dailabmit@gmail.com',
+ author='DataCebo, Inc.',
+ author_email='info@sdv.dev',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
- 'License :: OSI Approved :: MIT License',
+ 'License :: Free for non-commercial use',
'Natural Language :: English',
'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
+ 'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
description=(
- 'A framework to benchmark the performance of synthetic data generators '
- 'for non-temporal tabular data'
+ 'Benchmark tabular synthetic data generators using a variety of datasets'
),
entry_points={
'console_scripts': [
@@ -115,25 +105,24 @@
],
},
extras_require={
- 'all': development_requires + tests_require + dask_requires + gretel_requires,
+ 'all': development_requires + tests_require + dask_requires,
'dev': development_requires + tests_require + dask_requires,
'test': tests_require,
- 'gretel': gretel_requires,
'dask': dask_requires,
},
include_package_data=True,
install_requires=install_requires,
- license='MIT license',
+ license='BSL-1.1',
long_description=readme + '\n\n' + history,
long_description_content_type='text/markdown',
keywords='machine learning synthetic data generation benchmark generative models',
name='sdgym',
packages=find_packages(include=['sdgym', 'sdgym.*']),
- python_requires='>=3.6,<3.10',
+ python_requires='>=3.7,<3.11',
setup_requires=setup_requires,
test_suite='tests',
tests_require=tests_require,
url='https://github.com/sdv-dev/SDGym',
- version='0.5.0',
+ version='0.6.0.dev2',
zip_safe=False,
)
diff --git a/tasks.py b/tasks.py
index 75847ab0..71605486 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,7 +1,9 @@
import glob
+import inspect
import operator
import os
import re
+import pkg_resources
import platform
import shutil
import stat
@@ -17,6 +19,10 @@
}
+if not hasattr(inspect, 'getargspec'):
+ inspect.getargspec = inspect.getfullargspec
+
+
@task
def check_dependencies(c):
c.run('python -m pip check')
@@ -48,15 +54,18 @@ def readme(c):
def _validate_python_version(line):
- python_version_match = re.search(r"python_version(<=?|>=?)\'(\d\.?)+\'", line)
- if python_version_match:
+ is_valid = True
+ for python_version_match in re.finditer(r"python_version(<=?|>=?|==)\'(\d\.?)+\'", line):
python_version = python_version_match.group(0)
- comparison = re.search(r'(>=?|<=?)', python_version).group(0)
+ comparison = re.search(r'(>=?|<=?|==)', python_version).group(0)
version_number = python_version.split(comparison)[-1].replace("'", "")
comparison_function = COMPARISONS[comparison]
- return comparison_function(platform.python_version(), version_number)
+ is_valid = is_valid and comparison_function(
+ pkg_resources.parse_version(platform.python_version()),
+ pkg_resources.parse_version(version_number),
+ )
- return True
+ return is_valid
@task
@@ -69,8 +78,7 @@ def install_minimum(c):
for line in lines:
if started:
if line == ']':
- started = False
- continue
+ break
line = line.strip()
if _validate_python_version(line):
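As a side note (illustrative, not part of the diff), the switch to pkg_resources.parse_version matters because plain string comparison mis-orders two-digit minor versions such as 3.10:

    import operator
    import pkg_resources

    # Lexicographic string comparison puts '3.9' after '3.10', which is wrong
    # for Python versions and broke the python_version marker checks.
    assert operator.ge('3.9', '3.10')

    # parse_version compares the release segments numerically, so 3.10 > 3.9.
    assert operator.ge(
        pkg_resources.parse_version('3.10'),
        pkg_resources.parse_version('3.9'),
    )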
diff --git a/tests/integration/test_benchmark.py b/tests/integration/test_benchmark.py
index 1b256c69..438b62b8 100644
--- a/tests/integration/test_benchmark.py
+++ b/tests/integration/test_benchmark.py
@@ -1,82 +1,114 @@
-import json
+import contextlib
+import io
+
+import pytest
import sdgym
+from sdgym.synthesizers import create_single_table_synthesizer
def test_identity():
- output = sdgym.run(
- synthesizers=['Identity', 'Independent', 'Uniform'],
- datasets=['trains_v1', 'KRK_v1'],
+ output = sdgym.benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ sdv_datasets=['student_placements'],
)
assert not output.empty
- assert set(output['modality'].unique()) == {'single-table', 'multi-table'}
+ assert 'Train_Time' in output
+ assert 'Sample_Time' in output
- scores = output.groupby('synthesizer').score.mean().sort_values()
+ scores = output.groupby('Synthesizer').NewRowSynthesis.mean().sort_values()
- assert ['Uniform', 'Independent', 'Identity'] == scores.index.tolist()
+ assert [
+ 'DataIdentity',
+ 'IndependentSynthesizer',
+ 'UniformSynthesizer',
+ ] == scores.index.tolist()
+ quality_scores = output.groupby('Synthesizer').Quality_Score.mean().sort_values()
-def test_identity_jobs():
- jobs = [
- ('Identity', 'trains_v1', 0),
- ('Independent', 'trains_v1', 1),
- ('Uniform', 'KRK_v1', 1),
- ]
- output = sdgym.run(jobs=jobs)
+ assert [
+ 'UniformSynthesizer',
+ 'IndependentSynthesizer',
+ 'DataIdentity',
+ ] == quality_scores.index.tolist()
- assert not output.empty
- assert set(output['modality'].unique()) == {'single-table', 'multi-table'}
- columns = ['synthesizer', 'dataset', 'iteration']
- combinations = set(
- tuple(record)
- for record in output[columns].drop_duplicates().to_records(index=False)
+def test_benchmarking_no_metrics():
+ output = sdgym.benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ sdv_datasets=['student_placements'],
+ sdmetrics=[],
)
- assert combinations == set(jobs)
+ assert not output.empty
+ assert 'Train_Time' in output
+ assert 'Sample_Time' in output
+ # Expect no metric columns.
+ assert len(output.columns) == 9
+
+def test_benchmarking_no_report_output():
+ """Test that the benchmarking printing does not include report progress."""
+ prints = io.StringIO()
+ with contextlib.redirect_stderr(prints):
+ sdgym.benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ sdv_datasets=['student_placements'],
+ )
-def test_json_synthesizer():
- synthesizer = {
- 'name': 'synthesizer_name',
- 'synthesizer': 'sdgym.synthesizers.ydata.PreprocessedVanillaGAN',
- 'modalities': ['single-table'],
- 'init_kwargs': {'categorical_transformer': 'label_encoding'},
- 'fit_kwargs': {'data': '$real_data'}
- }
+ assert 'Creating report:' not in prints.getvalue()
- output = sdgym.run(
- synthesizers=[json.dumps(synthesizer)],
- datasets=['KRK_v1'],
- iterations=1,
- )
- assert set(output['synthesizer']) == {'synthesizer_name'}
-
-
-def test_json_synthesizer_multi_table():
- synthesizer = {
- 'name': 'HMA1',
- 'synthesizer': 'sdv.relational.HMA1',
- 'modalities': [
- 'multi-table'
- ],
- 'init_kwargs': {
- 'metadata': '$metadata'
- },
- 'fit_kwargs': {
- 'tables': '$real_data'
- }
- }
-
- output = sdgym.run(
- synthesizers=[json.dumps(synthesizer)],
- datasets=['university_v1', 'trains_v1'],
- iterations=1,
+def get_trained_synthesizer_err(data, metadata):
+ return {}
+
+
+def sample_from_synthesizer_err(synthesizer, num_rows):
+ raise ValueError('random error')
+
+
+def test_error_handling():
+ erroring_synthesizer = create_single_table_synthesizer(
+ 'my_synth', get_trained_synthesizer_err, sample_from_synthesizer_err)
+ output = sdgym.benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ custom_synthesizers=[erroring_synthesizer],
+ sdv_datasets=['student_placements'],
)
- # CSTest for `university_v1` is not valid because there are no categorical columns.
- valid_out = output.loc[~((output.dataset == 'university_v1') & (output.metric == 'CSTest'))]
+ assert not output.empty
+ assert 'Train_Time' in output
+ assert 'Sample_Time' in output
+ assert (
+ output[output['Synthesizer'] == 'Custom:my_synth'][['Train_Time', 'Sample_Time']]
+ ).isna().all(1).all()
+
+
+def test_compute_quality_score():
+ output = sdgym.benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ sdv_datasets=['student_placements'],
+ compute_quality_score=False,
+ )
- assert not valid_out.error.any()
+ assert not output.empty
+ assert 'Train_Time' in output
+ assert 'Sample_Time' in output
+ assert 'Quality_Score' not in output
+
+
+def test_duplicate_synthesizers():
+ custom_synthesizer = create_single_table_synthesizer(
+ 'my_synth', get_trained_synthesizer_err, sample_from_synthesizer_err)
+ with pytest.raises(
+ ValueError,
+ match=(
+ 'Synthesizers must be unique. Please remove repeated values in the `synthesizers` '
+ 'and `custom_synthesizers` parameters.'
+ )
+ ):
+ sdgym.benchmark_single_table(
+ synthesizers=['GaussianCopulaSynthesizer', 'GaussianCopulaSynthesizer'],
+ custom_synthesizers=[custom_synthesizer, custom_synthesizer]
+ )
diff --git a/tests/unit/synthesizers/test_generate.py b/tests/unit/synthesizers/test_generate.py
new file mode 100644
index 00000000..0adbed6e
--- /dev/null
+++ b/tests/unit/synthesizers/test_generate.py
@@ -0,0 +1,117 @@
+from unittest.mock import Mock
+
+import pytest
+
+from sdgym.synthesizers import FastMLPreset, SDVRelationalSynthesizer, SDVTabularSynthesizer
+from sdgym.synthesizers.generate import (
+ SYNTHESIZER_MAPPING, create_multi_table_synthesizer, create_sdv_synthesizer_variant,
+ create_sequential_synthesizer, create_single_table_synthesizer)
+
+
+def test_create_single_table_synthesizer():
+ """Test that a single table synthesizer is created."""
+ # Run
+ out = create_single_table_synthesizer('test_synth', Mock(), Mock())
+
+ # Assert
+ assert out.__name__ == 'Custom:test_synth'
+ assert hasattr(out, 'get_trained_synthesizer')
+ assert hasattr(out, 'sample_from_synthesizer')
+
+
+def test_create_multi_table_synthesizer():
+ """Test that a multi table synthesizer is created."""
+ # Run
+ out = create_multi_table_synthesizer('test_synth', Mock(), Mock())
+
+ # Assert
+ assert out.__name__ == 'Custom:test_synth'
+ assert hasattr(out, 'get_trained_synthesizer')
+ assert hasattr(out, 'sample_from_synthesizer')
+
+
+def test_create_sequential_synthesizer():
+ """Test that a sequential synthesizer is created."""
+ # Run
+ out = create_sequential_synthesizer('test_synth', Mock(), Mock())
+
+ # Assert
+ assert out.__name__ == 'Custom:test_synth'
+ assert hasattr(out, 'get_trained_synthesizer')
+ assert hasattr(out, 'sample_from_synthesizer')
+
+
+def test_create_sdv_variant_synthesizer():
+ """Test that a sdv variant synthesizer is created.
+
+ Expect that if the synthesizer class is a single-table synthesizer, the
+ new synthesizer inherits from the SDVTabularSynthesizer base class."""
+ # Setup
+ synthesizer_class = 'GaussianCopulaSynthesizer'
+ synthesizer_parameters = {}
+
+ # Run
+ out = create_sdv_synthesizer_variant('test_synth', synthesizer_class, synthesizer_parameters)
+
+ # Assert
+ assert out.__name__ == 'Variant:test_synth'
+ assert out._MODEL == SYNTHESIZER_MAPPING.get(synthesizer_class)
+ assert out._MODEL_KWARGS == synthesizer_parameters
+ assert issubclass(out, SDVTabularSynthesizer)
+
+
+def test_create_sdv_variant_synthesizer_error():
+ """Test that a sdv variant synthesizer is created.
+
+ Expect that if the synthesizer class is a single-table synthesizer, the
+ new synthesizer inherits from the SDVTabularSynthesizer base class."""
+ # Setup
+ synthesizer_class = 'test'
+ synthesizer_parameters = {}
+
+ # Run
+ with pytest.raises(
+ ValueError,
+ match=r"Synthesizer class test is not recognized. The supported options are "
+ "FastMLPreset, GaussianCopulaSynthesizer, CTGANSynthesizer, "
+ "CopulaGANSynthesizer, TVAESynthesizer, PARSynthesizer, HMASynthesizer"
+ ):
+ create_sdv_synthesizer_variant('test_synth', synthesizer_class, synthesizer_parameters)
+
+
+def test_create_sdv_variant_synthesizer_relational():
+ """Test that a sdv variant synthesizer is created.
+
+ Expect that if the synthesizer class is a relational synthesizer, the
+ new synthesizer inherits from the SDVRelationalSynthesizer base class."""
+ # Setup
+ synthesizer_class = 'HMASynthesizer'
+ synthesizer_parameters = {}
+
+ # Run
+ out = create_sdv_synthesizer_variant('test_synth', synthesizer_class, synthesizer_parameters)
+
+ # Assert
+ assert out.__name__ == 'Variant:test_synth'
+ assert out._MODEL == SYNTHESIZER_MAPPING.get(synthesizer_class)
+ assert out._MODEL_KWARGS == synthesizer_parameters
+ assert issubclass(out, SDVRelationalSynthesizer)
+
+
+def test_create_sdv_variant_synthesizer_preset():
+ """Test that a sdv variant synthesizer is created.
+
+ Expect that if the synthesizer class is a preset synthesizer, the
+ new synthesizer inherits from the FastMLPreset base class."""
+ # Setup
+ synthesizer_class = 'FastMLPreset'
+ synthesizer_parameters = {}
+
+ # Run
+ out = create_sdv_synthesizer_variant('test_synth', synthesizer_class, synthesizer_parameters)
+
+ # Assert
+ assert out.__name__ == 'Variant:test_synth'
+ assert out._MODEL == SYNTHESIZER_MAPPING.get(synthesizer_class)
+ assert out._MODEL_KWARGS == synthesizer_parameters
+ assert issubclass(out, FastMLPreset)
diff --git a/tests/unit/synthesizers/test_independent.py b/tests/unit/synthesizers/test_independent.py
new file mode 100644
index 00000000..097a1836
--- /dev/null
+++ b/tests/unit/synthesizers/test_independent.py
@@ -0,0 +1,22 @@
+from unittest.mock import Mock, patch
+
+import pandas as pd
+
+from sdgym.synthesizers import IndependentSynthesizer
+
+
+class TestIndependentSynthesizer:
+
+ @patch('sdgym.synthesizers.independent.GaussianMixture')
+ def test__get_trained_synthesizer(self, gm_mock):
+ """Expect that GaussianMixture is instantiated with 4 components."""
+ # Setup
+ independent = IndependentSynthesizer()
+ independent.length = 10
+ data = pd.DataFrame({'col1': [1, 2, 3, 4]})
+
+ # Run
+ independent._get_trained_synthesizer(data, Mock())
+
+ # Assert
+ gm_mock.assert_called_once_with(4)
diff --git a/tests/unit/test_benchmark.py b/tests/unit/test_benchmark.py
new file mode 100644
index 00000000..949c76de
--- /dev/null
+++ b/tests/unit/test_benchmark.py
@@ -0,0 +1,44 @@
+from unittest.mock import ANY, MagicMock, patch
+
+import pandas as pd
+import pytest
+
+from sdgym.benchmark import benchmark_single_table
+
+
+@patch('sdgym.benchmark.os.path')
+def test_output_file_exists(path_mock):
+ """Test the benchmark function when the output path already exists."""
+ # Setup
+ path_mock.exists.return_value = True
+ output_filepath = 'test_output.csv'
+
+ # Run and assert
+ with pytest.raises(
+ ValueError,
+ match='test_output.csv already exists. Please provide a file that does not already exist.',
+ ):
+ benchmark_single_table(
+ synthesizers=['DataIdentity', 'IndependentSynthesizer', 'UniformSynthesizer'],
+ sdv_datasets=['student_placements'],
+ output_filepath=output_filepath
+ )
+
+
+@patch('sdgym.benchmark.tqdm.tqdm')
+def test_progress_bar_updates(tqdm_mock):
+ """Test that the benchmarking function updates the progress bar on one line."""
+ # Setup
+ scores_mock = MagicMock()
+ scores_mock.__iter__.return_value = [pd.DataFrame([1, 2, 3])]
+ tqdm_mock.return_value = scores_mock
+
+ # Run
+ benchmark_single_table(
+ synthesizers=['DataIdentity'],
+ sdv_datasets=['student_placements'],
+ show_progress=True,
+ )
+
+ # Assert
+ tqdm_mock.assert_called_once_with(ANY, total=1, position=0, leave=True)
diff --git a/tests/unit/test_datasets.py b/tests/unit/test_datasets.py
index 9c55f193..46059363 100644
--- a/tests/unit/test_datasets.py
+++ b/tests/unit/test_datasets.py
@@ -1,10 +1,13 @@
import io
-from unittest.mock import Mock, patch
+from pathlib import Path
+from unittest.mock import Mock, call, patch
from zipfile import ZipFile
import botocore
-from sdgym.datasets import download_dataset
+from sdgym.datasets import (
+ _get_bucket_name, _get_dataset_path, download_dataset, get_available_datasets,
+ get_dataset_paths)
class AnyConfigWith:
@@ -41,8 +44,9 @@ def test_download_dataset_public_bucket(boto3_mock, tmpdir):
- file creation for dataset in datasets path
"""
# setup
+ modality = 'single_table'
dataset = 'my_dataset'
- bucket = 'my_bucket'
+ bucket = 's3://my_bucket'
bytesio = io.BytesIO()
with ZipFile(bytesio, mode='w') as zf:
@@ -60,6 +64,7 @@ def test_download_dataset_public_bucket(boto3_mock, tmpdir):
# run
download_dataset(
+ modality,
dataset,
datasets_path=str(tmpdir),
bucket=bucket
@@ -70,7 +75,8 @@ def test_download_dataset_public_bucket(boto3_mock, tmpdir):
's3',
config=AnyConfigWith(botocore.UNSIGNED)
)
- s3_mock.get_object.assert_called_once_with(Bucket=bucket, Key=f'{dataset}.zip')
+ s3_mock.get_object.assert_called_once_with(
+ Bucket='my_bucket', Key=f'{modality.upper()}/{dataset}.zip')
with open(f'{tmpdir}/{dataset}') as dataset_file:
assert dataset_file.read() == 'test_content'
@@ -101,8 +107,9 @@ def test_download_dataset_private_bucket(boto3_mock, tmpdir):
- file creation for dataset in datasets path
"""
# setup
+ modality = 'single_table'
dataset = 'my_dataset'
- bucket = 'my_bucket'
+ bucket = 's3://my_bucket'
aws_key = 'my_key'
aws_secret = 'my_secret'
bytesio = io.BytesIO()
@@ -121,6 +128,7 @@ def test_download_dataset_private_bucket(boto3_mock, tmpdir):
# run
download_dataset(
+ modality,
dataset,
datasets_path=str(tmpdir),
bucket=bucket,
@@ -134,6 +142,98 @@ def test_download_dataset_private_bucket(boto3_mock, tmpdir):
aws_access_key_id=aws_key,
aws_secret_access_key=aws_secret
)
- s3_mock.get_object.assert_called_once_with(Bucket=bucket, Key=f'{dataset}.zip')
+ s3_mock.get_object.assert_called_once_with(
+ Bucket='my_bucket', Key=f'{modality.upper()}/{dataset}.zip')
with open(f'{tmpdir}/{dataset}') as dataset_file:
assert dataset_file.read() == 'test_content'
+
+
+@patch('sdgym.datasets.Path')
+def test__get_dataset_path(mock_path):
+ """Test that the path to the dataset is returned if it already exists."""
+ # Setup
+ modality = 'single_table'
+ dataset = 'test_dataset'
+ datasets_path = 'local_path'
+ mock_path.return_value.__rtruediv__.side_effect = [False, False, True]
+
+ # Run
+ path = _get_dataset_path(modality, dataset, datasets_path)
+
+ # Assert
+ assert path == mock_path.return_value
+
+
+def test_get_bucket_name():
+ """Test that the bucket name is returned for s3 path."""
+ # Setup
+ bucket = 's3://bucket-name'
+
+ # Run
+ bucket_name = _get_bucket_name(bucket)
+
+ # Assert
+ assert bucket_name == 'bucket-name'
+
+
+def test_get_bucket_name_local_folder():
+ """Test that the bucket name is returned for a local path."""
+ # Setup
+ bucket = 'bucket-name'
+
+ # Run
+ bucket_name = _get_bucket_name(bucket)
+
+ # Assert
+ assert bucket_name == 'bucket-name'
+
+
+@patch('sdgym.datasets._get_available_datasets')
+def test_get_available_datasets(helper_mock):
+ """Test that the modality is set to single-table."""
+ # Run
+ get_available_datasets()
+
+ # Assert
+ helper_mock.assert_called_once_with('single_table')
+
+
+@patch('sdgym.datasets._get_dataset_path')
+@patch('sdgym.datasets.ZipFile')
+@patch('sdgym.datasets.Path')
+def test_get_dataset_paths(path_mock, zipfile_mock, helper_mock):
+ """Test that the dataset paths are generated correctly."""
+ # Setup
+ local_path = 'test_local_path'
+ bucket_path_mock = Mock()
+ bucket_path_mock.exists.return_value = True
+ path_mock.side_effect = [
+ Path('datasets_folder'), bucket_path_mock, bucket_path_mock]
+ bucket_path_mock.iterdir.return_value = [
+ Path('test_local_path/dataset_1.zip'),
+ Path('test_local_path/dataset_2'),
+ ]
+
+ # Run
+ get_dataset_paths(None, None, local_path, None, None)
+
+ # Assert
+ zipfile_mock.return_value.extractall.assert_called_once_with(Path('datasets_folder/dataset_1'))
+ helper_mock.assert_has_calls([
+ call(
+ 'single_table',
+ Path('datasets_folder/dataset_1'),
+ Path('datasets_folder'),
+ 'test_local_path',
+ None,
+ None,
+ ),
+ call(
+ 'single_table',
+ Path('test_local_path/dataset_2'),
+ Path('datasets_folder'),
+ 'test_local_path',
+ None,
+ None,
+ ),
+ ])
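As a side note (illustrative, not part of the diff), the reworked dataset helpers now take the modality as the first argument; a hedged sketch matching the calls in the updated tests, assuming the default bucket is used when none is passed:

    from sdgym.datasets import download_dataset, get_available_datasets

    # Lists the published single-table benchmark datasets.
    print(get_available_datasets())

    # Downloads one dataset into a local folder; 'single_table' comes first now.
    download_dataset('single_table', 'student_placements', datasets_path='datasets/')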
diff --git a/tests/unit/test_summary.py b/tests/unit/test_summary.py
index 34d4c443..b7d5265b 100644
--- a/tests/unit/test_summary.py
+++ b/tests/unit/test_summary.py
@@ -41,6 +41,9 @@ def test_make_summary_spreadsheet(add_sheet_mock, excel_writer_mock, summarize_m
'coverage_perc': [1.0, 0.5],
'time': [100, 200],
'best': [2, 0],
+ 'best_time': [1, 0],
+ 'second_best_time': [0, 1],
+ 'third_best_time': [0, 0],
'beat_uniform': [2, 1],
'beat_independent': [2, 1],
'beat_clbn': [2, 1],
@@ -50,7 +53,7 @@ def test_make_summary_spreadsheet(add_sheet_mock, excel_writer_mock, summarize_m
'errors': [0, 1],
'metric_errors': [0, 0],
'avg score': [0.9, 0.45]
- })
+ }, index=['synth1', 'synth2'])
preprocessed_data = pd.DataFrame({'modality': ['single-table']})
errors = pd.DataFrame({
'synth1': [0],
@@ -67,9 +70,11 @@ def test_make_summary_spreadsheet(add_sheet_mock, excel_writer_mock, summarize_m
# Assert
expected_summary = pd.DataFrame({
'coverage %': [1.0, 0.5],
- 'avg time': [100, 200],
- 'avg score': [0.9, 0.45]
- })
+ '# of Wins': [1, 0],
+ '# of 2nd best': [0, 1],
+ '# of 3rd best': [0, 0],
+ 'library': ['Other', 'Other'],
+ }, index=['synth1', 'synth2'])
expected_summary.index.name = ''
expected_quality = pd.DataFrame({
'total': [2, 2],
@@ -79,9 +84,9 @@ def test_make_summary_spreadsheet(add_sheet_mock, excel_writer_mock, summarize_m
'beat_independent': [2, 1],
'beat_clbn': [2, 1],
'beat_privbn': [2, 1]
- })
+ }, index=['synth1', 'synth2'])
expected_quality.index.name = ''
- expected_performance = pd.DataFrame({'time': [100, 200]})
+ expected_performance = pd.DataFrame({'time': [100, 200]}, index=['synth1', 'synth2'])
expected_performance.index.name = ''
expected_errors = pd.DataFrame({
'total': [2, 2],
@@ -92,7 +97,7 @@ def test_make_summary_spreadsheet(add_sheet_mock, excel_writer_mock, summarize_m
'memory_error': [0, 0],
'errors': [0, 1],
'metric_errors': [0, 0]
- })
+ }, index=['synth1', 'synth2'])
expected_errors.index.name = ''
add_sheet_calls = add_sheet_mock.mock_calls
read_csv_mock.assert_called_once_with('file_path.csv', 'aws_key', 'aws_secret')
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
new file mode 100644
index 00000000..93601122
--- /dev/null
+++ b/tests/unit/test_utils.py
@@ -0,0 +1,40 @@
+import sys
+
+from sdgym.utils import get_duplicates, get_size_of
+
+
+def test_get_size_of():
+ """Test that the correct size is returned."""
+ # Setup
+ test_obj = {'key': 'value'}
+
+ # Run
+ size = get_size_of(test_obj)
+
+ # Assert
+ assert size == sys.getsizeof('value')
+
+
+def test_get_size_of_nested_obj():
+ """Test that the correct size is returned when given a nested object."""
+ # Setup
+ test_inner_obj = {'inner_key': 'inner_value'}
+ test_obj = {'key1': 'value', 'key2': test_inner_obj}
+
+ # Run
+ size = get_size_of(test_obj)
+
+ # Assert
+ assert size == sys.getsizeof('value') + sys.getsizeof('inner_value')
+
+
+def test_get_duplicates():
+ """Test that the correct duplicates are returned."""
+ # Setup
+ items = ['a', 'a', 'b', 'c', 'd', 'd', 'd']
+
+ # Run
+ duplicates = get_duplicates(items)
+
+ # Assert
+ assert duplicates == {'a', 'd'}
diff --git a/tox.ini b/tox.ini
index 7b793aae..7d130b04 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py38-lint, py3{6,7,8,9}-{integration,unit,minimum,readme}
+envlist = py38-lint, py3{7,8,9,10}-{integration,unit,minimum,readme}
[testenv]
skipsdist = false