diff --git a/fleece-worker/__main__.py b/fleece-worker/__main__.py
index 0f23fd0..c3aa71c 100644
--- a/fleece-worker/__main__.py
+++ b/fleece-worker/__main__.py
@@ -1,17 +1,19 @@
-from typing import List, Tuple, Optional
-from fastapi import FastAPI, HTTPException, Request
-from fleece_network import Peer, loads
-from pydantic import BaseModel
-import anyio
-import uvicorn
-from .worker import Worker
-from .__init__ import __version__
 import argparse
-import requests
+import concurrent.futures
 import json
+from typing import List, Optional, Tuple
+
+import anyio
+import requests
 import torch
-import concurrent.futures
+import uvicorn
 from anyio.from_thread import BlockingPortal
+from fastapi import FastAPI, HTTPException, Request
+from fleece_network import Peer, loads
+from pydantic import BaseModel
+
+from .__init__ import __version__
+from .worker import Worker
 
 app = FastAPI()
 worker = Worker()
@@ -176,7 +178,8 @@ async def main() -> None:
                       json=data, headers={"api-token": worker.api_token})
     res = json.loads(r.content)
-    worker.worker_id = res["id"]
+    # worker.worker_id = res["id"]
+    worker.worker_id = worker_url
     worker.pull_worker_url()
     worker.start_heartbeat_daemon()
     worker.start_layer_forward_engine()
diff --git a/fleece-worker/worker.py b/fleece-worker/worker.py
index 5bbeda6..6b66323 100644
--- a/fleece-worker/worker.py
+++ b/fleece-worker/worker.py
@@ -1,20 +1,22 @@
-from typing import List, Optional, Tuple, Dict, Any, Set
+import concurrent.futures
+import json
 import os
-import torch
-from torch import Tensor, nn
-from .model import ModelArgs, TransformerBlock, RMSNorm, precompute_freqs_cis
-from fleece_network import Peer, dumps
-import requests
+import queue
+import socket
 import threading
-import concurrent.futures
 import time
-import socket
+import traceback
+from typing import Any, Dict, List, Optional, Set, Tuple
 from urllib.parse import urlparse
-import json
-from cryptography.hazmat.primitives.asymmetric import ec
+
+import requests
+import torch
 from cryptography.hazmat.primitives import hashes
-import queue
-import traceback
+from cryptography.hazmat.primitives.asymmetric import ec
+from fleece_network import Peer, dumps
+from torch import Tensor, nn
+
+from .model import ModelArgs, RMSNorm, TransformerBlock, precompute_freqs_cis
 
 
 torch.set_default_device("cpu")
@@ -195,13 +197,14 @@ def del_tensor(t):
 
 executor = concurrent.futures.ThreadPoolExecutor(max_workers=400)
 executor_forward = concurrent.futures.ThreadPoolExecutor(max_workers=40)
+latency = 0.001
 
 
 def requests_post(url, headers=None, data=None, json=None, worker=None, to_worker_id=None):
     try:
         if to_worker_id is not None:
             st = time.monotonic()
-        # time.sleep(0.01)
+        time.sleep(latency)
         r = requests.post(url, headers=headers, data=data, json=json)
         assert r.status_code == 200
         if to_worker_id is not None:
@@ -410,10 +413,11 @@ def pull_worker_url(self):
         for worker in res["workers"]:
             self.worker_urls[worker["worker_id"]] = worker["url"]
 
-    def get_worker_url(self, worker_id):
-        if worker_id not in self.worker_urls:
-            self.pull_worker_url()
-        return self.worker_urls.get(worker_id)
+    def get_worker_url(self, worker_id: str) -> str:
+        return worker_id
+        # if worker_id not in self.worker_urls:
+        #     self.pull_worker_url()
+        # return self.worker_urls.get(worker_id)
 
     def verify(self, tm_url, task_id, plan, timestamp, signature_hex):
         public_key_bytes = bytes.fromhex(self.tm_pubkeys[tm_url])
@@ -635,6 +639,50 @@ def post_layer_forward_engine_step(self, task_list: List[LayerForward], merged_h
         return task_update_list
 
     def layer_forward_engine(self):
+        q = self.layer_forward_engine_queue
+        while True:
+            q_buffered: list[list[LayerForward]] = [q.get()]
+            while True:
+                try:
+                    tasks = q.get(block=False)
+                    q_buffered.append(tasks)
+                except queue.Empty:
+                    break
+            prefill_tasks_list = [tasks for tasks in q_buffered if tasks[0].seqlen > 1]
+            decode_tasks_list = [tasks for tasks in q_buffered if tasks[0].seqlen == 1]
+
+            for tasks in prefill_tasks_list:
+                h = self.layer_forward_engine_step(tasks)
+                task_update_list = self.post_layer_forward_engine_step(tasks, h)
+                tmp_len = sum([len(task[4]) for task in task_update_list])
+                print(time.monotonic(), len(tasks), sum([task.bsz for task in tasks]), tmp_len)
+                executor_forward.submit(
+                    self.tmptmp,
+                    task_update_list
+                )
+
+            decode_tasks_list.sort(key=lambda x: x[0].bsz, reverse=False)
+            while len(decode_tasks_list) > 0:
+                total_bsz = 0
+                task_list = []
+                for i in reversed(range(len(decode_tasks_list))):
+                    print(i)
+                    cur_bsz = sum([task.bsz for task in decode_tasks_list[i]])
+                    if total_bsz + cur_bsz > MAX_TOTAL_BSZ:
+                        continue
+                    total_bsz += cur_bsz
+                    task_list.extend(decode_tasks_list.pop(i))
+                h = self.layer_forward_engine_step(task_list)
+                task_update_list = self.post_layer_forward_engine_step(task_list, h)
+                tmp_len = sum([len(task[4]) for task in task_update_list])
+                print(time.monotonic(), len(task_list), sum([task.bsz for task in task_list]), tmp_len)
+                executor_forward.submit(
+                    self.tmptmp,
+                    task_update_list
+                )
+
+
+    def layer_forward_engine_old(self):
+        q = self.layer_forward_engine_queue
+        while True:
+            task_list = []
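Note on the new layer_forward_engine: each burst drained from the queue is split into prefill task groups (seqlen > 1), which are stepped one group at a time, and decode task groups (seqlen == 1), which are greedily merged into batches capped at MAX_TOTAL_BSZ. The sketch below is a standalone illustration of that merging loop, not code from this patch: DecodeTask and the MAX_TOTAL_BSZ value are made-up stand-ins, and a guard is added for a group that exceeds the cap on its own.

# Standalone sketch (not from this patch): greedy batching of decode task
# groups under a total batch-size cap, mirroring the loop added in
# layer_forward_engine. DecodeTask and MAX_TOTAL_BSZ are assumed stand-ins.
from dataclasses import dataclass
from typing import List

MAX_TOTAL_BSZ = 32  # assumed cap; the real constant lives in worker.py


@dataclass
class DecodeTask:
    bsz: int  # batch size carried by this task


def batch_decode_groups(groups: List[List[DecodeTask]]) -> List[List[DecodeTask]]:
    """Merge task groups, largest first, without exceeding MAX_TOTAL_BSZ per batch."""
    groups = sorted(groups, key=lambda g: g[0].bsz)  # ascending, so we pop from the end
    batches: List[List[DecodeTask]] = []
    while groups:
        total_bsz = 0
        batch: List[DecodeTask] = []
        # Walk from the largest remaining group down, skipping any group that
        # would push the running total over the cap (same shape as the patch).
        for i in reversed(range(len(groups))):
            cur_bsz = sum(task.bsz for task in groups[i])
            if total_bsz + cur_bsz > MAX_TOTAL_BSZ:
                continue
            total_bsz += cur_bsz
            batch.extend(groups.pop(i))
        if not batch:
            # Guard not present in the patch: a single group larger than the cap
            # is taken on its own so the loop cannot spin forever.
            batch.extend(groups.pop())
        batches.append(batch)
    return batches


if __name__ == "__main__":
    demo = [[DecodeTask(4)], [DecodeTask(16), DecodeTask(8)], [DecodeTask(12)]]
    for b in batch_decode_groups(demo):
        print([t.bsz for t in b], "total:", sum(t.bsz for t in b))

Merging the largest groups first keeps each forward batch close to the cap, while smaller groups fill whatever headroom remains or start the next batch.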