fix: corrected unworking storageobject_reader #99

Merged 7 commits on Feb 26, 2021
10 changes: 3 additions & 7 deletions nck/readers/gcs_reader.py
@@ -28,7 +28,7 @@
 @click.command(name="read_gcs")
 @click.option("--gcs-bucket", required=True)
 @click.option("--gcs-prefix", required=True, multiple=True)
-@click.option("--gcs-format", required=True, type=click.Choice(["csv", "gz"]))
+@click.option("--gcs-format", required=True, type=click.Choice(["csv", "gz", "njson"]))
 @click.option("--gcs-dest-key-split", default=-1, type=int)
 @click.option("--gcs-csv-delimiter", default=",")
 @click.option("--gcs-csv-fieldnames", default=None)
@@ -39,14 +39,10 @@ def gcs(**kwargs):

 class GCSReader(ObjectStorageReader, GoogleBaseClass):
     def __init__(self, bucket, prefix, format, dest_key_split=-1, **kwargs):
-        super().__init__(
-            bucket, prefix, format, dest_key_split, platform="GCS", **kwargs
-        )
+        super().__init__(bucket, prefix, format, dest_key_split, platform="GCS", **kwargs)

     def create_client(self, config):
-        return storage.Client(
-            credentials=self._get_credentials(), project=config.project_id
-        )
+        return storage.Client(credentials=self._get_credentials(), project=config.project_id)

     def create_bucket(self, client, bucket):
         return client.bucket(bucket)
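With the new choice in place, a newline-delimited JSON export can be read end to end. A hypothetical construction, not taken from the PR: the bucket and prefix values are placeholders, and it assumes valid GCS credentials are available to GoogleBaseClass at runtime.

from nck.readers.gcs_reader import GCSReader

# Placeholder bucket/prefix values for illustration only.
reader = GCSReader("my-bucket", ["exports/"], "njson")

for stream in reader.read():        # one NormalizedJSONStream per matching object
    for record in stream.readlines():
        print(record)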
101 changes: 11 additions & 90 deletions nck/readers/objectstorage_reader.py
@@ -21,33 +21,7 @@
 from nck.config import logger
 from nck.readers.reader import Reader
 from nck.streams.normalized_json_stream import NormalizedJSONStream
-from nck.utils.file_reader import FileEnum
-
-
-def find_reader(_format, kwargs):
-    _format = _format.upper()
-    if _format in FileEnum.__members__:
-        r = getattr(FileEnum, _format).value
-        _reader = r(**kwargs).get_csv_reader()
-    else:
-        raise NotImplementedError(f"The file format {str(_format)} has not been implemented for reading yet.")
-    return _reader
-
-
-def no_files_seen_before(max_timestamp):
-    return not max_timestamp
-
-
-def _object_older_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-    return max_timestamp > _object_timestamp
-
-
-def _object_newer_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-    return max_timestamp < _object_timestamp
-
-
-def _object_as_old_as_most_recently_ingested_file(max_timestamp, _object_timestamp):
-    return max_timestamp == _object_timestamp
+from nck.utils.file_reader import create_file_reader


 class ObjectStorageReader(Reader):
@@ -58,7 +32,7 @@ def __init__(self, bucket, prefix, file_format, dest_key_split, platform=None, *
         self._platform = platform

         self._format = file_format
-        self._reader = find_reader(self._format, kwargs)
+        self._reader = create_file_reader(self._format, **kwargs).get_reader()
         self._dest_key_split = dest_key_split

         self.MAX_TIMESTAMP_STATE_KEY = f"{self._platform}_max_timestamp".lower()
@@ -69,7 +43,8 @@ def read(self):
         for prefix in self._prefix_list:

             objects_sorted_by_time = sorted(
-                self.list_objects(bucket=self._bucket, prefix=prefix), key=lambda o: self.get_timestamp(o),
+                self.list_objects(bucket=self._bucket, prefix=prefix),
+                key=lambda o: self.get_timestamp(o),
             )

             for _object in objects_sorted_by_time:
@@ -82,73 +57,19 @@
                     logger.info(f"Wrong extension: Skipping file {self.get_key(_object)}")
                     continue

-                if self.has_already_processed_object(_object):
-                    logger.info(f"Skipping already processed file {self.get_key(_object)}")
-                    continue
-
-                def result_generator():
-                    temp = tempfile.TemporaryFile()
-                    self.download_object_to_file(_object, temp)
-
-                    for record in self._reader(temp):
-                        yield record
-
-                    self.checkpoint_object(_object)
-
                 name = self.get_key(_object).split("/", self._dest_key_split)[-1]

-                yield NormalizedJSONStream(name, result_generator())
+                yield NormalizedJSONStream(name, self._result_generator(_object))
+
+    def _result_generator(self, _object):
+        with tempfile.TemporaryFile() as temp:
+            self.download_object_to_file(_object, temp)
+            for record in self._reader(temp):
+                yield record

     def is_compatible_object(self, _object):
         return self.get_key(_object).endswith("." + self._format)

-    def has_already_processed_object(self, _object):
-
-        assert self.get_timestamp(_object) is not None, "Object has no timestamp!"
-
-        max_timestamp = self.state.get(self.MAX_TIMESTAMP_STATE_KEY)
-
-        if no_files_seen_before(max_timestamp):
-            return False
-
-        _object_timestamp = self.get_timestamp(_object)
-
-        if _object_older_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-            return True
-
-        if _object_newer_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-            return False
-
-        if _object_as_old_as_most_recently_ingested_file(max_timestamp, _object_timestamp):
-            max_files = self.state.get(self.MAX_FILES_STATE_KEY)
-            return self.get_key(_object) in max_files
-
-    def checkpoint_object(self, _object):
-
-        assert self.get_timestamp(_object) is not None, "Object has no timestamp!"
-
-        max_timestamp = self.state.get(self.MAX_TIMESTAMP_STATE_KEY)
-        _object_timestamp = self.get_timestamp(_object)
-
-        if max_timestamp and _object_older_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-            raise RuntimeError("Object is older than max timestamp at checkpoint time")
-
-        elif not max_timestamp or _object_newer_than_most_recently_ingested_file(max_timestamp, _object_timestamp):
-            self.update_max_timestamp(_object_timestamp, _object)
-
-        else:
-            assert _object_as_old_as_most_recently_ingested_file(max_timestamp, _object_timestamp)
-            self.update_max_files(_object)
-
-    def update_max_timestamp(self, _object_timestamp, _object):
-        self.state.set(self.MAX_TIMESTAMP_STATE_KEY, _object_timestamp)
-        self.state.set(self.MAX_FILES_STATE_KEY, [self.get_key(_object)])
-
-    def update_max_files(self, _object):
-        max_files = self.state.get(self.MAX_FILES_STATE_KEY)
-        max_files.append(self.get_key(_object))
-        self.state.set(self.MAX_FILES_STATE_KEY, max_files)
-
     def create_client(self, config):
         raise NotImplementedError
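The nested result_generator closure (which also checkpointed ingestion state) is gone; _result_generator now owns the download-then-parse round trip, and its with block guarantees the temporary file is closed after the last record is yielded. The pattern in isolation, as a self-contained sketch rather than code from this repository:

import codecs
import json
import tempfile


def stream_records(download, reader):
    # download writes an object's bytes into a file handle;
    # reader turns that handle into an iterator of records.
    with tempfile.TemporaryFile() as temp:
        download(temp)
        for record in reader(temp):
            yield record


def fake_download(fd):
    fd.write(b'{"a": 1}\n{"a": 2}\n')


def njson_reader(fd):
    fd.seek(0)
    return (json.loads(line) for line in codecs.iterdecode(fd, encoding="utf8"))


print(list(stream_records(fake_download, njson_reader)))  # [{'a': 1}, {'a': 2}]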
6 changes: 2 additions & 4 deletions nck/readers/s3_reader.py
@@ -26,7 +26,7 @@
 @click.command(name="read_s3")
 @click.option("--s3-bucket", required=True)
 @click.option("--s3-prefix", required=True, multiple=True)
-@click.option("--s3-format", required=True, type=click.Choice(["csv", "gz"]))
+@click.option("--s3-format", required=True, type=click.Choice(["csv", "gz", "njson"]))
 @click.option("--s3-dest-key-split", default=-1, type=int)
 @click.option("--s3-csv-delimiter", default=",")
 @click.option("--s3-csv-fieldnames", default=None)
@@ -37,9 +37,7 @@ def s3(**kwargs):

 class S3Reader(ObjectStorageReader):
     def __init__(self, bucket, prefix, format, dest_key_split=-1, **kwargs):
-        super().__init__(
-            bucket, prefix, format, dest_key_split, platform="S3", **kwargs
-        )
+        super().__init__(bucket, prefix, format, dest_key_split, platform="S3", **kwargs)

     def create_client(self, config):
         boto_config = {
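Because the formats are validated by click.Choice, an unsupported value is now rejected at argument-parsing time, before any reader code runs. A quick sketch with click's test runner (option values are placeholders; hidden required options of the real command could also make the exit code non-zero):

from click.testing import CliRunner
from nck.readers.s3_reader import s3

# "txt" is not one of csv / gz / njson, so click refuses it up front.
result = CliRunner().invoke(s3, ["--s3-bucket", "b", "--s3-prefix", "p", "--s3-format", "txt"])
assert result.exit_code != 0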
50 changes: 36 additions & 14 deletions nck/utils/file_reader.py
@@ -15,7 +15,6 @@
 # You should have received a copy of the GNU Lesser General Public License
 # along with this program; if not, write to the Free Software Foundation,
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
-from enum import Enum
 import csv
 import codecs
 import gzip
@@ -26,7 +25,7 @@


 def unzip(input_file, output_path):
-    with zipfile.ZipFile(input_file, 'r') as zip_ref:
+    with zipfile.ZipFile(input_file, "r") as zip_ref:
         zip_ref.extractall(output_path)


@@ -53,20 +52,41 @@ def format_csv_fieldnames(csv_fieldnames):
     elif isinstance(csv_fieldnames, (str, bytes)):
         _csv_fieldnames = json.loads(csv_fieldnames)
     else:
-        raise TypeError(
-            f"The CSV fieldnames is of the following type: {type(csv_fieldnames)}."
-        )
+        raise TypeError(f"The CSV fieldnames is of the following type: {type(csv_fieldnames)}.")
     assert isinstance(_csv_fieldnames, list)
     return _csv_fieldnames


-class CSVReader(object):
+def create_file_reader(_format, **kwargs):
+    if _format == "csv":
+        return CSVReader(**kwargs)
+    if _format == "gz":
+        return GZReader(**kwargs)
+    if _format == "njson":
+        return NJSONReader(**kwargs)
+    else:
+        raise NotImplementedError(f"The file format {str(_format)} has not been implemented for reading yet.")
+
+
+class FileReader:
+    def __init__(self, **kwargs):
+        self.reader = lambda fd: self.read(fd, **kwargs)
+
+    def read(self, fd, **kwargs):
+        fd.seek(0)
+        return codecs.iterdecode(fd, encoding="utf8")
+
+    def get_reader(self):
+        return self.reader
+
+
+class CSVReader(FileReader):
     def __init__(self, csv_delimiter, csv_fieldnames, **kwargs):
         self.csv_delimiter = format_csv_delimiter(csv_delimiter)
         self.csv_fieldnames = format_csv_fieldnames(csv_fieldnames) if csv_fieldnames is not None else None
-        self.csv_reader = lambda fd: self.read_csv(fd, **kwargs)
+        super().__init__(**kwargs)

-    def read_csv(self, fd, **kwargs):
+    def read(self, fd, **kwargs):
         fd.seek(0)
         fd = self.decompress(fd)
         return csv.DictReader(
@@ -79,16 +99,18 @@ def read_csv(self, fd, **kwargs):
     def decompress(self, fd):
         return fd

-    def get_csv_reader(self):
-        return self.csv_reader
-

 class GZReader(CSVReader):
     def decompress(self, fd):
         gzf = gzip.GzipFile(mode="rb", fileobj=fd)
         return gzf


-class FileEnum(Enum):
-    CSV = CSVReader
-    GZ = GZReader
+class NJSONReader(FileReader):
+    def read(self, fd, **kwargs):
+        fd.seek(0)
+        return self.jsongene(fd, **kwargs)
+
+    def jsongene(self, fd, **kwargs):
+        for line in codecs.iterdecode(fd, encoding="utf8"):
+            yield json.loads(line)
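The FileEnum lookup is replaced by a plain factory plus a small class hierarchy: create_file_reader maps the format string to a reader instance, and get_reader() returns a callable that takes an open binary file and yields records. A minimal usage sketch, with in-memory buffers standing in for downloaded objects; it assumes format_csv_delimiter(",") returns the comma unchanged and that the hidden part of CSVReader.read lets DictReader take fieldnames from the header row when csv_fieldnames is None:

import io

from nck.utils.file_reader import create_file_reader

# CSV: fieldnames presumably come from the header row when csv_fieldnames is None.
csv_reader = create_file_reader("csv", csv_delimiter=",", csv_fieldnames=None).get_reader()
print(list(csv_reader(io.BytesIO(b"a,b\n1,2\n"))))  # [{'a': '1', 'b': '2'}]

# NJSON: one JSON document per line, no extra options required.
njson_reader = create_file_reader("njson").get_reader()
print(list(njson_reader(io.BytesIO(b'{"a": 1}\n{"a": 2}\n'))))  # [{'a': 1}, {'a': 2}]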
96 changes: 96 additions & 0 deletions tests/readers/test_objectstorage_reader.py
@@ -0,0 +1,96 @@
+import io
+import csv
+import json
+
+from parameterized import parameterized
+from nck.readers.objectstorage_reader import ObjectStorageReader
+from unittest import TestCase, mock
+
+
+mock_csv_names = ["a.csv", "a.njson", "b.csv", "b.njson"]
+mock_csv_files = [
+    [["a", "b", "c"], [4, 5, 6], [7, 8, 9]],
+    [{"a": "4", "b": "5", "c": "6"}, {"a": "7", "b": "8", "c": "9"}],
+    [["a", "b", "c"], [4, 5, 6], [7, 8, 9]],
+    [{"a": "4", "b": "5", "c": "6"}, {"a": "7", "b": "8", "c": "9"}],
+]
+
+mock_timestamp = [
+    1614179262,
+    1614179272,
+    1614179277,
+    16141792778,
+]
+
+
+def mock_to_object(self, _object):
+    return _object
+
+
+def mock_list_objects(self, bucket, prefix):
+    a = list(zip(mock_csv_names, mock_timestamp, mock_csv_files))
+    return [x for x in a if x[0].startswith(prefix)]
+
+
+def mock_get_timestamp(self, _object, **kwargs):
+    return _object[1]
+
+
+def write_to_file(self, _object, f, **kwargs):
+
+    if self._format == "csv":
+
+        text_file = io.TextIOWrapper(f, encoding="utf-8", newline="")
+        w = csv.writer(text_file)
+        w.writerows(_object[2])
+        text_file.detach()
+
+    else:
+
+        text_file = io.TextIOWrapper(f, encoding="utf-8")
+        for line in _object[2]:
+
+            json.dump(line, text_file)
+            text_file.write("\n")
+        text_file.detach()
+
+
+def mock_get_key(self, _object, **kwargs):
+    return _object[0]
+
+
+@mock.patch("nck.readers.objectstorage_reader.ObjectStorageReader.create_client")
+@mock.patch("nck.readers.objectstorage_reader.ObjectStorageReader.create_bucket")
+@mock.patch.object(ObjectStorageReader, "download_object_to_file", write_to_file)
+@mock.patch.object(ObjectStorageReader, "to_object", mock_to_object)
+@mock.patch.object(ObjectStorageReader, "get_timestamp", mock_get_timestamp)
+@mock.patch.object(ObjectStorageReader, "list_objects", mock_list_objects)
+@mock.patch.object(ObjectStorageReader, "get_key", mock_get_key)
+class ObjectStorageReaderTest(TestCase):
+    def test_wrong_format(self, a, b):
+        with self.assertRaises(NotImplementedError):
+            ObjectStorageReader(
+                bucket="", prefix=["a"], file_format="txt", dest_key_split=-1, csv_delimiter=",", csv_fieldnames=None
+            )
+
+    @parameterized.expand([("njson", 2), ("csv", 2)])
+    def test_ObjectStorageReader_filter_files(self, a, b, format, nb_files_expected):
+        reader = ObjectStorageReader(
+            bucket="", prefix=[""], file_format=format, dest_key_split=-1, csv_delimiter=",", csv_fieldnames=None
+        )
+        nb_file = len(list(reader.read()))
+        self.assertEqual(nb_file, nb_files_expected)
+
+    @parameterized.expand(
+        [
+            ("njson", [{"a": "4", "b": "5", "c": "6"}, {"a": "7", "b": "8", "c": "9"}]),
+            ("csv", [{"a": "4", "b": "5", "c": "6"}, {"a": "7", "b": "8", "c": "9"}]),
+        ]
+    )
+    def test_ObjectStorageReader_read_all_file(self, a, b, format, expected):
+        reader = ObjectStorageReader(
+            bucket="", prefix=["a"], file_format=format, dest_key_split=-1, csv_delimiter=",", csv_fieldnames=None
+        )
+        for file in reader.read():
+            for expect, data in zip(expected, file.readlines()):
+                self.assertEqual(expect, data)
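In these tests, write_to_file stands in for download_object_to_file: it serializes a fixture into the temporary file that the reader then parses back. The njson branch reduces to the following round trip, shown here as a self-contained sketch rather than code from the PR:

import codecs
import io
import json

rows = [{"a": "4", "b": "5", "c": "6"}, {"a": "7", "b": "8", "c": "9"}]

# Write: one JSON document per line, as write_to_file does for njson.
f = io.BytesIO()
text_file = io.TextIOWrapper(f, encoding="utf-8")
for line in rows:
    json.dump(line, text_file)
    text_file.write("\n")
text_file.detach()  # flushes and releases the underlying buffer

# Read back, as NJSONReader does.
f.seek(0)
parsed = [json.loads(line) for line in codecs.iterdecode(f, encoding="utf8")]
assert parsed == rows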