Skip to content

Commit

Permalink
513 refactor recursive ls and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
justinTM committed Jul 13, 2022
1 parent fcc1c8e commit 0e8ff70
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 42 deletions.
14 changes: 12 additions & 2 deletions databricks_cli/dbfs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,18 @@ class DbfsApi(object):
def __init__(self, api_client):
self.client = DbfsService(api_client)

def list_files(self, dbfs_path, headers=None):
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
def _recursive_list(self, **kwargs):
paths = self.client.list_files(**kwargs)
files = [p for p in paths if not p.is_dir]
for p in paths:
files = files + self._recursive_list(p) if p.is_dir else files
return files

def list_files(self, dbfs_path, headers=None, is_recursive=False):
if is_recursive:
list_response = self._recursive_list(dbfs_path, headers)
else:
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
if 'files' in list_response:
return [FileInfo.from_json(f) for f in list_response['files']]
else:
Expand Down
25 changes: 9 additions & 16 deletions databricks_cli/dbfs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@
@click.option('-l', is_flag=True, default=False,
help="""Displays full information including size, file type
and modification time since Epoch in milliseconds.""")
@click.option('--recursive', is_flag=True, default=False,
@click.option('--recursive', '-r', is_flag=True, default=False,
help='Displays all subdirectories and files.')
@click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
@debug_option
@profile_option
@eat_exceptions
@provide_api_client
def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA
"""
List files in DBFS.
"""
Expand All @@ -55,20 +55,13 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
dbfs_path = dbfs_path[0]
else:
error_and_quit('ls can take a maximum of one path.')

def echo_path(files):
table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
tablefmt='plain')
click.echo(table)

def recursive_echo(this_dbfs_path):
files = DbfsApi(api_client).list_files(this_dbfs_path)
echo_path(files)
for f in files:
if f.is_dir:
recursive_echo(this_dbfs_path.join(f.basename))

recursive_echo(dbfs_path) if recursive else echo_path(dbfs_path)

DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
absolute = absolute or recursive

table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
tablefmt='plain')
click.echo(table)


@click.command(context_settings=CONTEXT_SETTINGS)
Expand Down
73 changes: 49 additions & 24 deletions tests/dbfs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,29 @@
from databricks_cli.dbfs.dbfs_path import DbfsPath
from databricks_cli.dbfs.exceptions import LocalFileExistsException

