Keras 3.0 #50

Open
wants to merge 31 commits into base: feature/keras-3-tf

Commits (31)
af53e5d
update dependencies
AlexanderVNikitin Mar 21, 2024
b90e021
add auto-version to setup; v0.0.5
AlexanderVNikitin Mar 24, 2024
903bff4
SliceAndShuffle fix (#43)
liyiersan Mar 26, 2024
3c9315b
add dtw to readme
AlexanderVNikitin Mar 30, 2024
5d02985
improve docs
AlexanderVNikitin Apr 1, 2024
6237a0f
Fix discriminative metric (#45)
liyiersan May 24, 2024
2688d03
v0.0.6
AlexanderVNikitin May 24, 2024
0abe597
add fontsizes for visualizations
AlexanderVNikitin May 31, 2024
afcc307
Update README.md
AlexanderVNikitin Jun 3, 2024
1665a82
add transformer models and predictive maintenance simulator
AlexanderVNikitin Jun 4, 2024
129b99a
add metrics
AlexanderVNikitin Jun 4, 2024
efff04d
generalize dataloader
AlexanderVNikitin Jun 4, 2024
4327724
add Lotka Volterra simulator
AlexanderVNikitin Jun 5, 2024
50a0c6e
Update README.md
AlexanderVNikitin Jun 6, 2024
a5f23b9
detailed readme
AlexanderVNikitin Jun 7, 2024
b91cf21
fix typo
AlexanderVNikitin Jun 11, 2024
3b09e33
improve ds tutorial
AlexanderVNikitin Jun 15, 2024
ae58487
add colab versions of tutorials
AlexanderVNikitin Jun 15, 2024
b821cd7
improve visualizations
AlexanderVNikitin Jun 17, 2024
d5b8cd1
add synchronized_brainwave_datase and its test case, modify readme (#48)
uncircle Jun 18, 2024
f43887c
update readme
AlexanderVNikitin Jun 20, 2024
dafc344
v0.0.7
AlexanderVNikitin Jun 24, 2024
2b138cb
add wgan
AlexanderVNikitin Jul 11, 2024
366afb3
Initial commit for keras-3.0 branch
liyiersan Jul 19, 2024
683b64b
remaining are cgan, sts, and timeGAN
liyiersan Jul 19, 2024
f455948
fix cgan and cvae
liyiersan Jul 19, 2024
500eb81
fix test
liyiersan Jul 20, 2024
1444426
fix all test
liyiersan Jul 20, 2024
78c34eb
fix torch_train
liyiersan Jul 20, 2024
edddd17
add keras3.0 support for sts and timegan
liyiersan Aug 30, 2024
e7bb0d1
test pytest
liyiersan Sep 19, 2024
6 changes: 6 additions & 0 deletions .gitignore
@@ -133,3 +133,9 @@ dmypy.json

# Pyre type checker
.pyre/

# local history
.lh/

# data
data/
168 changes: 117 additions & 51 deletions README.md

Large diffs are not rendered by default.

Binary file added docs/_static/generation_process.gif
94 changes: 94 additions & 0 deletions docs/guides/augmentations.rst
@@ -0,0 +1,94 @@
.. _augmentations-label:

Augmentations
=============

[Recommended] A more in-depth tutorial on augmentations for time series data `is available in our repo <https://github.com/AlexanderVNikitin/tsgm/blob/main/tutorials/augmentations.ipynb>`_.

TSGM provides a wide variety of augmentation techniques beyond generative models.
For the following demonstrations, we first need to generate a toy dataset:

.. code-block:: python

    import tsgm

    X = tsgm.utils.gen_sine_dataset(100, 64, 2, max_value=20)

Jittering
------------
In tsgm, Gaussian noise augmentation can be applied as follows:

.. code-block:: python

    aug_model = tsgm.models.augmentations.GaussianNoise()
    samples = aug_model.generate(X=X, n_samples=10, variance=0.2)

The idea behind Gaussian noise augmentation is that adding a small amount of jitter to a time series usually does not change it semantically, while it increases the number of samples in the dataset.
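
For instance, the augmented samples can then be stacked with the original data. A minimal sketch, assuming ``generate`` returns an array of shape ``(n_samples, seq_len, feat_dim)``:

.. code-block:: python

    import numpy as np

    # Augmented samples follow the layout of X, so they can be
    # concatenated with the source dataset along the sample axis.
    X_extended = np.concatenate([X, samples], axis=0)
    print(X_extended.shape)  # with the toy dataset above: (110, 64, 2)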

Shuffle Features
----------------
Another approach to time series augmentation is to simply shuffle the features. This approach is suitable only for multivariate time series that are invariant to all, or to some, permutations of features. For instance, it can be applied when each feature represents the same kind of independent measurement taken by different sensors.

.. code-block:: python

    aug_model = tsgm.models.augmentations.Shuffle()
    samples = aug_model.generate(X=X, n_samples=3)

Slice and shuffle
-----------------
Slice and shuffle augmentation [3] cuts a time series into slices and shuffles those slices. This augmentation can be performed for time series that exhibit some form of invariance over time. For instance, imagine a time series measured from wearable devices for several days. A good strategy in this case is to slice the time series by days and, by shuffling those days, obtain additional samples.

.. code-block:: python

    aug_model = tsgm.models.augmentations.SliceAndShuffle()
    samples = aug_model.generate(X=X, n_samples=10, n_segments=3)

Magnitude Warping
-----------------
Magnitude warping [3] changes the magnitude of each sample in a time series dataset by multiplying the original series with a cubic spline curve. The spline has ``n_knots`` knots whose magnitudes are drawn from N(1, σ^2), where σ is set by the parameter ``sigma`` of ``.generate``. This process rescales the magnitude of the time series, which can be beneficial in many cases, such as our synthetic example with sines.

.. code-block:: python

    aug_model = tsgm.models.augmentations.MagnitudeWarping()
    samples = aug_model.generate(X=X, n_samples=10, sigma=1)



Window Warping
--------------
In this technique [4], selected windows in the time series are either sped up or slowed down. The resulting time series is then scaled back to the original length in order to keep the number of timesteps unchanged. See an example of such augmentation below:

.. code-block:: python

    aug_model = tsgm.models.augmentations.WindowWarping()
    samples = aug_model.generate(X=X, n_samples=10, scales=(0.5,), window_ratio=0.5)


Dynamic Time Warping Barycentric Average (DTWBA)
------------------------------------------------
Dynamic Time Warping Barycentric Average (DTWBA) [2] is an augmentation method based on Dynamic Time Warping (DTW) [1]. DTW measures similarity between time series by "syncing" them, i.e., by finding a warping of the time axis that aligns the two series (see the short DTW sketch after the algorithm steps below).

DTWBA goes like this:

1. The algorithm picks one time series to initialize the DTWBA result; this series can either be given explicitly or chosen randomly from the dataset.
2. For each of the N time series, the algorithm computes the DTW distance and the warping path (the path is the mapping that minimizes the distance).
3. After computing all DTW distances, the algorithm updates the DTWBA result by averaging with respect to all the paths found in step (2).
4. The algorithm repeats steps (2) and (3) until the DTWBA result converges.
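
As a quick illustration of the underlying DTW measure, here is a minimal sketch using the ``dtaidistance`` package (listed among TSGM's dependencies); the toy series are assumptions for the example:

.. code-block:: python

    import numpy as np
    from dtaidistance import dtw

    s1 = np.sin(np.linspace(0, 2 * np.pi, 64))
    s2 = np.sin(np.linspace(0, 2 * np.pi, 64) + 0.5)  # phase-shifted copy
    # DTW aligns ("syncs") the series before measuring the distance,
    # so the phase shift contributes little to the result.
    print(dtw.distance(s1, s2))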

.. code-block:: python

    import random

    aug_model = tsgm.models.augmentations.DTWBarycentricAveraging()
    initial_ids = random.sample(range(X.shape[0]), 10)
    initial_timeseries = X[initial_ids]
    samples = aug_model.generate(X=X, n_samples=10, initial_timeseries=initial_timeseries)


References
------------
[1] H. Sakoe and S. Chiba, “Dynamic programming algorithm optimization for spoken word recognition”. IEEE Transactions on Acoustics, Speech, and Signal Processing, 26(1), 43-49 (1978).

[2] F. Petitjean, A. Ketterlin & P. Gancarski. A global averaging method for dynamic time warping, with applications to clustering. Pattern Recognition, Elsevier, 2011, Vol. 44, Num. 3, pp. 678-693

[3] Um TT, Pfister FM, Pichler D, Endo S, Lang M, Hirche S, Fietzek U, Kulić D (2017) Data augmentation of wearable sensor data for Parkinson's disease monitoring using convolutional neural networks. In: Proceedings of the 19th ACM International Conference on Multimodal Interaction, pp. 216–220

[4] Rashid, K.M. and Louis, J., 2019. Window-warping: a time series data augmentation of IMU data for construction equipment activity identification. In ISARC. Proceedings of the international symposium on automation and robotics in construction (Vol. 36, pp. 651-657). IAARC Publications.
19 changes: 9 additions & 10 deletions docs/guides/introduction.rst
@@ -29,7 +29,7 @@ TSGM implements multiple augmentation approaches including window warping, shuff
    aug_model = tsgm.models.augmentations.GaussianNoise(variance=0.2)
    samples = aug_model.generate(X=X, n_samples=10)

More examples are available in `the augmentation tutorial. <https://github.com/AlexanderVNikitin/tsgm/blob/main/tutorials/augmentations.ipynb>`_
More examples are available in `the augmentation tutorial <https://github.com/AlexanderVNikitin/tsgm/blob/main/tutorials/augmentations.ipynb>`_ or in :ref:`augmentations-label`.

Generators
=============================
@@ -42,7 +42,7 @@ The training of data-driven simulators can be done via likelihood optimization,
- `tsgm.models.cgan.ConditionalGAN` - conditional GAN model for labeled and temporally labeled time-series simulation,\\
- `tsgm.models.cvae.BetaVAE` - beta-VAE model adapted for time-series simulation,\\
- `tsgm.models.cvae.cBetaVAE` - conditional beta-VAE model for labeled and temporally labeled time-series simulation,\\
- `tsgm.models.cvae.TimeGAN` - extended GAN-based model for time series generation.
- `tsgm.models.timegan.TimeGAN` - extended GAN-based model for time series generation.

A minimalistic example of synthetic data generation with VAEs:
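
(The example itself is collapsed in this diff view; the following is only an illustrative sketch. The architecture name ``vae_conv5``, the shapes, and the scaler usage are assumptions rather than content of this page.)

.. code-block:: python

    import keras
    import tsgm

    X = tsgm.utils.gen_sine_dataset(1000, 64, 1)
    X = tsgm.utils.TSFeatureWiseScaler().fit_transform(X)

    architecture = tsgm.models.architectures.zoo["vae_conv5"](
        seq_len=64, feat_dim=1, latent_dim=10)
    encoder, decoder = architecture.encoder, architecture.decoder

    vae = tsgm.models.cvae.BetaVAE(encoder, decoder)
    vae.compile(optimizer=keras.optimizers.Adam())
    vae.fit(X, epochs=1)
    samples = vae.generate(10)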

@@ -103,12 +103,13 @@ Metrics
=============================
In `tsgm.metrics`, we implemented several metrics for the evaluation of generated time series. Essentially, these metrics are subdivided into the following types:

- data similarity / distance,
- predictive consistency,
- fairness,
- privacy,
- downstream effectiveness,
- visual similarity.
- data similarity / distance: `tsgm.metrics.DistanceMetric`, `tsgm.metrics.MMDMetric`, `tsgm.metrics.DiscriminativeMetric`,
- predictive consistency: `tsgm.metrics.ConsistencyMetric`,
- fairness: `tsgm.metrics.DemographicParityMetric`, `tsgm.metrics.PredictiveParityMetric`,
- privacy: `tsgm.metrics.PrivacyMembershipInferenceMetric`,
- diversity: `tsgm.metrics.EntropyMetric`, `tsgm.metrics.ShannonEntropyMetric`, `tsgm.metrics.PairwiseDistanceMetric`,
- downstream effectiveness: `tsgm.metrics.DownstreamPerformanceMetric`,
- qualitative analysis: `tsgm.utils.visualization`.

See the following code for an example of using metrics:
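
(The metrics example is collapsed in this diff view. As an illustrative stand-in, assuming the metric objects are callable on a pair of datasets, a distance metric could be used roughly as follows; the arrays are placeholders:)

.. code-block:: python

    import numpy as np
    import tsgm

    # Placeholder arrays standing in for real and synthetic data,
    # shaped n_samples x seq_len x feat_dim.
    X_real = np.random.normal(size=(64, 100, 2)).astype(np.float32)
    X_syn = np.random.normal(size=(64, 100, 2)).astype(np.float32)

    metric = tsgm.metrics.MMDMetric()
    print(metric(X_real, X_syn))  # smaller values indicate closer distributions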

@@ -151,5 +151,3 @@ If you find *TSGM* useful, please consider citing our paper:
journal={arXiv preprint arXiv:2305.11567},
year={2023}
}


6 changes: 3 additions & 3 deletions docs/modules/root.rst
@@ -11,7 +11,7 @@ Datasets
    :members:
    :undoc-members:

.. automodule:: tsgm.utils.dataset
.. automodule:: tsgm.utils.datasets
    :members:
    :undoc-members:

@@ -78,9 +78,9 @@ Zoo
    :undoc-members:


Datasets
Simulators
--------------
.. automodule:: tsgm.utils.datasets
.. automodule:: tsgm.simulator
    :members:
    :undoc-members:

2 changes: 1 addition & 1 deletion requirements/docs_requirements.txt
@@ -7,5 +7,5 @@ nbsphinx
jupytext
pydata-sphinx-theme
sphinxcontrib-bibtex
sphinx-autoapi
sphinx-autoapi==3.0.0
sphinx_rtd_theme
16 changes: 13 additions & 3 deletions setup.py
@@ -1,8 +1,17 @@
import os
from setuptools import setup
from setuptools import find_packages


# Read the package version from tsgm/version.py
def get_version():
    with open(os.path.join(os.path.dirname(__file__), 'tsgm/version.py')) as f:
        exec(f.read())
    return locals()['__version__']
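# NB (assumption, not shown in this diff): tsgm/version.py is expected to
# define the version as a module-level string, e.g.
#     __version__ = "0.0.7"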


name = "tsgm"
version = get_version()

keywords = [
"machine learning",
@@ -44,7 +53,7 @@ def read_file(filename: str) -> str:


setup(name='tsgm',
      version='0.0.4',
      version=version,
      description='Time Series Generative Modelling Framework',
      author=author,
      author_email='',
@@ -70,8 +79,9 @@ def read_file(filename: str) -> str:
          "yfinance==0.2.28",
          "tqdm",
          "dtaidistance >= 2.3.10",
          "tensorflow",
          "tensorflow-probability",
          "tensorflow < 2.16",
          "tensorflow-probability < 0.24.0",
          "statsmodels"
      ],
      package_data={'tsgm': ['README.md']},
      packages=find_packages())
14 changes: 9 additions & 5 deletions tests/test_abc.py
@@ -1,10 +1,14 @@
import pytest
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import tsgm
import tensorflow_probability as tfp
from tsgm.backend import get_distributions

distributions = get_distributions()
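# Note: tsgm.backend itself is not shown in this diff. Judging by the usage
# below, get_distributions plausibly dispatches on KERAS_BACKEND, e.g.:
#
#     def get_distributions():
#         if os.environ.get("KERAS_BACKEND", "tensorflow") == "torch":
#             import torch.distributions as distributions
#         else:
#             import tensorflow_probability as tfp
#             distributions = tfp.distributions
#         return distributions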

import tensorflow as tf
import numpy as np
from tensorflow import keras
import keras


def test_abc_rejection_sampler_nn_simulator():
@@ -43,8 +47,8 @@ def test_abc_rejection_sampler_model_based_simulator():
    data = tsgm.dataset.DatasetProperties(N=100, D=2, T=100)
    simulator = tsgm.simulator.SineConstSimulator(data=data, max_scale=max_scale, max_const=20)
    priors = {
        "max_scale": tfp.distributions.Uniform(9, 11),
        "max_const": tfp.distributions.Uniform(19, 21)
        "max_scale": distributions.Uniform(9, 11),
        "max_const": distributions.Uniform(19, 21)
    }
    samples_ref = simulator.generate(10)

4 changes: 3 additions & 1 deletion tests/test_augmentations.py
@@ -1,6 +1,8 @@
import pytest
import numpy as np

import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import tsgm

@pytest.mark.parametrize("mock_aug", [
69 changes: 57 additions & 12 deletions tests/test_cgan.py
@@ -1,14 +1,20 @@
import pytest
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import tsgm

import tensorflow as tf
try:
    import tensorflow_privacy as tf_privacy
    __tf_privacy_available = True
except ModuleNotFoundError:
    __tf_privacy_available = False
import numpy as np
from tensorflow import keras
import keras


from tsgm.backend import get_backend



def _gen_dataset(seq_len: int, feature_dim: int, batch_size: int):
@@ -17,8 +23,13 @@ def _gen_dataset(seq_len: int, feature_dim: int, batch_size: int):
    scaler = tsgm.utils.TSFeatureWiseScaler((-1, 1))
    X_train = scaler.fit_transform(data).astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices(X_train)
    dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    backend = get_backend()
    if os.environ.get("KERAS_BACKEND") == "tensorflow":
        dataset = backend.data.Dataset.from_tensor_slices(X_train)
        dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    elif os.environ.get("KERAS_BACKEND") == "torch":
        # torch's TensorDataset expects tensors, so convert the numpy array
        dataset = backend.utils.data.TensorDataset(backend.from_numpy(X_train))
        dataset = backend.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataset
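# Note: the tsgm.backend module is not shown in this diff. Given the usage
# above, a plausible sketch of get_backend is:
#
#     def get_backend():
#         if os.environ.get("KERAS_BACKEND", "tensorflow") == "torch":
#             import torch
#             return torch
#         import tensorflow as tf
#         return tf
#
# so that backend.data.Dataset resolves to tf.data.Dataset under TensorFlow
# and backend.utils.data.DataLoader to torch.utils.data.DataLoader under torch.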


@@ -29,8 +40,13 @@ def _gen_cond_dataset(seq_len: int, batch_size: int):
    X_train = scaler.fit_transform(X).astype(np.float32)
    y = keras.utils.to_categorical(y_i, 2).astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices((X_train, y))
    dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    backend = get_backend()
    if os.environ.get("KERAS_BACKEND") == "tensorflow":
        # keep the labels so the conditional models receive (X, y) pairs
        dataset = backend.data.Dataset.from_tensor_slices((X_train, y))
        dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    elif os.environ.get("KERAS_BACKEND") == "torch":
        dataset = backend.utils.data.TensorDataset(
            backend.from_numpy(X_train), backend.from_numpy(y))
        dataset = backend.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataset, y


@@ -41,8 +57,13 @@ def _gen_t_cond_dataset(seq_len: int, batch_size: int):
    X_train = scaler.fit_transform(X).astype(np.float32)
    y = y.astype(np.float32)

    dataset = tf.data.Dataset.from_tensor_slices((X_train, y))
    dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    backend = get_backend()
    if os.environ.get("KERAS_BACKEND") == "tensorflow":
        # keep the temporal labels so the models receive (X, y) pairs
        dataset = backend.data.Dataset.from_tensor_slices((X_train, y))
        dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)
    elif os.environ.get("KERAS_BACKEND") == "torch":
        dataset = backend.utils.data.TensorDataset(
            backend.from_numpy(X_train), backend.from_numpy(y))
        dataset = backend.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return dataset, y


@@ -236,7 +257,6 @@ def test_dp_compiler():
        learning_rate=learning_rate
    )

    g_optimizer = tf_privacy.DPKerasAdamOptimizer(
        l2_norm_clip=l2_norm_clip,
        noise_multiplier=noise_multiplier,
        learning_rate=learning_rate
@@ -259,6 +279,31 @@ def test_dp_compiler():
    assert generated_samples.shape == (10, 64, 1)


def test_temporal_cgan_multiple_features():
    # TODO
    pass
def test_wavegan():
    latent_dim = 2
    output_dim = 1
    feature_dim = 1
    seq_len = 64
    batch_size = 48

    dataset = _gen_dataset(seq_len, feature_dim, batch_size)
    architecture = tsgm.models.architectures.zoo["wavegan"](
        seq_len=seq_len, feat_dim=feature_dim,
        latent_dim=latent_dim, output_dim=output_dim)
    discriminator, generator = architecture.discriminator, architecture.generator
    gan = tsgm.models.cgan.GAN(
        discriminator=discriminator, generator=generator, latent_dim=latent_dim, use_wgan=True
    )
    gan.compile(
        d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
        g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
        loss_fn=keras.losses.BinaryCrossentropy(),
    )

    gan.fit(dataset, epochs=1)

    assert gan.generator is not None
    assert gan.discriminator is not None
    # Check generation
    generated_samples = gan.generate(10)
    assert generated_samples.shape == (10, seq_len, 1)
4 changes: 4 additions & 0 deletions tests/test_dataset.py
@@ -1,7 +1,11 @@
import pytest

import numpy as np
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import tsgm
import tsgm.backend


def test_dataset():