diff --git a/00_intro.ipynb b/00_intro.ipynb index dd32621..1ecaa31 100644 --- a/00_intro.ipynb +++ b/00_intro.ipynb @@ -4,7 +4,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "2bb220eb-b4b5-46d8-a9a0-6e6331a01d46", "showTitle": false, @@ -19,7 +22,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "92a25ef9-1b17-4f26-a615-ad66f051c1e1", "showTitle": false, @@ -60,7 +66,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "3d527009-ad2d-4239-a427-37eabe35a6d6", "showTitle": false, @@ -114,7 +123,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "500f91de-740e-4cc9-b63b-2f81a77a9899", "showTitle": false, @@ -134,7 +146,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "7074a071-e301-4043-9d1d-2a90ee4f95e5", "showTitle": true, @@ -151,7 +166,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "195ac730-282a-43a9-9e60-04f0ac322c19", "showTitle": true, @@ -171,7 +189,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "4f30d5af-d4a0-4191-8d73-c7b4f52f15f6", "showTitle": false, @@ -209,7 +230,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "3f16324b-38db-4c66-8eda-32bf209db06b", "showTitle": true, @@ -228,7 +252,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "ad77ecd0-6ee3-4d25-8940-872490e6726a", "showTitle": false, @@ -246,7 +273,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "eea443a9-4080-44cc-96cc-161d579b970e", "showTitle": true, @@ -262,7 +292,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "0ffc2283-43d0-4444-91a7-9108753aa370", "showTitle": false, @@ -280,7 +313,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "683c13aa-0615-4a7e-b131-aa6754fb48df", "showTitle": true, @@ -300,7 +336,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "8bef3897-24b1-48c2-b944-bf2054bcecf7", "showTitle": false, @@ -315,7 +354,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "40f7a244-6869-49ca-9d90-92c071fe2c1b", "showTitle": false, @@ -333,7 +375,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "68e88377-88ed-40a2-b435-e559d6d7af44", "showTitle": true, @@ -351,7 +396,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "15f56bdb-4beb-4fe6-9a61-a495da7d5dc3", "showTitle": false, @@ -368,7 +416,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "32ad0039-f35e-4686-ad50-87dfaf65db1c", "showTitle": false, @@ -395,7 +446,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "6541a350-62cf-4f12-ae86-2a0bf0789829", "showTitle": false, @@ -410,7 +464,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "401f0dcc-4f88-4a51-a339-979af7c8cf1c", "showTitle": false, @@ -431,6 +488,7 @@ "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], + "environmentMetadata": null, "language": "python", "notebookMetadata": { "pythonIndentUnit": 2 diff --git a/01_causal_discovery.ipynb b/01_causal_discovery.ipynb index 674ca43..ec47945 100644 --- a/01_causal_discovery.ipynb +++ b/01_causal_discovery.ipynb @@ -4,7 +4,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "c0be0376-f008-4fb8-8ca5-cd67241d7ec5", "showTitle": false, @@ -20,7 +23,53 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "60f2f5e1-6ae9-4eea-b39d-8275fdc66a76", + "showTitle": false, + "title": "" + } + }, + "outputs": [], + "source": [ + "%sh \n", + "sudo apt-get -qq update\n", + "sudo apt-get -y -qq install graphviz libgraphviz-dev" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, + "inputWidgets": {}, + "nuid": "61a0f224-8195-4c8c-82c5-c273e0d7ee9b", + "showTitle": true, + "title": "Install a library for visualization" + } + }, + "outputs": [], + "source": [ + "%pip install pygraphviz==1.10 --quiet\n", + "dbutils.library.restartPython()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "d2f6fd8a-0cb6-44ef-9312-8d6d27403d5a", "showTitle": false, @@ -36,7 +85,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "8be398d8-971b-4f43-a35d-782ad8eb618e", "showTitle": false, @@ -54,7 +106,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "539c43b9-846f-498b-807e-dd0864091d6b", "showTitle": false, @@ -93,7 +148,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "5fb82a8b-4352-43fe-966a-60d7a7160458", "showTitle": false, @@ -113,7 +171,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "04e61f57-7fdf-4467-b6d8-fad10dcc3024", "showTitle": false, @@ -128,7 +189,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "3d04da71-222a-4869-91e6-13a713ac1d99", "showTitle": false, @@ -150,7 +214,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bcfe02e4-3fae-426d-a738-73d5835920d0", "showTitle": false, @@ -178,7 +245,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "5991b566-77f8-440d-b5cd-23ec2fcde25c", "showTitle": false, @@ -197,7 +267,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "910811fe-3f44-4c5f-b779-d81e013a983a", "showTitle": false, @@ -223,7 +296,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "b05c4958-c7d6-45e8-bc6c-7868cea8fd3f", "showTitle": false, @@ -239,7 +315,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "662f9d42-276e-4c09-b27d-6504c0bb1ba3", "showTitle": false, @@ -266,7 +345,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "dbc21d42-1b38-446b-ac49-888c69a40829", "showTitle": false, @@ -282,7 +364,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "2795b8e8-75c0-41d8-b7b4-84a366cd6df7", "showTitle": false, @@ -314,7 +399,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "7fc7ea1a-21ff-483f-8b1c-05a312df0022", "showTitle": false, @@ -332,7 +420,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "7ef03eaa-5715-4bc7-967b-b9a52baae38d", "showTitle": false, @@ -372,7 +463,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "eb6f6fd0-d18c-4aa6-9e01-00e6760e74e6", "showTitle": false, @@ -393,8 +487,15 @@ "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], + "environmentMetadata": null, "language": "python", "notebookMetadata": { + "mostRecentlyExecutedCommandWithImplicitDF": { + "commandId": 888983575928260, + "dataframes": [ + "_sqldf" + ] + }, "pythonIndentUnit": 2 }, "notebookName": "01_causal_discovery", diff --git a/02_identification_estimation.ipynb b/02_identification_estimation.ipynb index 91b5139..a4daeb0 100644 --- a/02_identification_estimation.ipynb +++ b/02_identification_estimation.ipynb @@ -4,7 +4,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "f51698b9-84bb-44bf-bfc5-b8d97b68080a", "showTitle": false, @@ -20,7 +23,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "19b7bbc9-c08a-4a2b-87ac-5069175bfcb6", "showTitle": false, @@ -36,7 +42,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "270d5fe1-fa4d-4774-b9d9-e003abae1ab1", "showTitle": false, @@ -51,7 +60,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "f0607752-b4d8-4ec8-93d8-ce0bd7a2ea59", "showTitle": false, @@ -69,7 +81,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "a5ba81e1-2a9d-4515-a41c-d0e147267a6d", "showTitle": false, @@ -85,7 +100,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "4f64fdf2-ea9e-4f3e-8310-7653dee1d10a", "showTitle": false, @@ -101,7 +119,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "06921b02-d3e2-455f-af01-a919dbaf18d6", "showTitle": false, @@ -131,7 +152,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "cdd8d10c-7102-42e0-9c7b-f6594eb52020", "showTitle": false, @@ -146,7 +170,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "2afe3068-7875-4b22-a724-0e0bfc5a4eda", "showTitle": false, @@ -164,7 +191,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "26e4e3f5-8c38-45b7-9a7b-cee6970447d9", "showTitle": false, @@ -210,7 +240,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "86a7b8e8-b634-44d2-b17f-a58973ae3595", "showTitle": false, @@ -227,7 +260,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "20ca3b6d-731d-4840-8c29-9e5126e62db3", "showTitle": false, @@ -237,7 +273,7 @@ "outputs": [], "source": [ "model_details = register_dowhy_model(\n", - " model_name=\"tech_support_total_effect_dowhy_model\",\n", + " model_name=f\"{catalog}.{db}.tech_support_total_effect_dowhy_model\",\n", " model=tech_support_effect_model,\n", " estimand=tech_support_total_effect_identified_estimand,\n", " estimate=tech_support_total_effect_estimate,\n", @@ -248,7 +284,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "9a9b6c67-7a78-45e3-999d-c56bf4de97b8", "showTitle": false, @@ -266,7 +305,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "440e4943-d277-420b-a0a7-93417879be0c", "showTitle": false, @@ -288,7 +330,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "b464cd92-a05f-4e6e-a6e4-84944eba7b03", "showTitle": false, @@ -304,7 +349,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "923607d7-b98c-469e-836b-505502108f74", "showTitle": false, @@ -342,7 +390,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "95152568-5e08-4a38-b878-0c29746fc06b", "showTitle": false, @@ -358,7 +409,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "e681358f-ce1b-4a57-a00d-51826f184195", "showTitle": false, @@ -368,7 +422,7 @@ "outputs": [], "source": [ "model_details = register_dowhy_model(\n", - " model_name=\"tech_support_direct_effect_dowhy_model\",\n", + " model_name=f\"{catalog}.{db}.tech_support_direct_effect_dowhy_model\",\n", " model=tech_support_effect_model,\n", " estimand=tech_support_direct_effect_identified_estimand,\n", " estimate=tech_support_direct_effect_estimate,\n", @@ -379,7 +433,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "edb0944a-5523-4308-803f-bcc8146c19ca", "showTitle": false, @@ -397,7 +454,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "fe3a2163-841b-4166-a119-2ce485d91538", "showTitle": false, @@ -423,7 +483,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "2c909cb2-2577-481e-badb-0683f987ca5d", "showTitle": false, @@ -463,7 +526,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "0f17cddd-3978-49fb-a13e-beaa397793d8", "showTitle": false, @@ -473,7 +539,7 @@ "outputs": [], "source": [ "model_details = register_dowhy_model(\n", - " model_name=\"discount_dowhy_model\",\n", + " model_name=f\"{catalog}.{db}.discount_dowhy_model\",\n", " model=discount_effect_model,\n", " estimand=discount_effect_identified_estimand,\n", " estimate=discount_effect_estimate,\n", @@ -484,7 +550,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "5b52b2d6-7047-4544-bec4-d12ddccaa5d8", "showTitle": false, @@ -502,7 +571,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "8daac7f0-1286-45ce-aef5-3ea506e9ac42", "showTitle": false, @@ -527,7 +599,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "8226a19b-8575-49cc-ae9a-bad7e92cfed4", "showTitle": false, @@ -551,7 +626,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "9373cd46-22ca-4b83-8a46-c3df70768340", "showTitle": false, @@ -568,7 +646,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "495cfae7-1e55-499a-8013-0aa2378f40b2", "showTitle": false, @@ -586,7 +667,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "dbe3b251-ae2a-4253-bca3-9d31e9870593", "showTitle": false, @@ -619,7 +703,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bf5ea988-f494-4a20-8572-b0e31fc3443e", "showTitle": false, @@ -638,7 +725,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "a5a469e8-a40e-4428-8884-dfd9c637c677", "showTitle": false, @@ -659,8 +749,15 @@ "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], + "environmentMetadata": null, "language": "python", "notebookMetadata": { + "mostRecentlyExecutedCommandWithImplicitDF": { + "commandId": 888983575931288, + "dataframes": [ + "_sqldf" + ] + }, "pythonIndentUnit": 2 }, "notebookName": "02_identification_estimation", diff --git a/03_promotional_offer_recommender.ipynb b/03_promotional_offer_recommender.ipynb index dc83ff0..13003df 100644 --- a/03_promotional_offer_recommender.ipynb +++ b/03_promotional_offer_recommender.ipynb @@ -4,7 +4,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "1963c485-808e-4e6b-96bd-c10c4033ae99", "showTitle": false, @@ -20,7 +23,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "eb599ae6-4b84-4808-922b-458856a6fa1d", "showTitle": false, @@ -36,7 +42,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "3f9a1d4f-6247-4358-bf1a-f79d7f2e3e94", "showTitle": false, @@ -54,7 +63,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "a03e256a-67a4-4e6a-9367-003f6624e916", "showTitle": false, @@ -125,7 +137,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "182d5be1-f0fb-417f-bd9c-e8cb8c75440c", "showTitle": false, @@ -145,7 +160,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "e6204a2c-bc30-46da-9d5e-71adfbff6837", "showTitle": false, @@ -156,45 +174,41 @@ "source": [ "from mlflow.models.signature import infer_signature\n", "\n", - "model_name = \"personalized_policy_recommender\"\n", + "model_name = f\"{catalog}.{db}.personalized_policy_recommender\"\n", "\n", "with mlflow.start_run(run_name=f\"{model_name}_run\") as experiment_run:\n", " #Instantiate a model \n", " personalizedIncentiveRecommender = PersonalizedIncentiveRecommender(\n", " models_dictionary={\n", " \"tech support\": get_registered_wrapped_model_estimator(\n", - " model_name=\"tech_support_total_effect_dowhy_model\"\n", + " model_name=f\"{catalog}.{db}.tech_support_total_effect_dowhy_model\"\n", " ),\n", " \"discount\": get_registered_wrapped_model_estimator(\n", - " model_name=\"discount_dowhy_model\"\n", + " model_name=f\"{catalog}.{db}.discount_dowhy_model\"\n", " ),\n", " },\n", " effect_modifiers=[\"Size\", \"Global Flag\"],\n", " )\n", " #Log the model in MLflow\n", - " mlflow.pyfunc.log_model(\n", + " model_details = mlflow.pyfunc.log_model(\n", " artifact_path=\"model\",\n", " python_model=personalizedIncentiveRecommender,\n", + " registered_model_name=model_name,\n", " signature=infer_signature(\n", - " input_df.drop([\"Tech Support\", \"Discount\", \"New Engagement Strategy\"], axis=1), personalizedIncentiveRecommender.predict({}, input_df)\n", + " input_df.drop([\"Tech Support\", \"Discount\", \"New Engagement Strategy\"], axis=1), \n", + " personalizedIncentiveRecommender.predict({}, input_df)\n", " ),\n", - " )\n", - "\n", - "#Register the model in MLflow\n", - "model_details = mlflow.register_model(\n", - " model_uri=f\"runs:/{experiment_run.info.run_id}/model\",\n", - " name=model_name,\n", - ")\n", - "\n", - "displayHTML(f\"

Model '{model_details.name}' registered

\")\n", - "displayHTML(f\"

-Version {model_details.version}

\")" + " )" ] }, { "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "47331bbd-1c8c-46fb-bbdf-2f3cef7adb21", "showTitle": false, @@ -219,7 +233,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "5d7d00c3-2d48-4e32-ac93-f0dae8ca3781", "showTitle": false, @@ -230,11 +247,9 @@ "source": [ "#Load the model from MLflow\n", "loaded_model = mlflow.pyfunc.load_model(\n", - " f\"models:/{model_details.name}/{model_details.version}\"\n", + " f\"{model_details.model_uri}\"\n", ")\n", - "\n", "final_df = loaded_model.predict(input_df)\n", - "\n", "display(final_df)" ] }, @@ -242,7 +257,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "3af4d7d2-07e7-4694-b507-0e6ed8d5910e", "showTitle": false, @@ -260,7 +278,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "566ab2e3-0d22-49ed-bec4-7bbf19cfa67e", "showTitle": false, @@ -277,7 +298,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "9bd15800-17bb-4805-8a3c-82e8da057d31", "showTitle": false, @@ -298,7 +322,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "430ed66f-df52-4b3c-97f2-b482236f2433", "showTitle": false, @@ -313,7 +340,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "6f5f2040-c2cf-4129-9b40-c26a30648ef9", "showTitle": false, @@ -329,7 +359,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "186e105c-e7e2-42b0-b528-9a624d7967c7", "showTitle": false, @@ -348,7 +381,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "a855e7b1-6cd4-4195-9be7-1effd6c7197a", "showTitle": false, @@ -366,7 +402,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bdc04774-3b77-467b-99ec-0584e33bff8d", "showTitle": false, @@ -383,7 +422,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "b6d8268c-0a51-4e53-b0b1-45e5dcce9724", "showTitle": false, @@ -399,7 +441,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "42ae946e-b195-4b3b-8f4f-d8c284afb6f6", "showTitle": false, @@ -419,7 +464,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "e84f7dcb-5129-4df8-a62b-b8c2f7291678", "showTitle": false, @@ -434,7 +482,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "d83c00ca-2833-4b9e-bd3e-106c2813c533", "showTitle": false, @@ -455,6 +506,7 @@ "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], + "environmentMetadata": null, "language": "python", "notebookMetadata": { "pythonIndentUnit": 2 diff --git a/04_refutation.ipynb b/04_refutation.ipynb index ccb7605..ce15c0a 100644 --- a/04_refutation.ipynb +++ b/04_refutation.ipynb @@ -4,7 +4,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "08ea43f7-a118-44bf-9e8f-0a64ca254932", "showTitle": false, @@ -20,7 +23,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "5252e041-9be1-4fae-b0d2-d1ea362b1e74", "showTitle": false, @@ -36,7 +42,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "715d7031-369f-4b0c-a7d2-5a9d6db62340", "showTitle": false, @@ -51,7 +60,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "44a73459-20fa-468d-b0c6-8491949a976c", "showTitle": false, @@ -76,7 +88,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "4c90b52c-6de2-4625-92c4-b9742865b275", "showTitle": false, @@ -92,7 +107,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "f86ad090-bbe1-4d34-92bc-7487b746135d", "showTitle": false, @@ -101,7 +119,7 @@ }, "outputs": [], "source": [ - "wrapped_model = get_registered_wrapped_model(model_name=\"discount_dowhy_model\")\n", + "wrapped_model = get_registered_wrapped_model(model_name=f\"{catalog}.{db}.discount_dowhy_model\")\n", "\n", "model = wrapped_model.get_model()\n", "estimand = wrapped_model.get_estimand()\n", @@ -112,7 +130,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bd614eae-c2f1-4e20-aea3-d089e985fb0f", "showTitle": false, @@ -128,7 +149,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "ec92f977-642e-4f60-8ee1-4ed65ca61916", "showTitle": false, @@ -167,7 +191,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "fd8c2bb9-4725-46f4-95a6-944618c51585", "showTitle": false, @@ -185,7 +212,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "703c0184-37c6-4c1e-9321-9c2e6c9a200a", "showTitle": false, @@ -225,7 +255,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "e3a0d131-1e6e-486b-bc61-909fa138573f", "showTitle": false, @@ -241,7 +274,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bf97f2c7-7a2d-4d9c-8655-e07b589411bb", "showTitle": false, @@ -278,7 +314,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "f786d7ca-774f-4d3f-84bd-ff6ad2b377c8", "showTitle": false, @@ -294,7 +333,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "59810965-5ed9-47f9-ad0a-d396984b6dea", "showTitle": false, @@ -331,7 +373,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "fd7bc428-3ce6-45b6-8342-05ee3cbfba28", "showTitle": false, @@ -347,7 +392,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "44ab4124-32ed-4938-876e-584c33358c9a", "showTitle": false, @@ -395,7 +443,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "eac0f339-0fbb-49db-b7f0-3d981d499596", "showTitle": false, @@ -413,7 +464,10 @@ "execution_count": 0, "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "065c5f30-7bfb-4b30-b0b7-6e45ed7fe2df", "showTitle": false, @@ -438,7 +492,10 @@ "cell_type": "markdown", "metadata": { "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, + "cellMetadata": { + "byteLimit": 2048000, + "rowLimit": 10000 + }, "inputWidgets": {}, "nuid": "bea61eee-89be-465c-a2e2-2350902cbdde", "showTitle": false, @@ -459,6 +516,7 @@ "metadata": { "application/vnd.databricks.v1+notebook": { "dashboards": [], + "environmentMetadata": null, "language": "python", "notebookMetadata": { "pythonIndentUnit": 2 diff --git a/RUNME.py b/RUNME.py index d085aef..c43dcbb 100644 --- a/RUNME.py +++ b/RUNME.py @@ -49,11 +49,6 @@ }, "task_key": "00_intro", "libraries": [ - { - "pypi": { - "package": "pygraphviz==1.10" - } - }, { "pypi": { "package": "networkx==2.8.8" @@ -88,11 +83,6 @@ } ], "libraries": [ - { - "pypi": { - "package": "pygraphviz==1.10" - } - }, { "pypi": { "package": "networkx==2.8.8" @@ -127,11 +117,6 @@ } ], "libraries": [ - { - "pypi": { - "package": "pygraphviz==1.10" - } - }, { "pypi": { "package": "networkx==2.8.8" @@ -166,11 +151,6 @@ } ], "libraries": [ - { - "pypi": { - "package": "pygraphviz==1.10" - } - }, { "pypi": { "package": "networkx==2.8.8" @@ -205,11 +185,6 @@ } ], "libraries": [ - { - "pypi": { - "package": "pygraphviz==1.10" - } - }, { "pypi": { "package": "networkx==2.8.8" @@ -237,20 +212,17 @@ { "job_cluster_key": "causal_cluster", "new_cluster": { - "spark_version": "13.3.x-cpu-ml-scala2.12", + "spark_version": "14.3.x-cpu-ml-scala2.12", "num_workers": 0, "spark_conf": { "spark.master": "local[*, 4]", "spark.databricks.cluster.profile": "singleNode" }, + "custom_tags": { + "ResourceClass": "SingleNode" + }, "node_type_id": {"AWS": "i3.8xlarge", "MSA": "Standard_E32_v3", "GCP": "n1-highmem-32"}, - "init_scripts": [ - { - "workspace": { - "destination": f"{nsc.solacc_path}/causal_init.sh" - } - } - ] + "data_security_mode": "SINGLE_USER", } } ] @@ -260,5 +232,9 @@ # DBTITLE 1,Deploy job and cluster dbutils.widgets.dropdown("run_job", "False", ["True", "False"]) -run_job = dbutils.widgets.get("run_job") == "True" +run_job = dbutils.widgets.get("run_job") == "False" nsc.deploy_compute(job_json, run_job=run_job) + +# COMMAND ---------- + + diff --git a/causal_init.sh b/causal_init.sh deleted file mode 100644 index ec8501b..0000000 --- a/causal_init.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -sudo apt-get -qq update -sudo apt-get -y -qq install graphviz libgraphviz-dev \ No newline at end of file diff --git a/util/notebook-config.py b/util/notebook-config.py index f99f343..cb8f2a3 100644 --- a/util/notebook-config.py +++ b/util/notebook-config.py @@ -19,6 +19,22 @@ # COMMAND ---------- +# Create catalog if it doesn't exist +catalog = "causal_solacc" +create_catalog_query = f"CREATE CATALOG IF NOT EXISTS {catalog}" +use_catalog_query = f"USE CATALOG {catalog}" + +# Create database with the user's name if it doesn't exist +email = spark.sql('select current_user() as user').collect()[0]['user'] +db = email.split('@')[0].replace('.', '_') +create_db_query = f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}" + +_ = spark.sql(create_catalog_query) +_ = spark.sql(use_catalog_query) +_ = spark.sql(create_db_query) + +# COMMAND ---------- + # Utility methods for manipulating and serializing the causal graph. @@ -176,6 +192,8 @@ def setup_treatment_and_out_models(): from pip._vendor import pkg_resources +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, Schema, TensorSpec, ColSpec def get_version(package): @@ -225,7 +243,10 @@ def get_model_env(): def register_dowhy_model(model_name, model, estimand, estimate): """Register a DoWhy model in MLflow.""" - + # Define a dummy input and output schema for the model signature + input_schema = Schema([ColSpec("double")]) + output_schema = Schema([ColSpec("double")]) + signature = ModelSignature(inputs=input_schema, outputs=output_schema) with mlflow.start_run(run_name=f"{model_name} run") as run: model_info = mlflow.pyfunc.log_model( artifact_path="model", @@ -234,27 +255,31 @@ def register_dowhy_model(model_name, model, estimand, estimate): estimand=estimand, estimate=estimate, ), + registered_model_name=model_name, + signature=signature, conda_env=get_model_env(), ) + return model_info - return mlflow.register_model( - model_uri=f"runs:/{run.info.run_id}/model", name=model_name - ) + +# Function to get the latest version of a registered model +def get_latest_model_version(client, model_name): + latest_version = 1 # Initialize the latest version to 1 + # Iterate through all model versions for the given registered model name + for mv in client.search_model_versions(f"name='{model_name}'"): + version_int = int(mv.version) # Convert version string to integer + # Update the latest version if a higher version is found + if version_int > latest_version: + latest_version = version_int + return latest_version # Return the latest version number def get_registered_wrapped_model(model_name): client = mlflow.MlflowClient() - latest_model_versions = client.get_latest_versions(name=model_name) - - if len(latest_model_versions) > 0: - latest_model_version = latest_model_versions[0].version - else: - raise Exception(f"There are no registered versions for model {model_name}") - + latest_model_version = get_latest_model_version(client, model_name) wrapped_model = mlflow.pyfunc.load_model( f"models:/{model_name}/{latest_model_version}" ) - return wrapped_model.unwrap_python_model()