-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactoring file and directory handling (#5)
1. **Pathlib Integration**: The use of `os` module is replaced with Python's more modern `pathlib` module across various tools. This shift enhances readability and consistency in handling file paths. 2. **Field Validation Enhancements**: - A new field validator, `check_directory_traversal`, is implemented to prevent directory traversal attacks. It ensures the provided paths are within allowed directories. - The `start_directory` fields in both `BuildDirectoryTree` and `PrintAllFilesInDirectory` classes now use `Path` instead of `str`, ensuring better path handling. 3. **Refactoring File Extensions Handling**: - File extensions are now handled using a set (`set[str]`) instead of a list (`list[str] | None`). This change simplifies the logic for including file types in the output. 4. **Improved Directory Traversal**: - `BuildDirectoryTree` and `PrintAllFilesInDirectory` classes have been refactored to use recursive functions for directory traversal, enhancing their efficiency and readability. 5. **Testing Adjustments**: - Tests have been updated to reflect the changes in the implementation, particularly the use of `pathlib` and the new field validations. - A new test file, `test_utils.py`, is added to specifically test the `check_directory_traversal` function.
- Loading branch information
1 parent
4a6f52f
commit 6129357
Showing
7 changed files
with
104 additions
and
71 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,41 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
from agency_swarm import BaseTool | ||
from pydantic import Field | ||
from pydantic import Field, field_validator | ||
|
||
from nalgonda.custom_tools.utils import check_directory_traversal | ||
|
||
|
||
class BuildDirectoryTree(BaseTool): | ||
"""Print the structure of directories and files.""" | ||
|
||
start_directory: str = Field( | ||
default_factory=lambda: os.getcwd(), | ||
start_directory: Path = Field( | ||
default_factory=Path.cwd, | ||
description="The starting directory for the tree, defaults to the current working directory.", | ||
) | ||
file_extensions: list[str] | None = Field( | ||
default_factory=lambda: None, | ||
description="List of file extensions to include in the tree. If None, all files will be included.", | ||
file_extensions: set[str] = Field( | ||
default_factory=set, | ||
description="Set of file extensions to include in the tree. If empty, all files will be included.", | ||
) | ||
|
||
def run(self) -> str: | ||
"""Run the tool.""" | ||
self._validate_start_directory() | ||
tree_str = self.print_tree() | ||
return tree_str | ||
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) | ||
|
||
def print_tree(self): | ||
"""Recursively print the tree of directories and files using os.walk.""" | ||
def run(self) -> str: | ||
"""Recursively print the tree of directories and files using pathlib.""" | ||
tree_str = "" | ||
start_path = self.start_directory.resolve() | ||
|
||
for root, _, files in os.walk(self.start_directory, topdown=True): | ||
level = root.replace(self.start_directory, "").count(os.sep) | ||
def recurse(directory: Path, level: int = 0) -> None: | ||
nonlocal tree_str | ||
indent = " " * 4 * level | ||
tree_str += f"{indent}{os.path.basename(root)}\n" | ||
tree_str += f"{indent}{directory.name}\n" | ||
sub_indent = " " * 4 * (level + 1) | ||
|
||
for f in files: | ||
if not self.file_extensions or f.endswith(tuple(self.file_extensions)): | ||
tree_str += f"{sub_indent}{f}\n" | ||
for path in sorted(directory.iterdir()): | ||
if path.is_dir(): | ||
recurse(path, level + 1) | ||
elif path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): | ||
tree_str += f"{sub_indent}{path.name}\n" | ||
|
||
recurse(start_path) | ||
return tree_str | ||
|
||
def _validate_start_directory(self): | ||
"""Do not allow directory traversal.""" | ||
if ".." in self.start_directory or ( | ||
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") | ||
): | ||
raise ValueError("Directory traversal is not allowed.") |
48 changes: 23 additions & 25 deletions
48
src/nalgonda/custom_tools/print_all_files_in_directory.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,42 @@ | ||
import os | ||
from pathlib import Path | ||
|
||
from agency_swarm import BaseTool | ||
from pydantic import Field | ||
from pydantic import Field, field_validator | ||
|
||
from nalgonda.custom_tools.utils import check_directory_traversal | ||
|
||
|
||
class PrintAllFilesInDirectory(BaseTool): | ||
"""Print the contents of all files in a start_directory recursively.""" | ||
|
||
start_directory: str = Field( | ||
default_factory=lambda: os.getcwd(), | ||
start_directory: Path = Field( | ||
default_factory=Path.cwd, | ||
description="Directory to search for Python files, by default the current working directory.", | ||
) | ||
file_extensions: list[str] | None = Field( | ||
default_factory=lambda: None, | ||
description="List of file extensions to include in the output. If None, all files will be included.", | ||
file_extensions: set[str] = Field( | ||
default_factory=set, | ||
description="Set of file extensions to include in the output. If empty, all files will be included.", | ||
) | ||
|
||
def run(self) -> str: | ||
"""Run the tool.""" | ||
self._validate_start_directory() | ||
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal) | ||
|
||
def run(self) -> str: | ||
""" | ||
Recursively searches for files within `start_directory` and compiles their contents into a single string. | ||
""" | ||
output = [] | ||
for root, _, files in os.walk(self.start_directory, topdown=True): | ||
for file in files: | ||
if not self.file_extensions or file.endswith(tuple(self.file_extensions)): | ||
file_path = os.path.join(root, file) | ||
output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n") | ||
start_path = self.start_directory.resolve() | ||
|
||
for path in start_path.rglob("*"): | ||
if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions): | ||
output.append(f"{str(path)}:\n```\n{self.read_file(path)}\n```\n") | ||
|
||
return "\n".join(output) | ||
|
||
@staticmethod | ||
def read_file(file_path): | ||
def read_file(file_path: Path): | ||
"""Read and return the contents of a file.""" | ||
try: | ||
with open(file_path, "r") as file: | ||
return file.read() | ||
return file_path.read_text() | ||
except IOError as e: | ||
return f"Error reading file {file_path}: {e}" | ||
|
||
def _validate_start_directory(self): | ||
"""Do not allow directory traversal.""" | ||
if ".." in self.start_directory or ( | ||
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp") | ||
): | ||
raise ValueError("Directory traversal is not allowed.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
import pytest | ||
|
||
from nalgonda.custom_tools.utils import check_directory_traversal | ||
|
||
@pytest.mark.parametrize("path", ["..", "/", "/sbin"]) | ||
def test_check_directory_traversal_raises_for_attempts(path): | ||
with pytest.raises(ValueError) as e: | ||
check_directory_traversal(path) | ||
assert e.errisinstance(ValueError) | ||
assert "Directory traversal is not allowed." in str(e.value) |