
Commit 8c934b9

Use absolute paths instead of relying on HF cache (#506)
Relying on the Huggingface cache to load LLMs is finicky on DBFS and when an access token is needed. Switch to downloading the model and using the absolute path, which is better practice anyway. Also adds a barrier to PyTriton server startup to ensure all servers are shut down if one fails.

Signed-off-by: Rishi Chandra <[email protected]>
1 parent 5d52c1b commit 8c934b9
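The notebooks below all follow the same replacement pattern: download the model once from the driver to an absolute path on shared storage, then hand that path to the executors. A minimal sketch of the idea, assuming a Databricks-style /dbfs mount (on Dataproc the notebooks use a /mnt/gcs mount instead):

    # Download once on the driver to a path every executor can read,
    # then pass the absolute path around instead of a Huggingface hub id.
    import os
    from huggingface_hub import snapshot_download

    models_dir = "/dbfs/FileStore/spark-dl-models"  # assumed DBFS FUSE path; /mnt/gcs/... on Dataproc
    os.makedirs(models_dir, exist_ok=True)
    model_path = snapshot_download(
        repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        local_dir=f"{models_dir}/deepseek-r1-distill-llama-8b",
    )
    # model_path is now a plain absolute path, so executors have no runtime
    # dependency on the Huggingface cache or on an access token.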

File tree

5 files changed: +111 -61 lines


examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md

-1 line

@@ -54,7 +54,6 @@ Make sure you are in [this](./) directory.
 - Under `Advanced Options > Init Scripts`, upload the init script from your workspace.
 - Under environment variables, set:
   - `FRAMEWORK=torch` or `FRAMEWORK=tf` based on the notebook used.
-  - `HF_HOME=/dbfs/FileStore/hf_home` to cache Huggingface models in DBFS.
   - `TF_GPU_ALLOCATOR=cuda_malloc_async` to implicity release unused GPU memory in Tensorflow notebooks.

examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/deepseek-r1_torch.ipynb

+31 -15 lines

@@ -45,27 +45,44 @@
     "on_standalone = not (on_databricks or on_dataproc)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# For cloud environments, load the model to the distributed file system.\n",
+    "if on_databricks:\n",
+    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
+    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+    "    model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n",
+    "elif on_dataproc:\n",
+    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
+    "    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+    "    model_path = f\"{models_dir}/deepseek-r1-distill-llama-8b\"\n",
+    "else:\n",
+    "    model_path = os.path.abspath(\"deepseek-r1-distill-llama-8b\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
+    "Download the model from huggingface hub."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "if on_databricks:\n",
-    "    hf_home = \"/dbfs/FileStore/hf_home\"\n",
-    "    dbutils.fs.mkdirs(hf_home)\n",
-    "    os.environ[\"HF_HOME\"] = hf_home\n",
-    "elif on_dataproc:\n",
-    "    hf_home = \"/mnt/gcs/hf_home\"\n",
-    "    os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
-    "    os.environ[\"HF_HOME\"] = hf_home"
+    "from huggingface_hub import snapshot_download\n",
+    "\n",
+    "model_path = snapshot_download(\n",
+    "    repo_id=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n",
+    "    local_dir=model_path\n",
+    ")"
    ]
   },
   {

@@ -108,7 +125,7 @@
    "import torch\n",
    "from transformers import pipeline\n",
    "\n",
-    "pipe = pipeline(\"text-generation\", model=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\", torch_dtype=torch.bfloat16, device=\"cuda\")"
+    "pipe = pipeline(\"text-generation\", model=model_path, torch_dtype=torch.bfloat16, device=\"cuda\")"
    ]
   },
   {

@@ -307,7 +324,6 @@
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "    elif on_dataproc:\n",
    "        conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
-    "        conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",

@@ -445,7 +461,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def triton_server(ports):\n",
+    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",

@@ -458,7 +474,7 @@
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-    "    pipe = pipeline(\"text-generation\", model=\"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\", torch_dtype=torch.bfloat16, device=device)\n",
+    "    pipe = pipeline(\"text-generation\", model=model_path, torch_dtype=torch.bfloat16, device=device)\n",
    "    print(f\"SERVER: Using {device} device.\")\n",
    "\n",
    "    @batch\n",

@@ -543,7 +559,7 @@
    "outputs": [],
    "source": [
    "model_name = \"deepseek-r1\"\n",
-    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
+    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
    ]
   },
   {
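One detail in the cell added above is worth spelling out: on Databricks the same directory is written two ways. A short clarifying sketch (not part of the diff):

    # dbutils.fs addresses the DBFS root, while ordinary Python and Huggingface
    # file I/O go through the /dbfs FUSE mount, so both names point at the same place.
    dbutils.fs.mkdirs("/FileStore/spark-dl-models")             # DBFS path for dbutils
    models_dir = "/dbfs/FileStore/spark-dl-models"              # same directory via the FUSE mount
    model_path = f"{models_dir}/deepseek-r1-distill-llama-8b"   # absolute path later passed to executors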

examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/gemma-7b_torch.ipynb

+42 -25 lines

@@ -45,43 +45,48 @@
     "on_standalone = not (on_databricks or on_dataproc)"
    ]
   },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
+    "# For cloud environments, load the model to the distributed file system.\n",
    "if on_databricks:\n",
-    "    hf_home = \"/dbfs/FileStore/hf_home\"\n",
-    "    dbutils.fs.mkdirs(hf_home)\n",
-    "    os.environ[\"HF_HOME\"] = hf_home\n",
+    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
+    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+    "    model_path = f\"{models_dir}/gemma-7b-it\"\n",
    "elif on_dataproc:\n",
-    "    hf_home = \"/mnt/gcs/hf_home\"\n",
-    "    os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
-    "    os.environ[\"HF_HOME\"] = hf_home"
+    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
+    "    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+    "    model_path = f\"{models_dir}/gemma-7b-it\"\n",
+    "else:\n",
+    "    model_path = os.path.abspath(\"gemma-7b-it\")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Warmup: Running locally\n",
+    "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from huggingface_hub import login\n",
    "\n",
-    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
+    "login()"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
+    "Once you have access, you can download the model:"
    ]
   },
   {

@@ -90,9 +95,22 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from huggingface_hub import login\n",
+    "from huggingface_hub import snapshot_download\n",
    "\n",
-    "login()"
+    "model_path = snapshot_download(\n",
+    "    repo_id=\"google/gemma-7b-it\",\n",
+    "    local_dir=model_path,\n",
+    "    ignore_patterns=\"*.gguf\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Warmup: Running locally\n",
+    "\n",
+    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
    ]
   },
   {

@@ -119,8 +137,8 @@
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
-    "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
-    "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\",\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
+    "model = AutoModelForCausalLM.from_pretrained(model_path,\n",
    "                                             device_map=\"auto\",\n",
    "                                             torch_dtype=torch.bfloat16)"
    ]

@@ -246,7 +264,6 @@
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "    elif on_dataproc:\n",
    "        conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
-    "        conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",

@@ -421,7 +438,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def triton_server(ports):\n",
+    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",

@@ -434,8 +451,8 @@
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-    "    tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
-    "    model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\", device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
+    "    model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
    "    print(f\"SERVER: Using {device} device.\")\n",
    "\n",
    "    @batch\n",

@@ -523,7 +540,7 @@
    "outputs": [],
    "source": [
    "model_name = \"gemma-7b\"\n",
-    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
+    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
    ]
   },
   {
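Since Gemma is a gated model, the access token only matters on the driver at download time; once the snapshot sits at an absolute path, executors never contact the hub, which is the access-token problem the commit message mentions. A sketch of that flow (the local_dir shown is the Databricks path used by the notebook; adjust for other environments):

    from huggingface_hub import login, snapshot_download

    login()  # driver-side only: use a token for an account that accepted the Gemma terms
    model_path = snapshot_download(
        repo_id="google/gemma-7b-it",
        local_dir="/dbfs/FileStore/spark-dl-models/gemma-7b-it",
        ignore_patterns="*.gguf",  # skip the GGUF copy of the weights; transformers loads the safetensors
    )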

examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/qwen-2.5-7b_torch.ipynb

+34 -20 lines

@@ -75,34 +75,51 @@
     "on_standalone = not (on_databricks or on_dataproc)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# For cloud environments, load the model to the distributed file system.\n",
+    "if on_databricks:\n",
+    "    models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
+    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
+    "elif on_dataproc:\n",
+    "    models_dir = \"/mnt/gcs/spark-dl-models\"\n",
+    "    os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+    "    model_path = f\"{models_dir}/qwen-2.5-7b\"\n",
+    "else:\n",
+    "    model_path = os.path.abspath(\"qwen-2.5-7b\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
+    "Download the model from huggingface hub."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
-    "if on_databricks:\n",
-    "    hf_home = \"/dbfs/FileStore/hf_home\"\n",
-    "    dbutils.fs.mkdirs(hf_home)\n",
-    "    os.environ[\"HF_HOME\"] = hf_home\n",
-    "elif on_dataproc:\n",
-    "    hf_home = \"/mnt/gcs/hf_home\"\n",
-    "    os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
-    "    os.environ[\"HF_HOME\"] = hf_home"
+    "from huggingface_hub import snapshot_download\n",
+    "\n",
+    "model_path = snapshot_download(\n",
+    "    repo_id=\"Qwen/Qwen2.5-7B-Instruct\",\n",
+    "    local_dir=model_path\n",
+    ")"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Warmup: Running locally\n",
+    "## Warmup: Running locally\n",
    "\n",
    "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."

@@ -131,14 +148,12 @@
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
-    "model_name = \"Qwen/Qwen2.5-7B-Instruct\"\n",
-    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
-    "    model_name,\n",
+    "    model_path,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\"\n",
    ")\n",
-    "tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side=\"left\")"
+    "tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")"
    ]
   },
   {

@@ -309,7 +324,6 @@
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "    elif on_dataproc:\n",
    "        conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
-    "        conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.maxFailures\", \"1\")\n",

@@ -508,7 +522,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def triton_server(ports):\n",
+    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import torch\n",

@@ -521,11 +535,11 @@
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    model = AutoModelForCausalLM.from_pretrained(\n",
-    "        \"Qwen/Qwen2.5-7B-Instruct\",\n",
+    "        model_path,\n",
    "        torch_dtype=torch.bfloat16,\n",
    "        device_map=\"auto\"\n",
    "    )\n",
-    "    tokenizer = AutoTokenizer.from_pretrained(\"Qwen/Qwen2.5-7B-Instruct\", padding_side=\"left\")\n",
+    "    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side=\"left\")\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",

@@ -610,7 +624,7 @@
    "outputs": [],
    "source": [
    "model_name = \"qwen-2.5\"\n",
-    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
+    "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
    ]
   },
   {

examples/ML+DL-Examples/Spark-DL/dl_inference/pytriton_utils.py

+4 lines

@@ -44,6 +44,8 @@ def _start_triton_server(
 ) -> List[tuple]:
     """Task to start Triton server process on a Spark executor."""
 
+    from pyspark import BarrierTaskContext
+
     def _prepare_pytriton_env():
         """Expose PyTriton to correct libpython3.11.so and Triton bundled libraries."""
         ld_library_paths = []

@@ -82,6 +84,7 @@ def _find_ports(start_port: int = 7000) -> List[int]:
 
         return ports
 
+    tc = BarrierTaskContext.get()
     ports = _find_ports()
     sig = inspect.signature(triton_server_fn)
     params = sig.parameters

@@ -105,6 +108,7 @@
     for _ in range(wait_retries):
         try:
             client.wait_for_model(wait_timeout)
+            tc.barrier()
             client.close()
             return [(hostname, (process.pid, ports))]
         except Exception:
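The two added lines are what implement the "all servers are shut down if one fails" behavior from the commit message: every startup task must reach tc.barrier(), and a task that raises before the barrier fails the whole barrier stage, so Spark tears the other startup tasks down with it. A minimal illustrative sketch of the pattern (not the repo's code; _start_triton_server above does the equivalent with a real Triton client):

    from pyspark.sql import SparkSession
    from pyspark import BarrierTaskContext

    def start_server(_):
        tc = BarrierTaskContext.get()
        # ... start a server here and wait for it to become healthy; raise on failure ...
        tc.barrier()              # blocks until every task in the stage has a healthy server
        return [tc.partitionId()]

    spark = SparkSession.builder.getOrCreate()
    # barrier() schedules all partitions together, so the stage succeeds or fails as a unit
    started = spark.sparkContext.parallelize(range(2), 2).barrier().mapPartitions(start_server).collect()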

0 commit comments
