45 | 45 | "on_standalone = not (on_databricks or on_dataproc)"
46 | 46 | ]
47 | 47 | },
48 | | - {
49 | | - "cell_type": "markdown",
50 | | - "metadata": {},
51 | | - "source": [
52 | | - "For cloud environments, set the huggingface cache dir to DBFS/GCS so that executors can load the model from cache."
53 | | - ]
54 | | - },
55 | 48 | {
56 | 49 | "cell_type": "code",
57 | 50 | "execution_count": 14,
58 | 51 | "metadata": {},
59 | 52 | "outputs": [],
60 | 53 | "source": [
| 54 | + "# For cloud environments, load the model to the distributed file system.\n",
61 | 55 | "if on_databricks:\n",
62 | | - " hf_home = \"/dbfs/FileStore/hf_home\"\n",
63 | | - " dbutils.fs.mkdirs(hf_home)\n",
64 | | - " os.environ[\"HF_HOME\"] = hf_home\n",
| 56 | + " models_dir = \"/dbfs/FileStore/spark-dl-models\"\n",
| 57 | + " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
| 58 | + " model_path = f\"{models_dir}/gemma-7b-it\"\n",
65 | 59 | "elif on_dataproc:\n",
66 | | - " hf_home = \"/mnt/gcs/hf_home\"\n",
67 | | - " os.mkdir(hf_home) if not os.path.exists(hf_home) else None\n",
68 | | - " os.environ[\"HF_HOME\"] = hf_home"
| 60 | + " models_dir = \"/mnt/gcs/spark-dl-models\"\n",
| 61 | + " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
| 62 | + " model_path = f\"{models_dir}/gemma-7b-it\"\n",
| 63 | + "else:\n",
| 64 | + " model_path = os.path.abspath(\"gemma-7b-it\")"
69 | 65 | ]
70 | 66 | },
71 | 67 | {
72 | 68 | "cell_type": "markdown",
73 | 69 | "metadata": {},
74 | 70 | "source": [
75 | | - "### Warmup: Running locally\n",
| 71 | + "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
| 72 | + ]
| 73 | + },
| 74 | + {
| 75 | + "cell_type": "code",
| 76 | + "execution_count": null,
| 77 | + "metadata": {},
| 78 | + "outputs": [],
| 79 | + "source": [
| 80 | + "from huggingface_hub import login\n",
76 | 81 | "\n",
77 | | - "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
| 82 | + "login()"
78 | 83 | ]
79 | 84 | },
80 | 85 | {
81 | 86 | "cell_type": "markdown",
82 | 87 | "metadata": {},
83 | 88 | "source": [
84 | | - "First visit the [Gemma Huggingface repository](https://huggingface.co/google/gemma-7b-it) to accept the terms to access the model, then login via huggingface_hub."
| 89 | + "Once you have access, you can download the model:"
85 | 90 | ]
86 | 91 | },
87 | 92 | {
90 | 95 | "metadata": {},
91 | 96 | "outputs": [],
92 | 97 | "source": [
93 | | - "from huggingface_hub import login\n",
| 98 | + "from huggingface_hub import snapshot_download\n",
94 | 99 | "\n",
95 | | - "login()"
| 100 | + "model_path = snapshot_download(\n",
| 101 | + " repo_id=\"google/gemma-7b-it\",\n",
| 102 | + " local_dir=model_path,\n",
| 103 | + " ignore_patterns=\"*.gguf\"\n",
| 104 | + ")"
| 105 | + ]
| 106 | + },
| 107 | + {
| 108 | + "cell_type": "markdown",
| 109 | + "metadata": {},
| 110 | + "source": [
| 111 | + "## Warmup: Running locally\n",
| 112 | + "\n",
| 113 | + "**Note**: If the driver node does not have sufficient GPU capacity, proceed to the PySpark section."
96 | 114 | ]
97 | 115 | },
98 | 116 | {
119 | 137 | "import torch\n",
120 | 138 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
121 | 139 | "\n",
122 | | - "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
123 | | - "model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\",\n",
| 140 | + "tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
| 141 | + "model = AutoModelForCausalLM.from_pretrained(model_path,\n",
124 | 142 | " device_map=\"auto\",\n",
125 | 143 | " torch_dtype=torch.bfloat16)"
126 | 144 | ]
246 | 264 | " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
247 | 265 | " elif on_dataproc:\n",
248 | 266 | " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
249 | | - " conf.set(\"spark.executorEnv.HF_HOME\", hf_home)\n",
250 | 267 | "\n",
251 | 268 | " conf.set(\"spark.executor.cores\", \"8\")\n",
252 | 269 | " conf.set(\"spark.task.maxFailures\", \"1\")\n",
421 | 438 | "metadata": {},
422 | 439 | "outputs": [],
423 | 440 | "source": [
424 | | - "def triton_server(ports):\n",
| 441 | + "def triton_server(ports, model_path):\n",
425 | 442 | " import time\n",
426 | 443 | " import signal\n",
427 | 444 | " import numpy as np\n",
434 | 451 | "\n",
435 | 452 | " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
436 | 453 | " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
437 | | - " tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\")\n",
438 | | - " model = AutoModelForCausalLM.from_pretrained(\"google/gemma-7b-it\", device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
| 454 | + " tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
| 455 | + " model = AutoModelForCausalLM.from_pretrained(model_path, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
439 | 456 | " print(f\"SERVER: Using {device} device.\")\n",
440 | 457 | "\n",
441 | 458 | " @batch\n",
523 | 540 | "outputs": [],
524 | 541 | "source": [
525 | 542 | "model_name = \"gemma-7b\"\n",
526 | | - "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name)"
| 543 | + "server_manager = TritonServerManager(num_nodes=num_nodes, model_name=model_name, model_path=model_path)"
527 | 544 | ]
528 | 545 | },
529 | 546 | {