Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: Add goodfire chat model #29427

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 266 additions & 0 deletions docs/docs/integrations/chat/goodfire.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "raw",
"id": "afaf8039",
"metadata": {
"vscode": {
"languageId": "raw"
}

Check failure on line 9 in docs/docs/integrations/chat/goodfire.ipynb

View workflow job for this annotation

GitHub Actions / cd . / make lint #3.9

Ruff (I001)

docs/docs/integrations/chat/goodfire.ipynb:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 9 in docs/docs/integrations/chat/goodfire.ipynb

View workflow job for this annotation

GitHub Actions / cd . / make lint #3.12

Ruff (I001)

docs/docs/integrations/chat/goodfire.ipynb:1:1: I001 Import block is un-sorted or un-formatted
},
"source": [
"---\n",
"sidebar_label: Goodfire\n",
"---"
]
},
{
"cell_type": "markdown",
"id": "e49f1e0d",
"metadata": {},
"source": [
"# Goodfire\n",
"\n",
"Goodfire is an AI inference platform to run certain Llama models with SAE feature steering. See the [Goodfire docs](https://docs.goodfire.ai/) for more information."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "433e8d2b-9519-4b49-b2c4-7ab65b046c94",
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if \"GOODFIRE_API_KEY\" not in os.environ:\n",
" os.environ[\"GOODFIRE_API_KEY\"] = getpass.getpass(\"Enter your Goodfire API key: \")"
]
},
{
"cell_type": "markdown",
"id": "a38cde65-254d-4219-a441-068766c0d4b5",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cb09c344-1836-4e0c-acf8-11d13ac1dbae",
"metadata": {},
"outputs": [
{
"ename": "ValueError",
"evalue": "model must be a Goodfire variant, got <class 'str'>",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[2], line 15\u001b[0m\n\u001b[1;32m 12\u001b[0m enthusiasm_variant \u001b[38;5;241m=\u001b[39m goodfire\u001b[38;5;241m.\u001b[39mVariant(MODEL_NAME)\n\u001b[1;32m 13\u001b[0m enthusiasm_variant\u001b[38;5;241m.\u001b[39mset(enthusiasm_feature, \u001b[38;5;241m0.3\u001b[39m)\n\u001b[0;32m---> 15\u001b[0m llm \u001b[38;5;241m=\u001b[39m \u001b[43mChatGoodfire\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 16\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mMODEL_NAME\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43mvariant\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43menthusiasm_variant\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mtemperature\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.6\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mseed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m42\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;66;43;03m# other params...\u001b[39;49;00m\n\u001b[1;32m 21\u001b[0m \u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/langchain/libs/community/langchain_community/chat_models/goodfire.py:80\u001b[0m, in \u001b[0;36mChatGoodfire.__init__\u001b[0;34m(self, model, goodfire_api_key, **kwargs)\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m 75\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not import goodfire python package. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 76\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease install it with `pip install goodfire`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 77\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21;01me\u001b[39;00m\n\u001b[1;32m 79\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, goodfire\u001b[38;5;241m.\u001b[39mVariant):\n\u001b[0;32m---> 80\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel must be a Goodfire variant, got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(model)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 82\u001b[0m \u001b[38;5;66;03m# Include model in kwargs for parent initialization\u001b[39;00m\n\u001b[1;32m 83\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m model\n",
"\u001b[0;31mValueError\u001b[0m: model must be a Goodfire variant, got <class 'str'>"
]
}
],
"source": [
"from langchain_community.chat_models import ChatGoodfire\n",
"import goodfire\n",
"\n",
"MODEL_NAME = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
"\n",
"goodfire_client = goodfire.Client(api_key=os.environ[\"GOODFIRE_API_KEY\"])\n",
"\n",
"base_variant = goodfire.Variant(MODEL_NAME)\n",
"\n",
"enthusiasm_feature = goodfire_client.features.lookup([55543], base_variant)[55543]\n",
"\n",
"enthusiasm_variant = goodfire.Variant(MODEL_NAME)\n",
"enthusiasm_variant.set(enthusiasm_feature, 0.3)\n",
"\n",
"llm = ChatGoodfire(\n",
" model=enthusiasm_variant,\n",
" temperature=0.6,\n",
" seed=42,\n",
" # other params...\n",
")"
]
},
{
"cell_type": "markdown",
"id": "2b4f3e15",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "62e0dbc3",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='J\\'ADORE LA PROGRAMMATION! \\n\\n(or in a more casual tone: J\\'ADORE LE CODAGE!)\\n\\nNote: \"J\\'adore\" is a stronger way to say \"I love\" in French, it\\'s more like \"I\\'m crazy about\" or \"I\\'m absolutely passionate about\". If you want to use a more literal translation, you can say: \"J\\'aime la programmation\" which means \"I like programming\".', additional_kwargs={}, response_metadata={}, id='run-d91dd50b-1d6a-4c04-a78c-b1b922c1fc92-0')"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "markdown",
"id": "39f7d928",
"metadata": {},
"source": [
"Note: The variant can be overridden after instantiation by providing a new variant to the `model` parameter."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ceac2cb6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore la programmation.\", additional_kwargs={}, response_metadata={}, id='run-b646d8ed-74c3-40a2-8530-7f094060bf23-0')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ai_msg = llm.invoke(messages, model=base_variant)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d86145b3-bfef-46e8-b227-4dda5c9c2705",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"J'ADORE LA PROGRAMMATION! \n",
"\n",
"(or in a more casual tone: J'ADORE LE CODAGE!)\n",
"\n",
"Note: \"J'adore\" is a stronger way to say \"I love\" in French, it's more like \"I'm crazy about\" or \"I'm absolutely passionate about\". If you want to use a more literal translation, you can say: \"J'aime la programmation\" which means \"I like programming\".\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"id": "18e2bfc0-7e78-4528-a73f-499ac150dca8",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e197d1d7-a070-4c96-9f8a-a0e86d046e0b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe das Programmieren.', additional_kwargs={}, response_metadata={}, id='run-f77167ac-e9a8-4fc0-9e43-5a4800290324-0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3a5bb5ca-c3ae-4a58-be67-2cd18574b9a3",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatGoodfire features and configurations head to the API reference: https://python.langchain.com/api_reference/goodfire/chat_models/langchain_goodfire.chat_models.ChatGoodfire.html\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
1 change: 1 addition & 0 deletions libs/community/extended_testing_deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ friendli-client>=1.2.4,<2
geopandas>=0.13.1
gitpython>=3.1.32,<4
gliner>=0.2.7
goodfire>=0.3.4
google-cloud-documentai>=2.20.1,<3
gql>=3.4.1,<4
gradientai>=1.4.0,<2
Expand Down
5 changes: 5 additions & 0 deletions libs/community/langchain_community/chat_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@
from langchain_community.chat_models.gigachat import (
GigaChat,
)
from langchain_community.chat_models.goodfire import (
ChatGoodfire,
)
from langchain_community.chat_models.google_palm import (
ChatGooglePalm,
)
Expand Down Expand Up @@ -210,6 +213,7 @@
"ChatEverlyAI",
"ChatFireworks",
"ChatFriendli",
"ChatGoodfire",
"ChatGooglePalm",
"ChatHuggingFace",
"ChatHunyuan",
Expand Down Expand Up @@ -276,6 +280,7 @@
"ChatEdenAI": "langchain_community.chat_models.edenai",
"ChatFireworks": "langchain_community.chat_models.fireworks",
"ChatFriendli": "langchain_community.chat_models.friendli",
"ChatGoodfire": "langchain_community.chat_models.goodfire",
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
"ChatHuggingFace": "langchain_community.chat_models.huggingface",
"ChatHunyuan": "langchain_community.chat_models.hunyuan",
Expand Down
Loading
Loading