This is sample code for CorrVAE on the dSprites dataset. The code is adapted from PCVAE: https://github.com/xguo7/PCVAE.
===========================================================================================================
Running environment:
--------------------
Python 3.9
===========================================================================================================
Dependencies:
-------------
PyTorch 1.8.1
networkx 2.5
pandas 1.1.3
numpy 1.20.2
rdkit 2021.09.3
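To compare an installed environment against the versions listed above, a short Python check can help. This is a minimal sketch that is not part of the repository: the `check_versions` helper is hypothetical, and it assumes each package exposes a `__version__` attribute (for rdkit in particular, this is an assumption that may not hold for every build).

```python
import importlib

# Versions copied from the Dependencies list in this README.
# Assumption: each module exposes __version__ (uncertain for some rdkit builds).
EXPECTED = {
    "torch": "1.8.1",
    "networkx": "2.5",
    "pandas": "1.1.3",
    "numpy": "1.20.2",
    "rdkit": "2021.09.3",
}


def check_versions(expected=EXPECTED):
    """Return {module: (expected, installed)} for every mismatch.

    `installed` is None when the module cannot be imported or has no
    __version__. Local build tags such as "+cu111" are ignored.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            module = importlib.import_module(name)
        except ImportError:
            mismatches[name] = (want, None)
            continue
        have = getattr(module, "__version__", None)
        # Strip a local version suffix (e.g. "1.8.1+cu111" -> "1.8.1").
        if have is None or have.split("+")[0] != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    for name, (want, have) in check_versions().items():
        print(f"{name}: expected {want}, found {have}")
```

An empty result means every listed package matched; newer versions may still work, but they are reported here as mismatches.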
===========================================================================================================
Data:
-----
The dSprites dataset can be downloaded from https://github.com/deepmind/dsprites-dataset. The dSprites dataset is located in the data folder. The code that reconstructs the dSprites dataset is in ./utils/datasets.py.
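For a quick look at the raw data outside the repository's own dataset code, the published .npz file can be read directly with NumPy. This is a minimal sketch under stated assumptions: the file name is the one distributed in the deepmind/dsprites-dataset repository, the `data/` location follows this README, and `load_dsprites` is a hypothetical helper, not the loader used by train.py.

```python
import numpy as np

# Assumed location: file name from deepmind/dsprites-dataset, placed in data/.
DSPRITES_PATH = "data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

# Column order of latents_values as documented by the dataset authors.
LATENT_NAMES = ("color", "shape", "scale", "orientation", "posX", "posY")


def load_dsprites(path=DSPRITES_PATH):
    """Load images and ground-truth latent factors from the dSprites .npz.

    Returns (imgs, latents): imgs is float32 with shape (N, 64, 64) and
    values in {0, 1}; latents has shape (N, 6). The pickled 'metadata'
    entry is never accessed, so allow_pickle can stay at its default (False).
    """
    with np.load(path) as data:
        # Convert binary uint8 images to float32 for use as model input.
        imgs = data["imgs"].astype(np.float32)
        latents = data["latents_values"]
    return imgs, latents
```

Usage: `imgs, latents = load_dsprites()`, then `latents[:, LATENT_NAMES.index("scale")]` selects the scale factor for every image.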
===========================================================================================================
Code description:
-----------------
To train the model, run:
python train.py
or run the code in train.py directly.
This trains the model on the dSprites dataset and saves the trained model as modelCorrVAE.pt.
For evaluation, run the code in test.py.