Merge branch 'master' into macos-support
borzunov authored Aug 27, 2023
2 parents 903c102 + d90a14d commit 7aff062
Showing 2 changed files with 13 additions and 10 deletions.
README.md — 3 changes: 2 additions & 1 deletion

@@ -81,7 +81,8 @@ of [Go toolchain](https://golang.org/doc/install) (1.15 or 1.16 are supported).

 - __Linux__ is the default OS for which hivemind is developed and tested. We recommend Ubuntu 18.04+ (64-bit), but
   other 64-bit distros should work as well. Legacy 32-bit is not recommended.
-- __macOS 10.x__ can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/).
+- __macOS__ is partially supported.
+  If you have issues, you can run hivemind using [Docker](https://docs.docker.com/desktop/mac/install/) instead.
+  We recommend using [our Docker image](https://hub.docker.com/r/learningathome/hivemind).
 - __Windows 10+ (experimental)__ can run hivemind
   using [WSL](https://docs.microsoft.com/ru-ru/windows/wsl/install-win10). You can configure WSL to use GPU by
hivemind/moe/server/runtime.py — 20 changes: 11 additions & 9 deletions

@@ -6,13 +6,14 @@
 from queue import SimpleQueue
 from selectors import EVENT_READ, DefaultSelector
 from statistics import mean
-from time import time
-from typing import Dict, NamedTuple, Optional
+from time import perf_counter
+from typing import Any, Dict, NamedTuple, Optional, Tuple

 import torch
 from prefetch_generator import BackgroundGenerator

 from hivemind.moe.server.module_backend import ModuleBackend
+from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.utils import get_logger

 logger = get_logger(__name__)
@@ -85,15 +86,11 @@ def run(self):

         for pool, batch_index, batch in batch_iterator:
             logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

-            start = time()
+            start = perf_counter()
             try:
-                outputs = pool.process_func(*batch)
-                batch_processing_time = time() - start
-
-                batch_size = outputs[0].size(0)
+                outputs, batch_size = self.process_batch(pool, batch_index, *batch)
+                batch_processing_time = perf_counter() - start
                 logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                 if self.stats_report_interval is not None:
                     self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)
@@ -108,6 +105,11 @@ def run(self):
         if not self.shutdown_trigger.is_set():
             self.shutdown()

+    def process_batch(self, pool: TaskPoolBase, batch_index: int, *batch: torch.Tensor) -> Tuple[Any, int]:
+        """process one batch of tasks from a given pool, return a batch of results and total batch size"""
+        outputs = pool.process_func(*batch)
+        return outputs, outputs[0].size(0)
+
     def shutdown(self):
         """Gracefully terminate a running runtime."""
         logger.info("Shutting down")
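Besides the clock change, the commit extracts batch execution into a `process_batch` method, giving subclasses a single override point instead of forcing them to copy the whole `run()` loop. A stripped-down sketch of that pattern (hypothetical `MiniRuntime`/`LoggingRuntime` names, not hivemind's actual classes):

```python
from typing import Any, Callable, Tuple

class MiniRuntime:
    """Toy stand-in for hivemind's Runtime, showing the overridable hook."""

    def run_once(self, process_func: Callable, *batch) -> Tuple[Any, int]:
        # The loop body only calls the hook; customization happens elsewhere.
        return self.process_batch(process_func, *batch)

    def process_batch(self, process_func: Callable, *batch) -> Tuple[Any, int]:
        """Process one batch, return (outputs, batch size)."""
        outputs = process_func(*batch)
        return outputs, len(outputs[0])

class LoggingRuntime(MiniRuntime):
    # A subclass can wrap batch processing (profiling, mixed precision, ...)
    # without touching the main loop.
    def process_batch(self, process_func, *batch):
        outputs, batch_size = super().process_batch(process_func, *batch)
        print(f"processed batch of size {batch_size}")
        return outputs, batch_size

outputs, n = LoggingRuntime().run_once(lambda xs: (xs,), [1, 2, 3])
print(n)  # 3
```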
