Skip to content

Commit

Permalink
feat: add Azure Cognitive Search vector store (#2041)
Browse files Browse the repository at this point in the history
* add vector column option

* add the vector option

* vector fields are added and code compiles, untested

* fix bug on checkparity when the index exists

* add FloatType to edm-spark type conversions

* fix synonymmap

* core functionality works

* add no nested field vector check

* add vector validation check

* modify vector columns behavior when column doesn't exist in df schema

* add another test

* clean up the unit test file

* add more tests

* add openai embedding pipeline test

* address comments

* address comments

* address comments

* update notebook

* change index name in notebook
  • Loading branch information
aydan-at-microsoft authored Sep 7, 2023
1 parent 111823d commit d494f6e
Show file tree
Hide file tree
Showing 5 changed files with 484 additions and 197 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import org.apache.spark.internal.{Logging => SLogging}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.{col, expr, struct, to_json}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -142,7 +144,7 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
override def responseDataType: DataType = ASResponses.schema
}

object AzureSearchWriter extends IndexParser with SLogging {
object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging {

val Logger: Logger = LogManager.getRootLogger

Expand All @@ -166,9 +168,11 @@ object AzureSearchWriter extends IndexParser with SLogging {
private def convertFields(fields: Seq[StructField],
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String]): Seq[IndexField] = {
fields.filterNot(_.name == searchActionCol).map { sf =>
val fullName = prefix.map(_ + sf.name).getOrElse(sf.name)
val isVector = vectorCols.exists(_.exists(_.name == fullName))
val (innerType, _) = sparkTypeToEdmType(sf.dataType)
IndexField(
sf.name,
Expand All @@ -177,31 +181,44 @@ object AzureSearchWriter extends IndexParser with SLogging {
if (keyCol == fullName) Some(true) else None,
None, None, None, None,
structFieldToSearchFields(sf.dataType,
keyCol, searchActionCol, prefix = Some(prefix.getOrElse("") + sf.name + "."))
keyCol, searchActionCol, None, prefix = Some(prefix.getOrElse("") + sf.name + ".")),
if (isVector) vectorCols.get.find(_.name == fullName).map(_.dimension) else None,
if (isVector) Some(AzureSearchAPIConstants.VectorConfigName) else None
)
}
}

private def structFieldToSearchFields(schema: DataType,
keyCol: String,
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]],
prefix: Option[String] = None
): Option[Seq[IndexField]] = {
schema match {
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, vectorCols, prefix))
// TODO: Support vector search in nested fields
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, None, prefix))
case _ => None
}
}

private def parseVectorColsJson(str: String): Seq[VectorColParams] = {
str.parseJson.convertTo[Seq[VectorColParams]]
}

private def dfToIndexJson(schema: StructType,
indexName: String,
keyCol: String,
searchActionCol: String): String = {
searchActionCol: String,
vectorCols: Option[Seq[VectorColParams]]): String = {

val vectorConfig = Some(VectorSearch(Seq(AlgorithmConfigs(AzureSearchAPIConstants.VectorConfigName,
AzureSearchAPIConstants.VectorSearchAlgorithm))))
val is = IndexInfo(
Some(indexName),
structFieldToSearchFields(schema, keyCol, searchActionCol).get,
None, None, None, None, None, None, None, None
structFieldToSearchFields(schema, keyCol, searchActionCol, vectorCols).get,
None, None, None, None, None, None, None, None,
if (vectorCols.isEmpty) None else vectorConfig
)
is.toJson.compactPrint
}
Expand All @@ -210,7 +227,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
options: Map[String, String] = Map()): DataFrame = {
val applicableOptions = Set(
"subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol"
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "vectorCols"
)

options.keys.foreach(k =>
Expand All @@ -224,11 +241,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
val batchSize = options.getOrElse("batchSize", "100").toInt
val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
val filterNulls = options.getOrElse("filterNulls", "false").toBoolean
val vectorColsInfo = options.get("vectorCols")

val keyCol = options.get("keyCol")
val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
if (indexJsonOpt.isDefined) {
List("keyCol", "indexName").foreach(opt =>
List("keyCol", "indexName", "vectorCols").foreach(opt =>
assert(!options.contains(opt), s"Cannot set both indexJson options and $opt")
)
}
Expand All @@ -242,22 +260,41 @@ object AzureSearchWriter extends IndexParser with SLogging {
}
}

val indexJson = indexJsonOpt.getOrElse {
dfToIndexJson(df.schema, indexName, keyCol.get, actionCol)
val (indexJson, preppedDF) = if (getExisting(subscriptionKey, serviceName, apiVersion).contains(indexName)) {
if (indexJsonOpt.isDefined) {
println(f"indexJsonOpt is specified, however an index for $indexName already exists," +
f"we will use the index definition obtained from the existing index instead")
}
val existingIndexJson = getIndexJsonFromExistingIndex(subscriptionKey, serviceName, indexName)
val vectorColNameTypeTuple = getVectorColConf(existingIndexJson)
(existingIndexJson, makeColsCompatible(vectorColNameTypeTuple, df))
} else if (indexJsonOpt.isDefined) {
val vectorColNameTypeTuple = getVectorColConf(indexJsonOpt.get)
(indexJsonOpt.get, makeColsCompatible(vectorColNameTypeTuple, df))
} else {
val vectorCols = vectorColsInfo.map(parseVectorColsJson)
val vectorColNameTypeTuple = vectorCols.map(_.map(vc => (vc.name, "Collection(Edm.Single)"))).getOrElse(Seq.empty)
val newDF = makeColsCompatible(vectorColNameTypeTuple, df)
val inferredIndexJson = dfToIndexJson(newDF.schema, indexName, keyCol.getOrElse(""), actionCol, vectorCols)
(inferredIndexJson, newDF)
}

// TODO: Support vector search in nested fields
// Throws an exception if any nested field is a vector in the schema
parseIndexJson(indexJson).fields.foreach(_.fields.foreach(assertNoNestedVectors))

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)

logInfo("checking schema parity")
checkSchemaParity(df.schema, indexJson, actionCol)
checkSchemaParity(preppedDF.schema, indexJson, actionCol)

val df1 = if (filterNulls) {
val collectionColumns = parseIndexJson(indexJson).fields
.filter(_.`type`.startsWith("Collection"))
.map(_.name)
collectionColumns.foldLeft(df) { (ndf, c) => filterOutNulls(ndf, c) }
collectionColumns.foldLeft(preppedDF) { (ndf, c) => filterOutNulls(ndf, c) }
} else {
df
preppedDF
}

new AddDocuments()
Expand All @@ -273,6 +310,48 @@ object AzureSearchWriter extends IndexParser with SLogging {
UDFUtils.oldUdf(checkForErrors(fatalErrors) _, ErrorUtils.ErrorSchema)(col("error"), col("input")))
}

private def assertNoNestedVectors(fields: Seq[IndexField]): Unit = {
def checkVectorField(field: IndexField): Unit = {
if (field.dimensions.nonEmpty && field.vectorSearchConfiguration.nonEmpty) {
throw new IllegalArgumentException(s"Nested field ${field.name} is a vector field, vector fields in nested" +
s" fields are not supported.")
}
field.fields.foreach(_.foreach(checkVectorField))
}
fields.foreach(checkVectorField)
}

private def getVectorColConf(indexJson: String): Seq[(String, String)] = {
parseIndexJson(indexJson).fields
.filter(f => f.vectorSearchConfiguration.nonEmpty && f.dimensions.nonEmpty)
.map(f => (f.name, f.`type`))
}
private def makeColsCompatible(vectorColNameTypeTuple: Seq[(String, String)],
df: DataFrame): DataFrame = {
vectorColNameTypeTuple.foldLeft(df) { case (accDF, (colName, colType)) =>
if (!accDF.columns.contains(colName)) {
println(s"Column $colName is specified in either indexJson or vectorCols but not found in dataframe " +
s"columns ${accDF.columns.toList}")
accDF
}
else {
val colDataType = accDF.schema(colName).dataType
assert(colDataType match {
case ArrayType(elementType, _) => elementType == FloatType || elementType == DoubleType
case VectorType => true
case _ => false
}, s"Vector column $colName needs to be one of (ArrayType(FloatType), ArrayType(DoubleType), VectorType)")
if (colDataType.isInstanceOf[ArrayType]) {
accDF.withColumn(colName, accDF(colName).cast(edmTypeToSparkType(colType, None)))
} else {
// first cast vectorUDT to array<double>, then cast it to correct array type
val modifiedDF = accDF.withColumn(colName, vector_to_array(accDF(colName)))
modifiedDF.withColumn(colName, modifiedDF(colName).cast(edmTypeToSparkType(colType, None)))
}
}
}
}

private def isEdmCollection(t: String): Boolean = {
t.startsWith("Collection(") && t.endsWith(")")
}
Expand All @@ -290,6 +369,7 @@ object AzureSearchWriter extends IndexParser with SLogging {
case "Edm.Int64" => LongType
case "Edm.Int32" => IntegerType
case "Edm.Double" => DoubleType
case "Edm.Single" => FloatType
case "Edm.DateTimeOffset" => StringType //See if there's a way to use spark datetimes
case "Edm.GeographyPoint" => StringType
case "Edm.ComplexType" => StructType(fields.get.map(f =>
Expand All @@ -310,10 +390,12 @@ object AzureSearchWriter extends IndexParser with SLogging {
case IntegerType => ("Edm.Int32", None)
case LongType => ("Edm.Int64", None)
case DoubleType => ("Edm.Double", None)
case FloatType => ("Edm.Single", None)
case DateType => ("Edm.DateTimeOffset", None)
case StructType(fields) => ("Edm.ComplexType", Some(fields.map { f =>
val (innerType, innerFields) = sparkTypeToEdmType(f.dataType)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields)
IndexField(f.name, innerType, None, None, None, None, None, None, None, None, None, None, innerFields,
None, None) // TODO: Support vector search in nested fields
}))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import spray.json._
import scala.util.{Failure, Success, Try}

object AzureSearchAPIConstants {
val DefaultAPIVersion = "2019-05-06"
val DefaultAPIVersion = "2023-07-01-Preview"
val VectorConfigName = "vectorConfig"
val VectorSearchAlgorithm = "hnsw"
}
import com.microsoft.azure.synapse.ml.cognitive.search.AzureSearchAPIConstants._

Expand All @@ -39,6 +41,26 @@ trait IndexLister {
}
}

trait IndexJsonGetter extends IndexLister {
def getIndexJsonFromExistingIndex(key: String,
serviceName: String,
indexName: String,
apiVersion: String = DefaultAPIVersion): String = {
val existingIndexNames = getExisting(key, serviceName, apiVersion)
assert(existingIndexNames.contains(indexName), s"Cannot find an existing index name with $indexName")

val indexJsonRequest = new HttpGet(
s"https://$serviceName.search.windows.net/indexes/$indexName?api-version=$apiVersion"
)
indexJsonRequest.setHeader("api-key", key)
indexJsonRequest.setHeader("Content-Type", "application/json")
val indexJsonResponse = safeSend(indexJsonRequest, close = false)
val indexJson = IOUtils.toString(indexJsonResponse.getEntity.getContent, "utf-8")
indexJsonResponse.close()
indexJson
}
}

object SearchIndex extends IndexParser with IndexLister {

import AzureSearchProtocol._
Expand Down Expand Up @@ -94,7 +116,9 @@ object SearchIndex extends IndexParser with IndexLister {
_ <- validAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSearchAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validIndexAnalyzer(field.analyzer, field.searchAnalyzer, field.indexAnalyzer)
_ <- validSynonymMaps(field.synonymMap)
_ <- validVectorField(field.dimensions, field.vectorSearchConfiguration)
// TODO: Fix and add back validSynonymMaps check. SynonymMaps needs to be Option[Seq[String]] type
//_ <- validSynonymMaps(field.synonymMap)
} yield field
}

Expand Down Expand Up @@ -182,6 +206,15 @@ object SearchIndex extends IndexParser with IndexLister {
}
}

private def validVectorField(d: Option[Int], v: Option[String]): Try[Option[String]] = {
if ((d.isDefined && v.isEmpty) || (v.isDefined && d.isEmpty)) {
Failure(new IllegalArgumentException("Both dimensions and vectorSearchConfig fields need to be defined for " +
"vector search"))
} else {
Success(v)
}
}

def getStatistics(indexName: String,
key: String,
serviceName: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.cognitive.search

import com.microsoft.azure.synapse.ml.core.schema.SparkBindings
import spray.json.DefaultJsonProtocol._
import spray.json.{JsonFormat, RootJsonFormat}
import spray.json.{DefaultJsonProtocol, JsonFormat, RootJsonFormat}

object ASResponses extends SparkBindings[ASResponses]

Expand All @@ -23,9 +23,19 @@ case class IndexInfo(
tokenizers: Option[Seq[String]],
tokenFilters: Option[Seq[String]],
defaultScoringProfile: Option[Seq[String]],
corsOptions: Option[Seq[String]]
corsOptions: Option[Seq[String]],
vectorSearch: Option[VectorSearch]
)

case class AlgorithmConfigs(
name: String,
kind: String
)

case class VectorSearch(
algorithmConfigurations: Seq[AlgorithmConfigs]
)

case class IndexField(
name: String,
`type`: String,
Expand All @@ -38,21 +48,32 @@ case class IndexField(
analyzer: Option[String],
searchAnalyzer: Option[String],
indexAnalyzer: Option[String],
synonymMap: Option[String],
fields: Option[Seq[IndexField]]
synonymMap: Option[Seq[String]],
fields: Option[Seq[IndexField]],
dimensions: Option[Int],
vectorSearchConfiguration: Option[String]
)

case class VectorColParams(
name: String,
dimension: Int
)

case class IndexStats(documentCount: Int, storageSize: Int)

case class IndexList(`@odata.context`: String, value: Seq[IndexName])
case class IndexName(name: String)

object AzureSearchProtocol {
object AzureSearchProtocol extends DefaultJsonProtocol {
implicit val IfEnc: JsonFormat[IndexField] = lazyFormat(jsonFormat(
IndexField,"name","type","searchable","filterable","sortable",
"facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields"))
implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat10(IndexInfo.apply)
"facetable","retrievable", "key","analyzer","searchAnalyzer", "indexAnalyzer", "synonymMaps", "fields",
"dimensions", "vectorSearchConfiguration"))
implicit val AcEnc: RootJsonFormat[AlgorithmConfigs] = jsonFormat2(AlgorithmConfigs.apply)
implicit val VsEnc: RootJsonFormat[VectorSearch] = jsonFormat1(VectorSearch.apply)
implicit val IiEnc: RootJsonFormat[IndexInfo] = jsonFormat11(IndexInfo.apply)
implicit val IsEnc: RootJsonFormat[IndexStats] = jsonFormat2(IndexStats.apply)
implicit val InEnc: RootJsonFormat[IndexName] = jsonFormat1(IndexName.apply)
implicit val IlEnc: RootJsonFormat[IndexList] = jsonFormat2(IndexList.apply)
implicit val VcpEnc: RootJsonFormat[VectorColParams] = jsonFormat2(VectorColParams.apply)
}
Loading

0 comments on commit d494f6e

Please sign in to comment.