Skip to content

Commit

Permalink
Support Python 3.10
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Dec 9, 2021
1 parent 8b99b35 commit a165b10
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
9 changes: 6 additions & 3 deletions efax/_src/exp_to_nat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
from jax import jit
from jax.tree_util import tree_map
from tjax import default_atol, default_rtol
from tjax.dataclasses import dataclass, field
from tjax.fixed_point import ComparingIteratedFunctionWithCombinator, ComparingState
from tjax.gradient import Adam, GradientTransformation
Expand All @@ -24,12 +25,15 @@ class ExpToNat(ExpectationParametrization[NP], Generic[NP, SP]):
This mixin implements the conversion from expectation to natural parameters using Newton's
method with a Jacobian to invert the gradient log-normalizer.
"""

# Implemented methods --------------------------------------------------------------------------
@jit
def to_nat(self) -> NP:
iterated_function = ExpToNatIteratedFunction[NP, SP](minimum_iterations=1000,
maximum_iterations=1000,
rtol=default_rtol(),
atol=default_atol(),
z_minimum_iterations=100,
z_maximum_iterations=1000,
transform=Adam(1e-1))
initial_search_parameters = self.initial_search_parameters()
initial_gt_state = iterated_function.transform.init(initial_search_parameters)
Expand Down Expand Up @@ -83,8 +87,7 @@ class ExpToNatIteratedFunction(
SP,
NP],
Generic[NP, SP]):

transform: GradientTransformation[Any, NP] = field() # TODO: kw_only=True
transform: GradientTransformation[Any, NP] = field()

def sampled_state(self, theta: ExpToNat[NP, SP], state: Tuple[Any, SP]) -> Tuple[Any, SP]:
current_gt_state, search_parameters = state
Expand Down
108 changes: 70 additions & 38 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = 'poetry.core.masonry.api'

[tool.poetry]
name = 'efax'
version = "1.4.10"
version = "1.5.0"
description = "Exponential families for JAX"
license = 'MIT'
authors = ['Neil Girdhar <[email protected]>']
Expand All @@ -18,13 +18,14 @@ classifiers = [
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development :: Libraries :: Python Modules',
'Typing :: Typed',
'License :: OSI Approved :: MIT License']

[tool.poetry.dependencies]
python = '>=3.8,<3.10'
python = '>=3.8,<3.11'
jax = '>=0.2'
numpy = '>=1.21'
scipy = '^1.4'
Expand Down

0 comments on commit a165b10

Please sign in to comment.