chore: improving docs
chanind committed Nov 23, 2023
1 parent 33e7ebd commit ab4748e
Showing 11 changed files with 95 additions and 2 deletions.
23 changes: 22 additions & 1 deletion docs/advanced_usage.rst
@@ -137,4 +137,25 @@ passing ``object_aggregation="first_token"`` when training a LRE. For instance,
object_aggregation="first_token",
)
If the answer is a single token, "mean" and "first_token" are equivalent.


Custom layer selection
''''''''''''''''''''''

By default, the library tries to guess which layers correspond to hidden activations in the model,
and uses these layers for reading activations and training LREs. If the guessed layers are not
correct, or if you want to use different layers to extract activations and train LREs, you can pass a
custom ``layer_matcher`` to the ``Trainer``, ``CausalEditor``, and ``ConceptMatcher`` when creating these
objects.

A ``layer_matcher`` is typically a string, and must include the substring ``"{num}"``, which will be replaced
with the layer number to select a layer in the model. For instance, for GPT models, the matcher for
hidden layers is ``"transformer.h.{num}"``. You can list all layers in a model by calling
``model.named_modules()``.
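
As a rough illustration of the substitution described above, the sketch below expands a template string into concrete layer names and checks them against a list of module names. The ``module_names`` list and the ``resolve_layer`` helper are hypothetical and only mimic what ``model.named_modules()`` might report for a GPT-style model; this is not the library's own implementation.

```python
# Hypothetical module names, standing in for the names that
# ``model.named_modules()`` might yield for a GPT-style model.
module_names = [
    "transformer",
    "transformer.h",
    "transformer.h.0",
    "transformer.h.0.attn",
    "transformer.h.1",
    "transformer.h.1.attn",
]

layer_matcher = "transformer.h.{num}"


def resolve_layer(matcher: str, num: int) -> str:
    """Illustrative helper: replace the ``{num}`` placeholder with a layer number."""
    if "{num}" not in matcher:
        raise ValueError('layer_matcher must contain the substring "{num}"')
    return matcher.replace("{num}", str(num))


# Expand the template for the first two layers and confirm that each
# resulting name is present in the (hypothetical) module list.
resolved = [resolve_layer(layer_matcher, i) for i in range(2)]
missing = [name for name in resolved if name not in module_names]
```

If ``missing`` is non-empty, the template probably does not fit the model's naming scheme, and a different ``layer_matcher`` would be needed.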

In most cases a string is sufficient, but if you want to customize the layer matcher further,
you can pass a function as ``layer_matcher`` which takes the layer number as an int and
returns the name of the corresponding layer in the model as a string. For instance, for GPT models, this could be provided as
``lambda num: f"transformer.h.{num}"``.
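
The two forms can be compared side by side. The sketch below is an assumption-laden illustration: the ``LayerMatcherLike`` alias and the ``resolve_layer_name`` helper are hypothetical names introduced here, not part of the library, and simply show that a template string and an equivalent function produce the same layer name.

```python
from typing import Callable, Union

# Assumed shape of a layer matcher, based on the description above:
# either a template string or a function from layer number to layer name.
LayerMatcherLike = Union[str, Callable[[int], str]]


def gpt_layer_matcher(num: int) -> str:
    """Function form of the matcher; equivalent to "transformer.h.{num}"."""
    return f"transformer.h.{num}"


def resolve_layer_name(matcher: LayerMatcherLike, num: int) -> str:
    """Hypothetical helper that accepts either form and returns a layer name."""
    if callable(matcher):
        return matcher(num)
    return matcher.replace("{num}", str(num))


# Both forms resolve to the same layer name.
from_function = resolve_layer_name(gpt_layer_matcher, 3)
from_string = resolve_layer_name("transformer.h.{num}", 3)
```

A function is mainly useful when the mapping from layer number to name cannot be expressed as a single template, for example when layer names are not numbered contiguously.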

14 changes: 14 additions & 0 deletions docs/api/causal_editor.rst
@@ -0,0 +1,14 @@
CausalEditor
============

.. autoclass:: linear_relational.CausalEditor
   :members:
   :undoc-members:

.. autoclass:: linear_relational.ConceptSwapRequest
   :members:
   :undoc-members:

.. autoclass:: linear_relational.ConceptSwapAndPredictGreedyRequest
   :members:
   :undoc-members:
19 changes: 19 additions & 0 deletions docs/api/concept_matcher.rst
@@ -0,0 +1,19 @@
ConceptMatcher
==============

.. autoclass:: linear_relational.ConceptMatcher
   :members:
   :undoc-members:


.. autoclass:: linear_relational.ConceptMatchQuery
   :members:
   :undoc-members:

.. autoclass:: linear_relational.QueryResult
   :members:
   :undoc-members:

.. autoclass:: linear_relational.ConceptMatchResult
   :members:
   :undoc-members:
14 changes: 14 additions & 0 deletions docs/api/lre.rst
@@ -0,0 +1,14 @@
Lre
===

.. autoclass:: linear_relational.Lre
   :members:
   :undoc-members:

.. autoclass:: linear_relational.LowRankLre
   :members:
   :undoc-members:

.. autoclass:: linear_relational.InvertedLre
   :members:
   :undoc-members:
7 changes: 7 additions & 0 deletions docs/api/trainer.rst
@@ -0,0 +1,7 @@
Trainer
=======


.. autoclass:: linear_relational.Trainer
   :members:
   :undoc-members:
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -10,7 +10,7 @@

sys.path.insert(0, os.path.abspath(".."))

from linear_relational import __version__
from linear_relational import __version__ # noqa: E402

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
10 changes: 10 additions & 0 deletions docs/index.rst
@@ -44,6 +44,7 @@ Linear Relational Concepts (LRCs) represent a concept :math:`(r, o)` as a direct
For more information on LREs and LRCs, check out `these <https://arxiv.org/abs/2308.09124>`_ `papers <https://arxiv.org/abs/2311.08968>`_.

------------

.. toctree::
   :maxdepth: 2
@@ -52,6 +53,15 @@ For more information on LREs and LRCs, check out `these <https://arxiv.org/abs/2
   advanced_usage
   about

.. toctree::
   :maxdepth: 2
   :caption: API Reference:

   api/trainer
   api/lre
   api/causal_editor
   api/concept_matcher

.. toctree::
   :caption: Project Links

2 changes: 2 additions & 0 deletions linear_relational/CausalEditor.py
@@ -45,6 +45,8 @@ class ConceptSwapAndPredictGreedyRequest(ConceptSwapRequest):


class CausalEditor:
    """Modify model activations during inference to swap concepts"""

    concepts: list[Concept]
    model: nn.Module
    tokenizer: Tokenizer
2 changes: 2 additions & 0 deletions linear_relational/ConceptMatcher.py
@@ -45,6 +45,8 @@ def best_match(self) -> ConceptMatchResult:


class ConceptMatcher:
    """Match concepts against subject activations in a model"""

    concepts: list[Concept]
    model: nn.Module
    tokenizer: Tokenizer
2 changes: 2 additions & 0 deletions linear_relational/__init__.py
@@ -12,6 +12,7 @@
    ConceptMatchResult,
    QueryResult,
)
from .lib.layer_matching import LayerMatcher
from .Lre import InvertedLre, LowRankLre, Lre
from .Prompt import Prompt
from .PromptValidator import PromptValidator
@@ -32,4 +33,5 @@
    "ConceptMatchQuery",
    "QueryResult",
    "Trainer",
    "LayerMatcher",
]
2 changes: 2 additions & 0 deletions linear_relational/training/Trainer.py
@@ -27,6 +27,8 @@


class Trainer:
    """Train LREs and concepts from prompts"""

    model: nn.Module
    tokenizer: Tokenizer
    layer_matcher: LayerMatcher
