-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
80 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,13 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
import flax.linen as nn | ||
import jax | ||
from jax.sharding import PartitionSpec as PS | ||
|
||
from fanan.config import Config | ||
|
||
_ARCHITECTURES: dict[str, Any] = {} # registry | ||
|
||
|
||
def register_architecture(cls): | ||
_ARCHITECTURES[cls.__name__.lower()] = cls | ||
return cls | ||
|
||
|
||
# class Architecture(ABC, nn.Module): | ||
class Architecture(ABC): | ||
"""Base class for all architectures.""" | ||
|
||
def __init__(self, config: Config) -> None: | ||
self._config = config | ||
|
||
@property | ||
def config(self) -> Config: | ||
return self._config | ||
|
||
# @abstractmethod | ||
# def __call__( | ||
# self, batch: dict[str, jax.Array], training: bool | ||
# ) -> dict[str, jax.Array]: | ||
# pass | ||
|
||
# @abstractmethod | ||
# def shard(self, ps: PS) -> tuple[Architecture, PS]: | ||
# pass | ||
|
||
|
||
from fanan.modeling.architectures.ddim import * # isort:skip | ||
# from fanan.modeling.architectures.ddpm import * # isort:skip | ||
|
||
|
||
def get_architecture(config: Config) -> Architecture: | ||
assert config.arch.architecture_name, "Arch config must specify 'architecture'." | ||
return _ARCHITECTURES[config.arch.architecture_name.lower()](config) | ||
__all__ = [ | ||
"Architecture", | ||
"register_architecture", | ||
"get_architecture", | ||
"DDIM", | ||
] | ||
|
||
from fanan.modeling.architectures.base import Architecture | ||
from fanan.modeling.architectures.ddim import DDIM | ||
from fanan.modeling.architectures.registry import ( | ||
get_architecture, | ||
register_architecture, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from __future__ import annotations | ||
|
||
from fanan.config import Config | ||
|
||
|
||
# class Architecture(ABC, nn.Module): | ||
class Architecture: # (ABC): | ||
"""Base class for all architectures.""" | ||
|
||
def __init__(self, config: Config) -> None: | ||
self._config = config | ||
|
||
@property | ||
def config(self) -> Config: | ||
return self._config | ||
|
||
# @abstractmethod | ||
# def __call__( | ||
# self, batch: dict[str, jax.Array], training: bool | ||
# ) -> dict[str, jax.Array]: | ||
# pass | ||
|
||
# @abstractmethod | ||
# def shard(self, ps: PS) -> tuple[Architecture, PS]: | ||
# pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from fanan.config import Config | ||
from fanan.modeling.architectures.base import Architecture | ||
|
||
_ARCHITECTURES: dict[str, Any] = {} # registry | ||
|
||
|
||
def register_architecture(cls): | ||
global _ARCHITECTURES | ||
_ARCHITECTURES[cls.__name__.lower()] = cls | ||
return cls | ||
|
||
|
||
def get_architecture(config: Config) -> Architecture: | ||
assert config.arch.architecture_name, "Arch config must specify 'architecture'." | ||
return _ARCHITECTURES[config.arch.architecture_name.lower()](config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters