From 13721d24fd3491f1fa82cd2ee2086a2d6a98d078 Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Mon, 2 Dec 2024 11:48:40 -0500 Subject: [PATCH 1/2] bug: fixing visualization + reducing memory usage --- .../tutorials/activation_patching.ipynb | 247 ++++++++++-------- 1 file changed, 132 insertions(+), 115 deletions(-) diff --git a/docs/source/notebooks/tutorials/activation_patching.ipynb b/docs/source/notebooks/tutorials/activation_patching.ipynb index 7b12ad0c..cb22820f 100644 --- a/docs/source/notebooks/tutorials/activation_patching.ipynb +++ b/docs/source/notebooks/tutorials/activation_patching.ipynb @@ -104,19 +104,35 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import clear_output" + "try:\n", + " import google.colab\n", + " is_colab = True\n", + "except ImportError:\n", + " is_colab = False\n", + "\n", + "if is_colab:\n", + " !pip install -U nnsight" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import clear_output" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/opt/anaconda3/envs/nnsight/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "/Users/adam/NDIF/mib-env/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } @@ -137,7 +153,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "id": "l8WfCzQhwFM0" }, @@ -145,14 +161,14 @@ "source": [ "import plotly.express as px\n", "import plotly.io as pio\n", - "pio.renderers.default = \"plotly_mimetype+notebook_connected+colab+notebook\"\n", + "pio.renderers.default = \"colab\" if is_colab else \"plotly_mimetype+notebook_connected+colab+notebook\"\n", "from nnsight import LanguageModel, util\n", "from nnsight.tracing.Proxy import Proxy" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": { "id": "qLmX2tdZmgiz" }, @@ -165,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -226,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "id": "7FrduLBCmhOp" }, @@ -249,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -314,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -381,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": { "id": "JodVtd5VAuo2" }, @@ -418,7 +434,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": { "id": "1vuK3rEMAug_" }, @@ -477,7 +493,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": { "id": "G60gi4ZUmxlG" }, @@ -541,7 +557,7 @@ " clean_logit_diff - corrupted_logit_diff\n", " )\n", "\n", - " _ioi_patching_results.append(patched_result.save())\n", + " _ioi_patching_results.append(patched_result.item().save())\n", "\n", " ioi_patching_results.append(_ioi_patching_results)" ] @@ -566,7 +582,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "metadata": { "id": "EaHSh6ctnJFl" }, @@ -577,7 +593,7 @@ " x_labels,\n", " plot_title=\"Normalized Logit Difference After Patching Residual Stream on the IOI Task\"):\n", "\n", - " ioi_patching_results = util.apply(ioi_patching_results, lambda x: x.value.item(), Proxy)\n", + " ioi_patching_results = util.apply(ioi_patching_results, lambda x: x.value, Proxy)\n", " fig = px.imshow(\n", " ioi_patching_results,\n", " color_continuous_midpoint=0.0,\n", @@ -601,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -715,13 +731,13 @@ 0, 0, 0, - 0.997462809085846, - 0.0017713640118017793, - -0.00017296147416345775, - 0.0008970139897428453, - 0.000027435267838882282, - -0.001084289513528347, - -0.00046759238466620445 + 0.9974616169929504, + 0.0017701691249385476, + -0.0001777326106093824, + 0.0008922415436245501, + 0.000023856726329540834, + -0.0010866738157346845, + -0.0004663989820983261 ], [ 0, @@ -733,13 +749,13 @@ 0, 0, 0, - 0.9971288442611694, - 0.002613507444038987, - -0.0004866778035648167, - 0.0006274326588027179, - -0.00006202756048878655, - -0.0004341929452493787, - -0.0009220635402016342 + 0.9971300363540649, + 0.002608733018860221, + -0.00048309870180673897, + 0.0006298175430856645, + -0.00006441315781557932, + -0.00043419242138043046, + -0.0009184839436784387 ], [ 0, @@ -752,12 +768,12 @@ 0, 0, 0.9950998425483704, - 0.0045661828480660915, - -0.00023498902737628669, - 0.0005928403697907925, - 0.00006202756048878655, - -0.0003220661892555654, - -0.002197207184508443 + 0.004563791677355766, + -0.0002361815859330818, + 0.0005916468217037618, + 0.00006441315781557932, + -0.0003244514809921384, + -0.0021983974147588015 ], [ 0, @@ -769,13 +785,13 @@ 0, 0, 0, - 0.9868704080581665, - 0.02010408788919449, - 0.0006465180194936693, - 0.0011272316332906485, - 0.00015387606981676072, - 0.0006274326588027179, - -0.0019156973576173186 + 0.9868668913841248, + 0.020101677626371384, + 0.0006417459226213396, + 0.0011212661629542708, + 0.0001526830455986783, + 0.000627431902103126, + -0.001915695145726204 ], [ 0, @@ -787,13 +803,13 @@ 0, 0, 0, - 0.9520777463912964, - 0.07812251895666122, - 0.0025824937038123608, - 0.00213160109706223, - 0.000459242524811998, - 0.0003208733396604657, - -0.0014481049729511142 + 0.9520730376243591, + 0.07812362164258957, + 0.0025824906770139933, + 0.0021339841187000275, + 0.00046401331201195717, + 0.00031968014081940055, + -0.0014445247361436486 ], [ 0, @@ -805,13 +821,13 @@ 0, 0, 0, - 0.952825665473938, - 0.07888712733983994, - 0.0019884605426341295, - 0.0002314105222467333, - 0.00021351795294322073, - 0.00024333890178240836, - -0.004971747752279043 + 0.9528281092643738, + 0.07888465374708176, + 0.001984879607334733, + 0.0002278317406307906, + 0.0002147105406038463, + 0.0002433386107441038, + -0.004974127281457186 ], [ 0, @@ -823,13 +839,13 @@ 0, 0, 0, - 0.9087968468666077, - 0.10826075822114944, - 0.003999584820121527, - -0.0003757438971661031, - 0.0002147107879864052, - 0.0005642122705467045, - 0.009820632636547089 + 0.9087921380996704, + 0.10825704783201218, + 0.003999580163508654, + -0.0003733577614184469, + 0.0002206747158197686, + 0.0005677900626324117, + 0.00981227122247219 ], [ 0, @@ -841,13 +857,13 @@ 0, 0, 0, - 0.6858399510383606, - 0.0358411967754364, - 0.0005200772429816425, - -0.0008099368424154818, - 0.000058449048083275557, - 0.0013109286082908511, - 0.41536638140678406 + 0.6858367323875427, + 0.03584115207195282, + 0.0005200766026973724, + -0.0008170928922481835, + 0.00005487047019414604, + 0.0013109270948916674, + 0.41536349058151245 ], [ 0, @@ -859,13 +875,13 @@ 0, 0, 0, - 0.10977685451507568, - 0.024682197719812393, - 0.0003005951002705842, - -0.0003554656286723912, - 0.0002815097104758024, - 0.001400391454808414, - 0.8544618487358093 + 0.10977672785520554, + 0.02468455396592617, + 0.00030059475102461874, + -0.000356658041710034, + 0.000279123691143468, + 0.0014027755241841078, + 0.8544548749923706 ], [ 0, @@ -877,13 +893,13 @@ 0, 0, 0, - 0.019870290532708168, - 0.01540430635213852, - -0.0001848898536991328, - -0.00009423417941434309, - 0.000027435267838882282, - 0.0006560607580468059, - 0.907935619354248 + 0.019870266318321228, + 0.01540309563279152, + -0.00018250395078212023, + -0.00009781257540453225, + 0.00002982090700243134, + 0.0006572527927346528, + 0.9079344868659973 ], [ 0, @@ -895,13 +911,13 @@ 0, 0, 0, - 0.01998003199696541, - 0.006296990439295769, - 0.00020993943326175213, - -0.0007359808660112321, - -0.00004174932109890506, - 0.0005713692517019808, - 0.9069038033485413 + 0.01998358592391014, + 0.006295789964497089, + 0.0002087463508360088, + -0.0007407513330690563, + -0.00004294210521038622, + 0.0005737542524002492, + 0.9069039225578308 ], [ 0, @@ -1825,9 +1841,9 @@ } }, "text/html": [ - "