Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Synthetic difference in differences #2095

Merged
merged 51 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
80762ec
Estimators for diff-in-diff, synthetic control and synthetic diff-in-…
memoryz Oct 4, 2023
2dbd448
add more params
memoryz Oct 4, 2023
1de7aa1
refactor
memoryz Oct 4, 2023
98a6659
adding unit tests for linalg
memoryz Oct 4, 2023
29deccc
more unit tests
memoryz Oct 5, 2023
b7293c2
Unit test for DiffInDiffEstimator
memoryz Oct 6, 2023
0eb6b62
more unit tests
memoryz Oct 6, 2023
5e93636
unit test for SyntheticControlEstimator
memoryz Oct 7, 2023
2c7a6a0
unit test for SyntheticDiffInDiffEstimator
memoryz Oct 7, 2023
697b578
logClass
memoryz Oct 7, 2023
0aa9fc2
Python code gen
memoryz Oct 8, 2023
5660ace
pyspark wrapper
memoryz Oct 9, 2023
b2d9477
expose loss history
memoryz Oct 9, 2023
ca354cc
fix bugs for synthetic control
memoryz Oct 10, 2023
37b138d
Merge branch 'microsoft:master' into master
memoryz Oct 11, 2023
b02aa6f
fix time effects for synthetic control estimator
memoryz Oct 12, 2023
e55a8ac
fix unit test
memoryz Oct 12, 2023
b76641b
add notebook
memoryz Oct 12, 2023
1a2e320
fixing indexing logic
memoryz Oct 26, 2023
03b8bbc
add file headers
memoryz Oct 26, 2023
3bee3d7
Merge branch 'master' into master
memoryz Oct 26, 2023
5787e9e
Add feature name to logClass call
memoryz Oct 26, 2023
faa1900
more scalastyle fixes
memoryz Oct 26, 2023
a7b6f20
More scalastyle and unit test fixes
memoryz Oct 27, 2023
0f37f60
Python style fix
memoryz Oct 27, 2023
aa6728a
fix unit test
memoryz Oct 27, 2023
c1eeaff
fix more python style issue
memoryz Oct 27, 2023
f6aaed6
Merge branch 'master' into master
memoryz Oct 27, 2023
b0326a7
python style fix
memoryz Oct 27, 2023
813a68d
Merge branch 'microsoft:master' into master
memoryz Oct 27, 2023
df296a6
fix unit test
memoryz Oct 27, 2023
280ab2d
Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/Diff…
memoryz Oct 30, 2023
638d9b4
Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/Synt…
memoryz Oct 30, 2023
6553653
Update core/src/main/scala/com/microsoft/azure/synapse/ml/causal/Synt…
memoryz Oct 30, 2023
864a04f
addressing comments
memoryz Oct 30, 2023
7c47d42
extract some constants to findUnusedColumn
memoryz Oct 30, 2023
e60ee35
Merge branch 'master' into master
memoryz Oct 30, 2023
43f767f
Expose zeta as an optional parameter, also return the RMSE for unit w…
memoryz Oct 31, 2023
076576e
Merge branch 'master' into master
memoryz Oct 31, 2023
394936e
Replace constant TimeIdxCol and UnitIdxCol with findUnusedColumn
memoryz Oct 31, 2023
be581ad
typo
memoryz Oct 31, 2023
a3af20e
Adding notebook to sidebar
memoryz Oct 31, 2023
2e636e0
Merge branch 'master' into master
memoryz Nov 3, 2023
5b9ab17
fix bad merge
memoryz Nov 3, 2023
c2d413f
Merge branch 'master' into master
memoryz Nov 11, 2023
af8c88e
Merge branch 'master' into master
memoryz Dec 18, 2023
666d1ed
address code review comments
memoryz Dec 18, 2023
5cc9027
Update docs/Explore Algorithms/Causal Inference/Quickstart - Syntheti…
memoryz Dec 18, 2023
f9bb83c
clean synapse widget output state
memoryz Dec 18, 2023
f1bf701
remove invalid image links
memoryz Jan 4, 2024
cb8fd47
Merge branch 'master' into master
mhamilton723 Jan 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions core/src/main/python/synapse/ml/causal/DiffInDiffModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
memoryz marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= "3":
basestring = str

