Skip to content

Commit

Permalink
pylinting
Browse files Browse the repository at this point in the history
  • Loading branch information
AmrMKayid committed May 25, 2024
1 parent 32b9447 commit 8759600
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 90 deletions.
14 changes: 14 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
PYTHON_MODULE_PATH=fanan

clean:
find . -name "*.pyc" -type f -delete
find . -name "__pycache__" -type d -delete
find . -name ".ipynb_checkpoints" -type d -delete

format:
ruff check ${PYTHON_MODULE_PATH}
docformatter --in-place --recursive ${PYTHON_MODULE_PATH}

pylinting:
## https://vald-phoenix.github.io/pylint-errors/
pylint --output-format=colorized ${PYTHON_MODULE_PATH}
10 changes: 5 additions & 5 deletions fanan/config/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:


class DataConfig(ConfigDict):
"""data configuration class."""
"""Data configuration class."""

def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
super().__init__(initial_dictionary=initial_dictionary, **kwargs)
Expand All @@ -46,7 +46,7 @@ def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
super().__init__(initial_dictionary=initial_dictionary, **kwargs)
self.timesteps: int = 1000
self.beta_1: float = 1e-4
self.beta_T: float = 0.02
self.beta_t: float = 0.02
self.timestep_size: float = 0.001
self.noise_schedule: str = "linear"
self.ema_decay: float = 0.999
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:


class TrainingConfig(ConfigDict):
"""training configuration class."""
"""Training configuration class."""

def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
super().__init__(initial_dictionary=initial_dictionary, **kwargs)
Expand All @@ -107,7 +107,7 @@ def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:


class FananConfig(ConfigDict):
"""fanan configuration class."""
"""Fanan configuration class."""

def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:
super().__init__(initial_dictionary=initial_dictionary, **kwargs)
Expand All @@ -128,7 +128,7 @@ def __init__(self, initial_dictionary: dict | None = None, **kwargs) -> None:

@classmethod
def read_config_from_yaml(cls, file_path: str):
with open(file_path) as file:
with open(file_path, encoding="utf-8") as file:
updates = yaml.safe_load(file)

cfg = cls()
Expand Down
17 changes: 9 additions & 8 deletions fanan/optimization/lr_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ def create_lr_schedule(config: Config):
lr_config = config.optimization.lr_schedule
schedule_type = lr_config.schedule_type
lr_kwargs = lr_config.lr_kwargs
if schedule_type == "constant":
return optax.constant_schedule(**lr_kwargs)
elif schedule_type == "constant_warmup":
return _constant_with_warmup(**lr_kwargs)
elif schedule_type == "cosine":
return _cosine_with_warmup(**lr_kwargs)
else:
raise NotImplementedError(schedule_type)
match schedule_type:
case "constant":
return optax.constant_schedule(**lr_kwargs)
case "constant_warmup":
return _constant_with_warmup(**lr_kwargs)
case "cosine":
return _cosine_with_warmup(**lr_kwargs)
case _:
raise NotImplementedError(schedule_type)


def _constant_with_warmup(value: float, warmup_steps: int):
Expand Down
7 changes: 4 additions & 3 deletions fanan/portal/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def load_tokenizer(
local_cache_dir: str | None = None,
trust_remote_code: bool = True,
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
"""Load a tokenizer from a model name or path, with optional revision and local cache directory."""
"""Load a tokenizer from a model name or path, with optional revision and
local cache directory."""
if _is_url_like(model_name_or_path):
if revision is not None:
raise ValueError("revision is not supported for URLs")
Expand All @@ -62,5 +63,5 @@ def load_tokenizer(
os.path.join(local_cache_dir, base_path),
trust_remote_code=trust_remote_code,
)
else:
return AutoTokenizer.from_pretrained(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code)

return AutoTokenizer.from_pretrained(model_name_or_path, revision=revision, trust_remote_code=trust_remote_code)
2 changes: 1 addition & 1 deletion fanan/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ def upsample2d(x, scale: Union[int, Tuple[int, int]], method: str = "bilinear"):
elif len(scale) == 2:
h_out, w_out = scale[0] * h, scale[1] * w
else:
raise ValueError("scale argument should be either int" "or Tuple[int, int]")
raise ValueError("scale argument should be either int or Tuple[int, int]")

return jax.image.resize(x, shape=(b, h_out, w_out, c), method=method)
60 changes: 30 additions & 30 deletions fanan/utils/logger.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
import logging
import os
from datetime import datetime
# import logging
# import os
# from datetime import datetime

from rich.logging import RichHandler
# from rich.logging import RichHandler


def setup_logger(output_dir="logs"):
# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# def setup_logger(output_dir="logs"):
# # Create the output directory if it doesn't exist
# os.makedirs(output_dir, exist_ok=True)

# Create a logger
logger = logging.getLogger(__name__)
# # Create a logger
# logger = logging.getLogger(__name__)

# Configure the RichHandler with the desired formatting
rich_handler = RichHandler(
rich_tracebacks=True,
markup=True,
show_time=True,
omit_repeated_times=False,
show_level=True,
show_path=True,
tracebacks_show_locals=True,
)
# # Configure the RichHandler with the desired formatting
# rich_handler = RichHandler(
# rich_tracebacks=True,
# markup=True,
# show_time=True,
# omit_repeated_times=False,
# show_level=True,
# show_path=True,
# tracebacks_show_locals=True,
# )

# Add the RichHandler to the logger
logger.addHandler(rich_handler)
# # Add the RichHandler to the logger
# logger.addHandler(rich_handler)

# Get the current date and time
current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
# # Get the current date and time
# current_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")

# Create a FileHandler with the specified output directory and set the formatter
file_handler = logging.FileHandler(os.path.join(output_dir, f"output_{current_datetime}.log"))
# # Create a FileHandler with the specified output directory and set the formatter
# file_handler = logging.FileHandler(os.path.join(output_dir, f"output_{current_datetime}.log"))

# Add the FileHandler to the logger
logger.addHandler(file_handler)
# # Add the FileHandler to the logger
# logger.addHandler(file_handler)

# Set the logging level
logger.setLevel(logging.INFO)
# # Set the logging level
# logger.setLevel(logging.INFO)

return logger
# return logger
Loading

0 comments on commit 8759600

Please sign in to comment.