Skip to content

Commit

Permalink
Refactoring file and directory handling (#5)
Browse files Browse the repository at this point in the history
1. **Pathlib Integration**: The `os` module is replaced with Python's more modern `pathlib` module across the custom tools and their tests. This makes file-path handling more readable and consistent.
2. **Field Validation Enhancements**:
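   - A new field validator, `check_directory_traversal`, is implemented to prevent directory traversal attacks. It rejects paths containing `..` and ensures the resolved path lies within an allowed base directory (the system temporary directory or the user's home directory).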
   - The `start_directory` fields in both `BuildDirectoryTree` and `PrintAllFilesInDirectory` classes now use `Path` instead of `str`, ensuring better path handling.
3. **Refactoring File Extensions Handling**:
   - File extensions are now handled using a set (`set[str]`) instead of a list (`list[str] | None`). This change simplifies the logic for including file types in the output.
4. **Improved Directory Traversal**:
   - `BuildDirectoryTree` now walks the tree with a recursive helper built on `Path.iterdir()`, and `PrintAllFilesInDirectory` uses `Path.rglob("*")`; both replace the previous `os.walk` loops, improving readability.
5. **Testing Adjustments**:
   - Tests have been updated to reflect the changes in the implementation, particularly the use of `pathlib` and the new field validations.
   - A new test file, `test_utils.py`, is added to specifically test the `check_directory_traversal` function.
  • Loading branch information
guiparpinelli authored Dec 13, 2023
1 parent 4a6f52f commit 6129357
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 71 deletions.
49 changes: 22 additions & 27 deletions src/nalgonda/custom_tools/build_directory_tree.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,41 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
from pydantic import Field
from pydantic import Field, field_validator

from nalgonda.custom_tools.utils import check_directory_traversal


class BuildDirectoryTree(BaseTool):
"""Print the structure of directories and files."""

start_directory: str = Field(
default_factory=lambda: os.getcwd(),
start_directory: Path = Field(
default_factory=Path.cwd,
description="The starting directory for the tree, defaults to the current working directory.",
)
file_extensions: list[str] | None = Field(
default_factory=lambda: None,
description="List of file extensions to include in the tree. If None, all files will be included.",
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the tree. If empty, all files will be included.",
)

def run(self) -> str:
"""Run the tool."""
self._validate_start_directory()
tree_str = self.print_tree()
return tree_str
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal)

def print_tree(self):
"""Recursively print the tree of directories and files using os.walk."""
def run(self) -> str:
"""Recursively print the tree of directories and files using pathlib."""
tree_str = ""
start_path = self.start_directory.resolve()

for root, _, files in os.walk(self.start_directory, topdown=True):
level = root.replace(self.start_directory, "").count(os.sep)
def recurse(directory: Path, level: int = 0) -> None:
nonlocal tree_str
indent = " " * 4 * level
tree_str += f"{indent}{os.path.basename(root)}\n"
tree_str += f"{indent}{directory.name}\n"
sub_indent = " " * 4 * (level + 1)

for f in files:
if not self.file_extensions or f.endswith(tuple(self.file_extensions)):
tree_str += f"{sub_indent}{f}\n"
for path in sorted(directory.iterdir()):
if path.is_dir():
recurse(path, level + 1)
elif path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
tree_str += f"{sub_indent}{path.name}\n"

recurse(start_path)
return tree_str

def _validate_start_directory(self):
"""Do not allow directory traversal."""
if ".." in self.start_directory or (
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp")
):
raise ValueError("Directory traversal is not allowed.")
48 changes: 23 additions & 25 deletions src/nalgonda/custom_tools/print_all_files_in_directory.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,42 @@
import os
from pathlib import Path

from agency_swarm import BaseTool
from pydantic import Field
from pydantic import Field, field_validator

from nalgonda.custom_tools.utils import check_directory_traversal


class PrintAllFilesInDirectory(BaseTool):
"""Print the contents of all files in a start_directory recursively."""

start_directory: str = Field(
default_factory=lambda: os.getcwd(),
start_directory: Path = Field(
default_factory=Path.cwd,
description="Directory to search for Python files, by default the current working directory.",
)
file_extensions: list[str] | None = Field(
default_factory=lambda: None,
description="List of file extensions to include in the output. If None, all files will be included.",
file_extensions: set[str] = Field(
default_factory=set,
description="Set of file extensions to include in the output. If empty, all files will be included.",
)

def run(self) -> str:
"""Run the tool."""
self._validate_start_directory()
_validate_start_directory = field_validator("start_directory", mode="before")(check_directory_traversal)

def run(self) -> str:
"""
Recursively searches for files within `start_directory` and compiles their contents into a single string.
"""
output = []
for root, _, files in os.walk(self.start_directory, topdown=True):
for file in files:
if not self.file_extensions or file.endswith(tuple(self.file_extensions)):
file_path = os.path.join(root, file)
output.append(f"{file_path}:\n```\n{self.read_file(file_path)}\n```\n")
start_path = self.start_directory.resolve()

for path in start_path.rglob("*"):
if path.is_file() and (not self.file_extensions or path.suffix in self.file_extensions):
output.append(f"{str(path)}:\n```\n{self.read_file(path)}\n```\n")

return "\n".join(output)

@staticmethod
def read_file(file_path):
def read_file(file_path: Path):
"""Read and return the contents of a file."""
try:
with open(file_path, "r") as file:
return file.read()
return file_path.read_text()
except IOError as e:
return f"Error reading file {file_path}: {e}"

def _validate_start_directory(self):
"""Do not allow directory traversal."""
if ".." in self.start_directory or (
self.start_directory.startswith("/") and not self.start_directory.startswith("/tmp")
):
raise ValueError("Directory traversal is not allowed.")
20 changes: 19 additions & 1 deletion src/nalgonda/custom_tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import tempfile
from pathlib import Path

from agency_swarm.util import get_openai_client


def get_chat_completion(user_prompt, system_message, **kwargs) -> str:
"""Generate a chat completion based on a prompt and a system message.
"""
Generate a chat completion based on a prompt and a system message.
This function is a wrapper around the OpenAI API.
"""
from config import settings
Expand All @@ -24,3 +28,17 @@ def get_chat_completion(user_prompt, system_message, **kwargs) -> str:
)

return str(completion.choices[0].message.content)


def check_directory_traversal(dir_path: str) -> Path:
"""
Ensures that the given directory path is within allowed paths.
"""
path = Path(dir_path)
if ".." in path.parts:
raise ValueError("Directory traversal is not allowed.")

allowed_bases = [Path(tempfile.gettempdir()).resolve(), Path.home().resolve()]
if not any(str(path.resolve()).startswith(str(base)) for base in allowed_bases):
raise ValueError("Directory traversal is not allowed.")
return path
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
import sys
from pathlib import Path

import pytest


# each test runs on cwd to its temp dir
@pytest.fixture(autouse=True)
def go_to_tmpdir(request):
    """Run each test with its own temporary directory as the working directory.

    The temporary directory is also placed at the front of ``sys.path`` for
    the duration of the test so that packages created there can be imported.
    """
    # Get the fixture dynamically by its name.
    tmpdir = request.getfixturevalue("tmpdir")
    tmpdir_entry = str(tmpdir)
    # Ensure packages created locally by the test can be imported.
    sys.path.insert(0, tmpdir_entry)
    try:
        # Chdir only for the duration of the test.
        with tmpdir.as_cwd():
            yield
    finally:
        # Undo the sys.path change so entries do not accumulate across tests
        # and stale temp directories do not stay importable.
        try:
            sys.path.remove(tmpdir_entry)
        except ValueError:
            # The test itself already removed the entry; nothing to clean up.
            pass


@pytest.fixture
def temp_dir(tmp_path: Path):
"""Create a temporary directory with some files inside it.
Expand Down
12 changes: 7 additions & 5 deletions tests/custom_tools/test_build_directory_tree.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path

from nalgonda.custom_tools import BuildDirectoryTree

Expand All @@ -7,7 +7,7 @@ def test_build_directory_tree_with_py_extension(temp_dir):
"""
Test if BuildDirectoryTree correctly lists only .py files in the directory tree.
"""
bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py"])
bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py"})
expected_output = f"{temp_dir.name}\n sub\n test.py\n"
assert bdt.run() == expected_output

Expand All @@ -16,7 +16,7 @@ def test_build_directory_tree_with_multiple_extensions(temp_dir):
"""
Test if BuildDirectoryTree lists files with multiple specified extensions.
"""
bdt = BuildDirectoryTree(start_directory=str(temp_dir), file_extensions=[".py", ".txt"])
bdt = BuildDirectoryTree(start_directory=temp_dir, file_extensions={".py", ".txt"})
expected_output = {
f"{temp_dir.name}",
" sub",
Expand All @@ -32,5 +32,7 @@ def test_build_directory_tree_default_settings():
Test if BuildDirectoryTree uses the correct default settings.
"""
bdt = BuildDirectoryTree()
assert bdt.start_directory == os.getcwd()
assert bdt.file_extensions is None
assert bdt.start_directory == Path.cwd()
assert bdt.file_extensions == set()


23 changes: 10 additions & 13 deletions tests/custom_tools/test_print_all_files_in_directory.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os

from nalgonda.custom_tools import PrintAllFilesInDirectory


def test_print_all_files_no_extension_filter(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of all files when no file extension filter is applied.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir))
pafid = PrintAllFilesInDirectory(start_directory=temp_dir)
expected_output = {
f"{temp_dir}/sub/test.py:\n```\nprint('hello')\n```",
f"{temp_dir}/sub/test.txt:\n```\nhello world\n```",
Expand All @@ -20,17 +18,17 @@ def test_print_all_files_with_py_extension(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of .py files only.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".py"])
expected_output = f"{os.path.join(temp_dir, 'sub', 'test.py')}:\n```\nprint('hello')\n```\n"
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".py"})
expected_output = f"{temp_dir.joinpath('sub', 'test.py')}:\n```\nprint('hello')\n```\n"
assert pafid.run() == expected_output


def test_print_all_files_with_txt_extension(temp_dir):
"""
Test if PrintAllFilesInDirectory correctly prints contents of .txt files only.
"""
pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"])
expected_output = f"{os.path.join(temp_dir, 'sub', 'test.txt')}:\n```\nhello world\n```\n"
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"})
expected_output = f"{temp_dir.joinpath('sub', 'test.txt')}:\n```\nhello world\n```\n"
assert pafid.run() == expected_output


Expand All @@ -39,12 +37,11 @@ def test_print_all_files_error_reading_file(temp_dir):
Test if PrintAllFilesInDirectory handles errors while reading a file.
"""
# Create an unreadable file
unreadable_file = os.path.join(temp_dir, "unreadable_file.txt")
with open(unreadable_file, "w") as f:
f.write("content")
os.chmod(unreadable_file, 0o000) # make the file unreadable
unreadable_file = temp_dir.joinpath("unreadable_file.txt")
unreadable_file.write_text("content")
unreadable_file.chmod(0o000) # make the file unreadable

pafid = PrintAllFilesInDirectory(start_directory=str(temp_dir), file_extensions=[".txt"])
pafid = PrintAllFilesInDirectory(start_directory=temp_dir, file_extensions={".txt"})
assert "Error reading file" in pafid.run()

os.chmod(unreadable_file, 0o644) # reset file permissions for cleanup
unreadable_file.chmod(0o644) # reset file permissions for cleanup
10 changes: 10 additions & 0 deletions tests/custom_tools/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from nalgonda.custom_tools.utils import check_directory_traversal

@pytest.mark.parametrize("path", ["..", "/", "/sbin"])
def test_check_directory_traversal_raises_for_attempts(path):
    """Paths outside the allowed base directories must be rejected with a ValueError."""
    with pytest.raises(ValueError) as exc_info:
        check_directory_traversal(path)
    # Inspect the captured exception only after the block has exited: any
    # statement placed after the raising call inside the block would never run.
    # The exception type is already guaranteed by pytest.raises(ValueError).
    assert "Directory traversal is not allowed." in str(exc_info.value)

0 comments on commit 6129357

Please sign in to comment.