From 50df271ba6bd8b74d930228afd60309627ed545d Mon Sep 17 00:00:00 2001
From: bvonodiripsa
Date: Tue, 27 Aug 2024 03:28:22 +0000
Subject: [PATCH] style changes

---
 ...Question - Answering with LLM on GPU.ipynb | 94 +++++++++----------
 1 file changed, 44 insertions(+), 50 deletions(-)

diff --git a/docs/Explore Algorithms/AI Services/QuickStart - Distributed Question - Answering with LLM on GPU.ipynb b/docs/Explore Algorithms/AI Services/QuickStart - Distributed Question - Answering with LLM on GPU.ipynb
index b518a64cf6..7ae41e1949 100644
--- a/docs/Explore Algorithms/AI Services/QuickStart - Distributed Question - Answering with LLM on GPU.ipynb
+++ b/docs/Explore Algorithms/AI Services/QuickStart - Distributed Question - Answering with LLM on GPU.ipynb
@@ -92,10 +92,16 @@
    },
    "outputs": [],
    "source": [
-    "import fitz \n",
+    "import fitz\n",
     "import pyspark.sql.functions as F\n",
     "from pyspark.sql.types import ArrayType, FloatType, StringType\n",
-    "from pyspark.sql.functions import explode, col, monotonically_increasing_id, concat_ws, collect_list\n",
+    "from pyspark.sql.functions import (\n",
+    "    explode,\n",
+    "    col,\n",
+    "    monotonically_increasing_id,\n",
+    "    concat_ws,\n",
+    "    collect_list,\n",
+    ")\n",
     "from pyspark.ml.functions import predict_batch_udf\n",
     "from sentence_transformers import SentenceTransformer\n",
     "from synapse.ml.featurize.text import PageSplitter\n",
@@ -245,6 +251,7 @@
     "# Register the function as a UDF\n",
     "extract_text_udf = udf(extract_text_from_binary_pdf, StringType())\n",
     "\n",
+    "\n",
     "# Apply the UDF to extract text from the binary content\n",
     "analyzed_df = df.withColumn(\"output_content\", extract_text_udf(df[\"content\"]))"
    ]
@@ -388,14 +395,15 @@
     "    # def encode_text_batch(texts):\n",
     "    def encode_text_batch():\n",
     "        # Load the model inside the function\n",
-    "        model = SentenceTransformer('bzantium/NV-Embed-v1', trust_remote_code=True)\n",
+    "        model = SentenceTransformer(\"bzantium/NV-Embed-v1\", trust_remote_code=True)\n",
     "        model.max_seq_length = 4096\n",
     "        model.tokenizer.padding_side = \"right\"\n",
-    "        \n",
+    "\n",
     "        def predict(inputs):\n",
-    "            \n",
+    "\n",
     "            output = model.encode(\n",
-    "                inputs.tolist(), prompt=query_prefix, normalize_embeddings=True)\n",
+    "                inputs.tolist(), prompt=query_prefix, normalize_embeddings=True\n",
+    "            )\n",
     "            return output\n",
     "\n",
     "        return predict\n",
@@ -406,9 +414,7 @@
     "\n",
     "    # Define the predict_batch_udf with the above function\n",
     "    return predict_batch_udf(\n",
-    "        encode_text_batch,\n",
-    "        return_type=ArrayType(FloatType()),\n",
-    "        batch_size=1\n",
+    "        encode_text_batch, return_type=ArrayType(FloatType()), batch_size=1\n",
     "    )"
    ]
   },
