forked from AgentTorch/AgentTorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from AgentTorch/master
merge from master
- Loading branch information
Showing
206 changed files
with
42,492 additions
and
2,339 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
name: release | ||
|
||
on: push | ||
jobs: | ||
build: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: 3.9 | ||
- run: pip install build | ||
- run: python -m build | ||
- uses: actions/upload-artifact@v3 | ||
with: | ||
name: package-distributions | ||
path: dist/ | ||
|
||
publish: | ||
runs-on: ubuntu-latest | ||
if: startsWith(github.ref, 'refs/tags/v') | ||
needs: | ||
- build | ||
permissions: | ||
id-token: write | ||
steps: | ||
- uses: actions/download-artifact@v3 | ||
with: | ||
name: package-distributions | ||
path: dist/ | ||
- uses: pypa/gh-action-pypi-publish@release/v1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,3 +27,4 @@ venv.bak/ | |
|
||
simulation_memory_output/ | ||
cache/ | ||
dist/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +0,0 @@ | ||
from agent_torch.config import Configurator | ||
from agent_torch.registry import Registry | ||
from agent_torch.runner import Runner | ||
from agent_torch.controller import Controller | ||
from agent_torch.initializer import Initializer | ||
|
||
from .version import __version__ | ||
|
||
from .distributions import distributions | ||
from agent_torch.helpers.soft import * | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from .config import Configurator | ||
from .registry import Registry | ||
from .runner import Runner | ||
from .controller import Controller | ||
from .initializer import Initializer | ||
|
||
from .version import __version__ | ||
|
||
from .distributions import distributions | ||
from .helpers.soft import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import importlib | ||
import sys | ||
from tqdm import trange | ||
|
||
from agent_torch.core.dataloader import DataLoader | ||
from agent_torch.core.runner import Runner | ||
|
||
|
||
class BaseExecutor: | ||
def __init__(self, model): | ||
self.model = model | ||
|
||
def _get_runner(self, config): | ||
module_name = f"{self.model.__name__}.simulator" | ||
module = importlib.import_module(module_name) | ||
registry = module.get_registry() | ||
runner = Runner(config, registry) | ||
return runner | ||
|
||
|
||
class Executor(BaseExecutor): | ||
def __init__(self, model, data_loader=None, pop_loader=None) -> None: | ||
super().__init__(model) | ||
if pop_loader: | ||
self.pop_loader = pop_loader | ||
self.data_loader = DataLoader(model, self.pop_loader) | ||
else: | ||
self.data_loader = data_loader | ||
|
||
self.config = self.data_loader.get_config() | ||
self.runner = self._get_runner(self.config) | ||
|
||
def init(self): | ||
self.runner.init() | ||
# self.learnable_params = [ | ||
# param for param in self.runner.parameters() if param.requires_grad | ||
# ] | ||
# self.opt = opt(self.learnable_params) | ||
|
||
def execute(self, key=None): | ||
num_episodes = self.config["simulation_metadata"]["num_episodes"] | ||
num_steps_per_episode = self.config["simulation_metadata"][ | ||
"num_steps_per_episode" | ||
] | ||
|
||
for episode in trange(num_episodes): | ||
# self.opt.zero_grad() | ||
self.runner.reset() | ||
self.runner.step(num_steps_per_episode) | ||
|
||
if key is not None: | ||
self.simulation_values = self.runner.get_simulation_values(key) | ||
|
||
def get_simulation_values(self, key, key_type="environment"): | ||
self.simulation_values = self.runner.state_trajectory[-1][-1][key_type][ | ||
key | ||
] # List containing values for each step | ||
return self.simulation_values |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from agent_torch.core.helpers.general import * | ||
from agent_torch.core.helpers.environment import * | ||
from agent_torch.core.helpers.initializer import * | ||
from agent_torch.core.helpers.soft import * |
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import abc | ||
import os | ||
|
||
|
||
class MemoryHandler(abc.ABC): | ||
"""Abstract base class for handling memory operations.""" | ||
|
||
@abc.abstractmethod | ||
def save_memory(self, context_in, context_out, agent_id): | ||
"""Save conversation context to memory.""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def get_memory(self, last_k, agent_id): | ||
"""Retrieve memory for the specified agent and last_k messages.""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def clear_memory(self, agent_id): | ||
"""Clear memory for the specified agent.""" | ||
pass | ||
|
||
@abc.abstractmethod | ||
def export_memory_to_file(self, file_dir, last_k): | ||
"""Export memory to a file.""" | ||
pass | ||
|
||
|
||
class DSPYMemoryHandler(MemoryHandler): | ||
"""Concrete implementation of MemoryHandler for dspy backend.""" | ||
|
||
def __init__(self, agent_memory, llm): | ||
self.agent_memory = agent_memory | ||
self.llm = llm | ||
|
||
def save_memory(self, query, output, agent_id): | ||
self.agent_memory[agent_id].save_context( | ||
{"input": query["agent_query"]}, {"output": output} | ||
) | ||
|
||
def get_memory(self, last_k, agent_id): | ||
last_k_memory = { | ||
"chat_history": self.agent_memory[agent_id].load_memory_variables({})[ | ||
"chat_history" | ||
][-last_k:] | ||
} | ||
return last_k_memory | ||
|
||
def clear_memory(self, agent_id): | ||
self.agent_memory[agent_id].clear() | ||
|
||
def export_memory_to_file(self, file_dir, last_k): | ||
if not os.path.exists(file_dir): | ||
os.makedirs(file_dir) | ||
for id in range(len(self.agent_memory)): | ||
file_name = f"output_mem_{id}.md" | ||
file_path = os.path.join(file_dir, file_name) | ||
memory = self.get_memory(agent_id=id, last_k=last_k) | ||
with open(file_path, "w") as f: | ||
f.write(str(memory)) | ||
self.llm.inspect_history(file_dir=file_dir, last_k=last_k) | ||
|
||
|
||
class LangchainMemoryHandler(MemoryHandler): | ||
"""Concrete implementation of MemoryHandler for langchain backend.""" | ||
|
||
def __init__(self, agent_memory): | ||
self.agent_memory = agent_memory | ||
|
||
def save_memory(self, query, output, agent_id): | ||
self.agent_memory[agent_id].save_context( | ||
{"input": query["agent_query"]}, {"output": output["text"]} | ||
) | ||
|
||
def get_memory(self, last_k, agent_id): | ||
last_k_memory = { | ||
"chat_history": self.agent_memory[agent_id].load_memory_variables({})[ | ||
"chat_history" | ||
][-last_k:] | ||
} | ||
return last_k_memory | ||
|
||
def clear_memory(self, agent_id): | ||
self.agent_memory[agent_id].clear() | ||
|
||
def export_memory_to_file(self, file_dir, last_k): | ||
if not os.path.exists(file_dir): | ||
os.makedirs(file_dir) | ||
for id in range(len(self.agent_memory)): | ||
file_name = f"output_mem_{id}.md" | ||
file_path = os.path.join(file_dir, file_name) | ||
memory = self.get_memory(agent_id=id, last_k=last_k) | ||
with open(file_path, "w") as f: | ||
f.write(str(memory)) |
Oops, something went wrong.