Merge pull request #385 from andreidenissov-cog/bugfix/381
Fixes to allow use of Minio service as S3 API storage/database endpoint.
andreidenissov-cog authored Jan 31, 2020
2 parents e14f8f0 + 8148cab commit a8c0b9a
Showing 8 changed files with 97 additions and 26 deletions.
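For context, the commit targets setups where the `database` section of the StudioML config points the S3 provider at an S3-compatible server such as Minio. A hypothetical sketch of such a config (only `database.type` is taken from the diff below; the `endpoint` and `bucket` keys are illustrative of a Minio-style setup, not verified against this commit):

# Hypothetical config dict as consumed by model.get_db_provider().
# Only 'type' is confirmed by the diff; the other keys are assumptions.
config = {
    'database': {
        'type': 's3',                           # dispatches to S3Provider
        'endpoint': 'http://minio.local:9000',  # assumed S3 API endpoint key
        'bucket': 'studioml-db',                # assumed bucket name key
    }
}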
11 changes: 11 additions & 0 deletions studio/base_artifact_store.py
@@ -0,0 +1,11 @@
+class BaseArtifactStore(object):
+
+    def __init__(self):
+        self.storage_client = None
+
+    def set_storage_client(self, sclient):
+        self.storage_client = sclient
+
+    def get_storage_client(self):
+        return self.storage_client
+
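The new base class is a minimal holder that lets a concrete artifact store publish its low-level storage client, so shared helpers (see studio/util.py below) can retrieve it later. A small illustrative sketch:

from studio.base_artifact_store import BaseArtifactStore

store = BaseArtifactStore()
store.set_storage_client('client-placeholder')  # stand-in for a boto3 client
assert store.get_storage_client() == 'client-placeholder'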
3 changes: 1 addition & 2 deletions studio/http_provider.py
@@ -4,7 +4,7 @@
 import time
 import re
 
-from . import pyrebase, logs
+from . import logs
 from .auth import get_auth
 from .http_artifact_store import HTTPArtifactStore
 from .experiment import experiment_from_dict
@@ -27,7 +27,6 @@ def __init__(
         self.logger.setLevel(self.verbose)
 
         self.auth = None
-        self.app = pyrebase.initialize_app(config)
         guest = config.get('guest')
         if not guest and 'serviceAccount' not in config.keys():
             self.auth = get_auth(
8 changes: 4 additions & 4 deletions studio/keyvalue_provider.py
@@ -4,7 +4,7 @@
 import re
 from threading import Thread
 
-from . import util, git_util, pyrebase, logs
+from . import util, git_util, logs
 from .firebase_artifact_store import FirebaseArtifactStore
 from .auth import get_auth
 from .experiment import experiment_from_dict
@@ -24,7 +24,6 @@ def __init__(
             compression=None):
         guest = db_config.get('guest')
 
-        self.app = pyrebase.initialize_app(db_config)
         self.logger = logs.getLogger(self.__class__.__name__)
         self.logger.setLevel(verbose)
 
@@ -425,11 +424,12 @@ def register_user(self, userid, email):
         if existing_email != email:
             self._set(keypath, email)
 
+    def get_artifact_store(self):
+        return self.store
+
     def __enter__(self):
         return self
 
     def __exit__(self, *args):
-        if self.app:
-            self.app.requests.close()
         if self.store:
             self.store.__exit__()
19 changes: 13 additions & 6 deletions studio/model.py
@@ -16,9 +16,9 @@
 from .firebase_provider import FirebaseProvider
 from .s3_provider import S3Provider
 from .gs_provider import GSProvider
+from .model_setup import setup_model
 from . import logs
 
-
 def get_config(config_file=None):
 
     config_paths = []
@@ -56,7 +56,6 @@ def replace_with_env(config):
         raise ValueError('None of the config paths {} exits!'
                          .format(config_paths))
 
-
 def get_db_provider(config=None, blocking_auth=True):
     if not config:
         config = get_config()
@@ -76,31 +75,39 @@ def get_db_provider(config=None, blocking_auth=True):
     artifact_store = None
 
     assert 'database' in config.keys()
+    db_provider = None
     db_config = config['database']
     if db_config['type'].lower() == 'firebase':
-        return FirebaseProvider(
+        db_provider = FirebaseProvider(
             db_config,
             blocking_auth,
             verbose=verbose,
             store=artifact_store)
+        artifact_store = db_provider.get_artifact_store()
     elif db_config['type'].lower() == 'http':
-        return HTTPProvider(db_config,
+        db_provider = HTTPProvider(db_config,
                             verbose=verbose,
                             blocking_auth=blocking_auth)
     elif db_config['type'].lower() == 's3':
-        return S3Provider(db_config,
+        db_provider = S3Provider(db_config,
                           verbose=verbose,
                           store=artifact_store,
                           blocking_auth=blocking_auth)
+        artifact_store = db_provider.get_artifact_store()
+
     elif db_config['type'].lower() == 'gs':
-        return GSProvider(db_config,
+        db_provider = GSProvider(db_config,
                           verbose=verbose,
                           store=artifact_store,
                           blocking_auth=blocking_auth)
+        artifact_store = db_provider.get_artifact_store()
+
     else:
+        _model_setup = None
         raise ValueError('Unknown type of the database ' + db_config['type'])
 
+    setup_model(db_provider, artifact_store)
+    return db_provider
 
 def parse_verbosity(verbosity=None):
     if verbosity is None:
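Note the control-flow change: each branch now assigns db_provider instead of returning early, the chosen provider's artifact store is read back via get_artifact_store(), and both objects are registered through setup_model() before the single return. A hedged sketch of the resulting call pattern (the config contents here are placeholders; a real S3 setup needs more keys than shown):

from studio import model

# Creates the provider, then registers (provider, store) via setup_model()
# as a side effect before returning the provider.
db = model.get_db_provider(config={'database': {'type': 's3'}})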
20 changes: 20 additions & 0 deletions studio/model_setup.py
@@ -0,0 +1,20 @@
+
+DB_KEY = "database"
+STORE_KEY = "store"
+
+# Global dictionary which keeps Database Provider
+# and Artifact Store objects created from experiment configuration.
+_model_setup = None
+
+def setup_model(db_provider, artifact_store):
+    _model_setup = { DB_KEY: db_provider, STORE_KEY: artifact_store }
+
+def get_db_provider():
+    if _model_setup is None:
+        return None
+    return _model_setup.get(DB_KEY, None)
+
+def get_artifact_store():
+    if _model_setup is None:
+        return None
+    return _model_setup.get(STORE_KEY, None)
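The module acts as a process-wide registry: model.get_db_provider() calls setup_model(), and consumers such as the new util helpers read the registered objects back. Intended usage, sketched (this assumes setup_model() publishes to the module-level _model_setup):

from studio import model_setup

# After model.get_db_provider() has run somewhere in the process:
store = model_setup.get_artifact_store()  # registered store, or None
db = model_setup.get_db_provider()        # registered provider, or None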
5 changes: 3 additions & 2 deletions studio/s3_artifact_store.py
@@ -13,7 +13,6 @@
 from .tartifact_store import TartifactStore
 from . import logs
 
-
 class S3ArtifactStore(TartifactStore):
     def __init__(self, config,
                  verbose=10,
@@ -38,12 +37,14 @@ def __init__(self, config,
             self.client.create_bucket(
                 Bucket=self.bucket
             )
+
         super(S3ArtifactStore, self).__init__(
             measure_timestamp_diff,
             compression=compression,
             verbose=verbose)
 
-
+        self.set_storage_client(self.client)
+
     def _upload_file(self, key, local_path):
         self.client.upload_file(local_path, self.bucket, key)
 
14 changes: 12 additions & 2 deletions studio/tartifact_store.py
@@ -21,11 +21,16 @@
 from .util import compression_to_extension, compression_to_taropt, timeit
 from .util import sixdecode
 
+from .base_artifact_store import BaseArtifactStore
+
 
-class TartifactStore(object):
+class TartifactStore(BaseArtifactStore):
 
     def __init__(self, measure_timestamp_diff=False, compression=None,
                  verbose=logs.DEBUG):
+
+        super(TartifactStore, self).__init__()
+
         if measure_timestamp_diff:
             try:
                 self.timestamp_shift = self._measure_timestamp_diff()
@@ -244,10 +249,12 @@ def finish_download():
             listtar = listtar.strip().split(b'\n')
             listtar = [s.decode('utf-8') for s in listtar]
 
+            isTarFromDotDir = False
             self.logger.info('List of files in the tar: ' + str(listtar))
             if listtar[0].startswith('./'):
                 # Files are archived into tar from .; adjust path
                 # accordingly
+                isTarFromDotDir = True
                 basepath = local_path
             else:
                 basepath = local_basepath
@@ -274,7 +281,10 @@ def finish_download():
             self.logger.info('tar stdout output: \n ' + str(tarout))
             self.logger.info('tar stderr output: \n ' + str(tarerr))
 
-            if len(listtar) == 1:
+            if len(listtar) == 1 and not isTarFromDotDir:
+                # Here we protect ourselves from the corner case,
+                # when we try to move A/. folder to A.
+                # os.rename() will fail to do that.
                 actual_path = os.path.join(basepath, listtar[0])
                 self.logger.info(
                     'Renaming {} into {}'.format(
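The new flag covers archives created from inside a directory (tar -C dir .), whose entries are listed relative to './'; renaming basepath plus './' onto the target would amount to moving A/. to A, which os.rename() refuses. A quick way to see the listing difference (hypothetical paths, Unix tar assumed):

import os
import subprocess
import tempfile

src = tempfile.mkdtemp()
open(os.path.join(src, 'weights.bin'), 'w').close()
tar_path = os.path.join(tempfile.mkdtemp(), 'from_dot.tar')

# Archive from inside the directory, the way a dot-dir artifact is packed.
subprocess.check_call(['tar', '-cf', tar_path, '-C', src, '.'])
listing = subprocess.check_output(['tar', '-tf', tar_path], text=True)
print(listing.splitlines()[0])  # './' -> the isTarFromDotDir case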
43 changes: 33 additions & 10 deletions studio/util.py
@@ -13,16 +13,14 @@
 import numpy as np
 import requests
 import six
-
-import boto3
 from botocore.exceptions import ClientError
+from . import model_setup
+from .base_artifact_store import BaseArtifactStore
 
 DAY = 86400
 HOUR = 3600
 MINUTE = 60
 
-
-
 def remove_backspaces(line):
     splitline = re.split('(\x08+)', line)
     try:
@@ -275,6 +273,20 @@ def upload_file(url, local_path, logger=None):
         logger.debug('File upload done in {} s'
                      .format(time.time() - tic))
 
+def _looks_like_url(name):
+    """
+    Function tries to determine if input argument
+    looks like URL and not like S3 bucket name.
+    :param name - input name
+    :return: True, if name looks like URL;
+             False otherwise.
+    """
+    if name.endswith('.com'):
+        return True
+    if name.find(':') >= 0:
+        # Assume it is port number
+        return True
+    return False
 
 def download_file_from_qualified(qualified, local_path, logger=None):
     if qualified.startswith('dockerhub://') or \
@@ -285,7 +297,7 @@ def download_file_from_qualified(qualified, local_path, logger=None):
         qualified.startswith('gs://')
 
     qualified_split = qualified.split('/')
-    if qualified_split[2].endswith('.com'):
+    if _looks_like_url(qualified_split[2]):
         bucket = qualified_split[3]
         key = '/'.join(qualified_split[4:])
     else:
@@ -309,15 +321,26 @@ def download_file_from_qualified(qualified, local_path, logger=None):
             # ).communicate()
             _s3_download_dir(bucket, key, local_path, logger=logger)
         else:
-            boto3.client('s3').download_file(bucket, key, local_path)
+            s3_client = _get_active_s3_client()
+            s3_client.download_file(bucket, key, local_path)
     else:
         raise NotImplementedError
 
+def _get_active_s3_client():
+    artifact_store = model_setup.get_artifact_store()
+    if artifact_store is None \
+            or not isinstance(artifact_store, BaseArtifactStore):
+        raise NotImplementedError("Artifact store is not set up or has the wrong type")
+
+    storage_client = artifact_store.get_storage_client()
+    if storage_client is None:
+        raise NotImplementedError("Expected boto3 storage client for current artifact store")
+    return storage_client
 
 def _s3_download_dir(bucket, dist, local, logger=None):
-    client = boto3.client('s3')
+    s3_client = _get_active_s3_client()
 
-    paginator = client.get_paginator('list_objects')
+    paginator = s3_client.get_paginator('list_objects')
     for result in paginator.paginate(
         Bucket=bucket,
         Delimiter='/',
@@ -355,15 +378,15 @@ def _s3_download_dir(bucket, dist, local, logger=None):
                         'Downloading {}/{} to {}'
                         .format(bucket, key, local_path))
 
-                client.download_file(bucket, key, local_path)
+                s3_client.download_file(bucket, key, local_path)
             except ClientError as e:
                 if logger:
                     logger.debug(
                         'Download failed with exception {}'.format(e))
 
 
 def has_aws_credentials():
-    return boto3.client('s3')._request_signer._credentials is not None
+    return _get_active_s3_client()._request_signer._credentials is not None
 
 
 def retry(f,
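The net effect in util.py: S3 access goes through the registered artifact store's boto3 client, which is configured for whatever endpoint the store was built with (e.g. Minio), instead of a fresh boto3.client('s3') bound to AWS defaults; and _looks_like_url() lets a qualified path name its host by port as well as by a '.com' suffix. How the splitting works out, with illustrative values:

# 's3://minio.local:9000/bucket/path/key' -> qualified_split[2] is the host
parts = 's3://minio.local:9000/bucket/path/key'.split('/')
assert parts[2] == 'minio.local:9000'  # has ':', so _looks_like_url() is True
assert parts[3] == 'bucket'            # bucket name follows the host

# 's3://mybucket/path/key' -> no host component, parts[2] is the bucket
parts = 's3://mybucket/path/key'.split('/')
assert ':' not in parts[2] and not parts[2].endswith('.com')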
