Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for userCol and itemCol as string types in SAR model #2283

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions core/src/main/python/synapse/ml/recommendation/SARModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
class SARModel(_SARModel):
    """Python wrapper for the Scala SARModel.

    Adds a ``recommendForUserSubset`` override that tolerates a string-typed
    user column by casting it to int before delegating to the JVM model.
    """

    def recommendForAllUsers(self, numItems):
        """Return the top ``numItems`` item recommendations for every user."""
        return self._call_java("recommendForAllUsers", numItems)

    def recommendForUserSubset(self, dataset, numItems):
        """Return the top ``numItems`` recommendations for the users in ``dataset``.

        If the user column is string-typed it is cast to int first, since the
        underlying Scala model works with integer ids.

        NOTE(review): ``cast("int")`` silently produces null for values that are
        not numeric strings (e.g. "user0") — confirm callers only pass numeric
        string ids, or index the ids upstream instead.
        """
        user_col = self.getUserCol()
        # isinstance is the canonical pyspark way to test a column's data type
        # (avoids constructing a throwaway StringType() just for equality).
        if isinstance(dataset.schema[user_col].dataType, StringType):
            dataset = dataset.withColumn(user_col, dataset[user_col].cast("int"))
        return self._call_java("recommendForUserSubset", dataset, numItems)
dciborow marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.recommendation
dciborow marked this conversation as resolved.
Show resolved Hide resolved

import breeze.linalg.{CSCMatrix => BSM}
Expand All @@ -13,7 +10,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseMatrix}
import org.apache.spark.sql.functions.{col, collect_list, sum, udf, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.text.SimpleDateFormat
Expand Down Expand Up @@ -106,8 +103,22 @@ class SAR(override val uid: String) extends Estimator[SARModel]
(0 to numItems.value).map(i => map.getOrElse(i, 0.0).toFloat).toArray
})

dataset
.withColumn(C.AffinityCol, (dataset.columns.contains(getTimeCol), dataset.columns.contains(getRatingCol)) match {
// NOTE(review): this cast logic is duplicated verbatim in the item-feature
// path later in this class — consider extracting a shared private helper
// (e.g. castIdColumnsToInt(ds: Dataset[_])) so the two copies cannot drift.
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

// Cast string-typed user/item id columns to int so the downstream integer
// arithmetic over ids keeps working.
// NOTE(review): Column.cast("int") yields null (not an error) for values
// that are not numeric strings (e.g. "user0") — confirm the ids are numeric
// strings, or index them with StringIndexer upstream instead.
val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.withColumn(C.AffinityCol, (castedDataset.columns.contains(getTimeCol), castedDataset.columns.contains(getRatingCol)) match {
case (true, true) => blendWeights(timeDecay(col(getTimeCol)), col(getRatingCol))
case (true, false) => timeDecay(col(getTimeCol))
case (false, true) => col(getRatingCol)
Expand Down Expand Up @@ -197,7 +208,21 @@ class SAR(override val uid: String) extends Estimator[SARModel]
})
})

dataset
// NOTE(review): identical cast logic also appears in the affinity path
// earlier in this class — extract a shared private helper so the two copies
// stay in sync.
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

// Cast string-typed id columns to int before building item features.
// NOTE(review): Column.cast("int") returns null for non-numeric strings
// (e.g. "item1") rather than failing — verify the expected id format.
val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.select(col(getItemCol), col(getUserCol))
.groupBy(getItemCol).agg(collect_list(getUserCol) as "collect_list")
.withColumn(C.FeaturesCol, createItemFeaturesVector(col("collect_list")))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.recommendation
dciborow marked this conversation as resolved.
Show resolved Hide resolved

import com.microsoft.azure.synapse.ml.codegen.Wrappable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,52 @@
.cache()
)

# Ratings fixture whose user and item id columns are strings, used to
# exercise SAR's handling of string-typed id columns.
_STRING_ID_ROWS = [
    ("user0", "item1", 4, 4),
    ("user0", "item3", 1, 1),
    ("user0", "item4", 5, 5),
    ("user0", "item5", 3, 3),
    ("user0", "item7", 3, 3),
    ("user0", "item9", 3, 3),
    ("user0", "item10", 3, 3),
    ("user1", "item1", 4, 4),
    ("user1", "item2", 5, 5),
    ("user1", "item3", 1, 1),
    ("user1", "item6", 4, 4),
    ("user1", "item7", 5, 5),
    ("user1", "item8", 1, 1),
    ("user1", "item10", 3, 3),
    ("user2", "item1", 4, 4),
    ("user2", "item2", 1, 1),
    ("user2", "item3", 1, 1),
    ("user2", "item4", 5, 5),
    ("user2", "item5", 3, 3),
    ("user2", "item6", 4, 4),
    ("user2", "item8", 1, 1),
    ("user2", "item9", 5, 5),
    ("user2", "item10", 3, 3),
    ("user3", "item2", 5, 5),
    ("user3", "item3", 1, 1),
    ("user3", "item4", 5, 5),
    ("user3", "item5", 3, 3),
    ("user3", "item6", 4, 4),
    ("user3", "item7", 5, 5),
    ("user3", "item8", 1, 1),
    ("user3", "item9", 5, 5),
    ("user3", "item10", 3, 3),
]

ratings_with_strings = (
    spark.createDataFrame(
        _STRING_ID_ROWS,
        ["originalCustomerID", "newCategoryID", "rating", "notTime"],
    )
    .coalesce(1)
    .cache()
)


class RankingSpec(unittest.TestCase):
@staticmethod
def adapter_evaluator(algo):
def adapter_evaluator(algo, data):
recommendation_indexer = RecommendationIndexer(
userInputCol=USER_ID,
userOutputCol=USER_ID_INDEX,
Expand All @@ -80,7 +122,7 @@ def adapter_evaluator(algo):

adapter = RankingAdapter(mode="allUsers", k=5, recommender=algo)
pipeline = Pipeline(stages=[recommendation_indexer, adapter])
output = pipeline.fit(ratings).transform(ratings)
output = pipeline.fit(data).transform(data)
print(str(output.take(1)) + "\n")

metrics = ["ndcgAt", "fcp", "mrr"]
Expand All @@ -91,13 +133,17 @@ def adapter_evaluator(algo):
+ str(RankingEvaluator(k=3, metricName=metric).evaluate(output)),
)

# def test_adapter_evaluator_als(self):
# als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(als)
#
# def test_adapter_evaluator_sar(self):
# sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(sar)
def test_adapter_evaluator_als(self):
als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(als, ratings)

def test_adapter_evaluator_sar(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings)

def test_adapter_evaluator_sar_with_strings(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings_with_strings)

def test_all_tiny(self):
customer_index = StringIndexer(inputCol=USER_ID, outputCol=USER_ID_INDEX)
Expand Down
Loading