updating with catalog and schema fixes
jamesrmccall committed Sep 16, 2024
1 parent d762bd6 commit 38bf7fc
Showing 5 changed files with 64 additions and 145 deletions.
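
At a high level, the commit replaces path-based Delta writes (write to a storage path, then DROP TABLE / CREATE TABLE ... USING DELTA LOCATION) with catalog-scoped managed tables written through saveAsTable. Below is a minimal sketch of that pattern, assuming a Unity Catalog-enabled Databricks workspace where spark is the ambient session; the catalog, schema, and table names are placeholders for illustration, not the ones derived in _resources/00-setup.py.

# Hypothetical names for illustration; 00-setup.py derives catalogName and dbName at runtime.
catalog_name = "example_catalog"
schema_name = "example_schema"

# Scope the session once; later reads and writes resolve against this catalog and schema.
spark.sql(f"USE CATALOG {catalog_name}")
spark.sql(f"USE {schema_name}")

# Managed tables need only a table name, with no storage path and no separate CREATE TABLE ... LOCATION step.
demo_df = spark.createDataFrame([(1, "nail_1"), (2, "screw_3")], ["id", "product"])
demo_df.write.mode("overwrite").saveAsTable("demo_table")

# Reads resolve the same bare table name against the active catalog and schema.
spark.read.table("demo_table").show()

The per-file diffs below apply this pattern to the demand, shipment-recommendation, store-mapping, transport-cost, and supply tables.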
4 changes: 4 additions & 0 deletions 01_Introduction_And_Setup.py
@@ -34,3 +34,7 @@
# COMMAND ----------

# MAGIC %run ./_resources/00-setup $reset_all_data=true

# COMMAND ----------


22 changes: 5 additions & 17 deletions 02_Fine_Grained_Demand_Forecasting.py
@@ -23,8 +23,8 @@

# COMMAND ----------

print(cloud_storage_path)
print(dbName)
spark.sql(f"""USE CATALOG {catalogName}""")
spark.sql(f"""USE {dbName}""")

# COMMAND ----------

@@ -156,24 +156,12 @@ def one_step_ahead_forecast(pdf: pd.DataFrame) -> pd.DataFrame:

# COMMAND ----------

distribution_center_demand_df_delta_path = os.path.join(cloud_storage_path, 'distribution_center_demand_df_delta')
distribution_center_demand.write.mode("overwrite").saveAsTable("distribution_center_demand")

# COMMAND ----------

# Write the data
distribution_center_demand.write \
.mode("overwrite") \
.format("delta") \
.save(distribution_center_demand_df_delta_path)

# COMMAND ----------

spark.sql(f"DROP TABLE IF EXISTS {dbName}.distribution_center_demand")
spark.sql(f"CREATE TABLE {dbName}.distribution_center_demand USING DELTA LOCATION '{distribution_center_demand_df_delta_path}'")

# COMMAND ----------

display(spark.sql(f"SELECT * FROM {dbName}.distribution_center_demand"))
# MAGIC %sql
# MAGIC select * from distribution_center_demand

# COMMAND ----------

34 changes: 8 additions & 26 deletions 03_Optimize_Transportation.py
@@ -42,11 +42,6 @@

# COMMAND ----------

print(cloud_storage_path)
print(dbName)

# COMMAND ----------

import os
import datetime as dt
import re
@@ -60,6 +55,11 @@

# COMMAND ----------

spark.sql(f"""USE CATALOG {catalogName}""")
spark.sql(f"""USE {dbName}""")

# COMMAND ----------

# MAGIC %md
# MAGIC ## Defining and solving the LP

@@ -213,13 +213,6 @@ def transport_optimization(pdf: pd.DataFrame) -> pd.DataFrame:

# COMMAND ----------

# Test the function
#product_selection = "nail_1"
# pdf = lp_table_all_info.filter(f.col("product")==product_selection).toPandas()
# transport_optimization(pdf)

# COMMAND ----------

spark.conf.set("spark.databricks.optimizer.adaptive.enabled", "false")
n_tasks = lp_table_all_info.select("product").distinct().count()

@@ -237,24 +230,13 @@ def transport_optimization(pdf: pd.DataFrame) -> pd.DataFrame:

# COMMAND ----------

shipment_recommendations_df_delta_path = os.path.join(cloud_storage_path, 'shipment_recommendations_df_delta')
optimal_transport_df.write.mode("overwrite").saveAsTable("shipment_recommendations")

# COMMAND ----------

# Write the data
optimal_transport_df.write \
.mode("overwrite") \
.format("delta") \
.save(shipment_recommendations_df_delta_path)

# COMMAND ----------

spark.sql(f"DROP TABLE IF EXISTS {dbName}.shipment_recommendations")
spark.sql(f"CREATE TABLE {dbName}.shipment_recommendations USING DELTA LOCATION '{shipment_recommendations_df_delta_path}'")

# COMMAND ----------
from pyspark.sql.functions import col

display(spark.sql(f"SELECT * FROM {dbName}.shipment_recommendations"))
display(spark.read.table('shipment_recommendations'))

# COMMAND ----------

33 changes: 18 additions & 15 deletions _resources/00-setup.py
@@ -10,6 +10,11 @@

# COMMAND ----------

dbName = "sco_data_james"
catalogPrefix = "supply_chain_optimization_catalog"

# COMMAND ----------

# Get dbName and cloud_storage_path, reset and create database
current_user = dbutils.notebook.entry_point.getDbutils().notebook().getContext().tags().apply('user')
if current_user.rfind('@') > 0:
@@ -18,21 +23,22 @@
current_user_no_at = current_user
current_user_no_at = re.sub(r'\W+', '_', current_user_no_at)

dbName = db_prefix+"_"+current_user_no_at
cloud_storage_path = f"/Users/{current_user}/field_demos/{db_prefix}"
catalogName = catalogPrefix+"_"+current_user_no_at

reset_all = dbutils.widgets.get("reset_all_data") == "true"

if reset_all:
spark.sql(f"DROP DATABASE IF EXISTS {dbName} CASCADE")
dbutils.fs.rm(cloud_storage_path, True)
spark.sql(f"DROP CATALOG IF EXISTS {catalogName} CASCADE")

spark.sql(f"""create database if not exists {dbName} LOCATION '{cloud_storage_path}/tables' """)
spark.sql(f"""create catalog if not exists {catalogName}""")
spark.sql(f"""USE CATALOG {catalogName}""")
spark.sql(f"""create database if not exists {dbName}""")
spark.sql(f"""USE {dbName}""")

# COMMAND ----------

print(cloud_storage_path)
print(dbName)
print(f"The catalog {catalogName} will be used")
print(f"The database {dbName} will be used")

# COMMAND ----------

@@ -41,25 +47,22 @@

# COMMAND ----------

path = cloud_storage_path

dirname = os.path.dirname(dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get())
filename = "01-data-generator"
if (os.path.basename(dirname) != '_resources'):
dirname = os.path.join(dirname,'_resources')

generate_data_notebook_path = os.path.join(dirname,filename)

# print(generate_data_notebook_path)

def generate_data():
dbutils.notebook.run(generate_data_notebook_path, 600, {"reset_all_data": reset_all, "dbName": dbName, "cloud_storage_path": cloud_storage_path})
dbutils.notebook.run(generate_data_notebook_path, 3000, {"reset_all_data": reset_all, "catalogName": catalogName, "dbName": dbName})

# COMMAND ----------

if reset_all_bool:
generate_data()
else:
try:
dbutils.fs.ls(path)
except:
generate_data()

# COMMAND ----------

116 changes: 29 additions & 87 deletions _resources/01-data-generator.py
@@ -1,23 +1,28 @@
# Databricks notebook source
dbutils.widgets.dropdown('reset_all_data', 'false', ['true', 'false'], 'Reset all data')
dbutils.widgets.text('dbName', 'supply_chain_optimization_max_kohler' , 'Database Name')
dbutils.widgets.text('cloud_storage_path', '/Users/[email protected]/field_demos/supply_chain_optimization', 'Storage Path')
dbutils.widgets.text('catalogName', 'supply_chain_optimization_catalog_max_kohler' , 'Catalog Name')
dbutils.widgets.text('dbName', 'sco_data' , 'Database Name')

# COMMAND ----------

print("Starting ./_resources/01-data-generator")

# COMMAND ----------

cloud_storage_path = dbutils.widgets.get('cloud_storage_path')
catalogName = dbutils.widgets.get('catalogName')
dbName = dbutils.widgets.get('dbName')
reset_all_data = dbutils.widgets.get('reset_all_data') == 'true'

# COMMAND ----------

print(cloud_storage_path)
print(dbName)
print(reset_all_data)
print(f"The catalog {catalogName} will be used")
print(f"The database {dbName} will be used")
print(f"Running withreset_all_data = {reset_all_data}")

# COMMAND ----------

spark.sql(f"""USE CATALOG {catalogName}""")
spark.sql(f"""USE {dbName}""")

# COMMAND ----------

@@ -269,12 +274,12 @@ def time_series_generator_pandas_udf(pdf):
# COMMAND ----------

# Test if demand is in a realistic range
#display(demand_df.groupBy("product", "store").mean("demand"))
# display(demand_df.groupBy("product", "store").mean("demand"))

# COMMAND ----------

# Select a specific time series
# display(demand_df.join(demand_df.sample(False, 1 / demand_df.count(), seed=0).limit(1).select("product", "store"), on=["product", "store"], how="inner"))
#display(demand_df.join(demand_df.sample(False, 1 / demand_df.count(), seed=0).limit(1).select("product", "store"), on=["product", "store"], how="inner"))

# COMMAND ----------

@@ -283,28 +288,19 @@ def time_series_generator_pandas_udf(pdf):

# COMMAND ----------

demand_df_delta_path = os.path.join(cloud_storage_path, 'demand_df_delta')

# COMMAND ----------

# Write the data
demand_df.write \
.mode("overwrite") \
.format("delta") \
.save(demand_df_delta_path)

# COMMAND ----------
demand_df.write.mode("overwrite").saveAsTable("part_level_demand")

spark.sql(f"DROP TABLE IF EXISTS {dbName}.part_level_demand")
spark.sql(f"CREATE TABLE {dbName}.part_level_demand USING DELTA LOCATION '{demand_df_delta_path}'")
#### table not yet stored

# COMMAND ----------

display(spark.sql(f"SELECT * FROM {dbName}.part_level_demand"))
# MAGIC %sql
# MAGIC SELECT * FROM part_level_demand

# COMMAND ----------

display(spark.sql(f"SELECT COUNT(*) as row_count FROM {dbName}.part_level_demand"))
# MAGIC %sql
# MAGIC SELECT COUNT(*) as row_count FROM part_level_demand

# COMMAND ----------

@@ -370,24 +366,7 @@ def time_series_generator_pandas_udf(pdf):

# COMMAND ----------

distribution_center_to_store_mapping_delta_path = os.path.join(cloud_storage_path, 'distribution_center_to_store_mapping')

# COMMAND ----------

# Write the data
distribution_center_to_store_mapping_table.write \
.mode("overwrite") \
.format("delta") \
.save(distribution_center_to_store_mapping_delta_path)

# COMMAND ----------

spark.sql(f"DROP TABLE IF EXISTS {dbName}.distribution_center_to_store_mapping_table")
spark.sql(f"CREATE TABLE {dbName}.distribution_center_to_store_mapping_table USING DELTA LOCATION '{distribution_center_to_store_mapping_delta_path}'")

# COMMAND ----------

display(spark.sql(f"SELECT * FROM {dbName}.distribution_center_to_store_mapping_table"))
distribution_center_to_store_mapping_table.write.mode("overwrite").saveAsTable("distribution_center_to_store_mapping_table")

# COMMAND ----------

Expand All @@ -407,8 +386,8 @@ def time_series_generator_pandas_udf(pdf):

# COMMAND ----------

tmp_map_distribution_center_to_store = spark.read.table(f"{dbName}.distribution_center_to_store_mapping_table")
distribution_center_df = (spark.read.table(f"{dbName}.part_level_demand").
tmp_map_distribution_center_to_store = spark.read.table("distribution_center_to_store_mapping_table")
distribution_center_df = (spark.read.table("part_level_demand").
select("product","store").
join(tmp_map_distribution_center_to_store, ["store"], how="inner").
select("product","distribution_center").
@@ -476,24 +455,7 @@ def cost_generator(pdf: pd.DataFrame) -> pd.DataFrame:

# COMMAND ----------

cost_table_delta_path = os.path.join(cloud_storage_path, 'cost_table')

# COMMAND ----------

# Write the data
transport_cost_table.write \
.mode("overwrite") \
.format("delta") \
.save(cost_table_delta_path)

# COMMAND ----------

spark.sql(f"DROP TABLE IF EXISTS {dbName}.transport_cost_table")
spark.sql(f"CREATE TABLE {dbName}.transport_cost_table USING DELTA LOCATION '{cost_table_delta_path}'")

# COMMAND ----------

display(spark.sql(f"SELECT * FROM {dbName}.transport_cost_table"))
transport_cost_table.write.mode("overwrite").saveAsTable("transport_cost_table")

# COMMAND ----------

Expand All @@ -503,7 +465,7 @@ def cost_generator(pdf: pd.DataFrame) -> pd.DataFrame:
# COMMAND ----------

# Create a list with all plants
all_plants = spark.read.table(f"{dbName}.transport_cost_table").select("plant").distinct().collect()
all_plants = spark.read.table(f"transport_cost_table").select("plant").distinct().collect()
all_plants = [row[0] for row in all_plants]

# Create a list with fractions: the sum must be larger than one to fulfill the demands
@@ -514,8 +476,8 @@ def cost_generator(pdf: pd.DataFrame) -> pd.DataFrame:
plant_supply_in_percentage_of_demand = {all_plants[i]: fractions_lst[i] for i in range(len(all_plants))}

#Get maximum demand in history and sum up the demand of all distribution centers
map_store_to_dc_tmp = spark.read.table(f"{dbName}.distribution_center_to_store_mapping_table")
max_demands_per_dc = (spark.read.table(f"{dbName}.part_level_demand").
map_store_to_dc_tmp = spark.read.table(f"distribution_center_to_store_mapping_table")
max_demands_per_dc = (spark.read.table(f"part_level_demand").
groupBy("product", "store").
agg(f.max("demand").alias("demand")).
join(map_store_to_dc_tmp, ["store"], how = "inner"). # This join will not produce duplicates, as one store is assigned to exactly one distribution center
Expand All @@ -532,37 +494,17 @@ def cost_generator(pdf: pd.DataFrame) -> pd.DataFrame:

# COMMAND ----------

display(spark.read.table(f"{dbName}.distribution_center_to_store_mapping_table"))

# COMMAND ----------

display(spark.read.table(f"{dbName}.part_level_demand"))

# COMMAND ----------

# MAGIC %md
# MAGIC Save as a Delta table

# COMMAND ----------

supply_table_delta_path = os.path.join(cloud_storage_path, 'supply_table')

# COMMAND ----------

# Write the data
plant_supply.write \
.mode("overwrite") \
.format("delta") \
.save(supply_table_delta_path)

# COMMAND ----------

spark.sql(f"DROP TABLE IF EXISTS {dbName}.supply_table")
spark.sql(f"CREATE TABLE {dbName}.supply_table USING DELTA LOCATION '{supply_table_delta_path}'")
plant_supply.write.mode("overwrite").saveAsTable("supply_table")

# COMMAND ----------

display(spark.sql(f"SELECT * FROM {dbName}.supply_table"))
# MAGIC %sql
# MAGIC SELECT * FROM supply_table

# COMMAND ----------

