Skip to content

Commit

Permalink
feat: add support for string-typed user/item ID columns to SAR
Browse files Browse the repository at this point in the history
  • Loading branch information
dciborow authored Sep 7, 2024
1 parent f3953bc commit d580344
Show file tree
Hide file tree
Showing 7 changed files with 531 additions and 266 deletions.
29 changes: 29 additions & 0 deletions core/src/main/python/synapse/ml/recommendation/SAR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in the project root for information.

import sys

if sys.version >= "3":
basestring = str

from synapse.ml.core.schema.Utils import *
from synapse.ml.recommendation._SAR import _SAR

@inherit_doc
class SAR(_SAR):
    """Python wrapper for the Smart Adaptive Recommendations (SAR) estimator.

    Accepts datasets whose user/item ID columns are string-typed by casting
    them to int before delegating to the JVM implementation.
    """

    def __init__(self, **kwargs):
        _SAR.__init__(self, **kwargs)

    def _cast_id_columns(self, dataset):
        """Cast string-typed user/item ID columns of *dataset* to int.

        Returns the (possibly unchanged) dataset. Shared by both public
        methods below so the casting rule lives in one place.

        NOTE(review): ``cast("int")`` produces NULL for non-numeric strings
        such as "user0" — confirm callers only pass numeric string IDs.
        """
        for col_name in (self.getUserCol(), self.getItemCol()):
            if dataset.schema[col_name].dataType == StringType():
                dataset = dataset.withColumn(col_name, dataset[col_name].cast("int"))
        return dataset

    def calculateUserItemAffinities(self, dataset):
        """Compute per-(user, item) affinity scores via the JVM SAR estimator."""
        return self._call_java("calculateUserItemAffinities", self._cast_id_columns(dataset))

    def calculateItemItemSimilarity(self, dataset):
        """Compute the item-item similarity matrix via the JVM SAR estimator."""
        return self._call_java("calculateItemItemSimilarity", self._cast_id_columns(dataset))
5 changes: 5 additions & 0 deletions core/src/main/python/synapse/ml/recommendation/SARModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
class SARModel(_SARModel):
    # Python wrapper around the fitted JVM SAR model; methods delegate via _call_java.

    def recommendForAllUsers(self, numItems):
        """Return the top *numItems* recommendations for every user seen at fit time."""
        return self._call_java("recommendForAllUsers", numItems)

    def recommendForUserSubset(self, dataset, numItems):
        """Return the top *numItems* recommendations for the users in *dataset*.

        A string-typed user ID column is cast to int to match the JVM model's
        expected schema.
        NOTE(review): cast("int") yields NULL for non-numeric strings (e.g.
        "user0") — confirm upstream guarantees numeric string IDs.
        """
        if dataset.schema[self.getUserCol()].dataType == StringType():
            dataset = dataset.withColumn(self.getUserCol(), dataset[self.getUserCol()].cast("int"))
        return self._call_java("recommendForUserSubset", dataset, numItems)
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.recommendation

import breeze.linalg.{CSCMatrix => BSM}
Expand All @@ -13,7 +10,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseMatrix}
import org.apache.spark.sql.functions.{col, collect_list, sum, udf, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.text.SimpleDateFormat
Expand Down Expand Up @@ -106,8 +103,22 @@ class SAR(override val uid: String) extends Estimator[SARModel]
(0 to numItems.value).map(i => map.getOrElse(i, 0.0).toFloat).toArray
})

dataset
.withColumn(C.AffinityCol, (dataset.columns.contains(getTimeCol), dataset.columns.contains(getRatingCol)) match {
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.withColumn(C.AffinityCol, (castedDataset.columns.contains(getTimeCol), castedDataset.columns.contains(getRatingCol)) match {
case (true, true) => blendWeights(timeDecay(col(getTimeCol)), col(getRatingCol))
case (true, false) => timeDecay(col(getTimeCol))
case (false, true) => col(getRatingCol)
Expand Down Expand Up @@ -197,7 +208,21 @@ class SAR(override val uid: String) extends Estimator[SARModel]
})
})

dataset
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.select(col(getItemCol), col(getUserCol))
.groupBy(getItemCol).agg(collect_list(getUserCol) as "collect_list")
.withColumn(C.FeaturesCol, createItemFeaturesVector(col("collect_list")))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.recommendation

