Skip to content

Commit

Permalink
Add new constraints to NL formulation, fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
k8culver committed Dec 24, 2024
1 parent 482f3ca commit 4780226
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 64 deletions.
10 changes: 4 additions & 6 deletions demo_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from dash import ALL, MATCH, Input, Output, State, ctx
from dash.exceptions import PreventUpdate

from demo_enums import SolverType
from src.demo_enums import SolverType
import src.employee_scheduling as employee_scheduling
import src.utils as utils
from demo_configs import (
Expand Down Expand Up @@ -322,8 +322,6 @@ def run_optimization(
availability = utils.availability_to_dict(sched_df["props"]["data"])
employees = list(availability.keys())

isolated_days_allowed = True if 0 in checklist else False

forecast = [
val if isinstance(val, int)
else forecast_placeholder[i]
Expand All @@ -335,10 +333,10 @@ def run_optimization(
shifts=shifts,
min_shifts=min(shifts_per_employee),
max_shifts=max(shifts_per_employee),
forecast,
allow_isolated_days_off=isolated_days_allowed,
shift_forecast=forecast,
allow_isolated_days_off=0 in checklist,
max_consecutive_shifts=consecutive_shifts,
num_full_time,
num_full_time=num_full_time,
)

if solver_type is SolverType.NL:
Expand Down
2 changes: 1 addition & 1 deletion demo_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
"value": 12,
}

# number of full time employees slider (value means default)
# number of full-time employees slider (value means default)
NUM_FULL_TIME = {
"min": 0,
"max": 9,
Expand Down
2 changes: 1 addition & 1 deletion demo_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
THUMBNAIL,
UNAVAILABLE_ICON,
)
from demo_enums import SolverType
from src.demo_enums import SolverType
from src.utils import COL_IDS


Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dash[diskcache]==2.16.1
dash-bootstrap-components==1.6.0
dwave-ocean-sdk>=7.0.0
dwave-optimization>=0.3.0
dwave-optimization>=0.4.0
Faker==21.0.0
pandas>=2.0
77 changes: 38 additions & 39 deletions src/employee_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dwave.optimization.symbols import BinaryVariable
from dwave.system import LeapHybridCQMSampler, LeapHybridNLSampler

from utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule
from src.utils import DAYS, FULL_TIME_SHIFTS, SHIFTS, validate_nl_schedule


MSGS = {
Expand All @@ -42,7 +42,7 @@
}


def build_cqm(#params: ModelParams
def build_cqm( # params: ModelParams
availability: dict[str, list[int]],
shifts: list[str],
min_shifts: int,
Expand All @@ -62,7 +62,7 @@ def build_cqm(#params: ModelParams
shift_forecast: A list of the number of expected employees needed per shift.
allow_isolated_days_off: Whether on-off-on should be allowed in the schedule.
max_consecutive_shifts: The maximum consectutive shifts to schedule a part-time employee for.
num_full_time: The number of full time employees.
num_full_time: The number of full-time employees.
Returns:
cqm: A Constrained Quadratic Model representing the problem.
Expand Down Expand Up @@ -118,7 +118,7 @@ def build_cqm(#params: ModelParams
)

for employee in employees_ft:
# Schedule employees for at most max_shifts
# Schedule full-time employees for all their shifts
cqm.add_constraint(
quicksum(x[employee, shift] for shift in shifts) <= FULL_TIME_SHIFTS,
label=f"overtime,{employee},",
Expand All @@ -129,7 +129,7 @@ def build_cqm(#params: ModelParams
label=f"insufficient,{employee},",
)

# Every shift needs shift_min and shift_max employees working
# Every shift needs shift_forecast employees working
for i, shift in enumerate(shifts):
cqm.add_constraint(
sum(x[employee, shift] for employee in employees) >= shift_forecast[i],
Expand Down Expand Up @@ -234,16 +234,15 @@ def run_cqm(cqm: ConstrainedQuadraticModel):
return feasible_sampleset, None


def build_nl(
def build_nl( # params: ModelParams
availability: dict[str, list[int]],
shifts: list[str],
min_shifts: int,
max_shifts: int,
shift_min: int,
shift_max: int,
requires_manager: bool,
shift_forecast: list,
allow_isolated_days_off: bool,
max_consecutive_shifts: int,
num_full_time: int,
) -> tuple[Model, BinaryVariable]:
"""Builds an employee scheduling nonlinear model.
Expand All @@ -252,11 +251,10 @@ def build_nl(
shifts (list[str]): Shift labels.
min_shifts (int): Minimum shifts per employee.
max_shifts (int): Maximum shifts per employee.
shift_min (int): Minimum employees per shift.
shift_max (int): Maximum employees per shift.
requires_manager (bool): Whether to require exactly one manager on every shift.
shift_forecast (list[int]): A list of the number of expected employees needed per shift.
allow_isolated_days_off (bool): Whether to allow isolated days off.
max_consecutive_shifts (int): Maximum consecutive shifts per employee.
num_full_time (int): The number of full-time employees.
Returns:
tuple[Model, BinaryVariable]: the NL model and assignments decision variable
Expand All @@ -281,8 +279,8 @@ def build_nl(
# Initialize model constants
min_shifts_constant = model.constant(min_shifts)
max_shifts_constant = model.constant(max_shifts)
shift_min_constant = model.constant(shift_min)
shift_max_constant = model.constant(shift_max)
full_time_shifts_constant = model.constant(FULL_TIME_SHIFTS)
shift_forecast_constant = model.constant(shift_forecast)
max_consecutive_shifts_c = model.constant(max_consecutive_shifts)
one_c = model.constant(1)

Expand All @@ -292,28 +290,32 @@ def build_nl(

# Objective: for infeasible solutions, focus on right number of shifts for employees
target_shifts = model.constant((min_shifts + max_shifts) / 2)
shift_difference_list = [
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_employees)
shift_difference_list_pt = [
(assignments[e, :].sum() - target_shifts) ** 2 for e in range(num_full_time, num_employees)
]
obj += add(*shift_difference_list)
shift_difference_list_ft = [
(assignments[e, :].sum() - full_time_shifts_constant) ** 2 for e in range(num_full_time)
]
obj += add(*shift_difference_list_pt, *shift_difference_list_ft)

model.minimize(-obj)

# CONSTRAINTS:
# Only schedule employees when they're available
model.add_constraint((availability_const >= assignments).all())

for e in range(len(employees)):
# Schedule employees for at most max_shifts
model.add_constraint(assignments[e, :].sum() <= max_shifts_constant)
# Schedule part-time employees for at most max_shifts
model.add_constraint((assignments[num_full_time:, :].sum(axis=1) <= max_shifts_constant).all())

# Schedule employees for at least min_shifts
model.add_constraint(assignments[e, :].sum() >= min_shifts_constant)
# Schedule part-time employees for at least min_shifts
model.add_constraint((assignments[num_full_time:, :].sum(axis=1) >= min_shifts_constant).all())

if num_full_time:
# Schedule full-time employees for all their shifts
model.add_constraint((assignments[:num_full_time, :].sum(axis=1) == full_time_shifts_constant).all())

# Every shift needs shift_min and shift_max employees working
for s in range(num_shifts):
model.add_constraint(assignments[:, s].sum() <= shift_max_constant)
model.add_constraint(assignments[:, s].sum() >= shift_min_constant)
# shft_fcst = model.constant(shift_forecast)
model.add_constraint((assignments.sum(axis=0) == shift_forecast_constant).all())

managers_c = model.constant(
[employees.index(e) for e in employees if e[-3:] == "Mgr"]
Expand All @@ -326,7 +328,7 @@ def build_nl(
negthree_c = model.constant(-3)
zero_c = model.constant(0)
# Adding many small constraints greatly improves feasibility
for e in range(len(employees)):
for e in range(num_full_time, num_employees): # for part-time employees
for s1 in range(len(shifts) - 2):
s2, s3 = s1 + 1, s1 + 2
model.add_constraint(
Expand All @@ -337,12 +339,11 @@ def build_nl(
<= zero_c
)

if requires_manager:
for shift in range(len(shifts)):
model.add_constraint(assignments[managers_c][:, shift].sum() == one_c)
# At least 1 manager per shift
model.add_constraint((assignments[managers_c].sum(axis=0) >= one_c).all())

# Don't exceed max_consecutive_shifts
for e in range(num_employees):
# Don't exceed max_consecutive_shifts for part-time employees
for e in range(num_full_time, num_employees):
for s in range(num_shifts - max_consecutive_shifts + 1):
s_window = s + max_consecutive_shifts + 1
model.add_constraint(
Expand All @@ -368,12 +369,11 @@ def run_nl(
shifts: list[str],
min_shifts: int,
max_shifts: int,
shift_min: int,
shift_max: int,
requires_manager: bool,
shift_forecast: list[int],
allow_isolated_days_off: bool,
max_consecutive_shifts: int,
time_limit: int | None = None,
num_full_time: int,
time_limit: Optional[int] = None,
msgs: dict[str, tuple[str, str]] = MSGS,
) -> Optional[defaultdict[str, list[str]]]:
"""Solves the NL scheduling model and detects any errors.
Expand All @@ -395,11 +395,10 @@ def run_nl(
shifts,
min_shifts,
max_shifts,
shift_min,
shift_max,
requires_manager,
shift_forecast,
allow_isolated_days_off,
max_consecutive_shifts,
num_full_time,
)

# Return errors if any error message list is populated
Expand Down
27 changes: 11 additions & 16 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,23 +59,21 @@ class ModelParams:
shifts (list[str]): List of shift labels.
min_shifts (int): Min shifts per employee.
max_shifts (int): Max shifts per employee.
shift_min (int): Min employees per shift.
shift_max (int): Max employees per shift.
requires_manager (bool): Whether a manager is required on every shift.
shift_forecast (list[int]): The forecasted employees per shift requirements.
allow_isolated_days_off (bool): Whether isolated shifts off are allowed
(pattern of on-off-on).
max_consecutive_shifts (int): Max consecutive shifts for each employee.
num_full_time: The number of full-time employees.
"""

availability: dict[str, list[int]]
shifts: list[str]
min_shifts: int
max_shifts: int
shift_min: int
shift_max: int
requires_manager: bool
shift_forecast: list[int]
allow_isolated_days_off: bool
max_consecutive_shifts: int
num_full_time: int


def get_random_string(length):
Expand Down Expand Up @@ -380,11 +378,10 @@ def validate_nl_schedule(
shifts: list[str],
min_shifts: int,
max_shifts: int,
shift_min: int,
shift_max: int,
requires_manager: bool,
shift_forecast: list[int],
allow_isolated_days_off: bool,
max_consecutive_shifts: int,
num_full_time: int,
) -> defaultdict[str, list[str]]:
"""Detect any errors in a solved NL scheduling model.
Expand Down Expand Up @@ -438,11 +435,10 @@ def validate_nl_schedule(

_validate_availability(result, availability, employees, shift_labels, errors, msgs)
_validate_shifts_per_employee(result, employees, min_shifts, max_shifts, errors, msgs)
_validate_employees_per_shift(result, shift_min, shift_max, shift_labels, errors, msgs)
_validate_employees_per_shift(result, shift_forecast, shift_labels, errors, msgs)
_validate_max_consecutive_shifts(result, max_consecutive_shifts, employees, shift_labels, errors, msgs)
_validate_trainee_shifts(result, employees, shift_labels, errors, msgs)
if requires_manager:
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
_validate_requires_manager(result, employees, shift_labels, errors, msgs)
if not allow_isolated_days_off:
_validate_isolated_days_off(result, employees, shift_labels, errors, msgs)

Expand Down Expand Up @@ -496,8 +492,7 @@ def _validate_shifts_per_employee(

def _validate_employees_per_shift(
results: np.ndarray,
shift_min: int,
shift_max: int,
shift_forecast: list[int],
shift_labels: list[int],
errors: defaultdict[str, list[str]],
msgs: dict[str, tuple[str, str]],
Expand All @@ -510,9 +505,9 @@ def _validate_employees_per_shift(

for s, day in enumerate(shift_labels):
num_employees = results[:, s].sum()
if num_employees < shift_min:
if num_employees < shift_forecast[s]:
errors[understaffed_key].append(understaffed_template.format(day=day))
elif num_employees > shift_max:
elif num_employees > shift_forecast[s]:
errors[overstaffed_key].append(overstaffed_template.format(day=day))
return errors

Expand Down

0 comments on commit 4780226

Please sign in to comment.