fix: add logger
n0w0f committed Sep 22, 2024
1 parent e29825f commit bfd872c
Showing 6 changed files with 64 additions and 44 deletions.
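
Context note (not part of the commit): the diff replaces print() calls with loguru's logger. A minimal sketch of the migration pattern follows, assuming only loguru's documented logger.info API; the key difference is that print() concatenates extra positional arguments, while logger.info() formats its message with str.format-style braces, so an argument passed without a "{}" placeholder is silently dropped. The fold_name value is a placeholder for illustration only.

from loguru import logger

fold_name = "train_fold_0"  # hypothetical value, for illustration only

print("Fold Name:", fold_name)            # print concatenates extra arguments
logger.info("Fold Name: ", fold_name)     # logs only "Fold Name: "; no "{}" to fill
logger.info(f"Fold Name: {fold_name}")    # interpolate eagerly with an f-string
logger.info("Fold Name: {}", fold_name)   # or let loguru format lazily via str.format
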
15 changes: 6 additions & 9 deletions src/mattext/models/benchmark.py
@@ -13,6 +13,7 @@
MatTextTask,
)
from mattext.models.utils import fold_key_namer
from loguru import logger


class BaseBenchmark(ABC):
@@ -44,22 +45,18 @@ def _initialize_task(self):

def _run_experiment(self, task, i, exp_name, test_name, local_rank):
fold_name = fold_key_namer(i)
print(
logger.info(
f"Running training on {self.train_data}, and testing on {self.test_data} for fold {i}"
)
print("-------------------------")
print(fold_name)
print("-------------------------")
logger.info("Fold Name: ",fold_name)

exp_cfg = self.task_cfg.copy()
exp_cfg.model.finetune.exp_name = exp_name
exp_cfg.model.finetune.path.finetune_traindata = self.train_data

finetuner = self._get_finetuner(exp_cfg, local_rank, fold_name)
ckpt = finetuner.finetune()
print("-------------------------")
print(ckpt)
print("-------------------------")
logger.info("Checkpoint: ",ckpt)

wandb.init(
config=dict(self.task_cfg.model.inference),
@@ -75,12 +72,12 @@ def _run_experiment(self, task, i, exp_name, test_name, local_rank):
predictions, prediction_ids = predict.predict()
self._record_predictions(task, i, predictions, prediction_ids)
except Exception as e:
print(
logger.error(
f"Error occurred during inference for finetuned checkpoint '{exp_name}': {str(e)}"
)
if isinstance(e, (ValueError, TypeError)):
raise
print(traceback.format_exc())
logger.error(traceback.format_exc())

@abstractmethod
def _get_finetuner(self, exp_cfg, local_rank, fold_name):
5 changes: 3 additions & 2 deletions src/mattext/models/helper.py
@@ -2,6 +2,7 @@
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
from loguru import logger

from mattext.models.utils import TokenizerMixin

@@ -15,8 +16,8 @@ def count_tokens_and_plot(
):
tokenizer = TokenizerMixin(representation)
ds = load_dataset("json", data_files=dataset_path, split="train")
print(ds)
print(representation)
logger.info("Dataset: ",ds)
logger.info("Representation: "representation)
dataset = ds[representation]

token_counts = []
15 changes: 5 additions & 10 deletions src/mattext/models/inference.py
@@ -2,6 +2,7 @@
import traceback

import wandb
from loguru import logger
from matbench.bench import MatbenchBenchmark
from omegaconf import DictConfig

@@ -65,21 +66,15 @@ def run_benchmarking(self, local_rank=None) -> None:
for i, (exp_name, test_name, train_data_path, test_data_path) in enumerate(
zip(self.exp_names, self.test_exp_names, self.train_data, self.test_data)
):
print(
logger.info(
f"Running training on {train_data_path}, and testing on {test_data_path}"
)
# wandb.init(
# config=dict(self.task_cfg.model.finetune),
# project=self.task_cfg.model.logging.wandb_project, name=exp_name)

exp_cfg = self.task_cfg.copy()
exp_cfg.model.finetune.exp_name = exp_name
exp_cfg.model.finetune.path.finetune_traindata = train_data_path

ckpt = exp_cfg.model.finetune.path.finetuned_modelname
print("-------------------------")
print(ckpt)
print("-------------------------")
logger.info("Checkpoint: ", ckpt)

wandb.init(
config=dict(self.task_cfg.model.inference),
@@ -95,10 +90,10 @@ def run_benchmarking(self, local_rank=None) -> None:
predictions = predict.predict()
benchmark.record(i, predictions)
except Exception as e:
print(
logger.error(
f"Error occurred during inference for finetuned checkpoint '{exp_name}':"
)
print(traceback.format_exc())
logger.error(traceback.format_exc())

if not os.path.exists(self.benchmark_save_path):
os.makedirs(self.benchmark_save_path)
2 changes: 1 addition & 1 deletion src/mattext/models/llama_sft.py
@@ -4,6 +4,7 @@
import torch
import wandb
from datasets import load_dataset
from loguru import logger
from omegaconf import DictConfig
from peft import (
LoraConfig,
@@ -22,7 +23,6 @@
from mattext.models.utils import (
EvaluateFirstStepCallback,
)
from loguru import logger


class FinetuneLLamaSFT:
64 changes: 45 additions & 19 deletions src/mattext/models/score.py
@@ -1,3 +1,4 @@
import json
import math
from dataclasses import dataclass, field
from typing import Any, Dict, List
@@ -12,7 +13,6 @@
precision_recall_fscore_support,
roc_auc_score,
)
import json

MATTEXT_MATBENCH = {
"kvrh": "matbench_log_kvrh",
@@ -32,6 +32,7 @@
"form_energy": "e_form",
}


def load_true_scores(dataset, mbids):
data_frame = load(MATTEXT_MATBENCH[dataset])
scores = []
@@ -40,6 +41,7 @@ def load_true_scores(dataset, mbids):
scores.append(score)
return scores


@dataclass
class MatTextTask:
task_name: str
@@ -49,19 +51,27 @@ class MatTextTask:
folds_results: Dict[int, Dict[str, Any]] = field(default_factory=dict)
recorded_folds: List[int] = field(default_factory=list)

def record_fold(self, fold: int, prediction_ids: List[str], predictions: List[float]):
def record_fold(
self, fold: int, prediction_ids: List[str], predictions: List[float]
):
if fold in self.recorded_folds:
raise ValueError(f"Fold {fold} has already been recorded.")
true_scores = load_true_scores(self.task_name, prediction_ids)

if self.is_classification:
self._calculate_classification_metrics(fold, prediction_ids, predictions, true_scores)
self._calculate_classification_metrics(
fold, prediction_ids, predictions, true_scores
)
else:
self._calculate_regression_metrics(fold, prediction_ids, predictions, true_scores)

self._calculate_regression_metrics(
fold, prediction_ids, predictions, true_scores
)

self.recorded_folds.append(fold)

def _calculate_regression_metrics(self, fold, prediction_ids, predictions, true_scores):
def _calculate_regression_metrics(
self, fold, prediction_ids, predictions, true_scores
):
mae = mean_absolute_error(true_scores, predictions)
rmse = math.sqrt(mean_squared_error(true_scores, predictions))
self.folds_results[fold] = {
@@ -72,11 +82,19 @@ def _calculate_regression_metrics(self, fold, prediction_ids, predictions, true_
"rmse": rmse,
}

def _calculate_classification_metrics(self, fold, prediction_ids, predictions, true_labels):
def _calculate_classification_metrics(
self, fold, prediction_ids, predictions, true_labels
):
pred_labels = np.argmax(predictions, axis=1)
accuracy = accuracy_score(true_labels, pred_labels)
precision, recall, f1, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted')
roc_auc = roc_auc_score(true_labels, predictions[:, 1]) if self.num_classes == 2 else None
precision, recall, f1, _ = precision_recall_fscore_support(
true_labels, pred_labels, average="weighted"
)
roc_auc = (
roc_auc_score(true_labels, predictions[:, 1])
if self.num_classes == 2
else None
)
self.folds_results[fold] = {
"prediction_ids": prediction_ids,
"predictions": predictions,
@@ -85,7 +103,7 @@ def _calculate_classification_metrics(self, fold, prediction_ids, predictions, t
"precision": precision,
"recall": recall,
"f1": f1,
"roc_auc": roc_auc
"roc_auc": roc_auc,
}

def get_final_results(self):
@@ -97,9 +115,9 @@ def get_final_results(self):

def _aggregate_results(self):
if self.is_classification:
metrics = ['accuracy', 'precision', 'recall', 'f1', 'roc_auc']
metrics = ["accuracy", "precision", "recall", "f1", "roc_auc"]
else:
metrics = ['mae', 'rmse']
metrics = ["mae", "rmse"]

final_scores = {metric: [] for metric in metrics}
for fold in range(self.num_folds):
@@ -108,9 +126,13 @@ def _aggregate_results(self):
final_scores[metric].append(self.folds_results[fold][metric])

return {
f"mean_{metric}": np.mean(scores) for metric, scores in final_scores.items() if scores
f"mean_{metric}": np.mean(scores)
for metric, scores in final_scores.items()
if scores
} | {
f"std_{metric}": np.std(scores) for metric, scores in final_scores.items() if scores
f"std_{metric}": np.std(scores)
for metric, scores in final_scores.items()
if scores
}

def to_file(self, file_path: str):
@@ -121,8 +143,12 @@ def to_file(self, file_path: str):
def from_file(file_path: str):
with open(file_path) as f:
data = json.load(f)
task = MatTextTask(task_name=data["task_name"], num_folds=data["num_folds"],
is_classification=data["is_classification"], num_classes=data["num_classes"])
task = MatTextTask(
task_name=data["task_name"],
num_folds=data["num_folds"],
is_classification=data["is_classification"],
num_classes=data["num_classes"],
)
task.folds_results = data["folds_results"]
task.recorded_folds = data["recorded_folds"]
return task
@@ -140,6 +166,6 @@ def _json_serializable(obj):
"is_classification": obj.is_classification,
"num_classes": obj.num_classes,
"folds_results": obj.folds_results,
"recorded_folds": obj.recorded_folds
"recorded_folds": obj.recorded_folds,
}
raise TypeError(f"Type {type(obj)} not serializable")
raise TypeError(f"Type {type(obj)} not serializable")
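
For orientation (not part of the diff): a hypothetical usage sketch of the MatTextTask API shown above. The dataset name, fold count, and prediction lists are placeholders; the constructor arguments are partly elided in this diff, so defaults for the remaining fields (e.g. num_classes) are assumed, and record_fold requires the matching matbench dataset to be loadable.

from loguru import logger
from mattext.models.score import MatTextTask

task = MatTextTask(task_name="kvrh", num_folds=5, is_classification=False)

for fold in range(task.num_folds):
    prediction_ids = [...]  # matbench record ids of this fold's test split (placeholder)
    predictions = [...]     # model outputs aligned with prediction_ids (placeholder)
    task.record_fold(fold, prediction_ids, predictions)

logger.info(f"Final results: {task.get_final_results()}")  # mean_/std_ mae and rmse over folds
task.to_file("kvrh_task_results.json")  # serialized via _json_serializable
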
7 changes: 4 additions & 3 deletions src/mattext/models/utils.py
@@ -2,6 +2,7 @@

import torch
import wandb
from loguru import logger
from tqdm import tqdm
from transformers import GenerationConfig, TrainerCallback
from transformers.integrations import WandbCallback
@@ -117,8 +118,8 @@ def __init__(
truncation=False,
padding=False,
)
print(f"special_tokens: {special_tokens}")
print(self._wrapped_tokenizer.tokenize("Se2Se3"))
logger.info(f"special_tokens: {special_tokens}")
logger.info(self._wrapped_tokenizer.tokenize("Se2Se3"))

# self._wrapped_tokenizer.add_special_tokens(special_tokens=special_tokens)

@@ -188,7 +189,7 @@ def on_log(
if state.is_world_process_zero:
step = state.global_step # Retrieve the current step
epoch = state.epoch # Retrieve the current epoch
print(f"Step: {step}, Epoch: {round(epoch,5)}")
logger.info(f"Step: {step}, Epoch: {round(epoch,5)}")

if (
"loss" in logs and "eval_loss" in logs
