Skip to content

Commit

Permalink
fix: latest versions of typing dont support Text instead str is recom…
Browse files Browse the repository at this point in the history
…mended
  • Loading branch information
init-22 committed Dec 3, 2024
1 parent 3afd1df commit 6ff2010
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
14 changes: 7 additions & 7 deletions algorithmic_efficiency/halton.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import functools
import itertools
import math
from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

from absl import logging
from numpy import random

_SweepSequence = List[Dict[Text, Any]]
_GeneratorFn = Callable[[float], Tuple[Text, float]]
_SweepSequence = List[Dict[str, Any]]
_GeneratorFn = Callable[[float], Tuple[str, float]]


def generate_primes(n: int) -> List[int]:
Expand Down Expand Up @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int,
return halton_sequence


def _generate_double_point(name: Text,
def _generate_double_point(name: str,
min_val: float,
max_val: float,
scaling: Text,
scaling: str,
halton_point: float) -> Tuple[str, float]:
"""Generate a float hyperparameter value from a Halton sequence point."""
if scaling not in ['linear', 'log']:
Expand Down Expand Up @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]:
return start, end


def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn:
min_val, max_val = range_endpoints
return functools.partial(_generate_double_point,
name,
Expand All @@ -244,7 +244,7 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn:


def uniform(
name: Text, search_points: Union[_DiscretePoints,
name: str, search_points: Union[_DiscretePoints,
Tuple[int, int]]) -> _GeneratorFn:
if isinstance(search_points, _DiscretePoints):
return functools.partial(_generate_discrete_point,
Expand Down
2 changes: 1 addition & 1 deletion algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from algorithmic_efficiency import param_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.wmt import bleu
#from algorithmic_efficiency.workloads.wmt import bleu
from algorithmic_efficiency.workloads.wmt.wmt_jax import decode
from algorithmic_efficiency.workloads.wmt.wmt_jax import models
from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from algorithmic_efficiency import param_utils
from algorithmic_efficiency import pytorch_utils
from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.wmt import bleu
#from algorithmic_efficiency.workloads.wmt import bleu
from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode
from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer
from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload
Expand Down

0 comments on commit 6ff2010

Please sign in to comment.