Skip to content

Commit

Permalink
single shot solution
Browse files Browse the repository at this point in the history
  • Loading branch information
hammoudhasan committed Jun 26, 2023
1 parent 2150be5 commit aa2b501
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
9 changes: 5 additions & 4 deletions examples/ai_society/role_playing_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@

from camel.configs import ChatGPTConfig
from camel.societies import RolePlaying
from camel.typing import TaskType
from camel.typing import TaskType, ModelType
from camel.utils import download_tasks


def generate_data(assistant_idx: int, assistant_role_name: str, user_idx: int,
user_role_name: str, task_idx: int, task_prompt: str,
verbose: bool = False) -> None:

max_num_messages = 40
max_num_messages = 100

original_task_prompt = task_prompt.replace(f"{task_idx+1}. ", "")

Expand All @@ -38,6 +38,7 @@ def generate_data(assistant_idx: int, assistant_role_name: str, user_idx: int,
task_prompt=original_task_prompt,
with_task_specify=True,
with_task_planner=False,
model_type=ModelType.GPT_3_5_TURBO_16K,
task_specify_agent_kwargs=dict(model_config=ChatGPTConfig(
temperature=1.4)),
)
Expand Down Expand Up @@ -204,7 +205,7 @@ def main() -> None:
try:
slurm_array_task_id = os.environ.get('SLURM_ARRAY_TASK_ID')
if slurm_array_task_id is None:
raise
raise ValueError("SLURM_ARRAY_TASK_ID is not set")
array_idx = int(slurm_array_task_id)
except (TypeError, ValueError) as e:
print(f"Error: {e}")
Expand All @@ -227,7 +228,7 @@ def main() -> None:
roles_per_chunk:(array_idx + 1) *
roles_per_chunk]

pool = multiprocessing.Pool()
pool = multiprocessing.Pool(processes=10)

for assistant_idx, assistant_role_name in enumerate(assistant_roles):
assistant_idx += array_idx * roles_per_chunk
Expand Down
5 changes: 2 additions & 3 deletions examples/single_shot/pair_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from camel.typing import TaskType, RoleType


def main(key: str = "generate_users", num_roles: int = 50):
def main():

single_shot_template = SingleShotPromptTemplateDict()
assistant_sys_msg_prompt = single_shot_template[RoleType.ASSISTANT]
Expand All @@ -39,5 +39,4 @@ def main(key: str = "generate_users", num_roles: int = 50):


if __name__ == "__main__":
main("generate_users", 50)
main("generate_assistants", 50)
main()

0 comments on commit aa2b501

Please sign in to comment.