import com.microsoft.azure.synapse.ml.codegen.Wrappable
Expand Down
64 changes: 55 additions & 9 deletions core/src/test/python/synapsemltest/recommendation/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,52 @@
.cache()
)

# Fixture mirroring `ratings` but with STRING user/item IDs ("user0", "item1")
# in place of integers, to exercise SAR's string-ID handling.
# Columns: originalCustomerID, newCategoryID, rating, notTime.
# coalesce(1) keeps a single partition for deterministic small-data tests.
ratings_with_strings = (
    spark.createDataFrame(
        [
            ("user0", "item1", 4, 4),
            ("user0", "item3", 1, 1),
            ("user0", "item4", 5, 5),
            ("user0", "item5", 3, 3),
            ("user0", "item7", 3, 3),
            ("user0", "item9", 3, 3),
            ("user0", "item10", 3, 3),
            ("user1", "item1", 4, 4),
            ("user1", "item2", 5, 5),
            ("user1", "item3", 1, 1),
            ("user1", "item6", 4, 4),
            ("user1", "item7", 5, 5),
            ("user1", "item8", 1, 1),
            ("user1", "item10", 3, 3),
            ("user2", "item1", 4, 4),
            ("user2", "item2", 1, 1),
            ("user2", "item3", 1, 1),
            ("user2", "item4", 5, 5),
            ("user2", "item5", 3, 3),
            ("user2", "item6", 4, 4),
            ("user2", "item8", 1, 1),
            ("user2", "item9", 5, 5),
            ("user2", "item10", 3, 3),
            ("user3", "item2", 5, 5),
            ("user3", "item3", 1, 1),
            ("user3", "item4", 5, 5),
            ("user3", "item5", 3, 3),
            ("user3", "item6", 4, 4),
            ("user3", "item7", 5, 5),
            ("user3", "item8", 1, 1),
            ("user3", "item9", 5, 5),
            ("user3", "item10", 3, 3),
        ],
        ["originalCustomerID", "newCategoryID", "rating", "notTime"],
    )
    .coalesce(1)
    .cache()
)


class RankingSpec(unittest.TestCase):
@staticmethod
def adapter_evaluator(algo):
def adapter_evaluator(algo, data):
recommendation_indexer = RecommendationIndexer(
userInputCol=USER_ID,
userOutputCol=USER_ID_INDEX,
Expand All @@ -80,7 +122,7 @@ def adapter_evaluator(algo):

adapter = RankingAdapter(mode="allUsers", k=5, recommender=algo)
pipeline = Pipeline(stages=[recommendation_indexer, adapter])
output = pipeline.fit(ratings).transform(ratings)
output = pipeline.fit(data).transform(data)
print(str(output.take(1)) + "\n")

metrics = ["ndcgAt", "fcp", "mrr"]
Expand All @@ -91,13 +133,17 @@ def adapter_evaluator(algo):
+ str(RankingEvaluator(k=3, metricName=metric).evaluate(output)),
)

# def test_adapter_evaluator_als(self):
# als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(als)
#
# def test_adapter_evaluator_sar(self):
# sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(sar)
    def test_adapter_evaluator_als(self):
        """Smoke-test the indexer + ranking-adapter pipeline with ALS on `ratings`."""
        als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
        self.adapter_evaluator(als, ratings)

    def test_adapter_evaluator_sar(self):
        """Smoke-test the indexer + ranking-adapter pipeline with SAR on `ratings`."""
        sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
        self.adapter_evaluator(sar, ratings)

    def test_adapter_evaluator_sar_with_strings(self):
        """Run the SAR pipeline on `ratings_with_strings` (string user/item IDs).

        NOTE(review): SAR casts string ID columns with cast("int"), which maps
        non-numeric values like "user0" to NULL — confirm this test actually
        exercises the intended string-ID path rather than null IDs.
        """
        sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
        self.adapter_evaluator(sar, ratings_with_strings)

def test_all_tiny(self):
customer_index = StringIndexer(inputCol=USER_ID, outputCol=USER_ID_INDEX)
Expand Down
Loading

0 comments on commit d580344

Please sign in to comment.