Skip to content

Commit 0f211c7

Browse files
authored
[Refactor] Remove config for some LTune modules (#8)
* [Refactor] Remove config for LTuneDataSet
* [Refactor] Remove LTune generator and model builder reliance on config
1 parent f4d288e commit 0f211c7

9 files changed

+118
-97
lines changed

endure.toml

-2
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,6 @@ norm_layer = "Batch"
312312

313313
categorical_mode = "reinmax"
314314

315-
k_clip = true
316-
317315
# kwargs specific to LTune models during forward pass
318316
[ltune.model.train_kwargs]
319317
temp = 1

endure/ltune/data/dataset.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,33 @@
11
import glob
2-
import logging
32
import numpy as np
43
import os
54
import pandas as pd
65
import pyarrow.parquet as pa
76
import torch
87
import torch.utils.data
9-
from typing import Any
108

9+
from endure.ltune.data.input_features import kINPUT_FEATS
1110

12-
class LTuneIterableDataSet(torch.utils.data.IterableDataset):
11+
12+
class LTuneDataSet(torch.utils.data.IterableDataset):
1313
def __init__(
1414
self,
15-
config: dict[str, Any],
1615
folder: str,
1716
format: str = "parquet",
1817
shuffle: bool = False,
19-
):
20-
self.log = logging.getLogger(config["log"]["name"])
21-
self._config = config
18+
) -> None:
2219
self._format = format
2320
self._fnames = glob.glob(os.path.join(folder, "*." + format))
2421
self._shuffle = shuffle
2522

2623
def _get_input_cols(self):
27-
return self._config["ltune"]["input_features"]
24+
return kINPUT_FEATS
2825

2926
def _load_data(self, fname):
3027
if self._format == "parquet":
3128
df = pa.read_table(fname).to_pandas()
3229
else:
3330
df = pd.read_csv(fname)
34-
if self._config["ltune"]["data"]["normalize_inputs"]:
35-
df = self._normalize_df(df)
36-
37-
return df
38-
39-
def _normalize_df(self, df):
40-
df[["z0", "z1", "q", "w"]] -= [0.5, 0.5, 0.5, 0.5]
41-
df[["z0", "z1", "q", "w"]] /= [0.3, 0.3, 0.3, 0.3]
4231

4332
return df
4433

endure/ltune/data/generator.py

+23-30
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,26 @@
1-
from typing import Any, Union
2-
import logging
1+
from typing import List, Tuple, Union
32

43
import numpy as np
54

6-
from endure.lsm.types import LSMDesign, System, Policy
5+
from endure.lsm.types import System
6+
from endure.ltune.data.input_features import kSYSTEM_HEADER, kWORKLOAD_HEADER
77

88

9-
class LTuneGenerator:
9+
class LTuneDataGenerator:
1010
def __init__(
1111
self,
12-
config: dict[str, Any],
13-
format: str = "parquet",
12+
page_sizes: List[int] = [4, 8, 16],
13+
entry_sizes: List[int] = [1024, 2048, 4096, 8192],
14+
memory_budget_range: Tuple[float, float] = (5.0, 20.0),
15+
selectivity_range: Tuple[float, float] = (1e-7, 1e-9),
16+
elements_range: Tuple[int, int] = (100000000, 1000000000),
1417
precision: int = 3,
1518
) -> None:
16-
self.log = logging.getLogger(config["log"]["name"])
17-
self._config = config
18-
self._header = self._gen_workload_header() + self._gen_system_header()
19-
self.format = format
19+
self.entry_sizes = entry_sizes
20+
self.memory_budget_range = memory_budget_range
21+
self.page_sizes = page_sizes
22+
self.selectivity_range = selectivity_range
23+
self.elements_range = elements_range
2024
self.precision = precision
2125

2226
def _sample_workload(self, dimensions: int) -> list:
@@ -31,25 +35,25 @@ def _sample_workload(self, dimensions: int) -> list:
3135
# TODO: Will want to configure environment to simulate larger ranges over
3236
# potential system values
3337
def _sample_entry_per_page(self, entry_size: int = 8192) -> int:
38+
# Potential page sizes are 4KB, 8KB, 16KB
3439
KB_TO_BITS = 8 * 1024
35-
page_sizes = np.array(self._config["generator"]["page_sizes"])
40+
page_sizes = np.array(self.page_sizes)
3641
entries_per_page = (page_sizes * KB_TO_BITS) / entry_size
3742
return np.random.choice(entries_per_page)
3843

3944
def _sample_selectivity(self) -> float:
40-
low, high = self._config["generator"]["selectivity_range"]
45+
low, high = self.selectivity_range
4146
return (high - low) * np.random.rand() + low
4247

4348
def _sample_entry_size(self) -> int:
44-
choices = self._config["generator"]["entry_sizes"]
45-
return np.random.choice(choices)
49+
return np.random.choice(self.entry_sizes)
4650

