diff --git a/conf/bandgap.yaml b/conf/bandgap.yaml new file mode 100644 index 0000000..130b8fd --- /dev/null +++ b/conf/bandgap.yaml @@ -0,0 +1,33 @@ + + +hydra: + job: + name: bandgap + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/benchmark.yaml b/conf/benchmark.yaml index 56f4217..0c8a299 100644 --- a/conf/benchmark.yaml +++ b/conf/benchmark.yaml @@ -1,24 +1,33 @@ - hydra: - job: - name: benchmark - run: - dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} - sweep: - dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} - subdir: ${hydra.job.override_dirname} - - - - defaults: - - model: none - - - - runs: - - - - name: benchmark_run - tasks: [benchmark] - +hydra: + job: + name: benchmark + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/bg/atoms.yaml b/conf/bg/atoms.yaml new file mode 100644 index 0000000..4c183f8 --- /dev/null +++ b/conf/bg/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/bg/atoms_params.yaml b/conf/bg/atoms_params.yaml new file mode 100644 index 0000000..728685e --- /dev/null +++ b/conf/bg/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/bg/cifp1.yaml b/conf/bg/cifp1.yaml new file mode 100644 index 0000000..51bed8e --- /dev/null +++ b/conf/bg/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: 
n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 128 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/bg/cifpsym.yaml b/conf/bg/cifpsym.yaml new file mode 100644 index 0000000..6175580 --- /dev/null +++ b/conf/bg/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/bg/composition.yaml b/conf/bg/composition.yaml new file mode 100644 index 0000000..7e52344 --- /dev/null +++ b/conf/bg/composition.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: composition + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-composition-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/bg/crystal_llm.yaml b/conf/bg/crystal_llm.yaml new file mode 100644 index 0000000..ce787ac --- /dev/null +++ b/conf/bg/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + \ No newline at end of file diff --git a/conf/bg/local_env.yaml b/conf/bg/local_env.yaml new file mode 100644 index 0000000..15a3667 --- /dev/null +++ b/conf/bg/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/bg/slices.yaml b/conf/bg/slices.yaml new file mode 100644 index 0000000..4a447e7 --- /dev/null +++ b/conf/bg/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/bg/zmatrix.yaml b/conf/bg/zmatrix.yaml new file mode 100644 index 0000000..5c1e96f --- /dev/null +++ b/conf/bg/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + 
wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/bg2m/atoms.yaml b/conf/bg2m/atoms.yaml new file mode 100644 index 0000000..1d55ace --- /dev/null +++ b/conf/bg2m/atoms.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: atoms_params + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/atoms_params.yaml b/conf/bg2m/atoms_params.yaml new file mode 100644 index 0000000..1d55ace --- /dev/null +++ b/conf/bg2m/atoms_params.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: atoms_params + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/cifp1.yaml b/conf/bg2m/cifp1.yaml new file mode 100644 index 0000000..ad74f90 --- /dev/null +++ b/conf/bg2m/cifp1.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: cif_p1 + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 1024 + training_arguments: + per_device_train_batch_size: 32 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/cif_p1_pt_30k_rt_2/checkpoint-46000 diff --git a/conf/bg2m/cifsymmetrized.yaml b/conf/bg2m/cifsymmetrized.yaml new file mode 100644 index 0000000..e7cc55b --- /dev/null +++ b/conf/bg2m/cifsymmetrized.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: cif_symmetrized + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 1024 + training_arguments: + per_device_train_batch_size: 32 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/cif_symmetrized_pt_30k_rt/checkpoint-45000 diff --git a/conf/bg2m/composition.yaml b/conf/bg2m/composition.yaml new file mode 100644 index 0000000..3783298 --- /dev/null +++ b/conf/bg2m/composition.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: composition + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/composition_pt_30k_rt/checkpoint-1000 diff --git a/conf/bg2m/crystal_llm.yaml b/conf/bg2m/crystal_llm.yaml new file mode 100644 index 0000000..9f97208 --- /dev/null +++ b/conf/bg2m/crystal_llm.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: crystal_llm_rep + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/crystal_llm_rep_pt_30k_rt/checkpoint-11000 diff --git a/conf/bg2m/local_env.yaml b/conf/bg2m/local_env.yaml new file mode 100644 index 0000000..cbb1363 --- /dev/null +++ b/conf/bg2m/local_env.yaml @@ -0,0 +1,13 @@ +# @package _global_ 
+model: + representation: zmatrix + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/bg2m/slice.yaml b/conf/bg2m/slice.yaml new file mode 100644 index 0000000..1fe01e1 --- /dev/null +++ b/conf/bg2m/slice.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: slice + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop2/checkpoints/checkpoints/slice_pt_30k_rt/checkpoint-23000 diff --git a/conf/bg2m/zmatrix.yaml b/conf/bg2m/zmatrix.yaml new file mode 100644 index 0000000..cbb1363 --- /dev/null +++ b/conf/bg2m/zmatrix.yaml @@ -0,0 +1,13 @@ +# @package _global_ +model: + representation: zmatrix + logging: + wandb_project: 2m_intel_ft + + finetune: + model_name: 2m_intel_ft + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: /work/so87pot/mattext/megaloop/checkpoints/checkpoints/atoms_params_pt_30k_atoms/checkpoint-1000 diff --git a/conf/classification.yaml b/conf/classification.yaml new file mode 100644 index 0000000..c5acbd4 --- /dev/null +++ b/conf/classification.yaml @@ -0,0 +1,33 @@ + + +hydra: + job: + name: is_metal + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + # launcher: + # _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher + # submitit_folder: ${hydra.sweep.dir}/.submitit/%j + # timeout_min: 3600 + # mem_gb: 160 + # nodes: 1 + # #gpus_per_task: 1 + # gres: gpu:1 + # #gpus_per_node: 2 + # name: ${hydra.job.name} + # partition: 'gpu' + # additional_parameters: + # nodelist: 'gpu[008,013-017]' + # tasks_per_node: 1 + +defaults: +- model: none +# - override hydra/launcher: submitit_slurm + +runs: + - name: classification_run + tasks: [classification] \ No newline at end of file diff --git a/conf/form/atoms.yaml b/conf/form/atoms.yaml new file mode 100644 index 0000000..6923edd --- /dev/null +++ b/conf/form/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/form/atoms_params.yaml b/conf/form/atoms_params.yaml new file mode 100644 index 0000000..42d2740 --- /dev/null +++ b/conf/form/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 2048 + + \ No newline at end of file diff --git a/conf/form/cifp1.yaml b/conf/form/cifp1.yaml new file mode 100644 index 0000000..221da0b --- /dev/null +++ 
b/conf/form/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/form/cifpsym.yaml b/conf/form/cifpsym.yaml new file mode 100644 index 0000000..0dccf71 --- /dev/null +++ b/conf/form/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/form/composition.yaml b/conf/form/composition.yaml new file mode 100644 index 0000000..4a2ab67 --- /dev/null +++ b/conf/form/composition.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: composition + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-composition-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 32 + training_arguments: + per_device_train_batch_size: 2048 + + \ No newline at end of file diff --git a/conf/form/crystal_llm.yaml b/conf/form/crystal_llm.yaml new file mode 100644 index 0000000..667968d --- /dev/null +++ b/conf/form/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/cllm/checkpoint-393000 + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + \ No newline at end of file diff --git a/conf/form/local_env.yaml b/conf/form/local_env.yaml new file mode 100644 index 0000000..0113a76 --- /dev/null +++ b/conf/form/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/local_env/checkpoint-381000 + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/local_env/checkpoint-381000 \ No newline at end of file diff --git a/conf/form/slices.yaml b/conf/form/slices.yaml new file mode 100644 index 0000000..9b21975 --- /dev/null +++ b/conf/form/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 128 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/form/zmatrix.yaml b/conf/form/zmatrix.yaml new file mode 
100644 index 0000000..02a38bc --- /dev/null +++ b/conf/form/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "form_energy" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + wandb_project: revision-form + + finetune: + model_name: revision-form + context_length: 512 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/form_energy.yaml b/conf/form_energy.yaml new file mode 100644 index 0000000..00ff258 --- /dev/null +++ b/conf/form_energy.yaml @@ -0,0 +1,19 @@ + + +hydra: + job: + name: formation_energy + run: + dir: ${hydra:runtime.cwd}/outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} + + +defaults: +- model: none + + +runs: + - name: benchmark_run + tasks: [benchmark] \ No newline at end of file diff --git a/conf/is_metal/atoms.yaml b/conf/is_metal/atoms.yaml new file mode 100644 index 0000000..4c183f8 --- /dev/null +++ b/conf/is_metal/atoms.yaml @@ -0,0 +1,19 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + path: + pretrained_checkpoint: n0w0f/MatText-atom-seq-2m + + \ No newline at end of file diff --git a/conf/is_metal/atoms_params.yaml b/conf/is_metal/atoms_params.yaml new file mode 100644 index 0000000..728685e --- /dev/null +++ b/conf/is_metal/atoms_params.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-atom-seq-plusplus-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/is_metal/cifp1.yaml b/conf/is_metal/cifp1.yaml new file mode 100644 index 0000000..51bed8e --- /dev/null +++ b/conf/is_metal/cifp1.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifp1-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 128 + path: + pretrained_checkpoint: n0w0f/MatText-cifp1-2m \ No newline at end of file diff --git a/conf/is_metal/cifpsym.yaml b/conf/is_metal/cifpsym.yaml new file mode 100644 index 0000000..6175580 --- /dev/null +++ b/conf/is_metal/cifpsym.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-cifsymmetrized-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 1024 + training_arguments: + per_device_train_batch_size: 64 + path: + pretrained_checkpoint: n0w0f/MatText-cifsymmetrized-2m \ No newline at end of file diff --git a/conf/is_metal/composition.yaml b/conf/is_metal/composition.yaml new file mode 100644 index 0000000..7f66ae7 --- /dev/null +++ b/conf/is_metal/composition.yaml @@ -0,0 +1,17 @@ +# 
@package _global_ +model: + representation: composition + dataset: "is-metal" + dataset_type: filtered + special_num_token: False + checkpoint: n0w0f/MatText-composition-2m + logging: + wandb_project: revision-bg-filtered + + finetune: + model_name: revision-bg-filtered + context_length: 32 + training_arguments: + per_device_train_batch_size: 1024 + + \ No newline at end of file diff --git a/conf/is_metal/crystal_llm.yaml b/conf/is_metal/crystal_llm.yaml new file mode 100644 index 0000000..ce787ac --- /dev/null +++ b/conf/is_metal/crystal_llm.yaml @@ -0,0 +1,16 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/alpaca_ckpt/checkpoint-393000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + \ No newline at end of file diff --git a/conf/is_metal/local_env.yaml b/conf/is_metal/local_env.yaml new file mode 100644 index 0000000..15a3667 --- /dev/null +++ b/conf/is_metal/local_env.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: /home/so87pot/n0w0f/structllm_ckpt/santiago_ckpt_rt/checkpoint-95000 \ No newline at end of file diff --git a/conf/is_metal/slices.yaml b/conf/is_metal/slices.yaml new file mode 100644 index 0000000..4a447e7 --- /dev/null +++ b/conf/is_metal/slices.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-slices-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-slices-2m \ No newline at end of file diff --git a/conf/is_metal/zmatrix.yaml b/conf/is_metal/zmatrix.yaml new file mode 100644 index 0000000..5c1e96f --- /dev/null +++ b/conf/is_metal/zmatrix.yaml @@ -0,0 +1,17 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: matbench + special_num_token: False + checkpoint: n0w0f/MatText-zmatrix-2m + logging: + wandb_project: revision-bg + + finetune: + model_name: revision-bg + context_length: 512 + training_arguments: + per_device_train_batch_size: 256 + path: + pretrained_checkpoint: n0w0f/MatText-zmatrix-2m \ No newline at end of file diff --git a/conf/llama_8b_bg/atoms.yaml b/conf/llama_8b_bg/atoms.yaml new file mode 100644 index 0000000..6406987 --- /dev/null +++ b/conf/llama_8b_bg/atoms.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: atom_sequences + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft + diff --git a/conf/llama_8b_bg/atoms_params.yaml b/conf/llama_8b_bg/atoms_params.yaml new file mode 100644 index 0000000..efe4430 --- /dev/null +++ b/conf/llama_8b_bg/atoms_params.yaml @@ -0,0 +1,11 @@ +# @package _global_ +model: + representation: atom_sequences_plusplus + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft + + + + \ No newline at end of 
file diff --git a/conf/llama_8b_bg/cifp1.yaml b/conf/llama_8b_bg/cifp1.yaml new file mode 100644 index 0000000..826af92 --- /dev/null +++ b/conf/llama_8b_bg/cifp1.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: cif_p1 + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft + diff --git a/conf/llama_8b_bg/cifpsym.yaml b/conf/llama_8b_bg/cifpsym.yaml new file mode 100644 index 0000000..86addd9 --- /dev/null +++ b/conf/llama_8b_bg/cifpsym.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: cif_symmetrized + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/composition.yaml b/conf/llama_8b_bg/composition.yaml new file mode 100644 index 0000000..5289b4b --- /dev/null +++ b/conf/llama_8b_bg/composition.yaml @@ -0,0 +1,8 @@ +# @package _global_ +model: + representation: composition + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft + \ No newline at end of file diff --git a/conf/llama_8b_bg/crystal_llm.yaml b/conf/llama_8b_bg/crystal_llm.yaml new file mode 100644 index 0000000..61b5d3b --- /dev/null +++ b/conf/llama_8b_bg/crystal_llm.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: crystal_text_llm + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft \ No newline at end of file diff --git a/conf/llama_8b_bg/local_env.yaml b/conf/llama_8b_bg/local_env.yaml new file mode 100644 index 0000000..7a25734 --- /dev/null +++ b/conf/llama_8b_bg/local_env.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: local_env + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/slices.yaml b/conf/llama_8b_bg/slices.yaml new file mode 100644 index 0000000..b680d22 --- /dev/null +++ b/conf/llama_8b_bg/slices.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: slices + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft diff --git a/conf/llama_8b_bg/zmatrix.yaml b/conf/llama_8b_bg/zmatrix.yaml new file mode 100644 index 0000000..94734f0 --- /dev/null +++ b/conf/llama_8b_bg/zmatrix.yaml @@ -0,0 +1,7 @@ +# @package _global_ +model: + representation: zmatrix + dataset: "bandgap" + dataset_type: filtered + logging: + wandb_project: llama-7B-ft diff --git a/conf/llm_sft.yaml b/conf/llm_sft.yaml index 74ce6ed..434b756 100644 --- a/conf/llm_sft.yaml +++ b/conf/llm_sft.yaml @@ -15,9 +15,6 @@ runs: - - - - name: llama_sft_run tasks: [llama_sft] diff --git a/conf/model/benchmark_example.yaml b/conf/model/benchmark_example.yaml index 68ef9c6..6ffd0f3 100644 --- a/conf/model/benchmark_example.yaml +++ b/conf/model/benchmark_example.yaml @@ -59,7 +59,7 @@ finetune: training_arguments: output_dir: "${model.finetune.path.output_dir}" overwrite_output_dir: True - num_train_epochs: 1 + num_train_epochs: 50 per_device_train_batch_size: 1024 save_strategy: "epoch" evaluation_strategy: "epoch" diff --git a/conf/model/classification_example.yaml b/conf/model/classification_example.yaml new file mode 100644 index 0000000..dd96a4e --- /dev/null +++ b/conf/model/classification_example.yaml @@ -0,0 +1,100 @@ +representation: ??? +special_num_token: False +dataset: ??? +dataset_type: ??? +fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: ??? 
+special_tokens: + { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", + } + +logging: + wandb_project: classification + wandb_log_model: "checkpoint" + +finetune: + model_name: classification + freeze_base_model: False + dataset_name: "${model.dataset}-train-${model.dataset_type}" + exp_name: + [ + "train_${model.representation}_${model.finetune.dataset_name}_0", + "train_${model.representation}_${model.finetune.dataset_name}_1", + "train_${model.representation}_${model.finetune.dataset_name}_2", + "train_${model.representation}_${model.finetune.dataset_name}_3", + "train_${model.representation}_${model.finetune.dataset_name}_4", + ] + + path: + pretrained_checkpoint: "${model.checkpoint}" + + finetune_data_rootpath: results # <--- Change this to the path of the finetune data + finetune_traindata: + [ + # "kvrh-train-filtered", + ] + + finetune_testdata: + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" # <--- Change this to the path where chkpoints and logs will be saved + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 32 + dataprep_seed: 42 + callbacks: + early_stopping: True + custom_logger: True + early_stopping_patience: 10 + early_stopping_threshold: 0.001 + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + overwrite_output_dir: True + num_train_epochs: 2 + per_device_train_batch_size: 1024 + save_strategy: "epoch" + evaluation_strategy: "epoch" + logging_strategy: "epoch" + logging_first_step: True + save_steps: 3 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 5 + learning_rate: 2e-4 + logging_steps: 1 + eval_steps: 1 + seed: 42 + load_best_model_at_end: True + +inference: + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + context_length: "${model.finetune.context_length}" + exp_name: + [ + "test_${model.representation}_${model.finetune.dataset_name}_0", + "test_${model.representation}_${model.finetune.dataset_name}_1", + "test_${model.representation}_${model.finetune.dataset_name}_2", + "test_${model.representation}_${model.finetune.dataset_name}_3", + "test_${model.representation}_${model.finetune.dataset_name}_4", + ] + path: + pretrained_checkpoint: [] + test_data_rootpath: # <--- Change this to the path of the finetune data + test_data: + [ + # "kvrh-train-filtered", + ] + root_path: "/home/so87pot/n0w0f/mattext/src/mattext/models/predictions" # <--- Change this to the path where predictions will be saved + output_dir: "${model.inference.path.root_path}/checkpoints/${model.inference.exp_name}" + logging_dir: "${model.inference.path.root_path}/logs/${model.inference.exp_name}" + predictions: "${model.inference.path.root_path}/checkpoints/inference${model.inference.exp_name}" + + benchmark_save_file: "${model.finetune.path.root_path}" diff --git a/conf/model/formation_energy.yaml b/conf/model/formation_energy.yaml new file mode 100644 index 0000000..6ffd0f3 --- /dev/null +++ b/conf/model/formation_energy.yaml @@ -0,0 +1,100 @@ +representation: ??? +special_num_token: False +dataset: ??? +dataset_type: ??? +fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: ??? 
+special_tokens: + { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", + } + +logging: + wandb_project: test-benchmark + wandb_log_model: "checkpoint" + +finetune: + model_name: test-benchmark + freeze_base_model: False + dataset_name: "${model.dataset}-train-${model.dataset_type}" + exp_name: + [ + "train_${model.representation}_${model.finetune.dataset_name}_0", + "train_${model.representation}_${model.finetune.dataset_name}_1", + "train_${model.representation}_${model.finetune.dataset_name}_2", + "train_${model.representation}_${model.finetune.dataset_name}_3", + "train_${model.representation}_${model.finetune.dataset_name}_4", + ] + + path: + pretrained_checkpoint: "${model.checkpoint}" + + finetune_data_rootpath: results # <--- Change this to the path of the finetune data + finetune_traindata: + [ + # "kvrh-train-filtered", + ] + + finetune_testdata: + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" # <--- Change this to the path where chkpoints and logs will be saved + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 32 + dataprep_seed: 42 + callbacks: + early_stopping: True + custom_logger: True + early_stopping_patience: 10 + early_stopping_threshold: 0.001 + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + overwrite_output_dir: True + num_train_epochs: 50 + per_device_train_batch_size: 1024 + save_strategy: "epoch" + evaluation_strategy: "epoch" + logging_strategy: "epoch" + logging_first_step: True + save_steps: 3 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 5 + learning_rate: 2e-4 + logging_steps: 1 + eval_steps: 1 + seed: 42 + load_best_model_at_end: True + +inference: + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + context_length: "${model.finetune.context_length}" + exp_name: + [ + "test_${model.representation}_${model.finetune.dataset_name}_0", + "test_${model.representation}_${model.finetune.dataset_name}_1", + "test_${model.representation}_${model.finetune.dataset_name}_2", + "test_${model.representation}_${model.finetune.dataset_name}_3", + "test_${model.representation}_${model.finetune.dataset_name}_4", + ] + path: + pretrained_checkpoint: [] + test_data_rootpath: # <--- Change this to the path of the finetune data + test_data: + [ + # "kvrh-train-filtered", + ] + root_path: "/home/so87pot/n0w0f/mattext/src/mattext/models/predictions" # <--- Change this to the path where predictions will be saved + output_dir: "${model.inference.path.root_path}/checkpoints/${model.inference.exp_name}" + logging_dir: "${model.inference.path.root_path}/logs/${model.inference.exp_name}" + predictions: "${model.inference.path.root_path}/checkpoints/inference${model.inference.exp_name}" + + benchmark_save_file: "${model.finetune.path.root_path}" diff --git a/conf/model/llama_8b.yaml b/conf/model/llama_8b.yaml new file mode 100644 index 0000000..8f175f6 --- /dev/null +++ b/conf/model/llama_8b.yaml @@ -0,0 +1,137 @@ +representation: ??? +add_special_tokens: False +dataset: ??? +dataset_type: ??? 
+fold : 5 +data_repository: "n0w0f/MatText" +checkpoint: "meta-llama/Meta-Llama-3-8B-Instruct" +special_tokens: { + "unk_token": "[UNK]", + "pad_token": "[PAD]", + "cls_token": "[CLS]", + "sep_token": "[SEP]", + "mask_token": "[MASK]", + "eos_token": "[EOS]", + "bos_token": "[BOS]", +} + +REPRESENTATION_MAP : { + "cif_p1" : "cif_p1", + "Slice" : "slice", + } + +PROPERTY_MAP : { + "gvrh" : "shear modulus (in GPa)", + "kvrh" : "bulk modulus (in GPa)", + "dielectric" : "refractive index", + "perovskites" : "formation energy (in eV)", + "bandgap" : "bandgap (in eV)", + "form_energy" : "formation energy (in eV)",} + +MATERIAL_MAP : { + "gvrh" : "material", + "kvrh" : "material", + "dielectric" : "dielectric material", + "perovskites" : "perovskite material", + "bandgap" : "material ", + "form_energy" : "material", } + + +logging: + wandb_project : test-llama + wandb_log_model : "checkpoint" + +finetune: + model_name: test-llama + freeze_base_model: False + dataprep_seed: 42 + dataset_name: "${model.dataset}-train-${model.dataset_type}" + benchmark_dataset: "${model.dataset}-test-${model.dataset_type}" + exp_name: [ + "train_${model.representation}_${model.finetune.dataset_name}", + ] + + + path: + pretrained_checkpoint: "${model.checkpoint}" + + + finetune_data_rootpath: "/work/so87pot/material_db/all_1" # <--- Change this to the path of the finetune data + finetune_traindata: [ + "${model.finetune.path.finetune_data_rootpath}/train_${model.finetune.dataset_name}_2.json", + ] + + finetune_testdata: [ + "${model.finetune.path.finetune_data_rootpath}/test_${model.finetune.dataset_name}_2.json", + ] + + root_path: "${hydra:runtime.cwd}/../../results/${now:%Y-%m-%d}/${now:%H-%M-%S}/${model.finetune.model_name}" + output_dir: "${model.finetune.path.root_path}/checkpoints/${model.finetune.exp_name}" + logging_dir: "${model.finetune.path.root_path}/logs/${model.finetune.exp_name}" + finetuned_modelname: "${model.finetune.path.root_path}/checkpoints/finetuned_${model.finetune.exp_name}" + + context_length: 1024 + callbacks: + early_stopping: False + custom_logger: False + early_stopping_patience: 5 + early_stopping_threshold: 0.001 + generation: + n_epochs: 1 + output_dir: "${model.finetune.path.output_dir}" + + bnb_config: + use_4bit: True + use_8bit: False + bnb_4bit_compute_dtype: "float16" + bnb_4bit_quant_type: "nf4" + use_nested_quant: False + + lora_config: + r: 32 + lora_alpha: 64 + lora_dropout: 0.05 + bias: "none" + task_type: "CAUSAL_LM" + #target_modules: ['q_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'k_proj', 'v_proj'] # Choose all linear layers from the model + + + training_arguments: + output_dir: "${model.finetune.path.output_dir}" + bf16: True + fp16: False + overwrite_output_dir: True + dataloader_num_workers: 2 + num_train_epochs: 5 + per_device_train_batch_size: 8 + per_device_eval_batch_size: 8 + save_strategy: "steps" + do_eval: True + evaluation_strategy: 'steps' + logging_strategy: 'steps' + logging_first_step: True + save_steps: 20 # Number of epochs before saving + report_to: "wandb" + save_total_limit: 2 + logging_steps: 10 + eval_steps: 10 + seed: 42 + load_best_model_at_end: True + # Number of update steps to accumulate the gradients for + gradient_accumulation_steps : 4 + # Enable gradient checkpointing + gradient_checkpointing : True + # Maximum gradient normal (gradient clipping) + max_grad_norm : 0.3 + # Initial learning rate (AdamW optimizer) + learning_rate : 3e-4 # 0.0005 crystal-llm + # Weight decay to apply to all layers except bias/LayerNorm weights + 
weight_decay : 0.001
+  # Optimizer to use
+  optim : "paged_adamw_32bit"
+  # Learning rate schedule
+  lr_scheduler_type : "cosine"
+  # Ratio of steps for a linear warmup (from 0 to learning rate)
+  warmup_ratio : 0.03
+  warmup_steps : 10
+  eval_accumulation_steps : 4
diff --git a/revision-scripts/5fold_split.py b/revision-scripts/5fold_split.py
new file mode 100644
index 0000000..2342aaa
--- /dev/null
+++ b/revision-scripts/5fold_split.py
@@ -0,0 +1,38 @@
+import json
+import os
+import random
+from sklearn.model_selection import KFold
+import fire
+
+def split_dataset(input_json, output_dir, n_splits=5, random_state=42):
+    # Load the data
+    with open(input_json, 'r') as f:
+        data = json.load(f)
+
+    # Create KFold object
+    kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
+
+    # Ensure output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Perform the split
+    for fold, (train_index, test_index) in enumerate(kf.split(data), 1):
+        train_data = [data[i] for i in train_index]
+        test_data = [data[i] for i in test_index]
+
+        # Save train data
+        train_file = os.path.join(output_dir, f'train_mp_classification_fold_{fold}.json')
+        with open(train_file, 'w') as f:
+            json.dump(train_data, f, indent=2)
+
+        # Save test data
+        test_file = os.path.join(output_dir, f'test_mp_classification_fold_{fold}.json')
+        with open(test_file, 'w') as f:
+            json.dump(test_data, f, indent=2)
+
+        print(f"Fold {fold} created: {train_file} and {test_file}")
+
+    print("Dataset splitting completed.")
+
+if __name__ == "__main__":
+    fire.Fire(split_dataset)
\ No newline at end of file
diff --git a/revision-scripts/matbench_is_metal.py b/revision-scripts/matbench_is_metal.py
new file mode 100644
index 0000000..76b2c64
--- /dev/null
+++ b/revision-scripts/matbench_is_metal.py
@@ -0,0 +1,81 @@
+import json
+import os
+
+import hydra
+from matbench.bench import MatbenchBenchmark
+from omegaconf import DictConfig
+
+# Check if the specified benchmark exists
+available_benchmarks = [
+    # "matbench_dielectric",
+    # "matbench_expt_gap",
+    # "matbench_expt_is_metal",
+    # "matbench_glass",
+    # "matbench_mp_e_form",
+    # "matbench_mp_gap",
+    "matbench_mp_is_metal",
+    # "matbench_phonons",
+    # "matbench_steels",
+]
+
+
+def convert_structure_to_serializable(pymatgen_structure):
+    return pymatgen_structure.to(fmt="cif")
+
+
+@hydra.main(version_base=None, config_path="../conf", config_name="config")
+def main(cfg: DictConfig) -> None:
+    mb = MatbenchBenchmark(autoload=False)
+    benchmarks = cfg.matbench.benchmarks.dataset
+    path = cfg.matbench.path.save_path
+    print(path)
+    if not os.path.exists(path):
+        os.mkdir(path)
+    else:
+        print(f"Directory '{path}' already exists.")
+    for benchmark_name in benchmarks:
+        if benchmark_name not in available_benchmarks:
+            raise ValueError(
+                f"Invalid benchmark name. Available benchmarks: {', '.join(available_benchmarks)}"
+            )
+
+    for benchmark_name in benchmarks:
+        benchmark = getattr(mb, benchmark_name)
+        benchmark.load()
+
+        for fold in benchmark.folds:
+            # Get train inputs and outputs
+            train_inputs, train_outputs = benchmark.get_train_and_val_data(fold)
+            test_inputs = benchmark.get_test_data(fold)
+
+            # Create the train data
+            train_data = [
+                {
+                    "structure": convert_structure_to_serializable(train_inputs[index]),
+                    "labels": train_outputs[index],
+                }
+                for index in range(len(train_inputs))
+            ]
+
+            # Save the train data as a JSON file
+            train_dataset_name = f"train_{benchmark_name}_{fold}.json"
+            with open(f"{path}/{train_dataset_name}", "w") as train_file:
+                json.dump(train_data, train_file)
+
+            print(f"Train data saved to {path}/{train_dataset_name}")
+
+            test_data = [
+                convert_structure_to_serializable(test_inputs[index])
+                for index in range(len(test_inputs))
+            ]
+
+            # Save the test data as a JSON file
+            test_dataset_name = f"test_{benchmark_name}_{fold}.json"
+            with open(f"{path}/{test_dataset_name}", "w") as test_file:
+                json.dump(test_data, test_file)
+
+            print(f"Test data saved to {path}/{test_dataset_name}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/revision-scripts/mp_classification.py b/revision-scripts/mp_classification.py
new file mode 100644
index 0000000..6230a06
--- /dev/null
+++ b/revision-scripts/mp_classification.py
@@ -0,0 +1,63 @@
+import json
+import os
+import pickle
+
+import fire
+import lmdb
+from pymatgen.core import Structure
+
+
+class Dataset:
+    def __init__(self, lmdb_path, max_readers=1):
+        self.env = lmdb.open(
+            lmdb_path,
+            subdir=False,
+            readonly=True,
+            lock=False,
+            readahead=False,
+            meminit=False,
+            max_readers=max_readers,
+        )
+        self.txn = self.env.begin()
+
+    def __len__(self):
+        return self.txn.stat()["entries"]
+
+    def get(self, index):
+        id_ = f"{index}".encode("ascii")
+        return pickle.loads(self.txn.get(id_))
+
+
+def create_json_from_lmdb(lmdb_path, output_dir):
+    dataset = Dataset(lmdb_path)
+    output_data = []
+
+    for i in range(len(dataset)):
+        d = dataset.get(i)
+
+        # Convert structure to CIF
+        structure = d["structure"]
+        cif = structure.to(fmt="cif")
+
+        entry = {
+            "structure": cif,
+            "is_stable": d["is_stable"],
+            "is_metal": d["is_metal"],
+            "is_magnetic": d["is_magnetic"],
+        }
+
+        output_data.append(entry)
+
+    # Ensure output directory exists
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Write to JSON file
+    output_file = os.path.join(output_dir, "mp_test.json")
+    with open(output_file, "w") as f:
+        json.dump(output_data, f, indent=2)
+
+    print(f"JSON file created: {output_file}")
+
+
+if __name__ == "__main__":
+    fire.Fire(create_json_from_lmdb)
diff --git a/revision-scripts/prep_json.py b/revision-scripts/prep_json.py
new file mode 100644
index 0000000..015e5fc
--- /dev/null
+++ b/revision-scripts/prep_json.py
@@ -0,0 +1,94 @@
+import json
+import os
+from matbench.bench import MatbenchBenchmark
+import numpy as np
+
+# Define the available benchmarks
+available_benchmarks = [
+    "matbench_mp_is_metal",
+]
+
+def convert_structure_to_serializable(pymatgen_structure):
+    """
+    Convert a pymatgen Structure object to a serializable format (CIF).
+    """
+    return pymatgen_structure.to(fmt="cif")
+
+def convert_label_to_serializable(label):
+    """
+    Convert labels to 0 or 1, specifically converting numpy booleans to Python integers.
+    """
+    return int(label)
+
+def download_benchmark_data(benchmark_name, save_path):
+    """
+    Download and save the Matbench benchmark data as JSON files.
+
+    Args:
+        benchmark_name (str): The name of the benchmark to download.
+        save_path (str): The directory path where the JSON files will be saved.
+    """
+    if benchmark_name not in available_benchmarks:
+        raise ValueError(
+            f"Invalid benchmark name. Available benchmarks: {', '.join(available_benchmarks)}"
+        )
+
+    # Load the MatbenchBenchmark
+    mb = MatbenchBenchmark(autoload=False)
+
+    # Create the save directory if it does not exist
+    if not os.path.exists(save_path):
+        os.mkdir(save_path)
+    else:
+        print(f"Directory '{save_path}' already exists.")
+
+    # Load the benchmark data
+    benchmark = getattr(mb, benchmark_name)
+    benchmark.load()
+
+    # Process each fold in the benchmark
+    for fold in benchmark.folds:
+        # Get train inputs and outputs
+        train_inputs, train_outputs = benchmark.get_train_and_val_data(fold)
+        test_inputs = benchmark.get_test_data(fold)
+
+        # Create the train data
+        train_data = [
+            {
+                "mbid": index,  # Add material ID (index)
+                "structure": convert_structure_to_serializable(train_inputs[index]),
+                "labels": convert_label_to_serializable(train_outputs[index]),  # Convert bool to 0 or 1
+            }
+            for index in train_inputs.index
+        ]
+
+        # Save the train data as a JSON file
+        train_dataset_name = f"train_{benchmark_name}_{fold}.json"
+        with open(os.path.join(save_path, train_dataset_name), "w") as train_file:
+            json.dump(train_data, train_file)
+
+        print(f"Train data saved to {save_path}/{train_dataset_name}")
+
+        # Create the test data
+        test_data = [
+            {
+                "mbid": index,  # Add material ID (index)
+                "structure": convert_structure_to_serializable(test_inputs[index])
+            }
+            for index in test_inputs.index
+        ]
+
+        # Save the test data as a JSON file
+        test_dataset_name = f"test_{benchmark_name}_{fold}.json"
+        with open(os.path.join(save_path, test_dataset_name), "w") as test_file:
+            json.dump(test_data, test_file)
+
+        print(f"Test data saved to {save_path}/{test_dataset_name}")
+
+if __name__ == "__main__":
+    # Define the benchmark name and the directory to save the data
+    benchmark_name = "matbench_mp_is_metal"
+    save_path = "./benchmark_data_is_metal"
+
+    # Download and save the benchmark data
+    download_benchmark_data(benchmark_name, save_path)
diff --git a/revision-scripts/prep_rep.py b/revision-scripts/prep_rep.py
new file mode 100644
index 0000000..5c277d7
--- /dev/null
+++ b/revision-scripts/prep_rep.py
@@ -0,0 +1,117 @@
+import json
+import fire
+
+from concurrent.futures import ProcessPoolExecutor, TimeoutError
+import multiprocessing
+from functools import partial
+from xtal2txt.core import TextRep
+
+from typing import List, Dict
+
+
+def read_json(json_file: str) -> List[Dict]:
+    """Read JSON data from a file.
+
+    Args:
+        json_file (str): The path to the JSON file.
+
+    Returns:
+        List[Dict]: A list of dictionaries containing the JSON data.
+    """
+    with open(json_file, 'r') as file:
+        data = json.load(file)
+    return data
+
+
+
+
+def process_entry_train_matbench(entry: dict, timeout: int) -> dict:
+
+    try:
+        text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env", "slice", "composition", "cif_symmetrized", "cif_p1", "crystal_llm_rep", "atoms", "atoms_params", "zmatrix", "wyckoff_rep", "mbid"])  # Build the requested MatText text representations for this structure
+        text_reps['is_stable'] = int(entry["is_stable"])
+        text_reps["is_magnetic"] = int(entry["is_magnetic"])
+        text_reps["is_metal"] = int(entry["is_metal"])
+        return text_reps  # Return the entire dictionary
+    except TimeoutError:
+        print("Timeout error processing a row")
+        return None
+    except Exception as e:
+        print(f"Error processing a row: {e}")
+        return None
+
+
+def process_entry_test_matbench(entry: List, timeout: int) -> dict:
+    # Ensure the give_slice function and necessary data are picklable
+    try:
+        text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env", "slice", "composition", "cif_symmetrized", "cif_p1", "crystal_llm_rep", "atoms", "atoms_params", "zmatrix", "wyckoff_rep", "mbid"])  # Build the requested MatText text representations for this structure
+        # For test entries, keep only the representations and the Matbench id
+        text_reps["mbid"] = entry["mbid"]
+        return text_reps  # Return the entire dictionary
+    except TimeoutError:
+        print("Timeout error processing a row")
+        return None
+    except Exception as e:
+        print(f"Error processing a row: {e}")
+        return None
+
+
+def process_batch(num_workers, batch, timeout, process_entry_func):
+
+    process_entry_with_timeout = partial(process_entry_func, timeout=timeout)
+
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        results = list(executor.map(process_entry_with_timeout, batch))
+
+    return [result for result in results if result is not None]
+
+
+
+def process_json_to_json(json_file: str, output_json_file: str, log_file_path: str, process_entry: str = 'test', num_workers: int = 48, timeout: int = 600, save_interval: int = 100, last_processed_entry: int = 0):
+
+    num_cpus = multiprocessing.cpu_count()
+    print(num_workers)
+
+    process_entry_funcs = {
+        'test': process_entry_test_matbench,
+        'train': process_entry_train_matbench
+    }
+    # Get the selected function
+    process_entry_func = process_entry_funcs[process_entry]
+
+    print(f"json file: {json_file}")
+    print(f"number of cpus: {num_cpus}")
+    print(f"number of workers: {num_workers}")
+    print(f"last processed entry: {last_processed_entry}")
+    print(f"save_interval: {save_interval}")
+
+    data = read_json(json_file)
+    batch_size = num_workers * 4
+
+    if last_processed_entry > 0:
+        data = data[last_processed_entry:]
+
+    batch_iterator = (data[i:i + batch_size] for i in range(0, len(data), batch_size))
+
+    for i, batch_data in enumerate(batch_iterator, start=1):
+        batch_results = process_batch(num_workers, batch_data, timeout, process_entry_func)
+
+        # Append batch_results to the output JSON file
+        with open(output_json_file, 'a') as f:
+            for result in batch_results:
+                json.dump(result, f)
+                f.write('\n')
+
+        last_processed_entry += len(batch_data)
+        if i % save_interval == 0:
+            with open(log_file_path, "w") as log_file:
+                log_file.write(f"Last processed entry index: {last_processed_entry}\n")
+                log_file.write(f"Last processed batch number: {i}\n")
+
+    print(f"Finished !!! logging at {log_file_path}")
+
+
+if __name__ == "__main__":
+    fire.Fire(process_json_to_json)
+
+
diff --git a/revision-scripts/text_rep.py b/revision-scripts/text_rep.py
new file mode 100644
index 0000000..0c025e1
--- /dev/null
+++ b/revision-scripts/text_rep.py
@@ -0,0 +1,116 @@
+import json
+import fire
+
+from concurrent.futures import ProcessPoolExecutor, TimeoutError
+import multiprocessing
+from functools import partial
+from xtal2txt.core import TextRep
+
+from typing import List, Dict
+
+
+def read_json(json_file: str) -> List[Dict]:
+    """Read JSON data from a file.
+
+    Args:
+        json_file (str): The path to the JSON file.
+
+    Returns:
+        List[Dict]: A list of dictionaries containing the JSON data.
+    """
+    with open(json_file, 'r') as file:
+        data = json.load(file)
+    return data
+
+
+
+
+def process_entry_train_matbench(entry: dict, timeout: int) -> dict:
+
+    try:
+        text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env", "slice", "composition", "cif_symmetrized", "cif_p1", "crystal_llm_rep", "atoms", "atoms_params", "zmatrix", "wyckoff_rep", "mbid"])  # Build the requested MatText text representations for this structure
+        text_reps['labels'] = entry["labels"]
+        text_reps["mbid"] = entry["mbid"]
+        return text_reps  # Return the entire dictionary
+    except TimeoutError:
+        print("Timeout error processing a row")
+        return None
+    except Exception as e:
+        print(f"Error processing a row: {e}")
+        return None
+
+
+def process_entry_test_matbench(entry: List, timeout: int) -> dict:
+    # Ensure the give_slice function and necessary data are picklable
+    try:
+        text_reps = TextRep.from_input(entry["structure"]).get_requested_text_reps(["local_env", "slice", "composition", "cif_symmetrized", "cif_p1", "crystal_llm_rep", "atoms", "atoms_params", "zmatrix", "wyckoff_rep", "mbid"])  # Build the requested MatText text representations for this structure
+        # For test entries, keep only the representations and the Matbench id
+        text_reps["mbid"] = entry["mbid"]
+        return text_reps  # Return the entire dictionary
+    except TimeoutError:
+        print("Timeout error processing a row")
+        return None
+    except Exception as e:
+        print(f"Error processing a row: {e}")
+        return None
+
+
+def process_batch(num_workers, batch, timeout, process_entry_func):
+
+    process_entry_with_timeout = partial(process_entry_func, timeout=timeout)
+
+    with ProcessPoolExecutor(max_workers=num_workers) as executor:
+        results = list(executor.map(process_entry_with_timeout, batch))
+
+    return [result for result in results if result is not None]
+
+
+
+def process_json_to_json(json_file: str, output_json_file: str, log_file_path: str, process_entry: str = 'test', num_workers: int = 48, timeout: int = 600, save_interval: int = 100, last_processed_entry: int = 0):
+
+    num_cpus = multiprocessing.cpu_count()
+    print(num_workers)
+
+    process_entry_funcs = {
+        'test': process_entry_test_matbench,
+        'train': process_entry_train_matbench
+    }
+    # Get the selected function
+    process_entry_func = process_entry_funcs[process_entry]
+
+    print(f"json file: {json_file}")
+    print(f"number of cpus: {num_cpus}")
+    print(f"number of workers: {num_workers}")
+    print(f"last processed entry: {last_processed_entry}")
+    print(f"save_interval: {save_interval}")
+
+    data = read_json(json_file)
+    batch_size = num_workers * 4
+
+    if last_processed_entry > 0:
+        data = data[last_processed_entry:]
+
+    batch_iterator = (data[i:i + batch_size] for i in range(0, len(data), batch_size))
len(data), batch_size)) + + for i, batch_data in enumerate(batch_iterator, start=1): + batch_results = process_batch(num_workers,batch_data, timeout, process_entry_func) + + # Append batch_results to the output JSON file + with open(output_json_file, 'a') as f: + for result in batch_results: + json.dump(result, f) + f.write('\n') + + last_processed_entry += len(batch_data) + if i % save_interval == 0: + with open(log_file_path, "w") as log_file: + log_file.write(f"Last processed entry index: {last_processed_entry}\n") + log_file.write(f"Last processed batch number: {i}\n") + + print(f"Finished !!! logging at {log_file_path}") + + +if __name__ == "__main__": + fire.Fire(process_json_to_json) + + diff --git a/src/mattext/main.py b/src/mattext/main.py index 91120a6..e193754 100644 --- a/src/mattext/main.py +++ b/src/mattext/main.py @@ -1,4 +1,5 @@ import os +from typing import Callable, Union import hydra import wandb @@ -6,7 +7,7 @@ from hydra import utils from omegaconf import DictConfig -from mattext.models.benchmark import Matbenchmark +from mattext.models.benchmark import Matbenchmark, MatbenchmarkClassification from mattext.models.finetune import FinetuneModel from mattext.models.inference import Benchmark from mattext.models.llama import FinetuneLLama @@ -18,36 +19,86 @@ class TaskRunner: def __init__(self): self.wandb_api_key = os.environ.get("WANDB_API_KEY") + self.task_map = { + "benchmark": self.run_benchmarking, + "classification": self.run_classification, + "inference": self.run_inference, + "finetune": self.run_finetuning, + "pretrain": self.run_pretraining, + "qmof": self.run_qmof, + "llama": self.run_llama, + "llama_sft": self.run_llama_sft, + "potential": self.run_potential, + } def run_task(self, run: list, task_cfg: DictConfig, local_rank=None) -> None: - if "benchmark" in run: - self.run_benchmarking(task_cfg) - - if "inference" in run: - self.run_inference(task_cfg) - - if "finetune" in run: - self.run_finetuning(task_cfg) - - if "pretrain" in run: - self.run_pretraining(task_cfg) + for task in run: + if task in self.task_map: + self.task_map[task](task_cfg, local_rank) + else: + print(f"Unknown task: {task}") + + def _run_experiment( + self, + task_cfg: DictConfig, + local_rank: Union[int, None], + model_class: Callable, + experiment_type: str, + use_folds: bool = False, + use_train_data_path: bool = False, + ): + if use_folds: + iterations = range(task_cfg.model.fold) + elif use_train_data_path: + iterations = zip( + task_cfg.model.finetune.exp_name, + task_cfg.model.finetune.path.finetune_traindata, + ) + else: + iterations = [None] + + for item in iterations: + if use_folds: + exp_name = f"{task_cfg.model.finetune.exp_name}_fold_{item}" + fold = f"fold_{item}" + elif use_train_data_path: + exp_name, train_data_path = item + fold = None + else: + exp_name = task_cfg.model[experiment_type].exp_name + fold = None - if "qmof" in run: - self.run_qmof(task_cfg) + wandb.init( + config=dict(task_cfg.model[experiment_type]), + project=task_cfg.model.logging.wandb_project, + name=exp_name, + ) - if "llama" in run: - self.run_llama(task_cfg, local_rank=local_rank) + exp_cfg = task_cfg.copy() + exp_cfg.model[experiment_type].exp_name = exp_name + if use_train_data_path: + exp_cfg.model.finetune.path.finetune_traindata = train_data_path - if "llama_sft" in run: - self.run_llama_sft(task_cfg, local_rank=local_rank) + if fold: + model = model_class(exp_cfg, local_rank, fold=fold) + else: + model = model_class(exp_cfg, local_rank) - if "potential" in run: - 
self.run_potential(task_cfg) + result = ( + model.finetune() if hasattr(model, "finetune") else model.pretrain_mlm() + ) + print(result) + wandb.finish() def run_benchmarking(self, task_cfg: DictConfig, local_rank=None) -> None: - print("Finetuning and testing on matbench dataset") - matbench_predictor = Matbenchmark(task_cfg) - matbench_predictor.run_benchmarking(local_rank=local_rank) + print("Benchmarking") + benchmark = Matbenchmark(task_cfg) + benchmark.run_benchmarking(local_rank=local_rank) + + def run_classification(self, task_cfg: DictConfig, local_rank=None) -> None: + print("Benchmarking Classification") + benchmark = MatbenchmarkClassification(task_cfg) + benchmark.run_benchmarking(local_rank=local_rank) def run_qmof(self, task_cfg: DictConfig, local_rank=None) -> None: print("Finetuning on qmof") @@ -60,89 +111,27 @@ def run_inference(self, task_cfg: DictConfig, local_rank=None) -> None: matbench_predictor.run_benchmarking(local_rank=local_rank) def run_llama(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = FinetuneLLama(exp_cfg, local_rank) - f = finetuner.finetune() - print(f) - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneLLama, "finetune", use_train_data_path=True + ) def run_llama_sft(self, task_cfg: DictConfig, local_rank=None) -> None: - for fold in range(task_cfg.model.fold): - exp_name = f"{task_cfg.model.finetune.exp_name}_fold_{fold}" - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - - finetuner = FinetuneLLamaSFT(exp_cfg, local_rank, fold=f"fold_{fold}") - f = finetuner.finetune() - print(f) - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneLLamaSFT, "finetune", use_folds=True + ) def run_finetuning(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = FinetuneModel(exp_cfg, local_rank) - finetuner.finetune() - wandb.finish() + self._run_experiment( + task_cfg, local_rank, FinetuneModel, "finetune", use_train_data_path=True + ) def run_potential(self, task_cfg: DictConfig, local_rank=None) -> None: - for exp_name, train_data_path in zip( - task_cfg.model.finetune.exp_name, - task_cfg.model.finetune.path.finetune_traindata, - ): - wandb.init( - config=dict(task_cfg.model.finetune), - project=task_cfg.model.logging.wandb_project, - name=exp_name, - ) - - exp_cfg = task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = train_data_path - - finetuner = PotentialModel(exp_cfg, local_rank) - finetuner.finetune() - wandb.finish() + self._run_experiment( + task_cfg, local_rank, PotentialModel, "finetune", 
use_train_data_path=True + ) def run_pretraining(self, task_cfg: DictConfig, local_rank=None) -> None: - wandb.init( - config=dict(task_cfg.model.pretrain), - project=task_cfg.model.logging.wandb_project, - name=task_cfg.model.pretrain.exp_name, - ) - print(task_cfg) - pretrainer = PretrainModel(task_cfg, local_rank) - pretrainer.pretrain_mlm() + self._run_experiment(task_cfg, local_rank, PretrainModel, "pretrain") def initialize_wandb(self): if self.wandb_api_key: diff --git a/src/mattext/models/benchmark.py b/src/mattext/models/benchmark.py index 93e4734..ef93603 100644 --- a/src/mattext/models/benchmark.py +++ b/src/mattext/models/benchmark.py @@ -1,34 +1,23 @@ import os import traceback +from abc import ABC, abstractmethod import wandb from matbench.bench import MatbenchBenchmark from omegaconf import DictConfig -from mattext.models.finetune import FinetuneModel -from mattext.models.predict import Inference -from mattext.models.score import MATTEXT_MATBENCH, MatTextTask +from mattext.models.finetune import FinetuneModel, FinetuneClassificationModel +from mattext.models.predict import Inference, InferenceClassification +from mattext.models.score import ( + MATTEXT_MATBENCH, + MatTextTask, +) from mattext.models.utils import fold_key_namer +from loguru import logger -class Matbenchmark: - """ - Class to perform predictions on Matbench datasets. - - Args: - - task_cfg (DictConfig): Configuration dictionary containing task parameters. - """ - +class BaseBenchmark(ABC): def __init__(self, task_cfg: DictConfig): - """ - Initializes the object with the given task configuration. - - Parameters: - task_cfg (DictConfig): The configuration dictionary containing task parameters. - - Returns: - None - """ self.task_cfg = task_cfg self.representation = self.task_cfg.model.representation self.task = self.task_cfg.model.dataset @@ -39,92 +28,137 @@ def __init__(self, task_cfg: DictConfig): self.train_data = self.task_cfg.model.finetune.dataset_name self.test_data = self.task_cfg.model.inference.benchmark_dataset self.benchmark_save_path = self.task_cfg.model.inference.benchmark_save_file - - # override wandb project name & tokenizer self.wandb_project = self.task_cfg.model.logging.wandb_project + @abstractmethod def run_benchmarking(self, local_rank=None) -> None: - """ - Runs benchmarking on the specified dataset. - - Args: - local_rank (int, optional): The local rank for distributed training. Defaults to None. - - Returns: - None + pass - Raises: - Exception: If an error occurs during inference for a finetuned checkpoint. 
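The refactored TaskRunner.run_task replaces the long chain of membership checks with a dispatch table (self.task_map) that maps task keywords to bound handler methods, and _run_experiment folds the repeated wandb-init/configure/finetune/finish loops into one parameterized helper. Below is a minimal standalone sketch of the dispatch pattern only; MiniRunner and its handlers are hypothetical stand-ins, not code from this patch.

from typing import Callable, Dict


class MiniRunner:
    def __init__(self) -> None:
        # Map task keywords to bound methods, mirroring self.task_map in mattext/main.py.
        self.task_map: Dict[str, Callable] = {
            "benchmark": self.run_benchmarking,
            "classification": self.run_classification,
        }

    def run_task(self, run: list, cfg: dict, local_rank=None) -> None:
        # Unknown task names are reported rather than raising, as in the patched run_task.
        for task in run:
            if task in self.task_map:
                self.task_map[task](cfg, local_rank)
            else:
                print(f"Unknown task: {task}")

    def run_benchmarking(self, cfg: dict, local_rank=None) -> None:
        print("benchmark task with", cfg, local_rank)

    def run_classification(self, cfg: dict, local_rank=None) -> None:
        print("classification task with", cfg, local_rank)


if __name__ == "__main__":
    MiniRunner().run_task(["benchmark", "unknown"], {"dataset": "bandgap"})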
- - """ + def _initialize_task(self): if self.task_type == "matbench": mb = MatbenchBenchmark(autoload=False) task = getattr(mb, MATTEXT_MATBENCH[self.task]) task.load() else: task = MatTextTask(task_name=self.task) + return task + + def _run_experiment(self, task, i, exp_name, test_name, local_rank): + fold_name = fold_key_namer(i) + logger.info( + f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" + ) + logger.info(f"Fold Name: {fold_name}") + + exp_cfg = self.task_cfg.copy() + exp_cfg.model.finetune.exp_name = exp_name + exp_cfg.model.finetune.path.finetune_traindata = self.train_data + + finetuner = self._get_finetuner(exp_cfg, local_rank, fold_name) + ckpt = finetuner.finetune() + logger.info(f"Checkpoint: {ckpt}") + + wandb.init( + config=dict(self.task_cfg.model.inference), + project=self.task_cfg.model.logging.wandb_project, + name=test_name, + ) + + exp_cfg.model.inference.path.test_data = self.test_data + exp_cfg.model.inference.path.pretrained_checkpoint = ckpt + + try: + predict = self._get_inference(exp_cfg, fold_name) + predictions, prediction_ids = predict.predict() + self._record_predictions(task, i, predictions, prediction_ids) + except Exception as e: + logger.error( + f"Error occurred during inference for finetuned checkpoint '{exp_name}': {str(e)}" + ) + if isinstance(e, (ValueError, TypeError)): + raise + logger.error(traceback.format_exc()) + + @abstractmethod + def _get_finetuner(self, exp_cfg, local_rank, fold_name): + pass + + @abstractmethod + def _get_inference(self, exp_cfg, fold_name): + pass + + @abstractmethod + def _record_predictions(self, task, fold, predictions, prediction_ids): + pass + + def _save_results(self, task): + if not os.path.exists(self.benchmark_save_path): + os.makedirs(self.benchmark_save_path) + + file_name = os.path.join( + self.benchmark_save_path, + f"mattext_benchmark_{self.representation}_{self.benchmark}.json", + ) + task.to_file(file_name) + + +class Matbenchmark(BaseBenchmark): + def run_benchmarking(self, local_rank=None) -> None: + task = self._initialize_task() for i, (exp_name, test_name) in enumerate( zip(self.exp_names, self.test_exp_names) ): - print( - f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}" - ) wandb.init( config=dict(self.task_cfg.model.finetune), project=self.task_cfg.model.logging.wandb_project, name=exp_name, ) - fold_name = fold_key_namer(i) - print("-------------------------") - print(fold_name) - print("-------------------------") + self._run_experiment(task, i, exp_name, test_name, local_rank) + + self._save_results(task) + + def _get_finetuner(self, exp_cfg, local_rank, fold_name): + return FinetuneModel(exp_cfg, local_rank, fold=fold_name) + + def _get_inference(self, exp_cfg, fold_name): + return Inference(exp_cfg, fold=fold_name) + + def _record_predictions(self, task, fold, predictions, prediction_ids): + if self.task_type == "matbench": + task.record(fold, predictions) + else: + task.record_fold( + fold=fold, prediction_ids=prediction_ids, predictions=predictions + ) - exp_cfg = self.task_cfg.copy() - exp_cfg.model.finetune.exp_name = exp_name - exp_cfg.model.finetune.path.finetune_traindata = self.train_data - finetuner = FinetuneModel(exp_cfg, local_rank, fold=fold_name) - ckpt = finetuner.finetune() - print("-------------------------") - print(ckpt) - print("-------------------------") +class MatbenchmarkClassification(BaseBenchmark): + def run_benchmarking(self, local_rank=None) -> None: + task = self._initialize_task() + for
i, (exp_name, test_name) in enumerate( + zip(self.exp_names, self.test_exp_names) + ): wandb.init( - config=dict(self.task_cfg.model.inference), + config=dict(self.task_cfg.model.finetune), project=self.task_cfg.model.logging.wandb_project, - name=test_name, + name=exp_name, ) + self._run_experiment(task, i, exp_name, test_name, local_rank) - exp_cfg.model.inference.path.test_data = self.test_data - exp_cfg.model.inference.path.pretrained_checkpoint = ckpt + self._save_results(task) - try: - predict = Inference(exp_cfg, fold=fold_name) - predictions, prediction_ids = predict.predict() - print(len(prediction_ids), len(predictions)) + def _initialize_task(self): + return MatTextTask(task_name=self.task, is_classification=True) - if self.task_type == "matbench": - task.record(i, predictions) - else: - task.record_fold( - fold=i, prediction_ids=prediction_ids, predictions=predictions - ) + def _get_finetuner(self, exp_cfg, local_rank, fold_name): + return FinetuneClassificationModel(exp_cfg, local_rank, fold=fold_name) - except Exception as e: - print( - f"Error occurred during inference for finetuned checkpoint '{exp_name}':" - ) - print(traceback.format_exc()) + def _get_inference(self, exp_cfg, fold_name): + return InferenceClassification(exp_cfg, fold=fold_name) - if not os.path.exists(self.benchmark_save_path): - os.makedirs(self.benchmark_save_path) - - file_name = os.path.join( - self.benchmark_save_path, - f"mattext_benchmark_{self.representation}_{self.benchmark}.json", + def _record_predictions(self, task, fold, predictions, prediction_ids): + task.record_fold( + fold=fold, prediction_ids=prediction_ids, predictions=predictions.values ) - task.to_file(file_name) - # Get final results after recording all folds - # final_results = task.get_final_results() - # print(final_results) diff --git a/src/mattext/models/finetune.py b/src/mattext/models/finetune.py index 2a3536d..76ceadd 100644 --- a/src/mattext/models/finetune.py +++ b/src/mattext/models/finetune.py @@ -1,10 +1,18 @@ +from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List +import numpy as np import torch import wandb from datasets import DatasetDict, load_dataset from omegaconf import DictConfig +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + roc_auc_score, +) +from sklearn.preprocessing import label_binarize from torch import nn from transformers import ( AutoModelForSequenceClassification, @@ -21,15 +29,7 @@ ) -class FinetuneModel(TokenizerMixin): - """Class to perform finetuning of a language model. - Initialize the FinetuneModel. - - Args: - cfg (DictConfig): Configuration for the fine-tuning. - local_rank (int, optional): Local rank for distributed training. Defaults to None. - """ - +class BaseFinetuneModel(TokenizerMixin, ABC): def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: super().__init__( cfg=cfg.model.representation, @@ -48,22 +48,6 @@ def __init__(self, cfg: DictConfig, local_rank=None, fold="fold_0") -> None: ) def _prepare_datasets(self, subset: str) -> DatasetDict: - """ - Prepare training and validation datasets. - - Args: - train_df (pd.DataFrame): DataFrame containing training data. - - Returns: - DatasetDict: Dictionary containing training and validation datasets. 
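In BaseFinetuneModel, _prepare_datasets pulls one fold of the dataset from the configured data repository, makes a shuffled 80/20 train/test split with a fixed seed, drops rows whose text representation is missing, and only then tokenizes. A small offline sketch of that flow; the in-memory toy dataset and the composition column are illustrative stand-ins for load_dataset(self.data_repository, subset).

from datasets import Dataset, DatasetDict

representation = "composition"  # illustrative MatText representation column

raw = DatasetDict(
    {
        "fold_0": Dataset.from_dict(
            {
                representation: ["NaCl", None, "TiO2", "GaAs", "MgO"],
                "labels": [1.2, 0.0, 3.4, 1.5, 7.8],
            }
        )
    }
)

fold = "fold_0"
# Same split as the patch: shuffled 80/20 with seed 42 on the selected fold.
dataset = raw[fold].train_test_split(shuffle=True, test_size=0.2, seed=42)
# Drop entries with a missing text representation before tokenization, as the patched filter does.
dataset = dataset.filter(lambda example: example[representation] is not None)
print(dataset)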
- """ - - def replace_none(example, replacement="[PAD]"): - for key, value in example.items(): - if value is None: - example[key] = replacement - return example - ds = load_dataset(self.data_repository, subset) dataset = ds[self.fold].train_test_split(shuffle=True, test_size=0.2, seed=42) dataset = dataset.filter( @@ -77,9 +61,7 @@ def replace_none(example, replacement="[PAD]"): ) def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for early stopping, and custom logging.""" callbacks = [] - if self.callbacks.early_stopping: callbacks.append( EarlyStoppingCallback( @@ -87,48 +69,27 @@ def _callbacks(self) -> List[TrainerCallback]: early_stopping_threshold=self.callbacks.early_stopping_threshold, ) ) - if self.callbacks.custom_logger: callbacks.append(CustomWandbCallback_FineTune()) - callbacks.append(EvaluateFirstStepCallback) - return callbacks - def _compute_metrics(self, p: Any, eval=True) -> Dict[str, float]: - preds = torch.tensor( - p.predictions.squeeze() - ) # Convert predictions to PyTorch tensor - label_ids = torch.tensor(p.label_ids) # Convert label_ids to PyTorch tensor - - if eval: - # Calculate RMSE as evaluation metric - eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item() - return {"eval_rmse": round(eval_rmse, 3)} - else: - # Calculate RMSE as training metric - loss = torch.sqrt(((preds - label_ids) ** 2).mean()).item() - return {"train_rmse": round(loss, 3), "loss": round(loss, 3)} - - def finetune(self) -> None: - """ - Perform fine-tuning of the language model. - """ + @abstractmethod + def _compute_metrics(self, p: Any) -> Dict[str, float]: + pass + def finetune(self) -> str: pretrained_ckpt = self.cfg.path.pretrained_checkpoint - config_train_args = self.cfg.training_arguments callbacks = self._callbacks() training_args = TrainingArguments( **config_train_args, - metric_for_best_model="eval_rmse", # Metric to use for determining the best model - greater_is_better=False, # Lower eval_rmse is better + metric_for_best_model=self.get_best_metric(), + greater_is_better=self.is_greater_better(), ) - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False - ) + model = self.get_model(pretrained_ckpt) if self.cfg.freeze_base_model: for param in model.base_model.parameters(): @@ -165,8 +126,75 @@ def finetune(self) -> None: wandb.finish() return self.cfg.path.finetuned_modelname + @abstractmethod + def get_best_metric(self) -> str: + pass + + @abstractmethod + def is_greater_better(self) -> bool: + pass + + @abstractmethod + def get_model(self, pretrained_ckpt: str): + pass + def evaluate(self): - """ - Evaluate the fine-tuned model on the test dataset. 
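finetune() obtains the task head through the subclass hook get_model and, when cfg.freeze_base_model is set, freezes the pretrained encoder so only the regression or classification head is trained. The hunk elides the body of that loop, so the sketch below assumes the conventional requires_grad = False; the tiny public checkpoint is only a placeholder for the MatText pretrained model and is downloaded from the Hub.

from transformers import AutoModelForSequenceClassification

# Placeholder checkpoint; the real code passes cfg.path.pretrained_checkpoint to get_model().
model = AutoModelForSequenceClassification.from_pretrained(
    "prajjwal1/bert-tiny", num_labels=1, ignore_mismatched_sizes=False
)

# Assumed freeze logic: keep the encoder fixed, train only the sequence-classification head.
for param in model.base_model.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")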
- """ ckpt = self.finetune() + + +class FinetuneModel(BaseFinetuneModel): + def _compute_metrics(self, p: Any) -> Dict[str, float]: + preds = torch.tensor(p.predictions.squeeze()) + label_ids = torch.tensor(p.label_ids) + eval_rmse = torch.sqrt(((preds - label_ids) ** 2).mean()).item() + return {"eval_rmse": round(eval_rmse, 3)} + + def get_best_metric(self) -> str: + return "eval_rmse" + + def is_greater_better(self) -> bool: + return False + + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False + ) + + +class FinetuneClassificationModel(BaseFinetuneModel): + def _compute_metrics(self, p: Any) -> Dict[str, float]: + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds_argmax = np.argmax(preds, axis=1) + labels = p.label_ids + precision, recall, f1, _ = precision_recall_fscore_support( + labels, preds_argmax, average="weighted" + ) + acc = accuracy_score(labels, preds_argmax) + + n_classes = preds.shape[1] + if n_classes == 2: + roc_auc = roc_auc_score(labels, preds[:, 1]) + else: + labels_binarized = label_binarize(labels, classes=range(n_classes)) + roc_auc = roc_auc_score( + labels_binarized, preds, average="weighted", multi_class="ovr" + ) + + return { + "accuracy": acc, + "f1": f1, + "precision": precision, + "recall": recall, + "roc_auc": roc_auc, + } + + def get_best_metric(self) -> str: + return "f1" + + def is_greater_better(self) -> bool: + return True + + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=2, ignore_mismatched_sizes=False + ) diff --git a/src/mattext/models/helper.py b/src/mattext/models/helper.py index 4f281b3..f360c67 100644 --- a/src/mattext/models/helper.py +++ b/src/mattext/models/helper.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt from datasets import load_dataset from tqdm import tqdm +from loguru import logger from mattext.models.utils import TokenizerMixin @@ -15,8 +16,8 @@ def count_tokens_and_plot( ): tokenizer = TokenizerMixin(representation) ds = load_dataset("json", data_files=dataset_path, split="train") - print(ds) - print(representation) + logger.info(f"Dataset: {ds}") + logger.info(f"Representation: {representation}") dataset = ds[representation] token_counts = [] diff --git a/src/mattext/models/inference.py b/src/mattext/models/inference.py index 49352e4..f927b65 100644 --- a/src/mattext/models/inference.py +++ b/src/mattext/models/inference.py @@ -2,6 +2,7 @@ import traceback import wandb +from loguru import logger from matbench.bench import MatbenchBenchmark from omegaconf import DictConfig @@ -65,21 +66,15 @@ def run_benchmarking(self, local_rank=None) -> None: for i, (exp_name, test_name, train_data_path, test_data_path) in enumerate( zip(self.exp_names, self.test_exp_names, self.train_data, self.test_data) ): - print( + logger.info( f"Running training on {train_data_path}, and testing on {test_data_path}" ) - # wandb.init( - # config=dict(self.task_cfg.model.finetune), - # project=self.task_cfg.model.logging.wandb_project, name=exp_name) - exp_cfg = self.task_cfg.copy() exp_cfg.model.finetune.exp_name = exp_name exp_cfg.model.finetune.path.finetune_traindata = train_data_path ckpt = exp_cfg.model.finetune.path.finetuned_modelname - print("-------------------------") - print(ckpt) - print("-------------------------") + logger.info(f"Checkpoint: {ckpt}") wandb.init( config=dict(self.task_cfg.model.inference), @@
-95,10 +90,10 @@ def run_benchmarking(self, local_rank=None) -> None: predictions = predict.predict() benchmark.record(i, predictions) except Exception as e: - print( + logger.error( f"Error occurred during inference for finetuned checkpoint '{exp_name}':" ) - print(traceback.format_exc()) + logger.error(traceback.format_exc()) if not os.path.exists(self.benchmark_save_path): os.makedirs(self.benchmark_save_path) diff --git a/src/mattext/models/llama_sft.py b/src/mattext/models/llama_sft.py index 53942a1..5cf015d 100644 --- a/src/mattext/models/llama_sft.py +++ b/src/mattext/models/llama_sft.py @@ -4,6 +4,7 @@ import torch import wandb from datasets import load_dataset +from loguru import logger from omegaconf import DictConfig from peft import ( LoraConfig, @@ -22,7 +23,6 @@ from mattext.models.utils import ( EvaluateFirstStepCallback, ) -from loguru import logger class FinetuneLLamaSFT: @@ -209,14 +209,14 @@ def finetune(self) -> None: trainer.save_state() trainer.save_model(self.output_dir_) - # Merge LoRA and base model - merged_model = trainer.model.merge_and_unload() - # Save the merged model - merged_model.save_pretrained( - f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained", - save_config=True, - safe_serialization=True, - ) + # # Merge LoRA and base model + # merged_model = trainer.model.merge_and_unload() + # # Save the merged model + # merged_model.save_pretrained( + # f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained", + # save_config=True, + # safe_serialization=True, + # ) self.tokenizer.save_pretrained( f"{self.cfg.path.finetuned_modelname}_{self.fold}/llamav3-8b-lora-save-pretrained" ) @@ -231,5 +231,15 @@ def finetune(self) -> None: ) as json_file: json.dump(merge_pred, json_file) + # Empty VRAM + del trainer + del collator + del pipe + del self.model + del self.tokenizer + import gc + + gc.collect() + gc.collect() wandb.finish() return self.cfg.path.finetuned_modelname diff --git a/src/mattext/models/predict.py b/src/mattext/models/predict.py index a7228d7..96b4836 100644 --- a/src/mattext/models/predict.py +++ b/src/mattext/models/predict.py @@ -1,18 +1,24 @@ +from abc import ABC, abstractmethod from functools import partial -from typing import List +from typing import List, Tuple, Union +import numpy as np import pandas as pd import torch from datasets import DatasetDict, load_dataset from omegaconf import DictConfig +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + roc_auc_score, +) +from sklearn.preprocessing import label_binarize from transformers import AutoModelForSequenceClassification, Trainer, TrainerCallback from mattext.models.utils import CustomWandbCallback_Inference, TokenizerMixin -class Inference(TokenizerMixin): - """Class to perform inference on a language model with a sequence classification head.""" - +class BaseInference(TokenizerMixin, ABC): def __init__(self, cfg: DictConfig, fold="fold_0"): super().__init__( cfg=cfg.model.representation, @@ -29,20 +35,10 @@ def __init__(self, cfg: DictConfig, fold="fold_0"): self.prediction_ids = None def _prepare_datasets(self, path: str) -> DatasetDict: - """ - Prepare training and validation datasets. - - Args: - train_df (pd.DataFrame): DataFrame containing training data. - - Returns: - DatasetDict: Dictionary containing training and validation datasets. 
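Both FinetuneClassificationModel._compute_metrics above and InferenceClassification further below turn raw logits into softmax probabilities before scoring, and use the positive-class probability for binary ROC-AUC. A self-contained sketch of that metric path on made-up binary logits (all numbers are illustrative):

import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

# Hypothetical logits for a binary is-metal style task (4 samples, 2 classes) and true labels.
logits = np.array([[2.0, -1.0], [0.2, 0.5], [-1.5, 2.5], [1.0, 0.0]])
labels = np.array([0, 1, 1, 0])

# Softmax over the class dimension, as in InferenceClassification.process_predictions.
probs = torch.nn.functional.softmax(torch.from_numpy(logits), dim=-1).numpy()
pred_labels = np.argmax(probs, axis=1)

precision, recall, f1, _ = precision_recall_fscore_support(labels, pred_labels, average="weighted")
metrics = {
    "accuracy": accuracy_score(labels, pred_labels),
    "precision": precision,
    "recall": recall,
    "f1": f1,
    # Binary case: score the probability of the positive class, matching the patch.
    "roc_auc": roc_auc_score(labels, probs[:, 1]),
}
print(metrics)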
- """ dataset = load_dataset(self.data_repository, path) filtered_dataset = dataset[self.fold].filter( lambda example: example[self.representation] is not None ) - return filtered_dataset.map( partial( self._tokenize_pad_and_truncate, context_length=self.context_length @@ -51,16 +47,21 @@ def _prepare_datasets(self, path: str) -> DatasetDict: ) def _callbacks(self) -> List[TrainerCallback]: - """Returns a list of callbacks for logging.""" return [CustomWandbCallback_Inference()] - def predict(self): + @abstractmethod + def get_model(self, pretrained_ckpt: str): + pass + + @abstractmethod + def process_predictions(self, predictions) -> Union[pd.Series, pd.DataFrame]: + pass + + def predict(self) -> Tuple[Union[pd.Series, pd.DataFrame], List[str]]: pretrained_ckpt = self.cfg.path.pretrained_checkpoint callbacks = self._callbacks() - model = AutoModelForSequenceClassification.from_pretrained( - pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False - ) + model = self.get_model(pretrained_ckpt) trainer = Trainer( model=model.to("cuda"), data_collator=None, callbacks=callbacks @@ -68,16 +69,71 @@ def predict(self): predictions = trainer.predict(self.tokenized_test_datasets) for callback in callbacks: - callback.on_predict_end( - None, None, None, model, predictions - ) # Manually trigger callback + callback.on_predict_end(None, None, None, model, predictions) torch.cuda.empty_cache() - # TODO: Save predictions to disk optional - # os.makedirs(self.cfg.path.predictions, exist_ok=True) - # predictions_path = os.path.join(self.cfg.path.predictions, 'predictions.npy') - # np.save(predictions_path, predictions.predictions) prediction_ids = self.tokenized_test_datasets["mbid"] self.prediction_ids = prediction_ids - return pd.Series(predictions.predictions.flatten()), prediction_ids + processed_predictions = self.process_predictions(predictions) + + return processed_predictions, prediction_ids + + +class Inference(BaseInference): + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=1, ignore_mismatched_sizes=False + ) + + def process_predictions(self, predictions) -> pd.Series: + return pd.Series(predictions.predictions.flatten()) + + +class InferenceClassification(BaseInference): + def __init__(self, cfg: DictConfig, fold="fold_0"): + super().__init__(cfg, fold) + self.num_labels = 2 # You might want to make this configurable + + def get_model(self, pretrained_ckpt: str): + return AutoModelForSequenceClassification.from_pretrained( + pretrained_ckpt, num_labels=self.num_labels, ignore_mismatched_sizes=False + ) + + def process_predictions(self, predictions) -> pd.DataFrame: + probabilities = torch.nn.functional.softmax( + torch.from_numpy(predictions.predictions), dim=-1 + ).numpy() + return pd.DataFrame( + probabilities, columns=[f"class_{i}" for i in range(self.num_labels)] + ) + + def evaluate(self, true_labels: List[int]) -> dict: + predictions, _ = self.predict() + pred_labels = np.argmax(predictions.values, axis=1) + + accuracy = accuracy_score(true_labels, pred_labels) + precision, recall, f1, _ = precision_recall_fscore_support( + true_labels, pred_labels, average="weighted" + ) + + if self.num_labels == 2: + roc_auc = roc_auc_score(true_labels, predictions.iloc[:, 1]) + else: + true_labels_binarized = label_binarize( + true_labels, classes=range(self.num_labels) + ) + roc_auc = roc_auc_score( + true_labels_binarized, + predictions, + average="weighted", + multi_class="ovr", + ) + + return { + "accuracy": 
accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "roc_auc": roc_auc, + } diff --git a/src/mattext/models/score.py b/src/mattext/models/score.py index baacdbe..62248ff 100644 --- a/src/mattext/models/score.py +++ b/src/mattext/models/score.py @@ -1,58 +1,53 @@ import json import math -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field from typing import Any, Dict, List import numpy as np import pandas as pd from matbench.data_ops import load from sklearn.metrics import ( + accuracy_score, mean_absolute_error, mean_squared_error, + precision_recall_fscore_support, + roc_auc_score, ) MATTEXT_MATBENCH = { "kvrh": "matbench_log_kvrh", "gvrh": "matbench_log_gvrh", "perovskites": "matbench_perovskites", + "bandgap": "matbench_mp_gap", + "form_energy": "matbench_mp_e_form", + "is-metal": "matbench_mp_is_metal", } MATMINER_COLUMNS = { "kvrh": "log10(K_VRH)", "gvrh": "log10(G_VRH)", "perovskites": "e_form", + "is-metal": "is_metal", + "bandgap": "gap pbe", + "form_energy": "e_form", } -METRIC_MAP = { - "mae": mean_absolute_error, - "rmse": lambda true, pred: math.sqrt(mean_squared_error(true, pred)), -} - - -def fold_key_namer(fold_key): - return f"fold_{fold_key}" - def load_true_scores(dataset, mbids): data_frame = load(MATTEXT_MATBENCH[dataset]) scores = [] for mbid in mbids: - # Get the score for the mbid score = data_frame.loc[mbid][MATMINER_COLUMNS[dataset]] scores.append(score) return scores -def mattext_score(prediction_ids, predictions, task_name): - true = load_true_scores(task_name, prediction_ids) - return mean_squared_error(true, predictions) - - @dataclass class MatTextTask: task_name: str num_folds: int = 5 - # metric: str + is_classification: bool = False + num_classes: int = 2 folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict) recorded_folds: List[int] = field(default_factory=list) @@ -62,6 +57,21 @@ def record_fold( if fold in self.recorded_folds: raise ValueError(f"Fold {fold} has already been recorded.") true_scores = load_true_scores(self.task_name, prediction_ids) + + if self.is_classification: + self._calculate_classification_metrics( + fold, prediction_ids, predictions, true_scores + ) + else: + self._calculate_regression_metrics( + fold, prediction_ids, predictions, true_scores + ) + + self.recorded_folds.append(fold) + + def _calculate_regression_metrics( + self, fold, prediction_ids, predictions, true_scores + ): mae = mean_absolute_error(true_scores, predictions) rmse = math.sqrt(mean_squared_error(true_scores, predictions)) self.folds_results[fold] = { @@ -71,63 +81,91 @@ def record_fold( "mae": mae, "rmse": rmse, } - self.recorded_folds.append(fold) + + def _calculate_classification_metrics( + self, fold, prediction_ids, predictions, true_labels + ): + pred_labels = np.argmax(predictions, axis=1) + accuracy = accuracy_score(true_labels, pred_labels) + precision, recall, f1, _ = precision_recall_fscore_support( + true_labels, pred_labels, average="weighted" + ) + roc_auc = ( + roc_auc_score(true_labels, predictions[:, 1]) + if self.num_classes == 2 + else None + ) + self.folds_results[fold] = { + "prediction_ids": prediction_ids, + "predictions": predictions, + "true_labels": true_labels, + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1": f1, + "roc_auc": roc_auc, + } def get_final_results(self): if len(self.recorded_folds) < self.num_folds: raise ValueError( f"All {self.num_folds} folds must be recorded before getting final results." 
) - final_scores_mae = [ - self.folds_results[fold]["mae"] for fold in range(self.num_folds) - ] - final_scores_rmse = [ - self.folds_results[fold]["rmse"] for fold in range(self.num_folds) - ] + return self._aggregate_results() + + def _aggregate_results(self): + if self.is_classification: + metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"] + else: + metrics = ["mae", "rmse"] + + final_scores = {metric: [] for metric in metrics} + for fold in range(self.num_folds): + for metric in metrics: + if metric in self.folds_results[fold]: + final_scores[metric].append(self.folds_results[fold][metric]) return { - "mean_mae_score": np.mean(final_scores_mae), - "std_mae_score": np.std(final_scores_mae), - "mean_rmse_score": np.mean(final_scores_rmse), - "std_rmse_score": np.std(final_scores_rmse), - "std_score": np.std(final_scores_mae), + f"mean_{metric}": np.mean(scores) + for metric, scores in final_scores.items() + if scores + } | { + f"std_{metric}": np.std(scores) + for metric, scores in final_scores.items() + if scores } def to_file(self, file_path: str): - final_results = ( - self.get_final_results() - if len(self.recorded_folds) == self.num_folds - else {} - ) - data_to_save = asdict(self) - data_to_save["final_results"] = final_results with open(file_path, "w") as f: - json.dump(data_to_save, f, default=self._json_serializable) + json.dump(self, f, default=self._json_serializable) @staticmethod def from_file(file_path: str): with open(file_path) as f: data = json.load(f) - task = MatTextTask(task_name=data["task_name"], metric=data["metric"]) + task = MatTextTask( + task_name=data["task_name"], + num_folds=data["num_folds"], + is_classification=data["is_classification"], + num_classes=data["num_classes"], + ) task.folds_results = data["folds_results"] task.recorded_folds = data["recorded_folds"] return task - @staticmethod - def _prepare_for_serialization(obj): - if isinstance(obj, dict): - return { - k: MatTextTask._prepare_for_serialization(v) for k, v in obj.items() - } - elif ( - isinstance(obj, (list, pd.Series, np.ndarray)) - ): - return MatTextTask._prepare_for_serialization(obj.tolist()) - else: - return obj - @staticmethod def _json_serializable(obj): if isinstance(obj, (np.ndarray, pd.Series)): return obj.tolist() + elif isinstance(obj, (np.bool_, np.integer, np.floating)): + return obj.item() + elif isinstance(obj, MatTextTask): + return { + "task_name": obj.task_name, + "num_folds": obj.num_folds, + "is_classification": obj.is_classification, + "num_classes": obj.num_classes, + "folds_results": obj.folds_results, + "recorded_folds": obj.recorded_folds, + } raise TypeError(f"Type {type(obj)} not serializable") diff --git a/src/mattext/models/utils.py b/src/mattext/models/utils.py index 13fd2cc..864df8c 100644 --- a/src/mattext/models/utils.py +++ b/src/mattext/models/utils.py @@ -2,6 +2,7 @@ import torch import wandb +from loguru import logger from tqdm import tqdm from transformers import GenerationConfig, TrainerCallback from transformers.integrations import WandbCallback @@ -117,8 +118,8 @@ def __init__( truncation=False, padding=False, ) - print(f"special_tokens: {special_tokens}") - print(self._wrapped_tokenizer.tokenize("Se2Se3")) + logger.info(f"special_tokens: {special_tokens}") + logger.info(self._wrapped_tokenizer.tokenize("Se2Se3")) # self._wrapped_tokenizer.add_special_tokens(special_tokens=special_tokens) @@ -188,7 +189,7 @@ def on_log( if state.is_world_process_zero: step = state.global_step # Retrieve the current step epoch = state.epoch # Retrieve the 
current epoch - print(f"Step: {step}, Epoch: {round(epoch,5)}") + logger.info(f"Step: {step}, Epoch: {round(epoch,5)}") if ( "loss" in logs and "eval_loss" in logs