diff --git a/algorithmic_efficiency/halton.py b/algorithmic_efficiency/halton.py index 9eb30861d..d710e3fce 100644 --- a/algorithmic_efficiency/halton.py +++ b/algorithmic_efficiency/halton.py @@ -10,13 +10,13 @@ import functools import itertools import math -from typing import Any, Callable, Dict, List, Sequence, Text, Tuple, Union +from typing import Any, Callable, Dict, List, Sequence, Tuple, Union from absl import logging from numpy import random -_SweepSequence = List[Dict[Text, Any]] -_GeneratorFn = Callable[[float], Tuple[Text, float]] +_SweepSequence = List[Dict[str, Any]] +_GeneratorFn = Callable[[float], Tuple[str, float]] def generate_primes(n: int) -> List[int]: @@ -195,10 +195,10 @@ def generate_sequence(num_samples: int, return halton_sequence -def _generate_double_point(name: Text, +def _generate_double_point(name: str, min_val: float, max_val: float, - scaling: Text, + scaling: str, halton_point: float) -> Tuple[str, float]: """Generate a float hyperparameter value from a Halton sequence point.""" if scaling not in ['linear', 'log']: @@ -234,7 +234,7 @@ def interval(start: int, end: int) -> Tuple[int, int]: return start, end -def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: +def loguniform(name: str, range_endpoints: Tuple[int, int]) -> _GeneratorFn: min_val, max_val = range_endpoints return functools.partial(_generate_double_point, name, @@ -244,7 +244,7 @@ def loguniform(name: Text, range_endpoints: Tuple[int, int]) -> _GeneratorFn: def uniform( - name: Text, search_points: Union[_DiscretePoints, + name: str, search_points: Union[_DiscretePoints, Tuple[int, int]]) -> _GeneratorFn: if isinstance(search_points, _DiscretePoints): return functools.partial(_generate_discrete_point, diff --git a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py index 442c85899..72108c9d9 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_jax import decode from algorithmic_efficiency.workloads.wmt.wmt_jax import models from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload diff --git a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py index 327ca34ad..b554b2ab3 100644 --- a/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py @@ -16,7 +16,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.wmt import bleu +#from algorithmic_efficiency.workloads.wmt import bleu from algorithmic_efficiency.workloads.wmt.wmt_pytorch import decode from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import Transformer from algorithmic_efficiency.workloads.wmt.workload import BaseWmtWorkload