TEST_DBFS_PATH = DbfsPath('dbfs:/test')
TEST_DBFS_PATH1 = DbfsPath('dbfs:/test')
TEST_DBFS_PATH2 = DbfsPath('dbfs:/dir/test')
DUMMY_TIME = 1613158406000
TEST_FILE_JSON = {
TEST_FILE_JSON1 = {
'path': '/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
TEST_FILE_JSON2 = {
'path': '/dir/test',
'is_dir': False,
'file_size': 1,
'modification_time': DUMMY_TIME
}
TEST_DIR_JSON = {
'path': '/dir',
'is_dir': True,
'file_size': 0,
'modification_time': DUMMY_TIME
}
TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)


def get_resource_does_not_exist_exception():
Expand All @@ -60,22 +74,22 @@ def get_partial_delete_exception(message="[...] operation has deleted 10 files [

class TestFileInfo(object):
def test_to_row_not_long_form_not_absolute(self):
file_info = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
file_info = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
row = file_info.to_row(is_long_form=False, is_absolute=False)
assert len(row) == 1
assert TEST_DBFS_PATH.basename == row[0]
assert TEST_DBFS_PATH1.basename == row[0]

def test_to_row_long_form_not_absolute(self):
file_info = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
file_info = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
row = file_info.to_row(is_long_form=True, is_absolute=False)
assert len(row) == 4
assert row[0] == 'file'
assert row[1] == 1
assert TEST_DBFS_PATH.basename == row[2]
assert TEST_DBFS_PATH1.basename == row[2]

def test_from_json(self):
file_info = api.FileInfo.from_json(TEST_FILE_JSON)
assert file_info.dbfs_path == TEST_DBFS_PATH
file_info = api.FileInfo.from_json(TEST_FILE_JSON1)
assert file_info.dbfs_path == TEST_DBFS_PATH1
assert not file_info.is_dir
assert file_info.file_size == 1

Expand All @@ -89,41 +103,52 @@ def dbfs_api():


class TestDbfsApi(object):
def test_list_files_recursive(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON1, TEST_DIR_JSON, TEST_FILE_JSON2]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files("dbfs:/")

assert len(files) == 2
assert TEST_FILE_INFO0 == files[0]
assert TEST_FILE_INFO1 == files[1]

def test_list_files_exists(self, dbfs_api):
json = {
'files': [TEST_FILE_JSON]
'files': [TEST_FILE_JSON1]
}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files(TEST_DBFS_PATH)
files = dbfs_api.list_files(TEST_DBFS_PATH1, is_recursive=True)

assert len(files) == 1
assert TEST_FILE_INFO == files[0]
assert TEST_FILE_INFO0 == files[0]

def test_list_files_does_not_exist(self, dbfs_api):
json = {}
dbfs_api.client.list.return_value = json
files = dbfs_api.list_files(TEST_DBFS_PATH)
files = dbfs_api.list_files(TEST_DBFS_PATH1)

assert len(files) == 0

def test_file_exists_true(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
assert dbfs_api.file_exists(TEST_DBFS_PATH)
dbfs_api.client.get_status.return_value = TEST_FILE_JSON1
assert dbfs_api.file_exists(TEST_DBFS_PATH1)

def test_file_exists_false(self, dbfs_api):
exception = get_resource_does_not_exist_exception()
dbfs_api.client.get_status = mock.Mock(side_effect=exception)
assert not dbfs_api.file_exists(TEST_DBFS_PATH)
assert not dbfs_api.file_exists(TEST_DBFS_PATH1)

def test_get_status(self, dbfs_api):
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
dbfs_api.client.get_status.return_value = TEST_FILE_JSON1
assert dbfs_api.get_status(TEST_DBFS_PATH1) == TEST_FILE_INFO0

def test_get_status_fail(self, dbfs_api):
exception = get_resource_does_not_exist_exception()
dbfs_api.client.get_status = mock.Mock(side_effect=exception)
with pytest.raises(exception.__class__):
dbfs_api.get_status(TEST_DBFS_PATH)
dbfs_api.get_status(TEST_DBFS_PATH1)

def test_put_file(self, dbfs_api, tmpdir):
test_file_path = os.path.join(tmpdir.strpath, 'test')
Expand All @@ -133,7 +158,7 @@ def test_put_file(self, dbfs_api, tmpdir):
api_mock = dbfs_api.client
test_handle = 0
api_mock.create.return_value = {'handle': test_handle}
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH1, True)

# Should not call add-block since file is < 2GB
assert api_mock.add_block.call_count == 0
Expand All @@ -148,7 +173,7 @@ def test_put_large_file(self, dbfs_api, tmpdir):
dbfs_api.MULTIPART_UPLOAD_LIMIT = 2
test_handle = 0
api_mock.create.return_value = {'handle': test_handle}
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH1, True)
assert api_mock.add_block.call_count == 1
assert test_handle == api_mock.add_block.call_args[0][0]
assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]
Expand All @@ -160,18 +185,18 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):
with open(test_file_path, 'w') as f:
f.write('test')
with pytest.raises(LocalFileExistsException):
dbfs_api.get_file(TEST_DBFS_PATH, test_file_path, False)
dbfs_api.get_file(TEST_DBFS_PATH1, test_file_path, False)

def test_get_file(self, dbfs_api, tmpdir):
api_mock = dbfs_api.client
api_mock.get_status.return_value = TEST_FILE_JSON
api_mock.get_status.return_value = TEST_FILE_JSON1
api_mock.read.return_value = {
'bytes_read': 1,
'data': b64encode(b'x'),
}

test_file_path = os.path.join(tmpdir.strpath, 'test')
dbfs_api.get_file(TEST_DBFS_PATH, test_file_path, True)
dbfs_api.get_file(TEST_DBFS_PATH1, test_file_path, True)

with open(test_file_path, 'r') as f:
assert f.read() == 'x'
Expand Down

0 comments on commit 0e8ff70

Please sign in to comment.