Skip to content

Commit

Permalink
Config flag to turn off global tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
JadenFiotto-Kaufman committed Feb 12, 2025
1 parent e40f17f commit fb337da
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
18 changes: 11 additions & 7 deletions src/nnsight/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
import yaml
from pydantic import BaseModel


class ApiConfigModel(BaseModel):
HOST: str = "ndif.dev"
SSL: bool = True
FORMAT:str ='json'
ZLIB:bool = True
FORMAT: str = "json"
ZLIB: bool = True
APIKEY: Optional[str] = None
JOB_ID:Optional[str] = None
JOB_ID: Optional[str] = None


class AppConfigModel(BaseModel):
LOGGING: bool = False
REMOTE_LOGGING: bool = True
DEBUG: bool = True
CONTROL_FLOW_HANDLING:bool = True
FRAME_INJECTION:bool = True
CONTROL_FLOW_HANDLING: bool = True
FRAME_INJECTION: bool = True
GLOBAL_TRACING: bool = True


class ConfigModel(BaseModel):
API: ApiConfigModel = ApiConfigModel()
Expand All @@ -32,11 +36,11 @@ def set_default_api_key(self, apikey: str):
def set_default_app_debug(self, debug: bool):

self.APP.DEBUG = debug

self.save()

def save(self):

from .. import PATH

with open(os.path.join(PATH, "config.yaml"), "w") as file:
Expand Down
54 changes: 33 additions & 21 deletions src/nnsight/tracing/contexts/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import FunctionType, MethodType
from typing import Any, Type, Union

from ... import util
from ... import util, CONFIG
from ..graph import Graph
from . import Tracer

Expand All @@ -25,7 +25,7 @@ def super_new(cls, *args, **kwargs):

@wraps(fn)
def inner(cls, *args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return cls(*args, **kwargs)

Expand All @@ -38,28 +38,29 @@ def global_patch_fn(fn: FunctionType) -> util.Patch:

@wraps(fn)
def inner(*args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)

return util.Patch(inspect.getmodule(fn), inner, fn.__name__)


def global_patch_method(cls: type, fn: MethodType) -> None:

@wraps(fn)
def inner(*args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)

patch = util.Patch(cls, inner, fn.__name__)

GlobalTracingContext.PATCHER.add(patch)


def global_patch(obj: Union[FunctionType, Type]):

Expand All @@ -72,7 +73,8 @@ def global_patch(obj: Union[FunctionType, Type]):
patch = global_patch_fn(obj)

GlobalTracingContext.PATCHER.add(patch)



class GlobalTracingContext(Tracer):
"""The Global Tracing Context handles adding tracing operations globally without reference to a given `GraphBasedContext`.
There should only be one of these and that is `GlobalTracingContext.GLOBAL_TRACING_CONTEXT`.
Expand All @@ -87,13 +89,17 @@ class GlobalTracingExit(AbstractContextManager):

def __enter__(self) -> Any:

GlobalTracingContext.PATCHER.__exit__(None, None, None)
if CONFIG.APP.GLOBAL_TRACING:

GlobalTracingContext.PATCHER.__exit__(None, None, None)

return self

def __exit__(self, exc_type, exc_val, traceback):

GlobalTracingContext.PATCHER.__enter__()
if CONFIG.APP.GLOBAL_TRACING:

GlobalTracingContext.PATCHER.__enter__()

if isinstance(exc_val, BaseException):

Expand Down Expand Up @@ -121,11 +127,13 @@ def try_register(graph_based_context: Tracer) -> bool:
bool: True if registering ws successful, False otherwise.
"""

if GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
if CONFIG.APP.GLOBAL_TRACING:

if GlobalTracingContext.GLOBAL_TRACING_CONTEXT:

return False
return False

GlobalTracingContext.register(graph_based_context)
GlobalTracingContext.register(graph_based_context)

return True

Expand All @@ -140,15 +148,18 @@ def try_deregister(graph_based_context: Tracer) -> bool:
Returns:
bool: True if deregistering ws successful, False otherwise.
"""
if (
not GlobalTracingContext.GLOBAL_TRACING_CONTEXT
or graph_based_context.graph
is not GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph
):

return False

GlobalTracingContext.deregister()
if CONFIG.APP.GLOBAL_TRACING:

if (
not GlobalTracingContext.GLOBAL_TRACING_CONTEXT
or graph_based_context.graph
is not GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph
):

return False

GlobalTracingContext.deregister()

return True

Expand Down Expand Up @@ -185,4 +196,5 @@ def __bool__(self) -> bool:

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is not None


GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()

0 comments on commit fb337da

Please sign in to comment.