from synapse.ml.causal._DiffInDiffModel import _DiffInDiffModel
from pyspark.ml.common import inherit_doc
from pyspark.sql import SparkSession, DataFrame
from pyspark import SparkContext, SQLContext


@inherit_doc
class DiffInDiffModel(_DiffInDiffModel):
@staticmethod
def _mapOption(option, func):
return func(option.get()) if option.isDefined() else None

@staticmethod
def _unwrapOption(option):
return DiffInDiffModel._mapOption(option, lambda x: x)

def __init__(self, java_obj=None) -> None:
super(DiffInDiffModel, self).__init__(java_obj=java_obj)

ctx = SparkContext._active_spark_context
sql_ctx = SQLContext.getOrCreate(ctx)

self.summary = java_obj.getSummary()
self.treatmentEffect = self.summary.treatmentEffect()
self.standardError = self.summary.standardError()
self.timeIntercept = DiffInDiffModel._unwrapOption(self.summary.timeIntercept())
self.unitIntercept = DiffInDiffModel._unwrapOption(self.summary.unitIntercept())
self.timeWeights = DiffInDiffModel._mapOption(
java_obj.getTimeWeights(), lambda x: DataFrame(x, sql_ctx)
)
self.unitWeights = DiffInDiffModel._mapOption(
java_obj.getUnitWeights(), lambda x: DataFrame(x, sql_ctx)
)
self.timeRMSE = DiffInDiffModel._unwrapOption(self.summary.timeRMSE())
self.unitRMSE = DiffInDiffModel._unwrapOption(self.summary.unitRMSE())
self.zeta = DiffInDiffModel._unwrapOption(self.summary.zeta())
self.lossHistoryTimeWeights = DiffInDiffModel._unwrapOption(
self.summary.getLossHistoryTimeWeightsJava()
)
self.lossHistoryUnitWeights = DiffInDiffModel._unwrapOption(
self.summary.getLossHistoryUnitWeightsJava()
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.causal.linalg.DVector
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import com.microsoft.azure.synapse.ml.param.DataFrameParam
import org.apache.spark.SparkException
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model}
import org.apache.spark.sql.types.{BooleanType, NumericType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.util

abstract class BaseDiffInDiffEstimator(override val uid: String)
extends Estimator[DiffInDiffModel]
with DiffInDiffEstimatorParams {

private def validateFieldNumericOrBooleanType(field: StructField): Unit = {
val dataType = field.dataType
require(dataType.isInstanceOf[NumericType] || dataType == BooleanType,
s"Column ${field.name} must be numeric type or boolean type, but got $dataType instead.")
}

protected def validateFieldNumericType(field: StructField): Unit = {
val dataType = field.dataType
require(dataType.isInstanceOf[NumericType],
s"Column ${field.name} must be numeric type, but got $dataType instead.")
}

override def transformSchema(schema: StructType): StructType = {
validateFieldNumericOrBooleanType(schema(getPostTreatmentCol))
validateFieldNumericOrBooleanType(schema(getTreatmentCol))
validateFieldNumericType(schema(getOutcomeCol))
schema
}

override def copy(extra: ParamMap): Estimator[DiffInDiffModel] = defaultCopy(extra)

private[causal] val findInteractionCol = DatasetExtensions.findUnusedColumnName("interaction") _

private[causal] def fitLinearModel(df: DataFrame,
featureCols: Array[String],
fitIntercept: Boolean,
weightCol: Option[String] = None) = {

val featuresCol = DatasetExtensions.findUnusedColumnName("features", df)
val assembler = new VectorAssembler()
.setInputCols(featureCols)
.setOutputCol(featuresCol)

val regression = weightCol
.map(new LinearRegression().setWeightCol)
.getOrElse(new LinearRegression())

regression
.setFeaturesCol(featuresCol)
.setLabelCol(getOutcomeCol)
.setFitIntercept(fitIntercept)
.setLoss("squaredError")
.setRegParam(1E-10)

assembler.transform _ andThen regression.fit apply df
}
}

case class DiffInDiffSummary(treatmentEffect: Double, standardError: Double,
timeWeights: Option[DVector] = None,
timeIntercept: Option[Double] = None,
timeRMSE: Option[Double] = None,
unitWeights: Option[DVector] = None,
unitIntercept: Option[Double] = None,
unitRMSE: Option[Double] = None,
zeta: Option[Double] = None,
lossHistoryTimeWeights: Option[List[Double]] = None,
lossHistoryUnitWeights: Option[List[Double]] = None) {
import scala.collection.JavaConverters._

def getLossHistoryTimeWeightsJava: Option[util.List[Double]] = {
lossHistoryTimeWeights.map(_.asJava)
}

def getLossHistoryUnitWeightsJava: Option[util.List[Double]] = {
lossHistoryUnitWeights.map(_.asJava)
}
}

class DiffInDiffModel(override val uid: String)
extends Model[DiffInDiffModel]
with HasUnitCol
with HasTimeCol
with Wrappable
with ComplexParamsWritable
with SynapseMLLogging {

logClass(FeatureNames.Causal)

final val timeIndex = new DataFrameParam(this, "timeIndex", "time index")
def getTimeIndex: DataFrame = $(timeIndex)
def setTimeIndex(value: DataFrame): this.type = set(timeIndex, value)

final val timeIndexCol = new Param[String](this, "timeIndexCol", "time index column")
def getTimeIndexCol: String = $(timeIndexCol)
def setTimeIndexCol(value: String): this.type = set(timeIndexCol, value)

final val unitIndex = new DataFrameParam(this, "unitIndex", "unit index")
def getUnitIndex: DataFrame = $(unitIndex)
def setUnitIndex(value: DataFrame): this.type = set(unitIndex, value)

final val unitIndexCol = new Param[String](this, "unitIndexCol", "unit index column")
def getUnitIndexCol: String = $(unitIndexCol)
def setUnitIndexCol(value: String): this.type = set(unitIndexCol, value)

override protected lazy val pyInternalWrapper = true

def this() = this(Identifiable.randomUID("DiffInDiffModel"))

private final var summary: Option[DiffInDiffSummary] = None

def getSummary: DiffInDiffSummary = summary.getOrElse {
throw new SparkException(
s"No summary available for this ${this.getClass.getSimpleName}")
}

private[causal] def setSummary(summary: Option[DiffInDiffSummary]): this.type = {
this.summary = summary
this
}

override def copy(extra: ParamMap): DiffInDiffModel = {
copyValues(new DiffInDiffModel(uid), extra)
.setSummary(this.summary)
.setParent(parent)
}

override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF

override def transformSchema(schema: StructType): StructType = schema

def getTimeWeights: Option[DataFrame] = {
(get(timeIndex), getSummary.timeWeights) match {
case (Some(idxDf), Some(timeWeights)) =>
Some(
idxDf.join(timeWeights, idxDf(getTimeIndexCol) === timeWeights("i"), "left_outer")
.select(
idxDf(getTimeCol),
timeWeights("value")
)
)
case _ =>
None
}
}

def getUnitWeights: Option[DataFrame] = {
(get(unitIndex), getSummary.unitWeights) match {
case (Some(idxDf), Some(unitWeights)) =>
Some(
idxDf.join(unitWeights, idxDf(getUnitIndexCol) === unitWeights("i"), "left_outer")
.select(
idxDf(getUnitCol),
unitWeights("value")
)
)
case _ =>
None
}
}
}

object DiffInDiffModel extends ComplexParamsReadable[DiffInDiffModel]

trait DiffInDiffEstimatorParams extends Params
with HasTreatmentCol
with HasOutcomeCol
with HasPostTreatmentCol
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import breeze.linalg.{DenseVector => BDV}
import com.microsoft.azure.synapse.ml.causal.linalg.DVector
trait CacheOps[T] {
def checkpoint(data: T): T = data
def cache(data: T): T = data
}

object BDVCacheOps extends CacheOps[BDV[Double]] {
override def checkpoint(data: BDV[Double]): BDV[Double] = data
override def cache(data: BDV[Double]): BDV[Double] = data
}

object DVectorCacheOps extends CacheOps[DVector] {
override def checkpoint(data: DVector): DVector = data.localCheckpoint(true)
override def cache(data: DVector): DVector = data.cache
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class DiffInDiffEstimator(override val uid: String)
extends BaseDiffInDiffEstimator(uid)
with ComplexParamsWritable
with Wrappable
with SynapseMLLogging {

logClass(FeatureNames.Causal)

def this() = this(Identifiable.randomUID("DiffInDiffEstimator"))

override def fit(dataset: Dataset[_]): DiffInDiffModel = logFit({
val interactionCol = findInteractionCol(dataset.columns.toSet)
val postTreatment = col(getPostTreatmentCol)
val treatment = col(getTreatmentCol)
val outcome = col(getOutcomeCol)

val didData = dataset.select(
postTreatment.cast(IntegerType).as(getPostTreatmentCol),
treatment.cast(IntegerType).as(getTreatmentCol),
outcome.cast(DoubleType).as(getOutcomeCol)
)
.withColumn(interactionCol, treatment * postTreatment)

val linearModel = fitLinearModel(
didData,
Array(getPostTreatmentCol, getTreatmentCol, interactionCol),
fitIntercept = true
)

val treatmentEffect = linearModel.coefficients(2)
val standardError = linearModel.summary.coefficientStandardErrors(2)
val summary = DiffInDiffSummary(treatmentEffect, standardError)

copyValues(new DiffInDiffModel(this.uid))
.setSummary(Some(summary))
.setParent(this)
}, dataset.columns.length)
}

object DiffInDiffEstimator extends ComplexParamsReadable[DiffInDiffEstimator]
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,14 @@

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasLabelCol, HasWeightCol}
import com.microsoft.azure.synapse.ml.core.contracts.{HasFeaturesCol, HasWeightCol}
import com.microsoft.azure.synapse.ml.param.EstimatorParam
import org.apache.spark.ml.classification.{LogisticRegression, ProbabilisticClassifier}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.ParamInjections.HasParallelismInjected
import org.apache.spark.ml.param.shared.{HasMaxIter, HasPredictionCol}
import org.apache.spark.ml.param.{DoubleArrayParam, DoubleParam, Param, Params}
import org.apache.spark.ml.classification.{LogisticRegression, ProbabilisticClassifier}
import org.apache.spark.ml.param.shared.HasMaxIter
import org.apache.spark.ml.param.{DoubleArrayParam, DoubleParam, Params}
import org.apache.spark.ml.regression.Regressor

trait HasTreatmentCol extends Params {
Copy link
Collaborator

@mhamilton723 mhamilton723 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically now that we are 1.0 this is a breaking change, if you could deprecate and create new ones would appreciate it. To defend yourself against this, please consider makign any APIs you dont want to have to babysit private, private[ml], or protected

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This trait is shared between multiple estimators (DoubleML, SyntheticDiffInDiff, etc.), so I moved it to SharedParams.scala. The trait is still defined in the same package, so it shouldn't be breaking.

val treatmentCol = new Param[String](this, "treatmentCol", "treatment column")
def getTreatmentCol: String = $(treatmentCol)

/**
* Set name of the column which will be used as treatment
*
* @group setParam
*/
def setTreatmentCol(value: String): this.type = set(treatmentCol, value)
}

trait HasOutcomeCol extends Params {
val outcomeCol: Param[String] = new Param[String](this, "outcomeCol", "outcome column")
def getOutcomeCol: String = $(outcomeCol)

/**
* Set name of the column which will be used as outcome
*
* @group setParam
*/
def setOutcomeCol(value: String): this.type = set(outcomeCol, value)
}
import org.apache.spark.ml.{Estimator, Model}

trait DoubleMLParams extends Params
with HasTreatmentCol with HasOutcomeCol with HasFeaturesCol
Expand Down Expand Up @@ -86,7 +62,7 @@ trait DoubleMLParams extends Params
def setSampleSplitRatio(value: Array[Double]): this.type = set(sampleSplitRatio, value)

private[causal] object DoubleMLModelTypes extends Enumeration {
type TreatmentType = Value
type DoubleMLModelTypes = Value
val Binary, Continuous = Value
}

Expand Down
Loading
Loading