Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config flag to turn off global tracing #327

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions src/nnsight/schema/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,24 @@
import yaml
from pydantic import BaseModel


class ApiConfigModel(BaseModel):
HOST: str = "ndif.dev"
SSL: bool = True
FORMAT:str ='json'
ZLIB:bool = True
FORMAT: str = "json"
ZLIB: bool = True
APIKEY: Optional[str] = None
JOB_ID:Optional[str] = None
JOB_ID: Optional[str] = None


class AppConfigModel(BaseModel):
LOGGING: bool = False
REMOTE_LOGGING: bool = True
DEBUG: bool = True
CONTROL_FLOW_HANDLING:bool = True
FRAME_INJECTION:bool = True
CONTROL_FLOW_HANDLING: bool = True
FRAME_INJECTION: bool = True
GLOBAL_TRACING: bool = True


class ConfigModel(BaseModel):
API: ApiConfigModel = ApiConfigModel()
Expand All @@ -32,11 +36,11 @@ def set_default_api_key(self, apikey: str):
def set_default_app_debug(self, debug: bool):

self.APP.DEBUG = debug

self.save()

def save(self):

from .. import PATH

with open(os.path.join(PATH, "config.yaml"), "w") as file:
Expand Down
54 changes: 33 additions & 21 deletions src/nnsight/tracing/contexts/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from types import FunctionType, MethodType
from typing import Any, Type, Union

from ... import util
from ... import util, CONFIG
from ..graph import Graph
from . import Tracer

Expand All @@ -25,7 +25,7 @@ def super_new(cls, *args, **kwargs):

@wraps(fn)
def inner(cls, *args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return cls(*args, **kwargs)

Expand All @@ -38,28 +38,29 @@ def global_patch_fn(fn: FunctionType) -> util.Patch:

@wraps(fn)
def inner(*args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)

return util.Patch(inspect.getmodule(fn), inner, fn.__name__)


def global_patch_method(cls: type, fn: MethodType) -> None:

@wraps(fn)
def inner(*args, **kwargs):

if not GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
return fn(*args, **kwargs)

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply(fn, *args, **kwargs)

patch = util.Patch(cls, inner, fn.__name__)

GlobalTracingContext.PATCHER.add(patch)


def global_patch(obj: Union[FunctionType, Type]):

Expand All @@ -72,7 +73,8 @@ def global_patch(obj: Union[FunctionType, Type]):
patch = global_patch_fn(obj)

GlobalTracingContext.PATCHER.add(patch)



class GlobalTracingContext(Tracer):
"""The Global Tracing Context handles adding tracing operations globally without reference to a given `GraphBasedContext`.
There should only be one of these and that is `GlobalTracingContext.GLOBAL_TRACING_CONTEXT`.
Expand All @@ -87,13 +89,17 @@ class GlobalTracingExit(AbstractContextManager):

def __enter__(self) -> Any:

GlobalTracingContext.PATCHER.__exit__(None, None, None)
if CONFIG.APP.GLOBAL_TRACING:

GlobalTracingContext.PATCHER.__exit__(None, None, None)

return self

def __exit__(self, exc_type, exc_val, traceback):

GlobalTracingContext.PATCHER.__enter__()
if CONFIG.APP.GLOBAL_TRACING:

GlobalTracingContext.PATCHER.__enter__()

if isinstance(exc_val, BaseException):

Expand Down Expand Up @@ -121,11 +127,13 @@ def try_register(graph_based_context: Tracer) -> bool:
bool: True if registering ws successful, False otherwise.
"""

if GlobalTracingContext.GLOBAL_TRACING_CONTEXT:
if CONFIG.APP.GLOBAL_TRACING:

if GlobalTracingContext.GLOBAL_TRACING_CONTEXT:

return False
return False

GlobalTracingContext.register(graph_based_context)
GlobalTracingContext.register(graph_based_context)

return True

Expand All @@ -140,15 +148,18 @@ def try_deregister(graph_based_context: Tracer) -> bool:
Returns:
bool: True if deregistering ws successful, False otherwise.
"""
if (
not GlobalTracingContext.GLOBAL_TRACING_CONTEXT
or graph_based_context.graph
is not GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph
):

return False

GlobalTracingContext.deregister()
if CONFIG.APP.GLOBAL_TRACING:

if (
not GlobalTracingContext.GLOBAL_TRACING_CONTEXT
or graph_based_context.graph
is not GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph
):

return False

GlobalTracingContext.deregister()

return True

Expand Down Expand Up @@ -185,4 +196,5 @@ def __bool__(self) -> bool:

return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.graph is not None


GlobalTracingContext.GLOBAL_TRACING_CONTEXT = GlobalTracingContext()