diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index 002e9e2..3363dae 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -56,7 +56,6 @@ re.compile(r"^from ((?!\.+).*?) import (?:.*)$"), ] - @contextmanager def _open(filename=None, mode="r"): """Open a file or ``sys.stdout`` depending on the provided filename. @@ -199,11 +198,16 @@ def generate_requirements_file(path, imports, symbol): num=len(imports), file=path, imports=", ".join([x["name"] for x in imports]), + ) ) fmt = "{name}" + symbol + "{version}" out_file.write( - "\n".join(fmt.format(**item) if item["version"] else "{name}".format(**item) for item in imports) + "\n" + "\n".join( + fmt.format(**item) if item["version"] else "{name}".format(**item) + for item in imports + ) + + "\n" ) @@ -301,7 +305,7 @@ def get_import_local(imports, encoding=None): # had to use second method instead of the previous one, # because we have a list in the 'exports' field # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python - result_unique = [i for n, i in enumerate(result) if i not in result[n + 1 :]] + result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]] return result_unique @@ -346,6 +350,9 @@ def parse_requirements(file_): delimiter, get module name by element index, create a dict consisting of module:version, and add dict to list of parsed modules. + If file ´file_´ is not found in the system, the program will print a + helpful message and end its execution immediately. + Args: file_: File to parse. @@ -362,9 +369,12 @@ def parse_requirements(file_): try: f = open(file_, "r") - except OSError: - logging.error("Failed on file: {}".format(file_)) - raise + except FileNotFoundError: + print(f"File {file_} was not found. Please, fix it and run again.") + sys.exit(1) + except OSError as error: + logging.error(f"There was an error opening the file {file_}: {str(error)}") + raise error else: try: data = [x.strip() for x in f.readlines() if x != "\n"] @@ -476,9 +486,16 @@ def init(args): if extra_ignore_dirs: extra_ignore_dirs = extra_ignore_dirs.split(",") - - path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") - if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path): + + path = ( + args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") + ) + if ( + not args["--print"] + and not args["--savepath"] + and not args["--force"] + and os.path.exists(path) + ): logging.warning("requirements.txt already exists, " "use --force to overwrite it") return @@ -538,7 +555,9 @@ def init(args): if scheme in ["compat", "gt", "no-pin"]: imports, symbol = dynamic_versioning(scheme, imports) else: - raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead") + raise ValueError( + "Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead" + ) else: symbol = "==" diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index a2cbbbe..988a72f 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -8,11 +8,12 @@ Tests for `pipreqs` module. """ -import io -import sys +from io import StringIO +from unittest.mock import patch import unittest import os import requests +import sys from pipreqs import pipreqs @@ -80,6 +81,8 @@ def setUp(self): "original": os.path.join(os.path.dirname(__file__), "_data/test.py"), "notebook": os.path.join(os.path.dirname(__file__), "_data_notebook/test.ipynb"), } + self.non_existing_filepath = "xpto" + def test_get_all_imports(self): imports = pipreqs.get_all_imports(self.project) @@ -479,7 +482,7 @@ def test_output_requirements(self): It should print to stdout the same content as requeriments.txt """ - capturedOutput = io.StringIO() + capturedOutput = StringIO() sys.stdout = capturedOutput pipreqs.init( @@ -583,6 +586,23 @@ def test_parse_requirements(self): self.assertListEqual(parsed_requirements, expected_parsed_requirements) + @patch("sys.exit") + def test_parse_requirements_handles_file_not_found(self, exit_mock): + captured_output = StringIO() + sys.stdout = captured_output + + # This assertion is needed, because since "sys.exit" is mocked, the program won't end, + # and the code that is after the except block will be run + with self.assertRaises(UnboundLocalError): + pipreqs.parse_requirements(self.non_existing_filepath) + + exit_mock.assert_called_once_with(1) + + printed_text = captured_output.getvalue().strip() + sys.stdout = sys.__stdout__ + + self.assertEqual(printed_text, "File xpto was not found. Please, fix it and run again.") + def tearDown(self): """ Remove requiremnts.txt files that were written