Use absolute paths instead of relying on HF cache #506

Merged 2 commits on Mar 3, 2025
@@ -54,7 +54,6 @@ Make sure you are in [this](./) directory.
- Under `Advanced Options > Init Scripts`, upload the init script from your workspace.
- Under environment variables, set:
- `FRAMEWORK=torch` or `FRAMEWORK=tf` based on the notebook used.
- `HF_HOME=/dbfs/FileStore/hf_home` to cache Huggingface models in DBFS.
- `TF_GPU_ALLOCATOR=cuda_malloc_async` to implicitly release unused GPU memory in TensorFlow notebooks.


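In place of the removed `HF_HOME` setting, the notebooks in this PR download each model once to shared storage and pass that absolute path to the executors. A condensed sketch of the pattern, assuming the `on_databricks` flag and the model name used in the notebook cells further down:

import os
from huggingface_hub import snapshot_download

# Pick a directory every executor can read: the DBFS FUSE mount on Databricks,
# a GCS mount on Dataproc, or a plain local path on standalone clusters.
models_dir = "/dbfs/FileStore/spark-dl-models" if on_databricks else os.path.abspath(".")
model_path = snapshot_download(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    local_dir=f"{models_dir}/deepseek-r1-distill-llama-8b",
)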
@@ -45,27 +45,44 @@
"on_standalone = not (on_databricks or on_dataproc)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# For cloud environments, load the model to the distributed file system.\n",
"if on_databricks:\n",
" models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
" dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
" model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n",
"elif on_dataproc:\n",
" models_dir = \"/mnt/gcs/spark-dl-models\"\n",
" os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
" model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n",
"else:\n",
" model_path = os.path.abspath(\"deepseek-r1-distill-llama-8b\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
"Download the model from the Hugging Face Hub."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if on_databricks:\n",
" hf_home = \"/dbfs/FileStore/hf_home\"\n",
" dbutils.fs.mkdirs(hf_home)\n",
" os.environ[\"HF_HOME\"] = hf_home\n",
"elif on_dataproc:\n",
" hf_home = \"/mnt/gcs/hf_home\"\n",
" os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
" os.environ[\"HF_HOME\"] = hf_home"
"from huggingface_hub import snapshot_download\n",
"\n",
"model_path = snapshot_download(\n",
" repo_id=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n",
" local_dir=model_path\n",
")"
]
},
{
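A note on the Databricks paths above: `dbutils.fs.mkdirs` takes a path on the `dbfs:/` root, while Python file APIs see the same location through the `/dbfs` FUSE mount, which is why the cell creates `/FileStore/spark-dl-models` but builds `model_path` under `/dbfs/FileStore`. Illustrative sketch (Databricks only):

import os

dbutils.fs.mkdirs("/FileStore/spark-dl-models")          # creates dbfs:/FileStore/spark-dl-models
assert os.path.isdir("/dbfs/FileStore/spark-dl-models")  # same directory via the /dbfs FUSE mount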
@@ -108,7 +125,7 @@
"import torch\n",
"from transformers import pipeline\n",
"\n",
"pipe = pipeline(\"text-generation\", model=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\", torch_dtype=torch.bfloat16, device=\"cuda\")"
"pipe = pipeline(\"text-generation\", model=model_path, torch_dtype=torch.bfloat16, device=\"cuda\")"
]
},
{
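With the weights at an absolute path on shared storage, executors no longer need an `HF_HOME` environment variable (this PR also drops the `spark.executorEnv.HF_HOME` setting): the path string is simply captured in the inference function's closure. A hedged sketch of that pattern; the function name and generation settings are illustrative, not taken from this diff:

def make_predict_fn():
    # model_path is a plain string pointing at DBFS/GCS, captured by this closure.
    import torch
    from transformers import pipeline

    pipe = pipeline("text-generation", model=model_path,
                    torch_dtype=torch.bfloat16, device="cuda")

    def predict(inputs):
        results = pipe(list(inputs), max_new_tokens=64)
        return [r[0]["generated_text"] for r in results]

    return predict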
@@ -307,7 +324,6 @@
" conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
" elif on_dataproc:\n",
" conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
" conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
"\n",
" conf.set(\"spark.executor.cores\", \"8\")\n",
" conf.set(\"spark.task.maxFailures\", \"1\")\n",
@@ -445,7 +461,7 @@
"metadata": {},
"outputs": [],
"source": [
"def triton_server(ports):\n",
"def triton_server(ports, model_path):\n",
" import time\n",
" import signal\n",
" import numpy as np\n",
@@ -458,7 +474,7 @@
"\n",
" print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" pipe = pipeline(\"text-generation\", model=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\", torch_dtype=torch.bfloat16, device=device)\n",
" pipe = pipeline(\"text-generation\", model=model_path, torch_dtype=torch.bfloat16, device=device)\n",
" print(f\"SERVER: Using {device} device.\")\n",
"\n",
" @batch\n",
@@ -543,7 +559,7 @@
"outputs": [],
"source": [
"model_name = \"deepseek-r1\"\n",
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
]
},
{
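Since `model_path` is now passed into `TritonServerManager` (and from there to `triton_server`), every node loads the same weights from shared storage instead of its own HF cache. A hypothetical pre-flight check, not part of this PR, to confirm the path is visible from the executors; `spark` and `num_nodes` are assumed from the notebook:

import os

# Count how many of the num_nodes tasks can see the shared model directory.
visible = (
    spark.sparkContext.parallelize(range(num_nodes), num_nodes)
    .map(lambda _: int(os.path.isdir(model_path)))
    .sum()
)
print(f"{visible}/{num_nodes} nodes can see {model_path}")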
@@ -45,43 +45,48 @@
"on_standalone = not (on_databricks or on_dataproc)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# For cloud environments, load the model to the distributed file system.\n",
"if on_databricks:\n",
" hf_home = \"/dbfs/FileStore/hf_home\"\n",
" dbutils.fs.mkdirs(hf_home)\n",
" os.environ[\"HF_HOME\"] = hf_home\n",
" models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
" dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
" model_path = f\"{models_dir}/gemma-7b-it\"\n",
"elif on_dataproc:\n",
" hf_home = \"/mnt/gcs/hf_home\"\n",
" os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
" os.environ[\"HF_HOME\"] = hf_home"
" models_dir = \"/mnt/gcs/spark-dl-models\"\n",
" os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
" model_path = f\"{models_dir}/gemma-7b-it\"\n",
"else:\n",
" model_path = os.path.abspath(\"gemma-7b-it\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Warmup: Running locally\n",
"First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import login\n",
"\n",
"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
"login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
"Once you have access, you can download the model:"
]
},
{
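The `login()` call above prompts for a token interactively. For unattended runs on a cluster, a non-interactive variant is possible; provisioning the token (for example via a cluster secret or environment variable) is an assumption not covered in this diff:

import os
from huggingface_hub import login

token = os.environ.get("HF_TOKEN")  # recent huggingface_hub versions also read this automatically
if token:
    login(token=token)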
@@ -90,9 +95,22 @@
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import login\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"login()"
"model_path = snapshot_download(\n",
" repo_id=\"google/gemma-7b-it\",\n",
" local_dir=model_path,\n",
" ignore_patterns=\"*.gguf\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Warmup: Running locally\n",
"\n",
"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
]
},
{
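The `ignore_patterns="*.gguf"` filter in the download cell above skips the GGUF weights shipped alongside the Transformers checkpoint, which this notebook does not need. An alternative, equally hedged sketch is to whitelist only the required files; the pattern list is an assumption about the repository layout:

from huggingface_hub import snapshot_download

model_path = snapshot_download(
    repo_id="google/gemma-7b-it",
    local_dir=model_path,
    allow_patterns=["*.safetensors", "*.json", "tokenizer.model"],  # weights, configs, tokenizer only
)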
@@ -119,8 +137,8 @@
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
"model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\",\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"model = AutoModelForCausalLM.from_pretrained(model_path,\n",
" device_map=\"auto\",\n",
" torch_dtype=torch.bfloat16)"
]
@@ -246,7 +264,6 @@
" conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
" elif on_dataproc:\n",
" conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
" conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
"\n",
" conf.set(\"spark.executor.cores\", \"8\")\n",
" conf.set(\"spark.task.maxFailures\", \"1\")\n",
@@ -421,7 +438,7 @@
"metadata": {},
"outputs": [],
"source": [
"def triton_server(ports):\n",
"def triton_server(ports, model_path):\n",
" import time\n",
" import signal\n",
" import numpy as np\n",
@@ -434,8 +451,8 @@
"\n",
" print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
" model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\", device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
" tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
" model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
" print(f\"SERVER: Using {device} device.\")\n",
"\n",
" @batch\n",
@@ -523,7 +540,7 @@
"outputs": [],
"source": [
"model_name = \"gemma-7b\"\n",
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
]
},
{
@@ -75,34 +75,51 @@
"on_standalone = not (on_databricks or on_dataproc)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# For cloud environments, load the model to the distributed file system.\n",
"if on_databricks:\n",
" models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
" dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
" model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
"elif on_dataproc:\n",
" models_dir = \"/mnt/gcs/spark-dl-models\"\n",
" os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
" model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
"else:\n",
" model_path = os.path.abspath(\"qwen-2.5-7b\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
"Download the model from the Hugging Face Hub."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if on_databricks:\n",
" hf_home = \"/dbfs/FileStore/hf_home\"\n",
" dbutils.fs.mkdirs(hf_home)\n",
" os.environ[\"HF_HOME\"] = hf_home\n",
"elif on_dataproc:\n",
" hf_home = \"/mnt/gcs/hf_home\"\n",
" os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
" os.environ[\"HF_HOME\"] = hf_home"
"from huggingface_hub import snapshot_download\n",
"\n",
"model_path = snapshot_download(\n",
" repo_id=\"Qwen/Qwen2.5-7B-Instruct\",\n",
" local_dir=model_path\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Warmup: Running locally\n",
"## Warmup: Running locally\n",
"\n",
"**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
]
@@ -131,14 +148,12 @@
"import torch\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_name = \"Qwen/Qwen2.5-7B-Instruct\"\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" model_path,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\"\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")"
"tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")"
]
},
{
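For context (not shown in this hunk), Qwen2.5-Instruct prompts are normally built with the tokenizer's chat template before calling generate. A minimal local sketch using the `model` and `tokenizer` loaded above; the prompt and generation settings are illustrative:

# Build a chat-formatted prompt and generate a short reply.
messages = [{"role": "user", "content": "Give a one-sentence summary of Apache Spark."}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer([text], return_tensors="pt", padding=True).to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))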
@@ -309,7 +324,6 @@
" conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
" elif on_dataproc:\n",
" conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
" conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
"\n",
" conf.set(\"spark.executor.cores\", \"8\")\n",
" conf.set(\"spark.task.maxFailures\", \"1\")\n",
@@ -508,7 +522,7 @@
"metadata": {},
"outputs": [],
"source": [
"def triton_server(ports):\n",
"def triton_server(ports, model_path):\n",
" import time\n",
" import signal\n",
" import torch\n",
@@ -521,11 +535,11 @@
"\n",
" print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
" model = AutoModelForCausalLM.from_pretrained(\n",
" \"Qwen/Qwen2.5-7B-Instruct\",\n",
" model_path,\n",
" torch_dtype=torch.bfloat16,\n",
" device_map=\"auto\"\n",
" )\n",
" tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\", padding_side=\"left\")\n",
" tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")\n",
"\n",
" @batch\n",
" def _infer_fn(**inputs):\n",
@@ -610,7 +624,7 @@
"outputs": [],
"source": [
"model_name = \"qwen-2.5\"\n",
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
"server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
]
},
{
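Each `triton_server` function above ends by binding its `@batch`-decorated `_infer_fn` to a PyTriton server on the allocated ports; that part of the functions is collapsed in this diff. A rough, generic sketch of such a binding; tensor names, dtypes, batch size, and port order here are assumptions, not the notebooks' actual values:

import numpy as np
from pytriton.model_config import ModelConfig, Tensor
from pytriton.triton import Triton, TritonConfig

# Assumed port order: (http, grpc, metrics) from the ports list passed to triton_server.
config = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])
with Triton(config=config) as triton:
    triton.bind(
        model_name="qwen-2.5",                # matches TritonServerManager's model_name
        infer_func=_infer_fn,                 # the @batch-decorated function above
        inputs=[Tensor(name="prompts", dtype=np.bytes_, shape=(-1,))],
        outputs=[Tensor(name="outputs", dtype=np.bytes_, shape=(-1,))],
        config=ModelConfig(max_batch_size=64),
    )
    triton.serve()                            # blocks until the server is shut down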
@@ -44,6 +44,8 @@ def _start_triton_server(
) -> List[tuple]:
"""Task to start Triton server process on a Spark executor."""

from pyspark import BarrierTaskContext

def _prepare_pytriton_env():
"""Expose PyTriton to correct libpython3.11.so and Triton bundled libraries."""
ld_library_paths = []
@@ -82,6 +84,7 @@ def _find_ports(start_port: int = 7000) -> List[int]:

return ports

tc = BarrierTaskContext.get()
ports = _find_ports()
sig = inspect.signature(triton_server_fn)
params = sig.parameters
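The signature inspection above is what lets `TritonServerManager` forward extra arguments such as `model_path` to a user-defined `triton_server` function only when that function declares them. A simplified sketch of the dispatch; this is not the actual implementation, which launches the server in a separate process:

import inspect

def call_server_fn(triton_server_fn, ports, model_path=None):
    # Forward model_path only if the user's server function accepts it.
    params = inspect.signature(triton_server_fn).parameters
    if "model_path" in params and model_path is not None:
        return triton_server_fn(ports, model_path)
    return triton_server_fn(ports)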
@@ -105,6 +108,7 @@ def _find_ports(start_port: int = 7000) -> List[int]:
for _ in range(wait_retries):
try:
client.wait_for_model(wait_timeout)
tc.barrier()
client.close()
return [(hostname, (process.pid, ports))]
except Exception:
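The new `BarrierTaskContext` usage turns server startup into a synchronization point: `tc.barrier()` blocks until every task in the barrier stage has confirmed its local Triton server is healthy, so no node reports success early. A self-contained sketch of the barrier-stage pattern; it is simplified relative to the real `_start_triton_server`:

from pyspark import BarrierTaskContext

def start_server_task(_it):
    tc = BarrierTaskContext.get()
    # ... start the local Triton server and wait for it to pass its health check ...
    tc.barrier()               # blocks until every task in the stage reaches this point
    yield tc.partitionId()

# Hypothetical driver-side usage: one barrier task per node.
# nodes = sc.parallelize(range(num_nodes), num_nodes).barrier() \
#           .mapPartitions(start_server_task).collect()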