4751
def _sample_memory_budget(self) -> float:
48-
low, high = self._config["generator"]["memory_budget"]
52+
low, high = self.memory_budget_range
4953
return (high - low) * np.random.rand() + low
5054

5155
def _sample_total_elements(self) -> int:
52-
low, high = self._config["generator"]["elements_range"]
56+
low, high = self.elements_range
5357
return np.random.randint(low=low, high=high)
5458

5559
def _sample_system(self) -> System:
@@ -63,10 +67,10 @@ def _sample_system(self) -> System:
6367
return system
6468

6569
def _gen_system_header(self) -> list:
66-
return ["B", "s", "E", "H", "N"]
70+
return kSYSTEM_HEADER
6771

6872
def _gen_workload_header(self) -> list:
69-
return ["z0", "z1", "q", "w"]
73+
return kWORKLOAD_HEADER
7074

7175
def generate_header(self) -> list:
7276
return self._gen_workload_header() + self._gen_system_header()
@@ -89,22 +93,11 @@ def generate_row_csv(self) -> list:
8993

9094
return line
9195

92-
def generate_row_parquet(self) -> dict[str, Union[int, float]]:
96+
def generate_row(self) -> dict[str, Union[int, float]]:
9397
header = self.generate_header()
9498
row = self.generate_row_csv()
9599
line = {}
96100
for key, val in zip(header, row):
97101
line[key] = val
98102

99103
return line
100-
101-
def generate_row(
102-
self,
103-
row_type: str = "parquet"
104-
) -> Union[list, dict[str, Union[int, float]]]:
105-
if row_type == "parquet":
106-
row = self.generate_row_parquet()
107-
else: # format == 'csv'
108-
row = self.generate_row_csv()
109-
110-
return row

endure/ltune/data/input_features.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
kSYSTEM_HEADER = [
2+
"entry_p_page",
3+
"selec",
4+
"entry_size",
5+
"max_h",
6+
"num_elem"
7+
]
8+
9+
kWORKLOAD_HEADER = [
10+
"z0",
11+
"z1",
12+
"q",
13+
"w",
14+
]
15+
16+
kINPUT_FEATS = kSYSTEM_HEADER + kWORKLOAD_HEADER

endure/ltune/loss.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import toml
77

88
from endure.lcm.model.builder import LearnedCostModelBuilder
9+
from endure.lsm.types import STR_POLICY_DICT
910

1011

1112
class LearnedCostModelLoss(torch.nn.Module):
@@ -24,7 +25,9 @@ def __init__(self, config: dict[str, Any], model_path: str):
2425
max_levels=lcm_cfg["lsm"]["max_levels"],
2526
**lcm_cfg["lcm"]["model"],
2627
)
27-
lcm_model = lcm_cfg["job"]["LCMTrain"]["model"]
28+
lcm_model = STR_POLICY_DICT.get(lcm_cfg["lsm"]["design"], None)
29+
if lcm_model is None:
30+
raise TypeError(f"Illegal LCM model choice: {lcm_model=}")
2831
self.model = self.lcm_builder.build_model(lcm_model)
2932

