From 368e9ae7e7b53e414c793f1e0ad3890b8e228f12 Mon Sep 17 00:00:00 2001 From: mateuslatrova Date: Thu, 19 Oct 2023 19:52:08 -0300 Subject: [PATCH] handle FileNotFoundError in parse_requirements function --- pipreqs/pipreqs.py | 198 ++++++++++++++++++++++-------------------- tests/test_pipreqs.py | 29 +++++-- 2 files changed, 126 insertions(+), 101 deletions(-) diff --git a/pipreqs/pipreqs.py b/pipreqs/pipreqs.py index a84f39b..dbc804c 100644 --- a/pipreqs/pipreqs.py +++ b/pipreqs/pipreqs.py @@ -50,14 +50,11 @@ from pipreqs import __version__ -REGEXP = [ - re.compile(r'^import (.+)$'), - re.compile(r'^from ((?!\.+).*?) import (?:.*)$') -] +REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")] @contextmanager -def _open(filename=None, mode='r'): +def _open(filename=None, mode="r"): """Open a file or ``sys.stdout`` depending on the provided filename. Args: @@ -70,13 +67,13 @@ def _open(filename=None, mode='r'): A file handle. """ - if not filename or filename == '-': - if not mode or 'r' in mode: + if not filename or filename == "-": + if not mode or "r" in mode: file = sys.stdin - elif 'w' in mode: + elif "w" in mode: file = sys.stdout else: - raise ValueError('Invalid mode for file: {}'.format(mode)) + raise ValueError("Invalid mode for file: {}".format(mode)) else: file = open(filename, mode) @@ -87,8 +84,7 @@ def _open(filename=None, mode='r'): file.close() -def get_all_imports( - path, encoding=None, extra_ignore_dirs=None, follow_links=True): +def get_all_imports(path, encoding=None, extra_ignore_dirs=None, follow_links=True): imports = set() raw_imports = set() candidates = [] @@ -137,11 +133,11 @@ def get_all_imports( # Cleanup: We only want to first part of the import. # Ex: from django.conf --> django.conf. But we only want django # as an import. - cleaned_name, _, _ = name.partition('.') + cleaned_name, _, _ = name.partition(".") imports.add(cleaned_name) packages = imports - (set(candidates) & imports) - logging.debug('Found packages: {0}'.format(packages)) + logging.debug("Found packages: {0}".format(packages)) with open(join("stdlib"), "r") as f: data = {x.strip() for x in f} @@ -151,56 +147,55 @@ def get_all_imports( def generate_requirements_file(path, imports, symbol): with _open(path, "w") as out_file: - logging.debug('Writing {num} requirements: {imports} to {file}'.format( - num=len(imports), - file=path, - imports=", ".join([x['name'] for x in imports]) - )) - fmt = '{name}' + symbol + '{version}' - out_file.write('\n'.join( - fmt.format(**item) if item['version'] else '{name}'.format(**item) - for item in imports) + '\n') + logging.debug( + "Writing {num} requirements: {imports} to {file}".format( + num=len(imports), file=path, imports=", ".join([x["name"] for x in imports]) + ) + ) + fmt = "{name}" + symbol + "{version}" + out_file.write( + "\n".join( + fmt.format(**item) if item["version"] else "{name}".format(**item) + for item in imports + ) + + "\n" + ) def output_requirements(imports, symbol): - generate_requirements_file('-', imports, symbol) + generate_requirements_file("-", imports, symbol) -def get_imports_info( - imports, pypi_server="https://pypi.python.org/pypi/", proxy=None): +def get_imports_info(imports, pypi_server="https://pypi.python.org/pypi/", proxy=None): result = [] for item in imports: try: logging.warning( - 'Import named "%s" not found locally. ' - 'Trying to resolve it at the PyPI server.', - item + 'Import named "%s" not found locally. ' "Trying to resolve it at the PyPI server.", + item, ) - response = requests.get( - "{0}{1}/json".format(pypi_server, item), proxies=proxy) + response = requests.get("{0}{1}/json".format(pypi_server, item), proxies=proxy) if response.status_code == 200: - if hasattr(response.content, 'decode'): + if hasattr(response.content, "decode"): data = json2package(response.content.decode()) else: data = json2package(response.content) elif response.status_code >= 300: - raise HTTPError(status_code=response.status_code, - reason=response.reason) + raise HTTPError(status_code=response.status_code, reason=response.reason) except HTTPError: - logging.warning( - 'Package "%s" does not exist or network problems', item) + logging.warning('Package "%s" does not exist or network problems', item) continue logging.warning( 'Import named "%s" was resolved to "%s:%s" package (%s).\n' - 'Please, verify manually the final list of requirements.txt ' - 'to avoid possible dependency confusions.', + "Please, verify manually the final list of requirements.txt " + "to avoid possible dependency confusions.", item, data.name, data.latest_release_id, - data.pypi_url + data.pypi_url, ) - result.append({'name': item, 'version': data.latest_release_id}) + result.append({"name": item, "version": data.latest_release_id}) return result @@ -225,25 +220,23 @@ def get_locally_installed_packages(encoding=None): filtered_top_level_modules = list() for module in top_level_modules: - if ( - (module not in ignore) and - (package[0] not in ignore) - ): + if (module not in ignore) and (package[0] not in ignore): # append exported top level modules to the list filtered_top_level_modules.append(module) version = None if len(package) > 1: - version = package[1].replace( - ".dist", "").replace(".egg", "") + version = package[1].replace(".dist", "").replace(".egg", "") # append package: top_level_modules pairs # instead of top_level_module: package pairs - packages.append({ - 'name': package[0], - 'version': version, - 'exports': filtered_top_level_modules - }) + packages.append( + { + "name": package[0], + "version": version, + "exports": filtered_top_level_modules, + } + ) return packages @@ -256,14 +249,14 @@ def get_import_local(imports, encoding=None): # if candidate import name matches export name # or candidate import name equals to the package name # append it to the result - if item in package['exports'] or item == package['name']: + if item in package["exports"] or item == package["name"]: result.append(package) # removing duplicates of package/version # had to use second method instead of the previous one, # because we have a list in the 'exports' field # https://stackoverflow.com/questions/9427163/remove-duplicate-dict-in-list-in-python - result_unique = [i for n, i in enumerate(result) if i not in result[n+1:]] + result_unique = [i for n, i in enumerate(result) if i not in result[n + 1:]] return result_unique @@ -294,7 +287,7 @@ def get_name_without_alias(name): match = REGEXP[0].match(name.strip()) if match: name = match.groups(0)[0] - return name.partition(' as ')[0].partition('.')[0].strip() + return name.partition(" as ")[0].partition(".")[0].strip() def join(f): @@ -308,6 +301,9 @@ def parse_requirements(file_): delimiter, get module name by element index, create a dict consisting of module:version, and add dict to list of parsed modules. + If file ´file_´ is not found in the system, the program will print a + helpful message and end its execution immediately. + Args: file_: File to parse. @@ -324,9 +320,12 @@ def parse_requirements(file_): try: f = open(file_, "r") - except OSError: - logging.error("Failed on file: {}".format(file_)) - raise + except FileNotFoundError: + print(f"File {file_} was not found. Please, fix it and run again.") + sys.exit(1) + except OSError as error: + logging.error(f"There was an error opening the file {file_}: {str(error)}") + raise error else: try: data = [x.strip() for x in f.readlines() if x != "\n"] @@ -353,6 +352,7 @@ def parse_requirements(file_): return modules + def compare_modules(file_, imports): """Compare modules in a file to imported modules in a project. @@ -379,7 +379,8 @@ def diff(file_, imports): logging.info( "The following modules are in {} but do not seem to be imported: " - "{}".format(file_, ", ".join(x for x in modules_not_imported))) + "{}".format(file_, ", ".join(x for x in modules_not_imported)) + ) def clean(file_, imports): @@ -427,30 +428,34 @@ def dynamic_versioning(scheme, imports): def init(args): - encoding = args.get('--encoding') - extra_ignore_dirs = args.get('--ignore') - follow_links = not args.get('--no-follow-links') - input_path = args[''] + encoding = args.get("--encoding") + extra_ignore_dirs = args.get("--ignore") + follow_links = not args.get("--no-follow-links") + input_path = args[""] if input_path is None: input_path = os.path.abspath(os.curdir) if extra_ignore_dirs: - extra_ignore_dirs = extra_ignore_dirs.split(',') - - path = (args["--savepath"] if args["--savepath"] else - os.path.join(input_path, "requirements.txt")) - if (not args["--print"] - and not args["--savepath"] - and not args["--force"] - and os.path.exists(path)): - logging.warning("requirements.txt already exists, " - "use --force to overwrite it") + extra_ignore_dirs = extra_ignore_dirs.split(",") + + path = ( + args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt") + ) + if ( + not args["--print"] + and not args["--savepath"] + and not args["--force"] + and os.path.exists(path) + ): + logging.warning("requirements.txt already exists, " "use --force to overwrite it") return - candidates = get_all_imports(input_path, - encoding=encoding, - extra_ignore_dirs=extra_ignore_dirs, - follow_links=follow_links) + candidates = get_all_imports( + input_path, + encoding=encoding, + extra_ignore_dirs=extra_ignore_dirs, + follow_links=follow_links, + ) candidates = get_pkg_names(candidates) logging.debug("Found imports: " + ", ".join(candidates)) pypi_server = "https://pypi.python.org/pypi/" @@ -459,11 +464,10 @@ def init(args): pypi_server = args["--pypi-server"] if args["--proxy"]: - proxy = {'http': args["--proxy"], 'https': args["--proxy"]} + proxy = {"http": args["--proxy"], "https": args["--proxy"]} if args["--use-local"]: - logging.debug( - "Getting package information ONLY from local installation.") + logging.debug("Getting package information ONLY from local installation.") imports = get_import_local(candidates, encoding=encoding) else: logging.debug("Getting packages information from Local/PyPI") @@ -473,20 +477,21 @@ def init(args): # the list of exported modules, installed locally # and the package name is not in the list of local module names # it add to difference - difference = [x for x in candidates if - # aggregate all export lists into one - # flatten the list - # check if candidate is in exports - x.lower() not in [y for x in local for y in x['exports']] - and - # check if candidate is package names - x.lower() not in [x['name'] for x in local]] - - imports = local + get_imports_info(difference, - proxy=proxy, - pypi_server=pypi_server) + difference = [ + x + for x in candidates + if + # aggregate all export lists into one + # flatten the list + # check if candidate is in exports + x.lower() not in [y for x in local for y in x["exports"]] and + # check if candidate is package names + x.lower() not in [x["name"] for x in local] + ] + + imports = local + get_imports_info(difference, proxy=proxy, pypi_server=pypi_server) # sort imports based on lowercase name of package, similar to `pip freeze`. - imports = sorted(imports, key=lambda x: x['name'].lower()) + imports = sorted(imports, key=lambda x: x["name"].lower()) if args["--diff"]: diff(args["--diff"], imports) @@ -501,8 +506,9 @@ def init(args): if scheme in ["compat", "gt", "no-pin"]: imports, symbol = dynamic_versioning(scheme, imports) else: - raise ValueError("Invalid argument for mode flag, " - "use 'compat', 'gt' or 'no-pin' instead") + raise ValueError( + "Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead" + ) else: symbol = "==" @@ -516,8 +522,8 @@ def init(args): def main(): # pragma: no cover args = docopt(__doc__, version=__version__) - log_level = logging.DEBUG if args['--debug'] else logging.INFO - logging.basicConfig(level=log_level, format='%(levelname)s: %(message)s') + log_level = logging.DEBUG if args["--debug"] else logging.INFO + logging.basicConfig(level=log_level, format="%(levelname)s: %(message)s") try: init(args) @@ -525,5 +531,5 @@ def main(): # pragma: no cover sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() # pragma: no cover diff --git a/tests/test_pipreqs.py b/tests/test_pipreqs.py index f239f07..e67a83d 100644 --- a/tests/test_pipreqs.py +++ b/tests/test_pipreqs.py @@ -8,11 +8,12 @@ Tests for `pipreqs` module. """ -import io -import sys +from io import StringIO +from unittest.mock import patch import unittest import os import requests +import sys from pipreqs import pipreqs @@ -76,9 +77,10 @@ def setUp(self): self.project_with_ignore_directory = os.path.join(os.path.dirname(__file__), "_data_ignore") self.project_with_duplicated_deps = os.path.join(os.path.dirname(__file__), "_data_duplicated_deps") - + self.requirements_path = os.path.join(self.project, "requirements.txt") self.alt_requirement_path = os.path.join(self.project, "requirements2.txt") + self.non_existing_filepath = "xpto" def test_get_all_imports(self): imports = pipreqs.get_all_imports(self.project) @@ -471,14 +473,14 @@ def test_compare_modules(self): modules_not_imported = pipreqs.compare_modules(filename, imports) self.assertSetEqual(modules_not_imported, expected_modules_not_imported) - + def test_output_requirements(self): """ Test --print parameter It should print to stdout the same content as requeriments.txt """ - capturedOutput = io.StringIO() + capturedOutput = StringIO() sys.stdout = capturedOutput pipreqs.init( @@ -540,6 +542,23 @@ def test_parse_requirements(self): self.assertListEqual(parsed_requirements, expected_parsed_requirements) + @patch("sys.exit") + def test_parse_requirements_handles_file_not_found(self, exit_mock): + captured_output = StringIO() + sys.stdout = captured_output + + # This assertion is needed, because since "sys.exit" is mocked, the program won't end, + # and the code that is after the except block will be run + with self.assertRaises(UnboundLocalError): + pipreqs.parse_requirements(self.non_existing_filepath) + + exit_mock.assert_called_once_with(1) + + printed_text = captured_output.getvalue().strip() + sys.stdout = sys.__stdout__ + + self.assertEqual(printed_text, "File xpto was not found. Please, fix it and run again.") + def tearDown(self): """ Remove requiremnts.txt files that were written