diff --git a/nck/readers/amazon_s3/reader.py b/nck/readers/amazon_s3/reader.py
index 70c311ed..3b610e99 100644
--- a/nck/readers/amazon_s3/reader.py
+++ b/nck/readers/amazon_s3/reader.py
@@ -21,16 +21,17 @@ class AmazonS3Reader(ObjectStorageReader):
-    def __init__(self, bucket, prefix, format, dest_key_split=-1, **kwargs):
+    def __init__(self, bucket, bucket_region, access_key_id, access_key_secret, prefix, format, dest_key_split=-1, **kwargs):
+        self.boto_config = {
+            "region_name": bucket_region,
+            "aws_access_key_id": access_key_id,
+            "aws_secret_access_key": access_key_secret,
+        }
         super().__init__(bucket, prefix, format, dest_key_split, platform="S3", **kwargs)
 
     def create_client(self, config):
-        boto_config = {
-            "region_name": config.REGION_NAME,
-            "aws_access_key_id": config.AWS_ACCESS_KEY_ID,
-            "aws_secret_access_key": config.AWS_SECRET_ACCESS_KEY,
-        }
-        return boto3.resource("s3", **boto_config)
+
+        return boto3.resource("s3", **self.boto_config)
 
     def create_bucket(self, client, bucket):
         return client.Bucket(bucket)
@@ -51,5 +52,5 @@ def to_object(_object):
         return _object.Object()
 
     @staticmethod
-    def download_object_to_file(_object, temp):
-        _object.download_fileobj(temp)
+    def download_object_to_file(_object, stream):
+        _object.download_fileobj(stream)
diff --git a/nck/readers/object_storage/reader.py b/nck/readers/object_storage/reader.py
index 7a107ad9..963a2736 100644
--- a/nck/readers/object_storage/reader.py
+++ b/nck/readers/object_storage/reader.py
@@ -16,13 +16,12 @@
 # along with this program; if not, write to the Free Software Foundation,
 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 
-import tempfile
+from io import BytesIO
 
 from nck import config
 from nck.config import logger
 from nck.readers.reader import Reader
-from nck.streams.json_stream import JSONStream
-from nck.utils.file_reader import create_file_reader
+from nck.streams.new_stream import NewStream
 
 
 class ObjectStorageReader(Reader):
@@ -33,19 +32,14 @@ def __init__(self, bucket, prefix, file_format, dest_key_split, platform=None, *
         self._platform = platform
         self._format = file_format
-        self._reader = create_file_reader(self._format, **kwargs).get_reader()
         self._dest_key_split = dest_key_split
 
-        self.MAX_TIMESTAMP_STATE_KEY = f"{self._platform}_max_timestamp".lower()
-        self.MAX_FILES_STATE_KEY = f"{self._platform}_max_files".lower()
-
     def read(self):
         for prefix in self._prefix_list:
             objects_sorted_by_time = sorted(
-                self.list_objects(bucket=self._bucket, prefix=prefix),
-                key=lambda o: self.get_timestamp(o),
+                self.list_objects(bucket=self._bucket, prefix=prefix), key=lambda o: self.get_timestamp(o),
             )
 
             for _object in objects_sorted_by_time:
@@ -60,13 +54,12 @@ def read(self):
                 name = self.get_key(_object).split("/", self._dest_key_split)[-1]
 
-                yield JSONStream(name, self._result_generator(_object))
+                yield NewStream(name, self._download_to_stream(_object))
 
-    def _result_generator(self, _object):
-        with tempfile.TemporaryFile() as temp:
-            self.download_object_to_file(_object, temp)
-            for record in self._reader(temp):
-                yield record
+    def _download_to_stream(self, _object):
+        f = BytesIO()
+        self.download_object_to_file(_object, f)
+        return f
 
     def is_compatible_object(self, _object):
         return self.get_key(_object).endswith("." + self._format)
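The reader changes above swap the temporary-file / record-generator path for a plain in-memory download. A minimal sketch of that pattern in isolation, using boto3 against a moto-mocked bucket; the bucket, key, and function names here are illustrative and not part of the patch:

```python
from io import BytesIO

import boto3
from moto import mock_s3


@mock_s3
def download_to_stream_example():
    # Illustrative bucket/key, not taken from the patch.
    s3 = boto3.resource("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="example-bucket")
    s3.Object("example-bucket", "data.csv").put(Body=b"a,b,c\n1,2,3\n")

    # Same pattern as the new _download_to_stream: pull the whole
    # object into an in-memory buffer instead of a temporary file.
    buffer = BytesIO()
    s3.Object("example-bucket", "data.csv").download_fileobj(buffer)

    # download_fileobj leaves the buffer positioned at the end, so it
    # must be rewound before anything reads from it (NewStream does
    # the same rewind with stream.seek(0) in its constructor).
    buffer.seek(0)
    return buffer.read()


print(download_to_stream_example())  # b'a,b,c\n1,2,3\n'
```

The explicit `seek(0)` is the reason `NewStream` rewinds the stream it is given: without it, a consumer reading the freshly downloaded buffer would see no data.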
diff --git a/nck/streams/new_stream.py b/nck/streams/new_stream.py
new file mode 100644
index 00000000..7c09c021
--- /dev/null
+++ b/nck/streams/new_stream.py
@@ -0,0 +1,5 @@
+class NewStream:
+    def __init__(self, name, stream):
+        self.name = name
+        self.stream = stream
+        self.stream.seek(0)
diff --git a/nck/writers/amazon_s3/writer.py b/nck/writers/amazon_s3/writer.py
index 88ae877c..bb9cdf0f 100644
--- a/nck/writers/amazon_s3/writer.py
+++ b/nck/writers/amazon_s3/writer.py
@@ -39,7 +39,7 @@ def _list_buckets(self, client):
         return client.buckets.all()
 
     def _create_blob(self, file_name, stream):
-        self._bucket.upload_fileobj(stream.as_file(), file_name)
+        self._bucket.upload_fileobj(stream.stream, file_name)
 
     def _get_uri(self, file_name):
         return f"s3{self._get_file_path(file_name)}"
diff --git a/tests/end_to_end/S3_test.py b/tests/end_to_end/S3_test.py
new file mode 100644
index 00000000..0c46c478
--- /dev/null
+++ b/tests/end_to_end/S3_test.py
@@ -0,0 +1,46 @@
+import boto3
+from moto import mock_s3
+from unittest import TestCase
+
+from nck.writers.amazon_s3.writer import AmazonS3Writer
+from nck.readers.amazon_s3.reader import AmazonS3Reader
+
+csv_file = [["a", "b", "c"], [4, 5, 6], [7, 8, 9]]
+
+
+@mock_s3
+class AmazonS3WriterTest(TestCase):
+    @classmethod
+    @mock_s3
+    def setUpClass(cls):
+
+        client = boto3.resource("s3", region_name="us-east-1")
+        client.create_bucket(Bucket="test")
+        obj = client.Object("test", "filename.csv")
+        obj.put(Body=b"some data")
+
+    def test_Write(self):
+
+        reader = AmazonS3Reader(
+            bucket="test",
+            bucket_region="us-east-1",
+            access_key_id="",
+            access_key_secret="",
+            prefix=[""],
+            format="csv",
+            dest_key_split=-1,
+            csv_delimiter=",",
+            csv_fieldnames=None,
+        )
+
+        writer = AmazonS3Writer("test", "us-east-1", "", "", filename="ok")
+
+        for stream in reader.read():
+            writer.write(stream)
+
+        client = boto3.resource("s3", region_name="us-east-1")
+        bucket = client.Bucket("test")
+
+        obj = list(bucket.objects.all())[-1]
+        bod = obj.get()["Body"].read().decode("utf-8")
+        self.assertEqual("some data", bod)
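For reference, the round trip exercised by the new end-to-end test reduces to a few lines: the reader yields `NewStream` objects whose `.stream` attribute is a rewound `BytesIO`, and the writer hands that buffer straight to `upload_fileobj`. A minimal sketch, assuming the patch above is applied and moto is installed; the bucket, key, and function names are illustrative:

```python
from io import BytesIO

import boto3
from moto import mock_s3

from nck.streams.new_stream import NewStream


@mock_s3
def round_trip_example():
    # Illustrative bucket/key, not taken from the patch.
    s3 = boto3.resource("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="example-bucket")

    # Wrap an in-memory payload the way ObjectStorageReader now does;
    # NewStream rewinds the buffer so it is ready to be re-read.
    stream = NewStream("copy.csv", BytesIO(b"some data"))

    # Same call the writer's _create_blob makes: the raw buffer is
    # handed directly to upload_fileobj, with no re-serialization.
    s3.Bucket("example-bucket").upload_fileobj(stream.stream, stream.name)

    return s3.Object("example-bucket", "copy.csv").get()["Body"].read()


print(round_trip_example())  # b'some data'
```

Because the bytes are passed through untouched, the test's assertion that the copied object still contains `"some data"` holds regardless of the declared file format.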