Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for userCol and itemCol as string types in SAR model #2283

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions core/src/main/python/synapse/ml/recommendation/SARModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
class SARModel(_SARModel):
def recommendForAllUsers(self, numItems):
return self._call_java("recommendForAllUsers", numItems)

def recommendForUserSubset(self, dataset, numItems):
if dataset.schema[self.getUserCol()].dataType == StringType():
dataset = dataset.withColumn(self.getUserCol(), dataset[self.getUserCol()].cast("int"))
return self._call_java("recommendForUserSubset", dataset, numItems)
dciborow marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseMatrix}
import org.apache.spark.sql.functions.{col, collect_list, sum, udf, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.text.SimpleDateFormat
Expand Down Expand Up @@ -106,8 +106,22 @@ class SAR(override val uid: String) extends Estimator[SARModel]
(0 to numItems.value).map(i => map.getOrElse(i, 0.0).toFloat).toArray
})

dataset
.withColumn(C.AffinityCol, (dataset.columns.contains(getTimeCol), dataset.columns.contains(getRatingCol)) match {
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.withColumn(C.AffinityCol, (castedDataset.columns.contains(getTimeCol), castedDataset.columns.contains(getRatingCol)) match {
case (true, true) => blendWeights(timeDecay(col(getTimeCol)), col(getRatingCol))
case (true, false) => timeDecay(col(getTimeCol))
case (false, true) => col(getRatingCol)
Expand Down Expand Up @@ -197,7 +211,21 @@ class SAR(override val uid: String) extends Estimator[SARModel]
})
})

dataset
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.select(col(getItemCol), col(getUserCol))
.groupBy(getItemCol).agg(collect_list(getUserCol) as "collect_list")
.withColumn(C.FeaturesCol, createItemFeaturesVector(col("collect_list")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,52 @@
.cache()
)

ratings_with_strings = (
spark.createDataFrame(
[
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3),
],
["originalCustomerID", "newCategoryID", "rating", "notTime"],
)
.coalesce(1)
.cache()
)


class RankingSpec(unittest.TestCase):
@staticmethod
def adapter_evaluator(algo):
def adapter_evaluator(algo, data):
recommendation_indexer = RecommendationIndexer(
userInputCol=USER_ID,
userOutputCol=USER_ID_INDEX,
Expand All @@ -80,7 +122,7 @@ def adapter_evaluator(algo):

adapter = RankingAdapter(mode="allUsers", k=5, recommender=algo)
pipeline = Pipeline(stages=[recommendation_indexer, adapter])
output = pipeline.fit(ratings).transform(ratings)
output = pipeline.fit(data).transform(data)
print(str(output.take(1)) + "\n")

metrics = ["ndcgAt", "fcp", "mrr"]
Expand All @@ -91,13 +133,17 @@ def adapter_evaluator(algo):
+ str(RankingEvaluator(k=3, metricName=metric).evaluate(output)),
)

# def test_adapter_evaluator_als(self):
# als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(als)
#
# def test_adapter_evaluator_sar(self):
# sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(sar)
def test_adapter_evaluator_als(self):
als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(als, ratings)

def test_adapter_evaluator_sar(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings)

def test_adapter_evaluator_sar_with_strings(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings_with_strings)

def test_all_tiny(self):
customer_index = StringIndexer(inputCol=USER_ID, outputCol=USER_ID_INDEX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol), transformedDf)
.setRatingCol(ratingCol), transformedDf),
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol), transformedDfWithStrings)
)
}

Expand Down Expand Up @@ -62,6 +66,41 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
assert(recs.count == 2)
}

test("SAR with string userCol and itemCol") {

val algo = sar
.setSupportThreshold(1)
.setSimilarityFunction("jacccard")
.setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

val adapter: RankingAdapter = new RankingAdapter()
.setK(5)
.setRecommender(algo)

val recopipeline = new Pipeline()
.setStages(Array(recommendationIndexer, adapter))
.fit(ratingsWithStrings)

val output = recopipeline.transform(ratingsWithStrings)

val evaluator: RankingEvaluator = new RankingEvaluator()
.setK(5)
.setNItems(10)

assert(evaluator.setMetricName("ndcgAt").evaluate(output) === 0.602819875812812)
assert(evaluator.setMetricName("fcp").evaluate(output) === 0.05 ||
evaluator.setMetricName("fcp").evaluate(output) === 0.1)
assert(evaluator.setMetricName("mrr").evaluate(output) === 1.0)

val users: DataFrame = spark
.createDataFrame(Seq(("user0","item0"),("user1","item1")))
.toDF(userColIndex, itemColIndex)

val recs = recopipeline.stages(1).asInstanceOf[RankingAdapterModel].getRecommenderModel
.asInstanceOf[SARModel].recommendForUserSubset(users, 10)
assert(recs.count == 2)
}

lazy val testFile: String = getClass.getResource("/demoUsage.csv.gz").getPath
lazy val simCount1: String = getClass.getResource("/sim_count1.csv.gz").getPath
lazy val simLift1: String = getClass.getResource("/sim_lift1.csv.gz").getPath
Expand Down Expand Up @@ -115,7 +154,12 @@ class SARModelSpec extends RankingTestBase with TransformerFuzzing[SARModel] {
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol)
.fit(transformedDf), transformedDf)
.fit(transformedDf), transformedDf),
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol)
.fit(transformedDfWithStrings), transformedDfWithStrings)
)
}

Expand Down
119 changes: 119 additions & 0 deletions docs/Quick Examples/estimators/core/_Recommendation.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,43 @@ ratings = (spark.createDataFrame([
.dropDuplicates()
.cache())

ratings_with_strings = (spark.createDataFrame([
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3)
], ["originalCustomerID", "newCategoryID", "rating", "notTime"])
.coalesce(1)
.cache())

dciborow marked this conversation as resolved.
Show resolved Hide resolved
recommendationIndexer = (RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand Down Expand Up @@ -275,6 +312,43 @@ ratings = (spark.createDataFrame([
.dropDuplicates()
.cache())

ratings_with_strings = (spark.createDataFrame([
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3)
], ["originalCustomerID", "newCategoryID", "rating", "notTime"])
.coalesce(1)
.cache())

dciborow marked this conversation as resolved.
Show resolved Hide resolved
recommendationIndexer = (RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand All @@ -298,6 +372,10 @@ adapter = (RankingAdapter()
res1 = recommendationIndexer.fit(ratings).transform(ratings).cache()

adapter.fit(res1).transform(res1).show()

res2 = recommendationIndexer.fit(ratings_with_strings).transform(ratings_with_strings).cache()

adapter.fit(res2).transform(res2).show()
dciborow marked this conversation as resolved.
Show resolved Hide resolved
```

</TabItem>
Expand Down Expand Up @@ -344,6 +422,43 @@ val ratings = (Seq(
.dropDuplicates()
.cache())

val ratings_with_strings = (Seq(
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3))
.toDF("originalCustomerID", "newCategoryID", "rating", "notTime")
.coalesce(1)
.cache())

dciborow marked this conversation as resolved.
Show resolved Hide resolved
val recommendationIndexer = (new RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand All @@ -367,6 +482,10 @@ val adapter = (new RankingAdapter()
val res1 = recommendationIndexer.fit(ratings).transform(ratings).cache()

adapter.fit(res1).transform(res1).show()

val res2 = recommendationIndexer.fit(ratings_with_strings).transform(ratings_with_strings).cache()

adapter.fit(res2).transform(res2).show()
dciborow marked this conversation as resolved.
Show resolved Hide resolved
```

</TabItem>
Expand Down