Skip to content

Commit

Permalink
optim(graph): Minor optimizations to examples using graph parameters (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
eddiebergman authored May 11, 2024
1 parent 529f334 commit d05b2ef
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 119 deletions.
28 changes: 8 additions & 20 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,19 @@ jobs:
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11']
os: [ubuntu-latest, macos-latest, windows-latest]
exclude:
- os: macos-latest # Segmentation fault on github actions that we can not reproduce in the wild
python-version: '3.8'
defaults:
run:
shell: bash

runs-on: ${{ matrix.os }}

steps:
- uses: actions/checkout@v4

- run: pipx install poetry
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: '**/pyproject.toml'
- run: poetry install
- run: poetry run pytest -m "" # Run all markers

# Temporary bugfix see https://github.com/pre-commit/pre-commit/issues/2178
- name: Pin virtualenv version
run: pip install virtualenv==20.10.0

- name: Install Poetry
uses: abatilo/[email protected]
with:
poetry-version: 1.3.2

- name: Run poetry install
run: poetry install

- name: Run pytest
timeout-minutes: 15
run: poetry run pytest -m "all_examples or runtime or neps_api or summary_csv"
149 changes: 78 additions & 71 deletions neps/search_spaces/architecture/cfg.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations
import itertools
import math
import sys
from collections import defaultdict, deque
from functools import partial
from queue import LifoQueue
from typing import Deque, Tuple
from typing import Deque, Tuple, Hashable

import numpy as np
from nltk import CFG
from nltk import CFG, Production
from nltk.grammar import Nonterminal
from scipy.integrate._ivp.radau import P
from torch import Value


class Grammar(CFG):
Expand Down Expand Up @@ -192,10 +195,12 @@ def sampler(
for i in range(0, n)
]

def _sampler(self, symbol=None, user_priors: bool = False):
def _sampler(self, symbol=None, user_priors: bool = False, *, _cache: dict[Hashable, str] | None = None):
# simple sampler where each production is sampled uniformly from all possible productions
# Tree chooses whether to return a tree or a list of terminals
# recursive implementation
if _cache is None:
_cache = {}

# init the sequence
tree = "(" + str(symbol)
Expand All @@ -208,12 +213,19 @@ def _sampler(self, symbol=None, user_priors: bool = False):
production = choice(productions, probs=self._prior[str(symbol)])
else:
production = choice(productions)

for sym in production.rhs():
if isinstance(sym, str):
# if terminal then add string to sequence
## if terminal then add string to sequence
tree = tree + " " + sym
else:
tree = tree + " " + self._sampler(sym, user_priors=user_priors) + ")"
cached = _cache.get(sym)
if cached is None:
cached = self._sampler(sym, user_priors=user_priors, _cache=_cache)
_cache[sym] = cached

tree = tree + " " + cached + ")"

return tree

def sampler_maxMin_func(self, symbol: str = None, largest: bool = True):
Expand Down Expand Up @@ -284,88 +296,83 @@ def _convergent_sampler(
return tree, depth, num_prod

def compute_prior(self, string_tree: str, log: bool = True) -> float:
    """Return the prior probability of ``string_tree`` under ``self.prior``.

    The string tree is scanned left to right. Every nonterminal that is
    encountered opens a new frame on a stack holding the productions that
    are still compatible with the symbols seen so far; every closing
    parenthesis pops a frame and accounts for the probability of the
    production it resolved to.

    Args:
        string_tree: Parenthesised derivation such as ``"(S a (B b))"``.
            A parenthesis surrounded by single spaces (``" ( "``) is treated
            as (part of) a terminal rather than as tree structure.
        log: If ``True`` return the sum of log-probabilities (with a small
            epsilon added before taking the log), otherwise the plain
            product of probabilities.

    Returns:
        The (log-)probability of the derivation.

    Raises:
        RuntimeError: If no known terminal/nonterminal starts at a position.
        Exception: If the tree is unbalanced or no production matches.
    """
    prior_prob = 0.0 if log else 1.0

    # Sets give O(1) membership checks inside the scanning loop.
    terminals = set(self.terminals)
    nonterminals = set(self.nonterminals)

    # Look up the productions of every nonterminal once, not once per visit.
    non_terminal_productions: dict[str, list[Production]] = {
        sym: self.productions(lhs=Nonterminal(sym)) for sym in self.nonterminals
    }

    # Longest symbols first so the first prefix match is the longest match.
    # Empty symbols are dropped: they would match everywhere without
    # advancing the cursor.
    symbols_by_size = sorted(
        (sym for sym in self.nonterminals + self.terminals if sym),
        key=len,
        reverse=True,
    )
    longest = len(symbols_by_size[0]) if symbols_by_size else 0

    # Each frame is (productions still compatible, number of rhs symbols seen).
    stack: list[tuple[list[Production], int]] = []

    i = 0
    tree_len = len(string_tree)
    while i < tree_len:
        char = string_tree[i]
        if char in " \t\n":
            i += 1
            continue

        # A structural "(" is skipped; " ( " is (part of) a terminal and
        # falls through to the symbol matching below.
        if char == "(" and (i == 0 or string_tree[i - 1 : i + 2] != " ( "):
            i += 1
            continue

        if char == ")" and string_tree[i - 1] != " ":
            # Closing symbol of a production.
            if not stack or not stack[-1][0]:
                raise Exception(f"Error in prior computation for {string_tree}")
            candidates, consumed = stack.pop()

            # Prefer the production whose rhs was consumed completely; fall
            # back to the first remaining candidate otherwise.
            production = next(
                (p for p in candidates if len(p.rhs()) == consumed),
                candidates[0],
            )
            lhs_production = production.lhs()
            idx = self.productions(lhs=lhs_production).index(production)

            # The prior is keyed by the string form of the nonterminal.
            prob = self.prior[str(lhs_production)][idx]
            if log:
                prior_prob += np.log(prob + 1e-15)
            else:
                prior_prob *= prob
            i += 1
            continue

        window = string_tree[i : i + longest]
        sym = next((s for s in symbols_by_size if window.startswith(s)), None)
        if sym is None:
            raise RuntimeError(f"Terminal or nonterminal at position {i} does not exist")

        if sym in terminals:
            if not stack:
                # A terminal outside of any production cannot be scored.
                raise Exception(f"Error in prior computation for {string_tree}")
            candidates, consumed = stack[-1]
            remaining = [
                p
                for p in candidates
                if len(p.rhs()) > consumed and p.rhs()[consumed] == sym
            ]
            stack[-1] = (remaining, consumed + 1)
        elif sym in nonterminals:
            if stack:
                candidates, consumed = stack[-1]
                remaining = [
                    p
                    for p in candidates
                    if len(p.rhs()) > consumed and str(p.rhs()[consumed]) == sym
                ]
                stack[-1] = (remaining, consumed + 1)

            stack.append((non_terminal_productions[sym], 0))
        else:
            raise Exception(f"Unknown symbol {sym}")

        i += len(sym)

    if stack:
        # Some production was opened but never closed.
        raise Exception(f"Error in prior computation for {string_tree}")

    return prior_prob
Expand Down
42 changes: 30 additions & 12 deletions neps/search_spaces/architecture/graph_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial
from typing import Any, ClassVar, Mapping
from typing_extensions import override, Self
from neps.utils.types import NotSet, _NotSet

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -41,6 +42,7 @@ class GraphParameter(ParameterWithPrior[nx.DiGraph, str], MutatableParameter):
DEFAULT_CONFIDENCE_SCORES: ClassVar[Mapping[str, float]] = {"not_in_use": 1.0}
default_confidence_choice = "not_in_use"
has_prior: bool
input_kwargs: dict[str, Any]

@property
@abstractmethod
Expand Down Expand Up @@ -71,10 +73,6 @@ def __eq__(self, other: Any) -> bool:
@abstractmethod
def compute_prior(self, normalized_value: float) -> float: ...

@override
def serialize_value(self) -> str:
return self.id # type: ignore

@override
def set_value(self, value: str | None) -> None:
# NOTE(eddiebergman): Not entirely sure how this should be done
Expand Down Expand Up @@ -156,9 +154,29 @@ def normalized_to_value(self, normalized_value: float) -> nx.DiGraph:

@override
def clone(self) -> Self:
    """Return a copy of this parameter without deep-copying the whole object.

    A fresh instance is built by calling the class with ``self.input_kwargs``
    and, if a value is currently set, the attributes that represent that
    value are carried over to the new instance.

    Returns:
        A new instance of the same class.
    """
    # presumably `input_kwargs` holds the constructor arguments -- confirm
    # against the subclasses' `__init__`.
    new_self = self.__class__(**self.input_kwargs)

    # HACK(eddiebergman): It seems the subclasses all have these and
    # so we just copy over those attributes, deepcloning anything that is mutable
    if self._value is not None:
        # (attribute name, whether it must be deep-copied)
        # NOTE(review): `string_tree_list` and `nxTree` are copied by
        # reference; confirm nothing mutates them in place.
        value_attrs = (
            ("_value", True),
            ("string_tree", False),
            ("string_tree_list", False),
            ("nxTree", False),
            ("_function_id", False),
        )
        for attr_name, is_mutable in value_attrs:
            # `NotSet` marks "attribute missing" so that a stored `None`
            # is still copied over.
            attr = getattr(self, attr_name, NotSet)
            if attr is NotSet:
                continue

            setattr(new_self, attr_name, deepcopy(attr) if is_mutable else attr)

    return new_self

class GraphGrammar(GraphParameter, CoreGraphGrammar):
hp_name = "graph_grammar"
Expand Down Expand Up @@ -207,7 +225,7 @@ def __init__(

@override
def sample(self, *, user_priors: bool = False) -> Self:
copy_self = deepcopy(self)
copy_self = self.clone()
copy_self.reset()
copy_self.string_tree = copy_self.grammars[0].sampler(1, user_priors=user_priors)[0]
_ = copy_self.value # required for checking if graph is valid!
Expand Down Expand Up @@ -386,7 +404,7 @@ def create_graph_from_string(self, child: str):
raise NotImplementedError


class GraphGrammarRepetitive(CoreGraphGrammar, GraphParameter):
class GraphGrammarRepetitive(GraphParameter, CoreGraphGrammar):
hp_name = "graph_grammar_repetitive"

def __init__(
Expand Down Expand Up @@ -487,7 +505,7 @@ def crossover(

@override
def sample(self, *, user_priors: bool = False) -> Self:
copy_self = deepcopy(self)
copy_self = self.clone()
copy_self.reset()
copy_self.string_tree_list = [grammar.sampler(1)[0] for grammar in copy_self.grammars]
copy_self.string_tree = copy_self.assemble_trees(
Expand Down Expand Up @@ -614,7 +632,7 @@ def recursive_worker(
)


class GraphGrammarMultipleRepetitive(CoreGraphGrammar, GraphParameter):
class GraphGrammarMultipleRepetitive(GraphParameter, CoreGraphGrammar):
hp_name = "graph_grammar_multiple_repetitive"

def __init__(
Expand Down Expand Up @@ -734,7 +752,7 @@ def _identify_macro_grammar(grammar, terminal_to_sublanguage_map):

@override
def sample(self, *, user_priors: bool = False) -> Self:
copy_self = deepcopy(self)
copy_self = self.clone()
copy_self.reset()
copy_self.string_tree_list = [
grammar.sampler(1, user_priors=user_priors)[0]
Expand Down
27 changes: 23 additions & 4 deletions neps/search_spaces/hyperparameters/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Mapping, TypeVar
from typing_extensions import Self, override

Expand All @@ -38,6 +39,22 @@
T = TypeVar("T", int, float)


# OPTIM(eddiebergman): When calculating priors over and over,
# creating this scipy.rvs is surprisingly slow. Since we do not
# mutate them, we just cache them. This is done across instances so
# we also can access this cache with new copies of the hyperparameters.
@lru_cache(maxsize=128, typed=False)
def _get_truncnorm_prior_and_std(
low: int | float,
high: int | float,
default: int | float,
confidence_score: float,
) -> tuple[TruncNorm, float]:
std = (high - low) * confidence_score
a, b = (low - default) / std, (high - default) / std
return scipy.stats.truncnorm(a, b), float(std)


class NumericalParameter(ParameterWithPrior[T, T], MutatableParameter):
"""A numerical hyperparameter is bounded by a lower and upper value.
Expand Down Expand Up @@ -222,10 +239,12 @@ def _get_truncnorm_prior_and_std(self) -> tuple[TruncNorm, float]:
default = self.default

assert default is not None

std = (high - low) * self.default_confidence_score
a, b = (low - default) / std, (high - default) / std
return scipy.stats.truncnorm(a, b), float(std)
return _get_truncnorm_prior_and_std(
low=low,
high=high,
default=default,
confidence_score=self.default_confidence_score,
)

def to_integer(self) -> IntegerParameter:
"""Convert the numerical hyperparameter to an integer hyperparameter."""
Expand Down
Loading

0 comments on commit d05b2ef

Please sign in to comment.