diff --git a/docs/advanced_usage.rst b/docs/advanced_usage.rst index 5e73fba..98214e5 100644 --- a/docs/advanced_usage.rst +++ b/docs/advanced_usage.rst @@ -137,4 +137,25 @@ passing ``object_aggregation="first_token"`` when training a LRE. For instance, object_aggregation="first_token", ) -If the answer is a single token, "mean" and "first_token" are equivalent. \ No newline at end of file +If the answer is a single token, "mean" and "first_token" are equivalent. + + +Custom layer selection +'''''''''''''''''''''' + +By default, the library will try to guess which layers correspond to hidden activations in the model, +and will use these layers for reading activations and training LREs. If the layers the library guesses are not +correct, or if you want to use different layers to extract activations and train LREs, you can pass in a +custom ``layer_matcher`` to the ``Trainer``, ``CausalEditor``, and ``ConceptMatcher`` when creating these +objects. + +A ``layer_matcher`` is typically a string, and must include the substring ``"{num}"`` which will be replaced +with the layer number to select a layer in the model. For instance, for GPT models, the matcher for +hidden layers is ``"transformer.h.{num}"``. You can find a list of all layers in a model by calling +``model.named_modules()``. + +For most cases, using a string is sufficient, but if you want to customize the layer matcher further +you can pass in a function to ``layer_matcher`` which takes in the layer number as an int and +returns the layer in the model as a string. For instance, for GPT models, this could be provided as +``lambda num: f"transformer.h.{num}"``. + diff --git a/docs/api/causal_editor.rst b/docs/api/causal_editor.rst new file mode 100644 index 0000000..29b099d --- /dev/null +++ b/docs/api/causal_editor.rst @@ -0,0 +1,14 @@ +CausalEditor +============ + +.. autoclass:: linear_relational.CausalEditor + :members: + :undoc-members: + +.. 
autoclass:: linear_relational.ConceptSwapRequest + :members: + :undoc-members: + +.. autoclass:: linear_relational.ConceptSwapAndPredictGreedyRequest + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/api/concept_matcher.rst b/docs/api/concept_matcher.rst new file mode 100644 index 0000000..086f74e --- /dev/null +++ b/docs/api/concept_matcher.rst @@ -0,0 +1,19 @@ +ConceptMatcher +============== + +.. autoclass:: linear_relational.ConceptMatcher + :members: + :undoc-members: + + +.. autoclass:: linear_relational.ConceptMatchQuery + :members: + :undoc-members: + +.. autoclass:: linear_relational.QueryResult + :members: + :undoc-members: + +.. autoclass:: linear_relational.ConceptMatchResult + :members: + :undoc-members: diff --git a/docs/api/lre.rst b/docs/api/lre.rst new file mode 100644 index 0000000..68681b4 --- /dev/null +++ b/docs/api/lre.rst @@ -0,0 +1,14 @@ +Lre +=== + +.. autoclass:: linear_relational.Lre + :members: + :undoc-members: + +.. autoclass:: linear_relational.LowRankLre + :members: + :undoc-members: + +.. autoclass:: linear_relational.InvertedLre + :members: + :undoc-members: \ No newline at end of file diff --git a/docs/api/trainer.rst b/docs/api/trainer.rst new file mode 100644 index 0000000..02eb071 --- /dev/null +++ b/docs/api/trainer.rst @@ -0,0 +1,7 @@ +Trainer +======= + + +.. 
autoclass:: linear_relational.Trainer + :members: + :undoc-members: diff --git a/docs/conf.py b/docs/conf.py index acde6ed..7111f19 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,7 @@ sys.path.insert(0, os.path.abspath("..")) -from linear_relational import __version__ +from linear_relational import __version__ # noqa: E402 # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information diff --git a/docs/index.rst b/docs/index.rst index cf19881..77e0b60 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -44,6 +44,7 @@ Linear Relational Concepts (LRCs) represent a concept :math:`(r, o)` as a direct For more information on LREs and LRCs, check out `these `_ `papers `_. +------------ .. toctree:: :maxdepth: 2 @@ -52,6 +53,15 @@ For more information on LREs and LRCs, check out `these ConceptMatchResult: class ConceptMatcher: + """Match concepts against subject activations in a model""" + concepts: list[Concept] model: nn.Module tokenizer: Tokenizer diff --git a/linear_relational/__init__.py b/linear_relational/__init__.py index 4282c09..6d58803 100644 --- a/linear_relational/__init__.py +++ b/linear_relational/__init__.py @@ -12,6 +12,7 @@ ConceptMatchResult, QueryResult, ) +from .lib.layer_matching import LayerMatcher from .Lre import InvertedLre, LowRankLre, Lre from .Prompt import Prompt from .PromptValidator import PromptValidator @@ -32,4 +33,5 @@ "ConceptMatchQuery", "QueryResult", "Trainer", + "LayerMatcher", ] diff --git a/linear_relational/training/Trainer.py b/linear_relational/training/Trainer.py index 4aa526b..4df053d 100644 --- a/linear_relational/training/Trainer.py +++ b/linear_relational/training/Trainer.py @@ -27,6 +27,8 @@ class Trainer: + """Train LREs and concepts from prompts""" + model: nn.Module tokenizer: Tokenizer layer_matcher: LayerMatcher