chore: updating docs with advanced usage
chanind committed Nov 23, 2023
1 parent b915827 commit 7439e71
Showing 4 changed files with 144 additions and 120 deletions.
58 changes: 0 additions & 58 deletions README.md
@@ -1,6 +1,5 @@
# Linear Relational


[![ci](https://img.shields.io/github/actions/workflow/status/chanind/linear-relational/ci.yaml?branch=main)](https://github.com/chanind/linear-relational)
[![PyPI](https://img.shields.io/pypi/v/linear-relational?color=blue)](https://pypi.org/project/linear-relational/)

@@ -176,42 +175,6 @@ edited_answer = editor.swap_subject_concepts_and_predict_greedy(
print(edited_answer) # " France"
```

#### Bulk editing

Edits can be performed in batches to make better use of GPU resources using `editor.swap_subject_concepts_and_predict_greedy_bulk()` as below:

```python
from linear_relational import CausalEditor, ConceptSwapAndPredictGreedyRequest

concepts = trainer.train_relation_concepts(...)

editor = CausalEditor(model, tokenizer, concepts=concepts)

swap_requests = [
ConceptSwapAndPredictGreedyRequest(
text="Shanghai is located in the country of",
subject="Shanghai",
remove_concept="located in country: China",
add_concept="located in country: France",
predict_num_tokens=1,
),
ConceptSwapAndPredictGreedyRequest(
text="Berlin is located in the country of",
subject="Berlin",
remove_concept="located in country: Germany",
add_concept="located in country: Japan",
predict_num_tokens=1,
),
]
edited_answers = editor.swap_subject_concepts_and_predict_greedy_bulk(
requests=swap_requests,
edit_single_layer=False,
magnitude_multiplier=0.1,
batch_size=4,
)
print(edited_answers) # [" France", " Japan"]
```

### Concept matching

We can use learned concepts (LRCs) as classifiers, matching them against subject activations in sentences. The `ConceptMatcher` class performs this matching.
@@ -229,27 +192,6 @@ print(match_info.best_match.name) # located in country: China
print(match_info.best_match.score) # 0.832
```

#### Bulk concept matching

We can perform concept matches in batches to better utilize GPU resources using `matcher.query_bulk()` as below:

```python
from linear_relational import ConceptMatcher, ConceptMatchQuery

concepts = trainer.train_relation_concepts(...)

matcher = ConceptMatcher(model, tokenizer, concepts=concepts)

match_queries = [
ConceptMatchQuery("Beijing is a northern city", subject="Beijing"),
ConceptMatchQuery("I saw him in Marseille", subject="Marseille"),
]
matches = matcher.query_bulk(match_queries, batch_size=4)

print(matches[0].best_match.name) # located in country: China
print(matches[1].best_match.name) # located in country: France
```

## Acknowledgements

This library is inspired by and uses modified code from the following excellent projects:
140 changes: 140 additions & 0 deletions docs/advanced_usage.rst
@@ -0,0 +1,140 @@
Advanced usage
==============

Bulk editing
''''''''''''

Edits can be performed in batches to make better use of GPU resources using ``editor.swap_subject_concepts_and_predict_greedy_bulk()`` as below:

.. code:: python

    from linear_relational import CausalEditor, ConceptSwapAndPredictGreedyRequest

    concepts = trainer.train_relation_concepts(...)

    editor = CausalEditor(model, tokenizer, concepts=concepts)

    swap_requests = [
        ConceptSwapAndPredictGreedyRequest(
            text="Shanghai is located in the country of",
            subject="Shanghai",
            remove_concept="located in country: China",
            add_concept="located in country: France",
            predict_num_tokens=1,
        ),
        ConceptSwapAndPredictGreedyRequest(
            text="Berlin is located in the country of",
            subject="Berlin",
            remove_concept="located in country: Germany",
            add_concept="located in country: Japan",
            predict_num_tokens=1,
        ),
    ]
    edited_answers = editor.swap_subject_concepts_and_predict_greedy_bulk(
        requests=swap_requests,
        edit_single_layer=False,
        magnitude_multiplier=0.1,
        batch_size=4,
    )
    print(edited_answers) # [" France", " Japan"]

Bulk concept matching
'''''''''''''''''''''

We can perform concept matches in batches to better utilize GPU resources using ``matcher.query_bulk()`` as below:

.. code:: python

    from linear_relational import ConceptMatcher, ConceptMatchQuery

    concepts = trainer.train_relation_concepts(...)

    matcher = ConceptMatcher(model, tokenizer, concepts=concepts)

    match_queries = [
        ConceptMatchQuery("Beijing is a northern city", subject="Beijing"),
        ConceptMatchQuery("I saw him in Marseille", subject="Marseille"),
    ]
    matches = matcher.query_bulk(match_queries, batch_size=4)

    print(matches[0].best_match.name) # located in country: China
    print(matches[1].best_match.name) # located in country: France

Customizing LRC training
''''''''''''''''''''''''

The base ``trainer.train_relation_concepts()`` function is a convenience wrapper which trains an LRE,
performs a low-rank inversion of the LRE, and uses the inverted LRE to generate concepts. If you want to customize
this process, you can generate an LRE using ``trainer.train_lre()``, invert it with ``lre.invert()``,
and finally train concepts from the inverted LRE with ``trainer.train_relation_concepts_from_inv_lre()``. This process
is shown below:

.. code:: python

    from linear_relational import Trainer

    trainer = Trainer(model, tokenizer)
    prompts = [...]
    lre = trainer.train_lre(...)
    inv_lre = lre.invert(rank=200)
    concepts = trainer.train_relation_concepts_from_inv_lre(
        inv_lre=inv_lre,
        prompts=prompts,
    )

Custom objects in prompts
'''''''''''''''''''''''''

By default, when you create a ``Prompt``, the answer to the prompt is assumed to be the object
corresponding to an LRC. For instance, in the prompt ``Prompt("Paris is located in", "France", subject="Paris")``,
the answer, "France", is assumed to be the object. However, if this is not the case, you can specify the object
explicitly using the ``object_name`` parameter as below:

.. code:: python

    from linear_relational import Prompt

    prompt1 = Prompt(
        text="PARIS IS LOCATED IN",
        answer="FRANCE",
        subject="PARIS",
        object_name="france",
    )
    prompt2 = Prompt(
        text="Paris is located in",
        answer="France",
        subject="Paris",
        object_name="france",
    )

Skipping prompt validation
''''''''''''''''''''''''''

By default, the ``Trainer`` validates that the model answers each prompt correctly,
and filters out any prompts where this is not the case.
If you want to skip this validation, you can pass ``validate_prompts=False`` to any method on the trainer,
like ``Trainer.train_relation_concepts(prompts, validate_prompts=False)``.
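
A minimal sketch of this, reusing the ``Prompt`` constructor shown earlier (``model`` and ``tokenizer`` are assumed
to be loaded already, and ``train_relation_concepts()`` may take additional arguments depending on your setup):

.. code:: python

    from linear_relational import Prompt, Trainer

    trainer = Trainer(model, tokenizer)
    prompts = [
        Prompt("Paris is located in", "France", subject="Paris"),
        Prompt("Shanghai is located in", "China", subject="Shanghai"),
    ]
    # validate_prompts=False skips checking that the model answers each
    # prompt correctly, so no prompts are filtered out before training.
    concepts = trainer.train_relation_concepts(prompts, validate_prompts=False)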


Multi-token object aggregation
''''''''''''''''''''''''''''''

If a prompt has an answer spanning multiple tokens, by default the ``Trainer`` will use the mean activation of
the answer tokens when training an LRE. An example of a prompt with a multi-token answer is "The CEO of Microsoft is Bill Gates",
where the object, "Bill Gates", has two tokens. Alternatively, you can use just the first token of the object by
passing ``object_aggregation="first_token"`` when training an LRE. For instance, you can run the following:

.. code:: python

    lre = trainer.train_lre(
        prompts=prompts,
        object_aggregation="first_token",
    )

If the answer is a single token, ``"mean"`` and ``"first_token"`` are equivalent.

63 changes: 2 additions & 61 deletions docs/usage.rst → docs/basic_usage.rst
@@ -1,5 +1,5 @@
Usage
=====
Basic usage
===========
This library assumes you're using PyTorch with a decoder-only generative language
model (e.g., GPT, LLaMA), and a tokenizer from Huggingface.
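
For example, a minimal setup might look like the sketch below (``gpt2`` is purely illustrative; any decoder-only
Huggingface causal LM should work):

.. code:: python

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load a decoder-only causal language model and its tokenizer.
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")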

@@ -156,43 +156,6 @@ hyperparameter that requires tuning depending on the model being edited.
)
print(edited_answer) # " France"

Bulk editing
''''''''''''

Edits can be performed in batches to make better use of GPU resources using ``editor.swap_subject_concepts_and_predict_greedy_bulk()`` as below:

.. code:: python

    from linear_relational import CausalEditor, ConceptSwapAndPredictGreedyRequest

    concepts = trainer.train_relation_concepts(...)

    editor = CausalEditor(model, tokenizer, concepts=concepts)

    swap_requests = [
        ConceptSwapAndPredictGreedyRequest(
            text="Shanghai is located in the country of",
            subject="Shanghai",
            remove_concept="located in country: China",
            add_concept="located in country: France",
            predict_num_tokens=1,
        ),
        ConceptSwapAndPredictGreedyRequest(
            text="Berlin is located in the country of",
            subject="Berlin",
            remove_concept="located in country: Germany",
            add_concept="located in country: Japan",
            predict_num_tokens=1,
        ),
    ]
    edited_answers = editor.swap_subject_concepts_and_predict_greedy_bulk(
        requests=swap_requests,
        edit_single_layer=False,
        magnitude_multiplier=0.1,
        batch_size=4,
    )
    print(edited_answers) # [" France", " Japan"]

Concept matching
''''''''''''''''

@@ -211,25 +174,3 @@ We can use the ``ConceptMatcher`` class to do this matching.
print(match_info.best_match.name) # located in country: China
print(match_info.best_match.score) # 0.832

Bulk concept matching
'''''''''''''''''''''

We can perform concept matches in batches to better utilize GPU resources using ``matcher.query_bulk()`` as below:

.. code:: python

    from linear_relational import ConceptMatcher, ConceptMatchQuery

    concepts = trainer.train_relation_concepts(...)

    matcher = ConceptMatcher(model, tokenizer, concepts=concepts)

    match_queries = [
        ConceptMatchQuery("Beijing is a northern city", subject="Beijing"),
        ConceptMatchQuery("I saw him in Marseille", subject="Marseille"),
    ]
    matches = matcher.query_bulk(match_queries, batch_size=4)

    print(matches[0].best_match.name) # located in country: China
    print(matches[1].best_match.name) # located in country: France

3 changes: 2 additions & 1 deletion docs/index.rst
@@ -48,7 +48,8 @@ For more information on LREs and LRCs, check out `these <https://arxiv.org/abs/2
.. toctree::
:maxdepth: 2

usage
basic_usage
advanced_usage
about

.. toctree::
