Skip to content

Commit

Permalink
ci: update test workflow to use CPU-specific dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
maxftz committed Feb 5, 2025
1 parent a7190da commit 01364f7
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
run: uv venv

- name: Install dependencies
run: uv pip install -r requirements_intel_cuda.txt
run: uv pip install -r requirements_intel_cpu.txt

- name: Run tests
run: uv run pytest
41 changes: 41 additions & 0 deletions requirements_intel_cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
--extra-index-url https://download.pytorch.org/whl/cpu
aiofiles==24.1.0
aiosqlite==0.21.0
annotated-types==0.7.0
asyncua==1.1.5
cffi==1.17.1
contourpy==1.2.1
cryptography==44.0.0
cycler==0.12.1
filelock==3.13.1
fonttools==4.53.0
fsspec==2024.2.0
iniconfig==2.0.0
jinja2==3.1.4
kiwisolver==1.4.5
markupsafe==2.1.5
matplotlib==3.9.0
mpmath==1.3.0
networkx==3.2.1
numpy==2.1.2
packaging==24.1
pillow==10.2.0
pluggy==1.5.0
pycparser==2.22
pydantic==2.8.0
pydantic-core==2.20.0
pyopenssl==25.0.0
pyparsing==3.1.2
pytest==8.3.4
python-dateutil==2.9.0.post0
pytz==2024.1
returns==0.23.0
setuptools==70.2.0
six==1.16.0
sortedcontainers==2.4.0
sympy==1.12
torch==2.3.1+cpu
torchaudio==2.3.1+cpu
torchvision==0.18.1+cpu
triton==3.2.0
typing-extensions==4.12.2
6 changes: 3 additions & 3 deletions src/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from io_models import CostConstraintMatrix, IncorrectInput, UserInput, UserOutput
from utils.logging import LogFile, create_logger, process_start_time
from utils.types import CostConstraintMatricesPath, UserInputPath
from utils.utils import random_string
from utils.utils import cuda_else_cpu, random_string

logger_params = {"log_file": LogFile(desc="assign")}

Expand Down Expand Up @@ -205,8 +205,8 @@ def load_cost_constraint_matrices(
return Success(
[
CostConstraintMatrix(
input[i][0].cuda(),
input[i][1].cuda(),
cuda_else_cpu(input[i][0]),
cuda_else_cpu(input[i][1]),
)
for i in range(len(input))
]
Expand Down
8 changes: 6 additions & 2 deletions src/io_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from returns.result import Failure, Result, Success
from torch import Tensor

from utils.utils import cuda_else_cpu

class IncorrectInput(Exception):

__slots__ = ("message", "input")
Expand Down Expand Up @@ -64,7 +66,7 @@ def to_cost_constraint_matrix(self) -> CostConstraintMatrix:
problem_size = num_resources * num_tasks
batch_size = 1

cost_matrix = torch.zeros(batch_size, problem_size, 3).cuda()
cost_matrix = cuda_else_cpu(torch.zeros(batch_size, problem_size, 3))
for i, resource in enumerate(self.resources):
for j, task in enumerate(self.tasks):
cost = next(
Expand All @@ -74,7 +76,9 @@ def to_cost_constraint_matrix(self) -> CostConstraintMatrix:
cost_matrix[0, i * num_tasks + j, 1] = i
cost_matrix[0, i * num_tasks + j, 2] = j

constraint_matrix = torch.zeros(batch_size, num_tasks, num_resources).cuda()
constraint_matrix = cuda_else_cpu(
torch.zeros(batch_size, num_tasks, num_resources)
)

for i, task in enumerate(self.tasks):
for constraint in task.constraints:
Expand Down
12 changes: 11 additions & 1 deletion src/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import random
import string

import torch


def safe_div(x,y):
if y == 0:
return None
return x / y

def random_string(length):
return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))
return ''.join(random.choices(string.ascii_uppercase + string.digits, k=length))


def cuda_else_cpu(tensor: torch.Tensor) -> torch.Tensor:
if torch.backends.mps.is_available():
return tensor.cuda()
if torch.backends.cuda.is_built():
return tensor.cuda()
return tensor.cpu()

0 comments on commit 01364f7

Please sign in to comment.