Skip to content

Commit

Permalink
Refactoring upsert_with_ctl_dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
galvana committed Feb 7, 2025
1 parent f413f0f commit 599aa3b
Showing 1 changed file with 43 additions and 63 deletions.
106 changes: 43 additions & 63 deletions src/fides/api/models/datasetconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,47 +60,62 @@ def upsert_with_ctl_dataset(
cls, db: Session, *, data: Dict[str, Any]
) -> "DatasetConfig":
"""
Create or update both the DatasetConfig and its associated CTL Dataset.
Create or update the DatasetConfig AND the corresponding CTL Dataset
This method handles:
1. Creating/updating the CTL Dataset
2. Creating/updating the DatasetConfig that references it
If the DatasetConfig exists with the supplied FidesKey, update the linked CtlDataset with the dataset contents.
If the DatasetConfig *does not exist*, upsert a CtlDataset on fides_key, and then link to the DatasetConfig on creation.
Args:
db: Database session
data: Dictionary containing:
- connection_config_id: ID of the connection config
- fides_key: Key for the dataset config
- dataset: Optional dataset contents for CTL dataset
Returns:
Updated or created DatasetConfig
"""
# Handle CTL dataset first
ctl_dataset = None
if "dataset" in data:
dataset_contents = data["dataset"]
ctl_dataset = cls._upsert_ctl_dataset(
db,
dataset_contents,
existing_ctl_dataset=cls._get_existing_ctl_dataset(db, data),

def upsert_ctl_dataset(dataset_contents: Dict[str, Any]) -> CtlDataset:
"""
Validate dataset_contents and create or update the matching CtlDataset.

The contents are validated by constructing a Dataset from them. An
existing CtlDataset is looked up by the "fides_key" entry of
dataset_contents: if one is found, its attributes are overwritten with the
validated values; otherwise a new CtlDataset is built from them. The row
is added to the session, committed and refreshed before being returned.
"""
validated_data = Dataset(**dataset_contents)

# Check for existing CTL dataset
ctl_dataset = (
db.query(CtlDataset)
.filter(CtlDataset.fides_key == dataset_contents.get("fides_key"))
.first()
)
# NOTE(review): the next two statements and the "Then handle DatasetConfig"
# comment look like removed-side lines of the diff that the page extraction
# interleaved here — confirm against the repository before relying on them.
data["ctl_dataset_id"] = ctl_dataset.id
data.pop("dataset")

# Then handle DatasetConfig
if ctl_dataset:
# Update existing CTL dataset
for key, val in validated_data.model_dump(mode="json").items():
setattr(ctl_dataset, key, val)
else:
# Create new CTL dataset
ctl_dataset = CtlDataset(**validated_data.model_dump(mode="json"))

db.add(ctl_dataset)
db.commit()
db.refresh(ctl_dataset)
return ctl_dataset

# Make a copy of data to avoid modifications
data_copy = data.copy()

# Handle CTL dataset if dataset data is provided
if "dataset" in data_copy:
ctl_dataset = upsert_ctl_dataset(data_copy["dataset"])
data_copy["ctl_dataset_id"] = ctl_dataset.id
data_copy.pop("dataset")

# Handle DatasetConfig
dataset_config = cls.filter(
db=db,
conditions=(
(cls.connection_config_id == data["connection_config_id"])
& (cls.fides_key == data["fides_key"])
(cls.connection_config_id == data_copy["connection_config_id"])
& (cls.fides_key == data_copy["fides_key"])
),
).first()

if dataset_config:
dataset_config.update(db=db, data=data)
dataset_config.update(db=db, data=data_copy)
else:
dataset_config = cls.create(db=db, data=data)
dataset_config = cls.create(db=db, data=data_copy)

return dataset_config

Expand Down Expand Up @@ -154,41 +169,6 @@ def get_dataset_with_stubbed_collection(self) -> GraphDataset:
dataset_graph.collections = [stubbed_collection]
return dataset_graph

@staticmethod
def _upsert_ctl_dataset(
    db: Session,
    dataset_data: Dict[str, Any],
    existing_ctl_dataset: Optional[CtlDataset] = None,
) -> CtlDataset:
    """
    Persist the supplied dataset contents as a CtlDataset row.

    The contents are validated by constructing a Dataset from them. When
    existing_ctl_dataset is truthy, each validated field is copied onto that
    row; otherwise a new CtlDataset is built from the validated fields. The
    row is added to the session, committed and refreshed before it is
    returned.
    """
    serialized_fields = Dataset(**dataset_data).model_dump(mode="json")

    if not existing_ctl_dataset:
        # Nothing to update: build a fresh row from the validated fields.
        target_row = CtlDataset(**serialized_fields)
    else:
        # Overwrite the existing row field by field.
        target_row = existing_ctl_dataset
        for field_name, field_value in serialized_fields.items():
            setattr(target_row, field_name, field_value)

    db.add(target_row)
    db.commit()
    db.refresh(target_row)
    return target_row

@staticmethod
def _get_existing_ctl_dataset(
    db: Session, data: Dict[str, Any]
) -> Optional[CtlDataset]:
    """
    Find the CtlDataset whose fides_key matches the nested dataset payload.

    Returns None when data has no "dataset" entry; otherwise returns the
    first CtlDataset whose fides_key equals data["dataset"]'s "fides_key"
    value (None if no row matches).
    """
    if "dataset" in data:
        candidates = db.query(CtlDataset)
        wanted_key = data["dataset"].get("fides_key")
        return candidates.filter(CtlDataset.fides_key == wanted_key).first()

    return None


def to_graph_field(
field: DatasetField, return_all_elements: Optional[bool] = None
Expand Down

0 comments on commit 599aa3b

Please sign in to comment.