Skip to content

Commit

Permalink
style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bvonodiripsa committed Aug 27, 2024
1 parent 81b0705 commit 50df271
Showing 1 changed file with 44 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,16 @@
},
"outputs": [],
"source": [
"import fitz \n",
"import fitz\n",
"import pyspark.sql.functions as F\n",
"from pyspark.sql.types import ArrayType, FloatType, StringType\n",
"from pyspark.sql.functions import explode, col, monotonically_increasing_id, concat_ws, collect_list\n",
"from pyspark.sql.functions import (\n",
" explode,\n",
" col,\n",
" monotonically_increasing_id,\n",
" concat_ws,\n",
" collect_list,\n",
")\n",
"from pyspark.ml.functions import predict_batch_udf\n",
"from sentence_transformers import SentenceTransformer\n",
"from synapse.ml.featurize.text import PageSplitter\n",
Expand Down Expand Up @@ -245,6 +251,7 @@
"# Register the function as a UDF\n",
"extract_text_udf = udf(extract_text_from_binary_pdf, StringType())\n",
"\n",
"\n",
"# Apply the UDF to extract text from the binary content\n",
"analyzed_df = df.withColumn(\"output_content\", extract_text_udf(df[\"content\"]))"
]
Expand Down Expand Up @@ -388,14 +395,15 @@
" # def encode_text_batch(texts):\n",
" def encode_text_batch():\n",
" # Load the model inside the function\n",
" model = SentenceTransformer('bzantium/NV-Embed-v1', trust_remote_code=True)\n",
" model = SentenceTransformer(\"bzantium/NV-Embed-v1\", trust_remote_code=True)\n",
" model.max_seq_length = 4096\n",
" model.tokenizer.padding_side = \"right\"\n",
" \n",
"\n",
" def predict(inputs):\n",
" \n",
"\n",
" output = model.encode(\n",
" inputs.tolist(), prompt=query_prefix, normalize_embeddings=True)\n",
" inputs.tolist(), prompt=query_prefix, normalize_embeddings=True\n",
" )\n",
" return output\n",
"\n",
" return predict\n",
Expand All @@ -406,9 +414,7 @@
"\n",
" # Define the predict_batch_udf with the above function\n",
" return predict_batch_udf(\n",
" encode_text_batch,\n",
" return_type=ArrayType(FloatType()),\n",
" batch_size=1\n",
" encode_text_batch, return_type=ArrayType(FloatType()), batch_size=1\n",
" )"
]
},
Expand Down Expand Up @@ -522,18 +528,19 @@
"source": [
"from pyspark.sql.types import StructType, StructField, StringType, IntegerType\n",
"\n",
"task_name_to_instruct = {\"example\": \"Given a question, retrieve passages from the provided context that answer the question\",}\n",
"task_name_to_instruct = {\n",
" \"example\": \"Given a question, retrieve passages from the provided context that answer the question\",\n",
"}\n",
"\n",
"query_prefix = \"Instruct: \"+task_name_to_instruct[\"example\"]+\"\\nQuery: \"\n",
"query_prefix = \"Instruct: \" + task_name_to_instruct[\"example\"] + \"\\nQuery: \"\n",
"\n",
"encode_udf = create_encode_udf(query_prefix)\n",
"\n",
"user_question = \"What did the astronaut Edgar Mitchell call Earth?\"\n",
"# Define schema explicitly\n",
"schema = StructType([\n",
" StructField(\"id\", IntegerType(), True),\n",
" StructField(\"query\", StringType(), True)\n",
"])\n",
"schema = StructType(\n",
" [StructField(\"id\", IntegerType(), True), StructField(\"query\", StringType(), True)]\n",
")\n",
"\n",
"# Create DataFrame with id = 1 and the user query\n",
"temp_df = spark.createDataFrame([(1, user_question)], schema).cache()\n",
Expand Down Expand Up @@ -577,7 +584,9 @@
},
"outputs": [],
"source": [
"(_, _, knn_df) = rapids_knn_model.kneighbors(query_embeddings.select(\"id\", \"embeddings\"))"
"(_, _, knn_df) = rapids_knn_model.kneighbors(\n",
" query_embeddings.select(\"id\", \"embeddings\")\n",
")"
]
},
{
Expand All @@ -599,17 +608,17 @@
"source": [
"# Add text to the results\n",
"result_df = (\n",
" knn_df.withColumn(\n",
" \"zipped\", F.explode(F.arrays_zip(F.col(\"indices\"), F.col(\"distances\")))\n",
" )\n",
" .select(\n",
" F.col(\"query_id\"),\n",
" F.col(\"zipped.indices\").alias(\"id\"),\n",
" F.col(\"zipped.distances\").alias(\"distance\"),\n",
" )\n",
" .join(embeddings, on=\"id\", how=\"inner\")\n",
" .select(\"query_id\", \"id\", \"chunk\", \"distance\")\n",
" )"
" knn_df.withColumn(\n",
" \"zipped\", F.explode(F.arrays_zip(F.col(\"indices\"), F.col(\"distances\")))\n",
" )\n",
" .select(\n",
" F.col(\"query_id\"),\n",
" F.col(\"zipped.indices\").alias(\"id\"),\n",
" F.col(\"zipped.distances\").alias(\"distance\"),\n",
" )\n",
" .join(embeddings, on=\"id\", how=\"inner\")\n",
" .select(\"query_id\", \"id\", \"chunk\", \"distance\")\n",
")"
]
},
{
Expand All @@ -630,7 +639,9 @@
"outputs": [],
"source": [
"# Concatenate all strings in the 'combined_text' column across all question related chunks\n",
"concatenated_text = result_df.agg(concat_ws(\" \", collect_list(\"chunk\")).alias(\"concatenated_text\")).collect()[0][\"concatenated_text\"]\n"
"concatenated_text = result_df.agg(\n",
" concat_ws(\" \", collect_list(\"chunk\")).alias(\"concatenated_text\")\n",
").collect()[0][\"concatenated_text\"]"
]
},
{
Expand Down Expand Up @@ -678,7 +689,7 @@
"# Put model in global if we want to reuse it\n",
"global llm\n",
"\n",
"if 'llm' in globals() and llm is not None:\n",
"if \"llm\" in globals() and llm is not None:\n",
" print(\"Model is already loaded.\")\n",
"else:\n",
" print(\"Model is not loaded.\")\n",
Expand All @@ -689,8 +700,8 @@
" build_config.max_input_len = 5120\n",
" build_config.max_seq_len = 5632\n",
"\n",
" llm = LLM(model=\"microsoft/Phi-3-mini-4k-instruct\", build_config=build_config) \n",
" \n",
" llm = LLM(model=\"microsoft/Phi-3-mini-4k-instruct\", build_config=build_config)\n",
"\n",
"sampling_params = SamplingParams(temperature=0.8, top_p=0.95)\n",
"\n",
"context = concatenated_text\n",
Expand Down Expand Up @@ -741,32 +752,15 @@
},
"outputs": [],
"source": [
"# Mocking the custom classes if they are not already defined\n",
"class CompletionOutput:\n",
" def __init__(self, index, text, token_ids, cumulative_logprob=None, logprobs=None):\n",
" self.index = index\n",
" self.text = text\n",
" self.token_ids = token_ids\n",
" self.cumulative_logprob = cumulative_logprob\n",
" self.logprobs = logprobs\n",
"\n",
"class RequestOutput:\n",
" def __init__(self, request_id, prompt, prompt_token_ids, outputs, finished):\n",
" self.request_id = request_id\n",
" self.prompt = prompt\n",
" self.prompt_token_ids = prompt_token_ids\n",
" self.outputs = outputs\n",
" self.finished = finished\n",
"\n",
"output_text = outputs.outputs[0].text\n",
"\n",
"# Split the text by '\\n'\n",
"split_text = output_text.split('\\n')\n",
"split_text = output_text.split(\"\\n\")\n",
"\n",
"for item in split_text:\n",
" if len(item) > 10:\n",
" # Split the item at the colon and take the part after it\n",
" result = item.split(':', 1)[-1].strip()\n",
" result = item.split(\":\", 1)[-1].strip()\n",
" print(\"Answer: \" + result)\n",
" break"
]
Expand Down

0 comments on commit 50df271

Please sign in to comment.