@@ -522,18 +528,19 @@
    "source": [
     "from pyspark.sql.types import StructType, StructField, StringType, IntegerType\n",
     "\n",
-    "task_name_to_instruct = {\"example\": \"Given a question, retrieve passages from the provided context that answer the question\",}\n",
+    "task_name_to_instruct = {\n",
+    "    \"example\": \"Given a question, retrieve passages from the provided context that answer the question\",\n",
+    "}\n",
     "\n",
-    "query_prefix = \"Instruct: \"+task_name_to_instruct[\"example\"]+\"\\nQuery: \"\n",
+    "query_prefix = \"Instruct: \" + task_name_to_instruct[\"example\"] + \"\\nQuery: \"\n",
     "\n",
     "encode_udf = create_encode_udf(query_prefix)\n",
     "\n",
     "user_question = \"What did the astronaut Edgar Mitchell call Earth?\"\n",
     "# Define schema explicitly\n",
-    "schema = StructType([\n",
-    "    StructField(\"id\", IntegerType(), True),\n",
True)\n", - "])\n", + "schema = StructType(\n", + " [StructField(\"id\", IntegerType(), True), StructField(\"query\", StringType(), True)]\n", + ")\n", "\n", "# Create DataFrame with id = 1 and the user query\n", "temp_df = spark.createDataFrame([(1, user_question)], schema).cache()\n", @@ -577,7 +584,9 @@ }, "outputs": [], "source": [ - "(_, _, knn_df) = rapids_knn_model.kneighbors(query_embeddings.select(\"id\", \"embeddings\"))" + "(_, _, knn_df) = rapids_knn_model.kneighbors(\n", + " query_embeddings.select(\"id\", \"embeddings\")\n", + ")" ] }, { @@ -599,17 +608,17 @@ "source": [ "# Add text to the results\n", "result_df = (\n", - " knn_df.withColumn(\n", - " \"zipped\", F.explode(F.arrays_zip(F.col(\"indices\"), F.col(\"distances\")))\n", - " )\n", - " .select(\n", - " F.col(\"query_id\"),\n", - " F.col(\"zipped.indices\").alias(\"id\"),\n", - " F.col(\"zipped.distances\").alias(\"distance\"),\n", - " )\n", - " .join(embeddings, on=\"id\", how=\"inner\")\n", - " .select(\"query_id\", \"id\", \"chunk\", \"distance\")\n", - " )" + " knn_df.withColumn(\n", + " \"zipped\", F.explode(F.arrays_zip(F.col(\"indices\"), F.col(\"distances\")))\n", + " )\n", + " .select(\n", + " F.col(\"query_id\"),\n", + " F.col(\"zipped.indices\").alias(\"id\"),\n", + " F.col(\"zipped.distances\").alias(\"distance\"),\n", + " )\n", + " .join(embeddings, on=\"id\", how=\"inner\")\n", + " .select(\"query_id\", \"id\", \"chunk\", \"distance\")\n", + ")" ] }, { @@ -630,7 +639,9 @@ "outputs": [], "source": [ "# Concatenate all strings in the 'combined_text' column across all question related chunks\n", - "concatenated_text = result_df.agg(concat_ws(\" \", collect_list(\"chunk\")).alias(\"concatenated_text\")).collect()[0][\"concatenated_text\"]\n" + "concatenated_text = result_df.agg(\n", + " concat_ws(\" \", collect_list(\"chunk\")).alias(\"concatenated_text\")\n", + ").collect()[0][\"concatenated_text\"]" ] }, { @@ -678,7 +689,7 @@ "# Put model in global if we want to reuse it\n", "global llm\n", "\n", - "if 'llm' in globals() and llm is not None:\n", + "if \"llm\" in globals() and llm is not None:\n", " print(\"Model is already loaded.\")\n", "else:\n", " print(\"Model is not loaded.\")\n", @@ -689,8 +700,8 @@ " build_config.max_input_len = 5120\n", " build_config.max_seq_len = 5632\n", "\n", - " llm = LLM(model=\"microsoft/Phi-3-mini-4k-instruct\", build_config=build_config) \n", - " \n", + " llm = LLM(model=\"microsoft/Phi-3-mini-4k-instruct\", build_config=build_config)\n", + "\n", "sampling_params = SamplingParams(temperature=0.8, top_p=0.95)\n", "\n", "context = concatenated_text\n", @@ -741,32 +752,15 @@ }, "outputs": [], "source": [ - "# Mocking the custom classes if they are not already defined\n", - "class CompletionOutput:\n", - " def __init__(self, index, text, token_ids, cumulative_logprob=None, logprobs=None):\n", - " self.index = index\n", - " self.text = text\n", - " self.token_ids = token_ids\n", - " self.cumulative_logprob = cumulative_logprob\n", - " self.logprobs = logprobs\n", - "\n", - "class RequestOutput:\n", - " def __init__(self, request_id, prompt, prompt_token_ids, outputs, finished):\n", - " self.request_id = request_id\n", - " self.prompt = prompt\n", - " self.prompt_token_ids = prompt_token_ids\n", - " self.outputs = outputs\n", - " self.finished = finished\n", - "\n", "output_text = outputs.outputs[0].text\n", "\n", "# Split the text by '\\n'\n", - "split_text = output_text.split('\\n')\n", + "split_text = output_text.split(\"\\n\")\n", "\n", "for item in split_text:\n", " if 
     "    if len(item) > 10:\n",
     "        # Split the item at the colon and take the part after it\n",
-    "        result = item.split(':', 1)[-1].strip()\n",
+    "        result = item.split(\":\", 1)[-1].strip()\n",
     "        print(\"Answer: \" + result)\n",
     "        break"
    ]