added logger object to models - maybe debugging will actually log?
sstill88 committed Dec 24, 2024
1 parent e5b3132 commit da28300
Showing 2 changed files with 37 additions and 45 deletions.
13 changes: 11 additions & 2 deletions cerulean_cloud/cloud_run_orchestrator/handler.py
@@ -383,7 +383,7 @@ async def _orchestrate(
"Removed tiles over land",
severity="INFO",
scene_id=payload.sceneid,
- n_tiles=len(tileset_list),
+ n_tilesets=len(tileset_list),
)
)
except ValueError as e:
@@ -514,13 +514,22 @@ async def _orchestrate(
for tileset in tileset_list
]

+ logger.info(
+ structured_log(
+ "Initializing model",
+ severity="INFO",
+ scene_id=payload.sceneid,
+ model_type=model_dict["type"],
+ )
+ )
+ model = get_model(model_dict, scene_id=payload.sceneid)
+
# Stitch inferences
logger.info(
structured_log(
"Stitching result", severity="INFO", scene_id=payload.sceneid
)
)
- model = get_model(model_dict, scene_id=payload.sceneid)
tileset_fc_list = [
model.postprocess_tileset(
tileset_results, [[b] for b in tileset_bounds]
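A note on the logging pattern in this file: structured_log is imported from cerulean_cloud.cloud_run_orchestrator.utils and is always passed to a logger call as a pre-formatted message. Its implementation is not part of this diff; a minimal sketch consistent with the call sites (an assumption, not the repo's actual helper) could be:

    import json

    def structured_log(message: str, severity: str = "INFO", **fields) -> str:
        # One JSON object per stdout line; Cloud Logging promotes fields like
        # "severity" and keeps the rest as queryable structured payload.
        return json.dumps({"message": message, "severity": severity, **fields})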
69 changes: 26 additions & 43 deletions cerulean_cloud/models.py
@@ -35,11 +35,6 @@
)
from cerulean_cloud.cloud_run_orchestrator.utils import structured_log

- logger = logging.getLogger("model")
- handler = logging.StreamHandler(sys.stdout)
- logger.addHandler(handler)
- logger.setLevel(logging.INFO)
-

class BaseModel:
"""
@@ -69,6 +64,11 @@ def __init__(self, model_dict=None, model_path_local=None, scene_id=None):
)
self.scene_id = scene_id

+ self.logger = logging.getLogger("model")
+ handler = logging.StreamHandler(sys.stdout)
+ self.logger.addHandler(handler)
+ self.logger.setLevel(logging.INFO)
+
def load(self):
"""
Loads the model from the given path.
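A side note on the __init__ hunk above: logging.getLogger("model") returns a process-wide singleton, so attaching a StreamHandler on every model instantiation will stack handlers and can duplicate each log record. A guarded variant of the same setup (a sketch of a common idiom, not what this commit does) stays idempotent:

    import logging
    import sys

    logger = logging.getLogger("model")
    if not logger.handlers:  # attach the stdout handler only once per process
        logger.addHandler(logging.StreamHandler(sys.stdout))
        logger.setLevel(logging.INFO)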
@@ -78,7 +78,7 @@ def load(self):
self.model = torch.jit.load(self.model_path_local, map_location="cpu")
self.model.eval()
except Exception as e:
- logger.error(
+ self.logger.error(
structured_log(
"Error loading model",
severity="ERROR",
@@ -96,7 +96,7 @@ def predict(self, inf_stack: List[InferenceInput]) -> InferenceResultStack:
Args:
inf_stack: The input data stack for inference.
"""
- logger.info(
+ self.logger.info(
structured_log(
"Predicting images",
severity="INFO",
@@ -172,7 +172,7 @@ def nms_feature_reduction(
- A geojson FeatureCollection containing the retained features.
"""

- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: initializing feature list",
severity="DEBUG",
@@ -191,7 +191,7 @@
if not feature_list:
return geojson.FeatureCollection([])

- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: filtering None type geometry",
severity="DEBUG",
@@ -206,7 +206,7 @@
]

# Precompute the areas of all features to optimize geometry operations
- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: precomputing areas",
severity="DEBUG",
@@ -224,7 +224,7 @@
],
crs="EPSG:4326",
)
- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: reprojecting features",
severity="DEBUG",
@@ -240,7 +240,7 @@
feats_to_remove = []

# If the feature has fewer overlaps than required, mark it for removal
- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: finding overlaps",
severity="DEBUG",
@@ -257,7 +257,7 @@
total_features = len(gdf)
log_interval = max(total_features // 10, 1) # Ensure at least one log every 10%

- logger.debug(
+ self.logger.debug(
structured_log(
"DEBUG NMS: removing intersecting features",
severity="DEBUG",
@@ -272,7 +272,7 @@
# TODO: Remove this log after identifying if this is the step that is failing
if (i + 1) % log_interval == 0 or (i + 1) == total_features:
percentage_complete = int((i + 1) / total_features * 100)
- logger.debug(
+ self.logger.debug(
structured_log(
"Progress update",
severity="DEBUG",
@@ -338,7 +338,7 @@ def preprocess_tiles(self, inf_stack: List[InferenceInput]):
)
for record in inf_stack
]
- logger.info(
+ self.logger.info(
structured_log(
"Stacked tensors",
severity="INFO",
@@ -429,15 +429,15 @@ def postprocess_tileset(
geojson.FeatureCollection: A geojson feature collection representing the processed and combined geographical data.
"""

- logger.info(
+ self.logger.info(
structured_log(
"Reducing feature count on tiles",
severity="INFO",
scene_id=self.scene_id,
)
)
scene_polys = self.reduce_tile_features(tileset_results, tileset_bounds)
- logger.info(
+ self.logger.info(
structured_log(
"Stitching tiles into scene",
severity="INFO",
@@ -446,7 +446,7 @@
)
)
feature_collection = self.stitch(scene_polys)
- logger.info(
+ self.logger.info(
structured_log(
"Reducing feature count on scene",
severity="INFO",
@@ -456,7 +456,7 @@
reduced_feature_collection = self.nms_feature_reduction(feature_collection)
n_feats = len(reduced_feature_collection.get("features", []))

- logger.info(
+ self.logger.info(
structured_log(
"Generated features",
severity="INFO",
@@ -859,7 +859,7 @@ def preprocess_tiles(self, inf_stack: List[InferenceInput]):
for record in inf_stack
]
batch_tensor = torch.cat(stack_tensors, dim=0).to(self.device)
- logger.info(
+ self.logger.info(
structured_log(
"Generated batch tensor",
severity="INFO",
@@ -869,7 +869,7 @@ def preprocess_tiles(self, inf_stack: List[InferenceInput]):
)
return batch_tensor # Only the tensor batch is needed for the model
except Exception as e:
- logger.error(
+ self.logger.error(
structured_log(
"Failure in preprocessing",
severity="ERROR",
@@ -951,7 +951,7 @@ def postprocess_tileset(
Args:
tileset_results: The list of InferenceResultStacks to stitch together.
"""
- logger.info(
+ self.logger.info(
structured_log(
"Stitching tiles into scene",
severity="INFO",
@@ -961,7 +961,7 @@
)
scene_array_probs, transform = self.stitch(tileset_results, tileset_bounds)

- logger.info(
+ self.logger.info(
structured_log(
"Finding instances in scene", severity="INFO", scene_id=self.scene_id
)
@@ -970,7 +970,7 @@
feature_collection = self.instantiate(scene_array_probs, transform)
n_feats = len(feature_collection.get("features", []))

- logger.info(
+ self.logger.info(
structured_log(
"Generated features. Reducing feature count on scene",
severity="INFO",
@@ -982,7 +982,7 @@
reduced_feature_collection = self.nms_feature_reduction(feature_collection)
n_feats = len(reduced_feature_collection.get("features", []))

- logger.info(
+ self.logger.info(
structured_log(
"Reduced features",
severity="INFO",
@@ -1238,14 +1238,6 @@ def get_model(
An instance of the appropriate model class.
"""
model_type = model_dict["type"]
- logger.info(
- structured_log(
- "Initializing model",
- severity="INFO",
- scene_id=scene_id,
- model_type=model_type,
- )
- )

if model_type == "MASKRCNN":
return MASKRCNNModel(model_dict, model_path_local, scene_id=scene_id)
@@ -1309,16 +1301,7 @@ def b64_image_to_array(image: str, tensor: bool = False, to_float=False, scene_i
if to_float:
np_img = dtype_to_float(np_img)
return torch.tensor(np_img) if tensor else np_img
- except Exception as e:
- logger.error(
- structured_log(
- "Failed to convert base64 image to array",
- severity="ERROR",
- scene_id=scene_id,
- exception=str(e),
- traceback=traceback.format_exc(),
- )
- )
+ except Exception:
raise


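Given the question in the commit message ("maybe debugging will actually log?"), one quick local sanity check is to construct any model and inspect the shared "model" logger; this snippet is hypothetical and not part of the commit:

    import logging

    # Run after constructing a model, since __init__ now attaches the handler.
    logger = logging.getLogger("model")
    print("handlers attached:", logger.handlers)           # expect a StreamHandler
    print("effective level:", logger.getEffectiveLevel())  # expect 20 (INFO)
    logger.info("test record - should appear on stdout")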
