include wandb
Jonatan Menger committed Feb 20, 2025
1 parent 1dc0837 commit fec878c
Showing 6 changed files with 266 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -11,7 +11,7 @@ tests/__pycache__/
 *.log
 *.log.*
-*.log.*

+wandb
 # anndata files
 # only track data files in data/demo
 data/*
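The new "wandb" entry keeps Weights & Biases run files out of version control. As a minimal sketch (assuming offline mode so no account is needed), each run drops its artifacts into a local ./wandb/ directory:

# sketch: a wandb run creates ./wandb/, which the new .gitignore entry skips
import wandb

run = wandb.init(project="mmcontext", mode="offline")  # project name assumed for illustration
run.log({"demo_metric": 1.0})
run.finish()  # run files now live under ./wandb/, untracked by git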
25 changes: 12 additions & 13 deletions conf/train_conf.yaml
@@ -8,7 +8,7 @@
 # python train.py dataset.basename=my_awesome_dataset

 # Here are some defaults:
-embedding_method: "geneformer" # This refers to the precomputed "embeddings" in .obsm of the anndata object accessible through the sharelink in the dataset
+embedding_method: "scvi" # This refers to the precomputed "embeddings" in .obsm of the anndata object accessible through the sharelink in the dataset

 input_dim_map:
   hvg: 2000
@@ -18,7 +18,7 @@ input_dim_map:

 dataset:
   basename: "geo_7k_cellxgene_3_5K"
-  type: "pairs"
+  type: "multiplets"
   test_datasets:
     - "jo-mengr/bowel_disease_single"

@@ -27,32 +27,31 @@ text_encoder:
   freeze_text_encoder: True
   unfreeze_last_n_layers: 0

-#loss: "MultipleNegativesRankingLoss"
-loss: "ContrastiveLoss"
+loss: "MultipleNegativesRankingLoss"
+#loss: "ContrastiveLoss"
 #evaluator: "TripletEvaluator" # If not provided, the default evaluator is used

 adapter:
   omics_input_dim: None # this will be overwritten by the embedding_dim_map
   hidden_dim: 512
-  output_dim: 2048
+  output_dim: 1024

 trainer:
   unfreeze_epoch: 0.03
   output_dir: "../../models/trained"
-  num_train_epochs: 64
-  per_device_train_batch_size: 128
-  per_device_eval_batch_size: 128
+  num_train_epochs: 32
+  per_device_train_batch_size: 64
+  per_device_eval_batch_size: 54
   learning_rate: 2e-5
   warmup_ratio: 0.1
-  fp16: true
+  fp16: false
   bf16: false
   eval_strategy: "steps"
-  eval_steps: 500
+  eval_steps: 100
   save_strategy: "steps"
-  save_steps: 2000
+  save_steps: 500
   save_total_limit: 2
-  logging_steps: 500
-  run_name: "mmcontext"
+  logging_steps: 100

 save_dir: "out"
 # You could also define logging via Hydra’s logging config if desired,
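A minimal sketch of how a config like this is typically consumed; resolving adapter.omics_input_dim from input_dim_map is an assumption based on the comments above, not code shown in this commit:

# sketch: load the config and resolve the omics input dim (assumed mechanism)
from omegaconf import OmegaConf

cfg = OmegaConf.load("conf/train_conf.yaml")
# assumption: input_dim_map holds one entry per embedding_method (e.g. "scvi", "hvg")
cfg.adapter.omics_input_dim = cfg.input_dim_map[cfg.embedding_method]
print(cfg.adapter.omics_input_dim, cfg.adapter.output_dim)  # e.g. <scvi dim>, 1024

The loss swap is also consistent with the dataset change: MultipleNegativesRankingLoss can consume the extra hard negatives a "multiplets" dataset provides, whereas ContrastiveLoss expects labeled pairs.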
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
"sentence-transformers",
"session-info",
"torch",
"wandb>=0.19.6",
"zarr<=2.18.4",
]
optional-dependencies.dev = [
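A quick import-time check that the new pin resolves; the WANDB_MODE environment variable is standard wandb behavior, not something this commit configures:

# sketch: sanity-check the wandb>=0.19.6 dependency in a fresh environment
import os
os.environ.setdefault("WANDB_MODE", "offline")  # no credentials needed locally

import wandb
major, minor = map(int, wandb.__version__.split(".")[:2])
assert (major, minor) >= (0, 19), wandb.__version__
print(f"wandb {wandb.__version__} importable")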
6 changes: 3 additions & 3 deletions scripts/start_job.slurm
@@ -3,12 +3,12 @@
 #SBATCH --output=mmcontext_slurm.out
 #SBATCH --error=mmcontext_slurm.err
 #SBATCH --partition=gpu
-#SBATCH --gres=gpu:1 # Request number of GPUs
+#SBATCH --gres=gpu:2 # Request number of GPUs
 #SBATCH --mem=64G # Request 64GB of host RAM, not GPU VRAM!
 #SBATCH --time=01:00:00 # Max job time of 1 hour

 source .venv/bin/activate

 # Now run your Python script
-#accelerate launch --num_processes 2 scripts/train_hydra.py
-python3 scripts/train_hydra.py
+accelerate launch --num_processes 2 scripts/train_hydra.py
+#python3 scripts/train_hydra.py
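Read together with the batch-size change in conf/train_conf.yaml, the global train batch is unchanged; a back-of-the-envelope check (the equivalence is an inference from the two diffs, not stated in the commit):

# sketch: global data-parallel train batch before and after this commit
old_global = 1 * 128  # 1 GPU, per_device_train_batch_size: 128
new_global = 2 * 64   # --num_processes 2, per_device_train_batch_size: 64
assert old_global == new_global == 128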
5 changes: 3 additions & 2 deletions scripts/train_hydra.py
@@ -17,6 +17,7 @@
     SentenceTransformerTrainingArguments,
 )
 from sentence_transformers.evaluation import BinaryClassificationEvaluator
+from transformers.integrations import WandbCallback

 from mmcontext.engine.callback import UnfreezeTextEncoderCallback
 from mmcontext.eval import SystemMonitor, zero_shot_classification_roc
@@ -123,7 +124,7 @@ def main(cfg: DictConfig):
         save_steps=cfg.trainer.save_steps,
         save_total_limit=cfg.trainer.save_total_limit,
         logging_steps=cfg.trainer.logging_steps,
-        run_name=cfg.trainer.run_name,
+        run_name=str(hydra_run_dir),
     )

     # -------------------------------------------------------------------------
@@ -144,7 +145,7 @@ def main(cfg: DictConfig):
         loss=loss_obj,
         evaluator=dev_evaluator,
         extra_feature_keys=["omics_representation"],
-        callbacks=[unfreeze_callback],
+        callbacks=[unfreeze_callback, WandbCallback()],
     )
     trainer.train()

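A condensed sketch of the wandb wiring added here; hydra_run_dir is assumed to be the Hydra output directory created for the current run:

# sketch: name the W&B run after the Hydra run dir and attach the callback
from pathlib import Path
from transformers.integrations import WandbCallback

hydra_run_dir = Path("outputs/2025-02-20/10-00-00")  # hypothetical run dir

run_name = str(hydra_run_dir)  # unique, traceable name per Hydra run
callbacks = [WandbCallback()]  # appended next to UnfreezeTextEncoderCallback
print(run_name, callbacks)

Depending on report_to, transformers may already attach WandbCallback automatically; passing it explicitly, as the diff does, guarantees wandb logging is active.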
