Skip to content

Commit

Permalink
Merge pull request #370 from argonne-lcf/server-sort
Browse files Browse the repository at this point in the history
Server sort
  • Loading branch information
cms21 authored Aug 1, 2023
2 parents 1d23048 + ff9dd8c commit 51b9c7d
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 13 deletions.
2 changes: 2 additions & 0 deletions balsam/_api/bases.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def acquire_jobs(
max_nodes_per_job: Optional[int] = None,
max_aggregate_nodes: Optional[float] = None,
serial_only: bool = False,
sort_by: Optional[str] = None,
filter_tags: Optional[Dict[str, str]] = None,
states: Set[JobState] = RUNNABLE_STATES,
app_ids: Optional[Set[int]] = None,
Expand All @@ -385,6 +386,7 @@ def acquire_jobs(
max_nodes_per_job=max_nodes_per_job,
max_aggregate_nodes=max_aggregate_nodes,
serial_only=serial_only,
sort_by=sort_by,
filter_tags=filter_tags,
states=states,
app_ids=app_ids,
Expand Down
1 change: 1 addition & 0 deletions balsam/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class LauncherSettings(BaseSettings):
local_app_launcher: Type[AppRun] = Field("balsam.platform.app_run.LocalAppRun")
mpirun_allows_node_packing: bool = False
serial_mode_prefetch_per_rank: int = 64
sort_by: Optional[str] = None
serial_mode_startup_params: Dict[str, str] = {"cpu_affinity": "none"}

@validator("compute_node", pre=True, always=True)
Expand Down
1 change: 1 addition & 0 deletions balsam/schemas/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class SessionAcquire(BaseModel):
max_nodes_per_job: Optional[int]
max_aggregate_nodes: Optional[float]
serial_only: bool = False
sort_by: Optional[str] = None
filter_tags: Dict[str, str]
states: Set[JobState] = RUNNABLE_STATES
app_ids: Set[int] = set()
Expand Down
59 changes: 48 additions & 11 deletions balsam/server/models/crud/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def create(db: Session, owner: schemas.UserOut, session: schemas.SessionCreate)
def _acquire_jobs(db: orm.Session, job_q: Select, session: models.Session) -> List[Dict[str, Any]]:
acquired_jobs = [{str(key): value for key, value in job.items()} for job in db.execute(job_q).mappings()]
acquired_ids = [job["id"] for job in acquired_jobs]
# logger.info(f"*** in _acquire_jobs acquired_ids={acquired_ids}")

stmt = update(models.Job.__table__).where(models.Job.id.in_(acquired_ids)).values(session_id=session.id)

Expand All @@ -130,7 +131,7 @@ def _acquire_jobs(db: orm.Session, job_q: Select, session: models.Session) -> Li
return acquired_jobs


def _footprint_func() -> Any:
def _footprint_func_nodes() -> Any:
footprint = cast(models.Job.num_nodes, Float) / cast(models.Job.node_packing_count, Float)
return (
func.sum(footprint)
Expand All @@ -146,6 +147,22 @@ def _footprint_func() -> Any:
)


def _footprint_func_walltime() -> Any:
    """Running window-sum of job node footprints, longest-walltime-first.

    A job's footprint is num_nodes / node_packing_count (cast to Float so the
    division is not integer division). The window function accumulates that
    footprint over jobs ordered by descending wall_time_min, with ties broken
    by larger num_nodes, larger node_packing_count, then ascending id, and
    labels the result "aggregate_footprint" for filtering by the caller.
    """
    per_job_footprint = cast(models.Job.num_nodes, Float) / cast(models.Job.node_packing_count, Float)
    ordering = (
        models.Job.wall_time_min.desc(),
        models.Job.num_nodes.desc(),
        models.Job.node_packing_count.desc(),
        models.Job.id.asc(),
    )
    windowed_sum = func.sum(per_job_footprint).over(order_by=ordering)
    return windowed_sum.label("aggregate_footprint")


def acquire(
db: Session, owner: schemas.UserOut, session_id: int, spec: schemas.SessionAcquire
) -> List[Dict[str, Any]]:
Expand Down Expand Up @@ -182,21 +199,41 @@ def acquire(
return _acquire_jobs(db, job_q, session)

# MPI Mode Launcher will take this path:
lock_ids_q = (
job_q.with_only_columns([models.Job.id])
.order_by(
models.Job.num_nodes.asc(),
models.Job.node_packing_count.desc(),
models.Job.wall_time_min.desc(),
# logger.info(f"*** In session.acquire: spec.sort_by = {spec.sort_by}")
if spec.sort_by == "long_large_first":
lock_ids_q = (
job_q.with_only_columns([models.Job.id])
.order_by(
models.Job.wall_time_min.desc(),
models.Job.num_nodes.desc(),
models.Job.node_packing_count.desc(),
)
.limit(spec.max_num_jobs)
.with_for_update(of=models.Job.__table__, skip_locked=True)
)
.limit(spec.max_num_jobs)
.with_for_update(of=models.Job.__table__, skip_locked=True)
)
else:
lock_ids_q = (
job_q.with_only_columns([models.Job.id])
.order_by(
models.Job.num_nodes.asc(),
models.Job.node_packing_count.desc(),
models.Job.wall_time_min.desc(),
)
.limit(spec.max_num_jobs)
.with_for_update(of=models.Job.__table__, skip_locked=True)
)

locked_ids = db.execute(lock_ids_q).scalars().all()
# logger.info(f"*** locked_ids: {locked_ids}")
if spec.sort_by == "long_large_first":
subq = select(models.Job.__table__, _footprint_func_walltime()).where(models.Job.id.in_(locked_ids)).subquery() # type: ignore
else:
subq = select(models.Job.__table__, _footprint_func_nodes()).where(models.Job.id.in_(locked_ids)).subquery() # type: ignore

subq = select(models.Job.__table__, _footprint_func()).where(models.Job.id.in_(locked_ids)).subquery() # type: ignore
# logger.info(f"*** max_aggregate_nodes: {spec.max_aggregate_nodes}")
cols = [c for c in subq.c if c.name not in ["aggregate_footprint", "session_id"]]
job_q = select(cols).where(subq.c.aggregate_footprint <= spec.max_aggregate_nodes)

return _acquire_jobs(db, job_q, session)


Expand Down
6 changes: 6 additions & 0 deletions balsam/site/job_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
filter_tags: Optional[Dict[str, str]] = None,
states: Set[str] = {"PREPROCESSED", "RESTART_READY"},
serial_only: bool = False,
sort_by: Optional[str] = None,
max_wall_time_min: Optional[int] = None,
max_nodes_per_job: Optional[int] = None,
max_aggregate_nodes: Optional[float] = None,
Expand All @@ -90,6 +91,7 @@ def __init__(
self.app_ids = set() if app_ids is None else app_ids
self.states = states
self.serial_only = serial_only
self.sort_by = sort_by
self.max_wall_time_min = max_wall_time_min
self.max_nodes_per_job = max_nodes_per_job
self.max_aggregate_nodes = max_aggregate_nodes
Expand Down Expand Up @@ -158,6 +160,7 @@ def _get_acquire_parameters(self, num_jobs: int) -> Dict[str, Any]:
max_aggregate_nodes=self.max_aggregate_nodes,
max_wall_time_min=request_time,
serial_only=self.serial_only,
sort_by=self.sort_by,
filter_tags=self.filter_tags,
states=self.states,
app_ids=self.app_ids,
Expand All @@ -182,6 +185,7 @@ def __init__(
filter_tags: Optional[Dict[str, str]] = None,
states: Set[JobState] = {JobState.preprocessed, JobState.restart_ready},
serial_only: bool = False,
sort_by: Optional[str] = None,
max_wall_time_min: Optional[int] = None,
scheduler_id: Optional[int] = None,
app_ids: Optional[Set[int]] = None,
Expand All @@ -192,6 +196,7 @@ def __init__(
self.app_ids = set() if app_ids is None else app_ids
self.states = states
self.serial_only = serial_only
self.sort_by = sort_by
self.max_wall_time_min = max_wall_time_min
self.start_time = time.time()

Expand Down Expand Up @@ -229,6 +234,7 @@ def get_jobs(
max_aggregate_nodes=max_aggregate_nodes,
max_wall_time_min=request_time,
serial_only=self.serial_only,
sort_by=self.sort_by,
filter_tags=self.filter_tags,
states=self.states,
app_ids=self.app_ids,
Expand Down
3 changes: 3 additions & 0 deletions balsam/site/launcher/_mpi_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def acquire_jobs(self) -> List["Job"]:
def launch_runs(self) -> None:
acquired = self.acquire_jobs()
acquired.extend(self.job_stash)
logger.info(f"acquired jobs: {acquired}")
self.job_stash = []
for job in acquired:
assert job.id is not None
Expand Down Expand Up @@ -271,10 +272,12 @@ def main(
)

scheduler_id = node_cls.get_scheduler_id()

job_source = SynchronousJobSource(
client=site_config.client,
site_id=site_config.site_id,
filter_tags=filter_tags_dict,
sort_by=site_config.settings.launcher.sort_by,
max_wall_time_min=wall_time_min,
scheduler_id=scheduler_id,
)
Expand Down
1 change: 1 addition & 0 deletions balsam/site/launcher/_serial_mode_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def master_main(wall_time_min: int, master_port: int, log_filename: str, num_wor
max_wall_time_min=wall_time_min,
scheduler_id=scheduler_id,
serial_only=True,
sort_by=site_config.settings.launcher.sort_by,
max_nodes_per_job=1,
)
status_updater = BulkStatusUpdater(site_config.client)
Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_unauth_user_cannot_view_sites(anon_client):
def test_register(anon_client):
    """Registering a fresh user returns an integer id and echoes the username."""
    creds = {"username": f"user{uuid4()}", "password": "foo"}
    response = anon_client.post("/" + urls.PASSWORD_REGISTER, **creds)
    assert isinstance(response["id"], int)
    assert response["username"] == creds["username"]


Expand Down
2 changes: 1 addition & 1 deletion tests/server/test_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_create_site(auth_client):
name="thetalogin3.alcf.anl.gov",
path="/projects/myProject/balsam-site",
)
assert type(posted_site["id"]) == int
assert isinstance(posted_site["id"], int)
site_list = auth_client.get("/sites/")["results"]
assert isinstance(site_list, list)
assert len(site_list) == 1
Expand Down

0 comments on commit 51b9c7d

Please sign in to comment.