diff --git a/agentstack/generation/gen_utils.py b/agentstack/generation/gen_utils.py index 79f983f..3faf78c 100644 --- a/agentstack/generation/gen_utils.py +++ b/agentstack/generation/gen_utils.py @@ -1,3 +1,6 @@ +import ast + + def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): if next_line: code_to_insert = ['\n'] + code_to_insert @@ -16,3 +19,43 @@ def insert_code_after_tag(file_path, tag, code_to_insert, next_line=False): with open(file_path, 'w') as file: file.writelines(lines) + + +def insert_after_tasks(file_path, code_to_insert): + with open(file_path, 'r') as file: + content = file.read() + + module = ast.parse(content) + + # Track the last task function and its line number + last_task_end = None + last_task_start = None + for node in ast.walk(module): + if isinstance(node, ast.FunctionDef) and \ + any(isinstance(deco, ast.Name) and deco.id == 'task' for deco in node.decorator_list): + last_task_end = node.end_lineno + last_task_start = node.lineno + + if last_task_end is not None: + lines = content.split('\n') + + # Get the indentation of the task function + task_line = lines[last_task_start - 1] # -1 for 0-based indexing + indentation = '' + for char in task_line: + if char in [' ', '\t']: + indentation += char + else: + break + + # Add the same indentation to each line of the inserted code + indented_code = '\n' + '\n'.join(indentation + line for line in code_to_insert) + + lines.insert(last_task_end, indented_code) + content = '\n'.join(lines) + + with open(file_path, 'w') as file: + file.write(content) + return True + return False + diff --git a/agentstack/generation/task_generation.py b/agentstack/generation/task_generation.py index ad6c761..7976fe8 100644 --- a/agentstack/generation/task_generation.py +++ b/agentstack/generation/task_generation.py @@ -1,6 +1,6 @@ from typing import Optional -from .gen_utils import insert_code_after_tag +from .gen_utils import insert_code_after_tag, insert_after_tasks from ..utils import verify_agentstack_project, get_framework import os from ruamel.yaml import YAML @@ -77,7 +77,6 @@ def generate_crew_task( # Add task to crew.py file_path = 'src/crew.py' - tag = '# Task definitions' code_to_insert = [ "@task", f"def {name}(self) -> Task:", @@ -87,4 +86,4 @@ def generate_crew_task( "" ] - insert_code_after_tag(file_path, tag, code_to_insert) + insert_after_tasks(file_path, code_to_insert)