3033
data = torch.load(

endure/ltune/model/builder.py

+43-36
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,65 @@
11
import torch
22
import logging
3-
from typing import Any, Optional
3+
from typing import Any, Optional, Tuple, Type
44
from torch import nn
5-
from reinmax import reinmax
5+
from endure.lsm.types import Policy
66

77
from endure.ltune.model import ClassicTuner, QLSMTuner, KapLSMTuner
8+
from endure.ltune.data.input_features import kINPUT_FEATS
89

910

1011
class LTuneModelBuilder:
11-
def __init__(self, config: dict[str, Any]):
12-
self._config = config
13-
self.log = logging.getLogger(self._config["log"]["name"])
12+
def __init__(
13+
self,
14+
hidden_length: int = 1,
15+
hidden_width: int = 64,
16+
norm_layer: str = "Batch",
17+
dropout: float = 0.0,
18+
categorical_mode: str = "gumbel",
19+
size_ratio_range: Tuple[int, int] = (2, 31),
20+
max_levels: int = 16,
21+
) -> None:
22+
self.hidden_length = hidden_length
23+
self.hidden_width = hidden_width
24+
self.dropout = dropout
25+
self.categorical_mode = categorical_mode
26+
self.max_levels = max_levels
27+
self.size_ratio_min, self.size_ratio_max = size_ratio_range
28+
self.capacity_range = self.size_ratio_max - self.size_ratio_min + 1
29+
30+
self.norm_layer = nn.BatchNorm1d
31+
if norm_layer == "Layer":
32+
self.norm_layer = nn.LayerNorm
33+
1434
self._models = {
15-
# "Tier": ClassicTuner,
16-
# "Level": ClassicTuner,
17-
"KLSM": KapLSMTuner,
18-
"Classic": ClassicTuner,
19-
"QLSM": QLSMTuner,
35+
Policy.Classic: ClassicTuner,
36+
Policy.QFixed: QLSMTuner,
37+
Policy.KHybrid: KapLSMTuner,
2038
}
2139

2240
def get_choices(self):
2341
return self._models.keys()
2442

25-
def build_model(self, choice: Optional[str] = None) -> torch.nn.Module:
26-
lsm_design: str = self._config["lsm"]["design"]
27-
if choice is None:
28-
choice = lsm_design
29-
30-
model_params = self._config["ltune"]["model"]
31-
capacity_range = (
32-
self._config["lsm"]["size_ratio"]["max"] -
33-
self._config["lsm"]["size_ratio"]["min"] + 1
34-
)
35-
args = {
36-
'num_feats': len(self._config["ltune"]["input_features"]),
37-
'capacity_range': capacity_range,
38-
'hidden_length': model_params["hidden_length"],
39-
'hidden_width': model_params["hidden_width"],
40-
'dropout_percentage': model_params["dropout"],
41-
}
43+
def build_model(self, policy: Policy) -> torch.nn.Module:
44+
feat_list = kINPUT_FEATS
4245

43-
if model_params["norm_layer"] == "Batch":
44-
args['norm_layer'] = nn.BatchNorm1d
45-
elif model_params["norm_layer"] == "Layer":
46-
args['norm_layer'] = nn.LayerNorm
46+
kwargs = {
47+
"num_feats": len(feat_list),
48+
"capacity_range": self.capacity_range,
49+
"hidden_length": self.hidden_length,
50+
"hidden_width": self.hidden_width,
51+
"dropout_percentage": self.dropout,
52+
"norm_layer": self.norm_layer,
53+
}
4754

48-
model_class = self._models.get(choice, None)
55+
model_class = self._models.get(policy, None)
4956
if model_class is None:
50-
raise NotImplementedError(f"Model for LSM Design not implemented yet")
57+
raise NotImplementedError(f"Tuner for LSM Design not implemented.")
5158

5259
if model_class is KapLSMTuner:
53-
args['num_kap'] = self._config['lsm']['max_levels']
54-
args['categorical_mode'] = model_params.get('categorical_mode', 'gumbel')
60+
kwargs["num_kap"] = self.max_levels
61+
kwargs["categorical_mode"] = self.categorical_mode
5562

56-
model = model_class(**args)
63+
model = model_class(**kwargs)
5764

5865
return model

endure/ltune/util/ltune_eval.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from endure.lcm.util import eval_lcm_impl
88
from endure.lsm.cost import EndureCost
99
from endure.lsm.types import LSMDesign, System, Policy, STR_POLICY_DICT
10-
from endure.ltune.data.generator import LTuneGenerator
10+
from endure.ltune.data.generator import LTuneDataGenerator
1111
from endure.ltune.loss import LearnedCostModelLoss
1212
import endure.lsm.solver as Solver
1313

@@ -20,7 +20,7 @@ def __init__(
2020
design_type: str = "Level",
2121
) -> None:
2222
self.policy = STR_POLICY_DICT.get(design_type, Policy.KHybrid)
23-
self.gen = LTuneGenerator(config)
23+
self.gen = LTuneDataGenerator()
2424
self.loss = LearnedCostModelLoss(
2525
config,
2626
config["job"]["LTuneTrain"]["loss_fn_path"]

jobs/ltune_data_gen.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pyarrow.parquet as pq
1010

1111
from endure.data.io import Reader
12-
from endure.ltune.data.generator import LTuneGenerator
12+
from endure.ltune.data.generator import LTuneDataGenerator
1313

1414

1515
class LTuneDataGenJob:
@@ -23,7 +23,7 @@ def __init__(self, config):
2323
)
2424

2525
def _choose_generator(self):
26-
return LTuneGenerator(self.config)
26+
return LTuneDataGenerator()
2727

2828
def generate_csv_file(self, generator, idx: int, pos: int) -> int:
2929
fname_prefix = self.setting["file_prefix"]
@@ -52,7 +52,7 @@ def generate_csv_file(self, generator, idx: int, pos: int) -> int:
5252
return idx
5353

5454
def generate_parquet_file(
55-
self, generator: LTuneGenerator, idx: int, pos: int
55+
self, generator: LTuneDataGenerator, idx: int, pos: int
5656
) -> int:
5757
fname_prefix = self.setting["file_prefix"]
5858
fname = f"{fname_prefix}-{idx:04}.parquet"
@@ -71,7 +71,7 @@ def generate_parquet_file(
7171
ncols=80,
7272
disable=self.config["log"]["disable_tqdm"],
7373
):
74-
table.append(generator.generate_row_parquet())
74+
table.append(generator.generate_row())
7575
table = pa.Table.from_pylist(table)
7676
pq.write_table(table, fpath)
7777

0 commit comments

Comments (0)