Skip to content

Commit

Permalink
!513 refactor recursive dbfs ls and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
justinTM committed Jul 13, 2022
1 parent 0b6a78e commit 3330faf
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
14 changes: 12 additions & 2 deletions databricks_cli/dbfs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,18 @@ class DbfsApi(object):
def __init__(self, api_client):
self.client = DbfsService(api_client)

def list_files(self, dbfs_path, headers=None):
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
def _recursive_list(self, **kwargs):
paths = self.client.list_files(**kwargs)
files = [p for p in paths if not p.is_dir]
for p in paths:
files = files + self._recursive_list(p) if p.is_dir else files
return files

def list_files(self, dbfs_path, headers=None, is_recursive=False):
if is_recursive:
list_response = self._recursive_list(dbfs_path, headers)
else:
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
if 'files' in list_response:
return [FileInfo.from_json(f) for f in list_response['files']]
else:
Expand Down
9 changes: 7 additions & 2 deletions databricks_cli/dbfs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,14 @@
@click.option('-l', is_flag=True, default=False,
help="""Displays full information including size, file type
and modification time since Epoch in milliseconds.""")
@click.option('--recursive', '-r', is_flag=True, default=False,
help='Displays all subdirectories and files.')
@click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
@debug_option
@profile_option
@eat_exceptions
@provide_api_client
def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA
"""
List files in DBFS.
"""
Expand All @@ -53,7 +55,10 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
dbfs_path = dbfs_path[0]
else:
error_and_quit('ls can take a maximum of one path.')
files = DbfsApi(api_client).list_files(dbfs_path)

DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
absolute = absolute or recursive

table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
tablefmt='plain')
click.echo(table)
Expand Down
47 changes: 36 additions & 11 deletions tests/dbfs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,26 @@

TEST_DBFS_PATH = DbfsPath('dbfs:/test')
DUMMY_TIME = 1613158406000
TEST_FILE_JSON = {
TEST_FILE_JSON1 = {
'path': '/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
TEST_FILE_JSON2 = {
'path': '/dir/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_DIR_JSON = {
'path': '/dir',
'is_dir': True,
'file_size': 0,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)


def get_resource_does_not_exist_exception():
Expand Down Expand Up @@ -74,7 +87,7 @@ def test_to_row_long_form_not_absolute(self):
assert TEST_DBFS_PATH.basename == row[2]

def test_from_json(self):
file_info = api.FileInfo.from_json(TEST_FILE_JSON)
file_info = api.FileInfo.from_json(TEST_FILE_JSON0)
assert file_info.dbfs_path == TEST_DBFS_PATH
assert not file_info.is_dir
assert file_info.file_size == 1
Expand All @@ -89,15 +102,26 @@ def dbfs_api():


class TestDbfsApi(object):
def test_list_files_recursive(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON0, TEST_DIR_JSON, TEST_FILE_JSON1]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files("dbfs:/")

assert len(files) == 2
assert TEST_FILE_INFO0 == files[0]
assert TEST_FILE_INFO1 == files[1]

def test_list_files_exists(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON]
'files': [TEST_FILE_JSON0]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files(TEST_DBFS_PATH)
files = dbfs_api.list_files(TEST_DBFS_PATH, is_recursive=True)

assert len(files) == 1
assert TEST_FILE_INFO == files[0]
assert TEST_FILE_INFO0 == files[0]

def test_list_files_does_not_exist(self, dbfs_api):
json = {}
Expand All @@ -107,7 +131,7 @@ def test_list_files_does_not_exist(self, dbfs_api):
assert len(files) == 0

def test_file_exists_true(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
assert dbfs_api.file_exists(TEST_DBFS_PATH)

def test_file_exists_false(self, dbfs_api):
Expand All @@ -116,8 +140,8 @@ def test_file_exists_false(self, dbfs_api):
assert not dbfs_api.file_exists(TEST_DBFS_PATH)

def test_get_status(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
dbfs_api.client.get_status.return_value = TEST_FILE_JSON0
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO0

def test_get_status_fail(self, dbfs_api):
exception = get_resource_does_not_exist_exception()
Expand Down Expand Up @@ -151,7 +175,8 @@ def test_put_large_file(self, dbfs_api, tmpdir):
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
assert api_mock.add_block.call_count == 1
assert test_handle == api_mock.add_block.call_args[0][0]
assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]
assert b64encode(b'test').decode(
) == api_mock.add_block.call_args[0][1]
assert api_mock.close.call_count == 1
assert test_handle == api_mock.close.call_args[0][0]

Expand All @@ -164,7 +189,7 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):

def test_get_file(self, dbfs_api, tmpdir):
api_mock = dbfs_api.client
api_mock.get_status.return_value = TEST_FILE_JSON
api_mock.get_status.return_value = TEST_FILE_JSON0
api_mock.read.return_value = {
'bytes_read': 1,
'data': b64encode(b'x'),
Expand Down

0 comments on commit 3330faf

Please sign in to comment.