Skip to content

Commit

Permalink
[CP-SAT] more python type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Nov 20, 2024
1 parent 92cc57b commit 6c92b6b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions ortools/sat/python/cp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3522,8 +3522,8 @@ def value(self, expression: LinearExprT) -> int:
if not self.has_response():
raise RuntimeError("solve() has not been called.")

value = 0
to_process = [(expression, 1)]
value: int = 0
to_process: list[tuple[LinearExprT, int]] = [(expression, 1)]
while to_process:
expr, coeff = to_process.pop()
if isinstance(expr, IntegralTypes):
Expand Down
20 changes: 13 additions & 7 deletions ortools/sat/python/cp_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def on_solution_callback(self) -> None:
self.__solution_count += 1

@property
def solution_count(self) -> None:
def solution_count(self) -> int:
return self.__solution_count


Expand All @@ -57,13 +57,13 @@ class SolutionObjective(cp_model.CpSolverSolutionCallback):

def __init__(self) -> None:
cp_model.CpSolverSolutionCallback.__init__(self)
self.__obj = 0
self.__obj: float = 0

def on_solution_callback(self) -> None:
self.__obj = self.objective_value

@property
def obj(self) -> None:
def obj(self) -> float:
return self.__obj


Expand All @@ -88,11 +88,11 @@ def on_solution_callback(self) -> None:
self.__bool_var_values.append(self.boolean_value(bool_var))

@property
def int_var_values(self) -> None:
def int_var_values(self) -> list[int]:
return self.__int_var_values

@property
def bool_var_values(self) -> None:
def bool_var_values(self) -> list[bool]:
return self.__bool_var_values


Expand Down Expand Up @@ -647,7 +647,10 @@ def testCircuit(self) -> None:
print("testCircuit")
model = cp_model.CpModel()
x = [model.new_bool_var(f"x{i}") for i in range(5)]
model.add_circuit((i, i + 1, x[i]) for i in range(5))
arcs: list[tuple[int, int, cp_model.LiteralT]] = [
(i, i + 1, x[i]) for i in range(5)
]
model.add_circuit(arcs)
self.assertLen(model.proto.variables, 5)
self.assertLen(model.proto.constraints, 1)
self.assertLen(model.proto.constraints[0].circuit.heads, 5)
Expand All @@ -659,7 +662,10 @@ def testMultipleCircuit(self) -> None:
print("testMultipleCircuit")
model = cp_model.CpModel()
x = [model.new_bool_var(f"x{i}") for i in range(5)]
model.add_multiple_circuit((i, i + 1, x[i]) for i in range(5))
arcs: list[tuple[int, int, cp_model.LiteralT]] = [
(i, i + 1, x[i]) for i in range(5)
]
model.add_multiple_circuit(arcs)
self.assertLen(model.proto.variables, 5)
self.assertLen(model.proto.constraints, 1)
self.assertLen(model.proto.constraints[0].routes.heads, 5)
Expand Down

0 comments on commit 6c92b6b

Please sign in to comment.