Skip to content

Commit

Permalink
Merge branch 'next' into add_notebook_support
Browse files Browse the repository at this point in the history
  • Loading branch information
Fernando-crz authored Oct 23, 2023
2 parents 190ce1c + 368e9ae commit 1c9a03e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
39 changes: 29 additions & 10 deletions pipreqs/pipreqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
re.compile(r"^from ((?!\.+).*?) import (?:.*)$"),
]


@contextmanager
def _open(filename=None, mode="r"):
"""Open a file or ``sys.stdout`` depending on the provided filename.
Expand Down Expand Up @@ -199,11 +198,16 @@ def generate_requirements_file(path, imports, symbol):
num=len(imports),
file=path,
imports=", ".join([x["name"] for x in imports]),

)
)
fmt = "{name}" + symbol + "{version}"
out_file.write(
"\n".join(fmt.format(**item) if item["version"] else "{name}".format(**item) for item in imports) + "\n"
"\n".join(
fmt.format(**item) if item["version"] else "{name}".format(**item)
for item in imports
)
+ "\n"
)


Expand Down Expand Up @@ -301,7 +305,7 @@ def get_import_local(imports, encoding=None):
# had to use second method instead of the previous one,
# because we have a list in the 'exports' field
# https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]]
result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]]

return result_unique

Expand Down Expand Up @@ -346,6 +350,9 @@ def parse_requirements(file_):
delimiter, get module name by element index, create a dict consisting of
module:version, and add dict to list of parsed modules.
If file ´file_´ is not found in the system, the program will print a
helpful message and end its execution immediately.
Args:
file_: File to parse.
Expand All @@ -362,9 +369,12 @@ def parse_requirements(file_):

try:
f = open(file_, "r")
except OSError:
logging.error("Failed on file: {}".format(file_))
raise
except FileNotFoundError:
print(f"File {file_} was not found. Please, fix it and run again.")
sys.exit(1)
except OSError as error:
logging.error(f"There was an error opening the file {file_}: {str(error)}")
raise error
else:
try:
data = [x.strip() for x in f.readlines() if x != "\n"]
Expand Down Expand Up @@ -476,9 +486,16 @@ def init(args):

if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(",")

path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):

Check warning on line 489 in pipreqs/pipreqs.py

View workflow job for this annotation

GitHub Actions / Lint

[flake8] reported by reviewdog 🐶 blank line contains whitespace Raw Output: ./pipreqs/pipreqs.py:489:1: W293 blank line contains whitespace
path = (
args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
)
if (
not args["--print"]
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)
):
logging.warning("requirements.txt already exists, " "use --force to overwrite it")
return

Expand Down Expand Up @@ -538,7 +555,9 @@ def init(args):
if scheme in ["compat", "gt", "no-pin"]:
imports, symbol = dynamic_versioning(scheme, imports)
else:
raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead")
raise ValueError(
"Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead"
)
else:
symbol = "=="

Expand Down
26 changes: 23 additions & 3 deletions tests/test_pipreqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
Tests for `pipreqs` module.
"""

import io
import sys
from io import StringIO
from unittest.mock import patch
import unittest
import os
import requests
import sys

from pipreqs import pipreqs

Expand Down Expand Up @@ -80,6 +81,8 @@ def setUp(self):
"original": os.path.join(os.path.dirname(__file__), "_data/test.py"),
"notebook": os.path.join(os.path.dirname(__file__), "_data_notebook/test.ipynb"),
}
self.non_existing_filepath = "xpto"


def test_get_all_imports(self):
imports = pipreqs.get_all_imports(self.project)
Expand Down Expand Up @@ -479,7 +482,7 @@ def test_output_requirements(self):
It should print to stdout the same content as requeriments.txt
"""

capturedOutput = io.StringIO()
capturedOutput = StringIO()
sys.stdout = capturedOutput

pipreqs.init(
Expand Down Expand Up @@ -583,6 +586,23 @@ def test_parse_requirements(self):

self.assertListEqual(parsed_requirements, expected_parsed_requirements)

@patch("sys.exit")
def test_parse_requirements_handles_file_not_found(self, exit_mock):
captured_output = StringIO()
sys.stdout = captured_output

# This assertion is needed, because since "sys.exit" is mocked, the program won't end,
# and the code that is after the except block will be run
with self.assertRaises(UnboundLocalError):
pipreqs.parse_requirements(self.non_existing_filepath)

exit_mock.assert_called_once_with(1)

printed_text = captured_output.getvalue().strip()
sys.stdout = sys.__stdout__

self.assertEqual(printed_text, "File xpto was not found. Please, fix it and run again.")

def tearDown(self):
"""
Remove requiremnts.txt files that were written
Expand Down

0 comments on commit 1c9a03e

Please sign in to comment.