Skip to content

Commit

Permalink
Add code for extension
Browse files Browse the repository at this point in the history
  • Loading branch information
melanibe committed Sep 16, 2024
2 parents f590d3b + b318f1b commit f183d7e
Show file tree
Hide file tree
Showing 65 changed files with 6,206 additions and 823 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ jobs:
python -V
conda info
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --extend-ignore=E203 --show-source --statistics --max-line-length=119
flake8 . --count --extend-ignore=E203,E741 --show-source --statistics --max-line-length=119
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
joined_simple.csv
lightning_logs/
*.sh
*.png
*.pdf
*.eps
embed_cf/

outputs2/
outputs
znew_scripts
padchest_cf_images_v0
Expand Down
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Counterfactual contrastive learning: robust representations via causal image synthesis
# Robust representations for image classification via counterfactual contrastive learning

This repository contains the code for the paper "Counterfactual contrastive learning: robust representations via causal image synthesis". <br>
[Link to paper](https://arxiv.org/abs/2403.09605).
This repository contains the code for the papers "Counterfactual contrastive learning: robust representations via causal image synthesis" and extended version "Robust representations for image classification via counterfactual contrastive learning". <br>
[Link to Data Engineering Workshop @ MICCAI 2024 paper](https://arxiv.org/abs/2403.09605).
[Link to extension (pre-print)]()

![alt text](figure1.png)

Expand Down Expand Up @@ -64,4 +65,4 @@ python classification/train.py experiment=base_density trainer.finetune_path=PAT
```
You can choose the proportion of labelled data to use for finetuning with the flag `data.prop_train=1.0`

5. Evaluate on the test set by running the notebook `evaluation/embed_density.ipynb` to run and save inference results on the test set.
5. Evaluate on the test set by running the inference code `evaluation/simclr_embed.py` to run and save inference results on the test set.
2 changes: 2 additions & 0 deletions causal_models/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def main(args):
ckpt_args["data_dir"] = args.data_dir
if args.lr < ckpt_args["lr"]:
ckpt_args["lr"] = args.lr
if args.epochs > ckpt_args["epochs"]:
ckpt_args["epochs"] = args.epochs
vars(args).update(ckpt_args)
else:
print(f"Checkpoint not found at: {args.resume}")
Expand Down
2 changes: 1 addition & 1 deletion causal_models/train_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def setup_dataloaders(args, cache: bool = True, shuffle_train=True):
"data=embed",
"data.batch_size=16",
f"data.cache={cache}",
f"data.exclude_cviews={not 'cview' in args.parents_x}",
f"data.exclude_cviews={'cview' not in args.parents_x}",
],
)
print(cfg)
Expand Down
Loading

0 comments on commit f183d7e

Please sign in to comment.