Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #295

Merged
merged 10 commits into from
Nov 20, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"protobuf",
"python-socketio[client]",
"tokenizers>=0.13.0",
"pydantic>=2.4.0",
"pydantic>=2.9.0",
"torch>=2.4.0",
"sentencepiece",
"torchvision",
Expand Down
7 changes: 7 additions & 0 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
from functools import wraps
from typing import Dict, Union

from importlib.metadata import version, PackageNotFoundError

try:
__version__ = version("nnsight")
except PackageNotFoundError:
__version__ = "unknown version"

import torch
import yaml

Expand Down
1 change: 1 addition & 0 deletions src/nnsight/contexts/GraphBasedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class GlobalTracingContext(GraphBasedContext):
global_patch(torch, "randn"),
global_patch(torch, "randperm"),
global_patch(torch, "zeros"),
global_patch(torch, "cat")
]
+ [
global_patch_class(value)
Expand Down
3 changes: 2 additions & 1 deletion src/nnsight/contexts/backends/RemoteBackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def submit_request(self, request: "RequestModel") -> "ResponseModel":

else:

raise Exception(response.reason)
msg = response.json()['detail']
raise ConnectionError(msg)

def get_response(self) -> "ResponseModel":
"""Retrieves and handles the response object from the remote endpoint.
Expand Down
21 changes: 20 additions & 1 deletion src/nnsight/envoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,28 @@ def _set_tracer(self, tracer: Tracer, propagate=True):
if propagate:
for envoy in self._sub_envoys:
envoy._set_tracer(tracer, propagate=True)


def _tracing(self) -> bool:
"""Whether or not tracing.

Returns:
bool: Is tracing.
"""

try:

return self._tracer.graph.alive

except:

return False

def _scanning(self) -> bool:
"""Whether or not in scanning mode. Checks the current Tracer's Invoker.

Returns:
bool: _description_
bool: Is scanning.
"""

try:
Expand Down Expand Up @@ -420,6 +436,9 @@ def __call__(
Returns:
InterventionProxy: Module call proxy.
"""

if not self._tracing():
return self._module(*args, **kwargs)

if isinstance(self._tracer.backend, EditBackend):
hook = True
Expand Down
4 changes: 4 additions & 0 deletions src/nnsight/models/NNsightModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def __init__(
self.dispatch_model()

logger.info(f"Initialized `{self._model_key}`")

def __call__(self, *args, **kwargs):

return self._envoy(*args, **kwargs)

def trace(
self,
Expand Down
10 changes: 10 additions & 0 deletions src/nnsight/schema/format/functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import operator
from inspect import getmembers, isbuiltin, isfunction, ismethod, ismethoddescriptor, isclass
from typing import Callable

import einops
import torch
Expand All @@ -24,6 +25,15 @@ def get_function_name(fn, module_name=None):

return f"{module_name}.{fn.__qualname__}"

def update_function(function: str | Callable, new_function: Callable):

if not isinstance(function, str):

function = get_function_name(function)

new_function.__name__ = function

FUNCTIONS_WHITELIST[function] = new_function

FUNCTIONS_WHITELIST = {}

Expand Down
Loading