From 60f3ed2c2550dfb2deb57a74f87462b043a8d10f Mon Sep 17 00:00:00 2001
From: Sudheesh Singanamalla
Date: Wed, 9 Aug 2023 11:55:08 -0700
Subject: [PATCH 1/2] Update requirements and postgres DB connector strings

Signed-off-by: Sudheesh Singanamalla
---
 config.json      |  4 ++--
 requirements.txt | 13 +++++++------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/config.json b/config.json
index e4a3413..7adc4e9 100644
--- a/config.json
+++ b/config.json
@@ -1,7 +1,7 @@
 {
   "_comment": "This doesn't appear to get used in ETL at all? TODO: remove if unnecessary.",
-  "rds_uri": "postgres://$CYBERGREEN_STATS_RDS_NAME:$CYBERGREEN_STATS_RDS_PASSWORD@$CYBERGREEN_RAW_SCAN_RDS_NAME.crovisjepxcd.eu-west-1.rds.amazonaws.com:5432/$CYBERGREEN_STATS_RDS_NAME",
-  "redshift_uri": "postgres://$CYBERGREEN_REDSHIFT_USER:$CYBERGREEN_REDSHIFT_PASSWORD@$CYBERGREEN_REDSHIFT_CLUSTER_NAME.cqxchced59ta.eu-west-1.redshift.amazonaws.com:5439/$CYBERGREEN_BUILD_ENV",
+  "rds_uri": "postgresql://$CYBERGREEN_STATS_RDS_NAME:$CYBERGREEN_STATS_RDS_PASSWORD@$CYBERGREEN_RAW_SCAN_RDS_NAME.crovisjepxcd.eu-west-1.rds.amazonaws.com:5432/$CYBERGREEN_STATS_RDS_NAME",
+  "redshift_uri": "postgresql://$CYBERGREEN_REDSHIFT_USER:$CYBERGREEN_REDSHIFT_PASSWORD@$CYBERGREEN_REDSHIFT_CLUSTER_NAME.cqxchced59ta.eu-west-1.redshift.amazonaws.com:5439/$CYBERGREEN_BUILD_ENV",
   "role_arn": "arn:aws:iam::635396214416:role/RedshiftCopyUnload",
   "source_path": "$CYBERGREEN_SOURCE_ROOT",
   "dest_path": "$CYBERGREEN_DEST_ROOT",
diff --git a/requirements.txt b/requirements.txt
index 6490f6c..8443a30 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
-datapackage==0.8.4
-jsontableschema-sql==0.8.0
-SQLAlchemy==1.2.6
-psycopg2==2.7.3.2
-boto3==1.4.7
-rfc3986==0.4.1
+datapackage==1.15.2
+tableschema-sql==2.0.1
+SQLAlchemy==2.0.19
+psycopg2==2.9.7
+boto3==1.28.22
+rfc3986==2.0.0
+sqlalchemy-redshift==0.8.14

From efa92662c3324afe3366ec9d08b6f1db2fe88cb0 Mon Sep 17 00:00:00 2001
From: Sudheesh Singanamalla
Date: Thu, 17 Aug 2023 15:22:09 -0700
Subject: [PATCH 2/2] Update tests, begin migrating to latest postgresql

Signed-off-by: Sudheesh Singanamalla
---
 load_asn_ref_data.sh       |   2 +-
 main.py                    | 161 +++++++++++++++++++++++--------------
 models.py                  |  39 +++++++++
 requirements.txt           |   1 +
 tests/__init__.py          |   0
 tests/aggregation_tests.py |  96 +++++++++------------
 tests/config.test.json     |   4 +-
 tests/requirements.txt     |   2 +-
 8 files changed, 183 insertions(+), 122 deletions(-)
 mode change 100644 => 100755 load_asn_ref_data.sh
 create mode 100644 models.py
 create mode 100644 tests/__init__.py

diff --git a/load_asn_ref_data.sh b/load_asn_ref_data.sh
old mode 100644
new mode 100755
index 25a7ba7..7507afe
--- a/load_asn_ref_data.sh
+++ b/load_asn_ref_data.sh
@@ -1,4 +1,4 @@
 # Sets up environment to load ASN ref data
 source ../env
-./load_asn_ref_data.py
+python ./load_asn_ref_data.py
diff --git a/main.py b/main.py
index 5c0623e..6496af4 100644
--- a/main.py
+++ b/main.py
@@ -1,8 +1,9 @@
 from __future__ import print_function

-from datapackage import push_datapackage
+from datapackage import push_datapackage, Package, Resource
 from psycopg2.extensions import AsIs
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, text, MetaData, Table, Integer, Column, TIMESTAMP, String, BigInteger, Text, \
+    Float, Boolean
 from os.path import dirname, join
 from string import Template
 from textwrap import dedent
@@ -17,6 +18,9 @@
 import csv
 import os

+from sqlalchemy.dialects import postgresql
+from sqlalchemy.sql.ddl import CreateTable
+
 #utils
 def rpath(*args):
     return join(dirname(__file__), *args)
@@ -84,10 +88,12 @@ def run(self):
         shutil.rmtree(self.tmpdir)


-    def drop_tables(self, cursor, tables):
+    def drop_tables(self, conn, tables):
+        cursor = conn.connect()
         for tablename in tables:
+            statement = text("DROP TABLE IF EXISTS :table CASCADE")
             cursor.execute(
-                "DROP TABLE IF EXISTS %(table)s CASCADE",
+                statement,
                 {"table": AsIs(tablename)}
             )

@@ -127,35 +133,55 @@ def upload_manifest(self):

     def create_tables(self):
         conn = self.connRedshift.connect()
-        tablenames = [
-            'dim_risk', 'logentry', 'count'
-        ]
-        self.drop_tables(conn, tablenames)
-        create_logentry = dedent('''
-            CREATE TABLE logentry(
-                date TIMESTAMP, ip VARCHAR(32), risk INT,
-                asn BIGINT, country VARCHAR(2)
-            )
-        ''')
-        create_risk = dedent('''
-            CREATE TABLE dim_risk(
-                id INT, slug VARCHAR(32), title VARCHAR(32),
-                is_archived BOOLEAN,
-                taxonomy VARCHAR(16), measurement_units VARCHAR(32),
-                amplification_factor FLOAT, description TEXT
-            )
-        ''')
-        create_count = dedent('''
-            CREATE TABLE count(
-                date TIMESTAMP, risk INT, country VARCHAR(2),
-                asn BIGINT, count INT, count_amplified FLOAT
-            )
-        ''')
-        conn.execute(create_risk)
-        conn.execute(create_logentry)
-        conn.execute(create_count)
-        conn.close()
-        logging.info('Redshift tables created')
+
+        transaction = conn.begin()
+
+        metadata = MetaData()
+        logentry_table = Table('logentry', metadata,
+                               Column('id', Integer),
+                               Column('date', TIMESTAMP),
+                               Column('ip', String(32)),
+                               Column('risk', Integer),
+                               Column('asn', BigInteger),
+                               Column('country', String(2)))
+        statement = CreateTable(logentry_table, if_not_exists=True)
+        print(statement.compile(dialect=postgresql.dialect()))
+
+        conn.execute(statement)
+        transaction.commit()
+
+        transaction = conn.begin()
+        metadata = MetaData()
+        risk_table = Table('dim_risk', metadata,
+                           Column('id', Integer),
+                           Column('slug', String(32)),
+                           Column('title', String(32)),
+                           Column('is_archived', Boolean),
+                           Column('taxonomy', String(16)),
+                           Column('measurement_units', String(32)),
+                           Column('amplification_factor', Float),
+                           Column('description', Text))
+        statement = CreateTable(risk_table, if_not_exists=True)
+        print(statement.compile(dialect=postgresql.dialect()))
+
+        conn.execute(statement)
+        transaction.commit()
+
+        transaction = conn.begin()
+        metadata = MetaData()
+        count_table = Table('count', metadata,
+                            Column('date', TIMESTAMP),
+                            Column('risk', Integer),
+                            Column('country', String(2)),
+                            Column('asn', BigInteger),
+                            Column('count', Integer),
+                            Column('count_amplified', Float))
+        statement = CreateTable(count_table, if_not_exists=True)
+        print(statement.compile(dialect=postgresql.dialect()))
+
+        conn.execute(statement)
+        transaction.commit()
+        print('Redshift tables created')


     def load_data(self):
@@ -182,14 +208,18 @@ def load_ref_data(self):
             if inv.get('name') == 'risk':
                 url = inv.get('url')
                 dp = datapackage.DataPackage(url)
-                risks = dp.resources[0].data
-                query = dedent('''
-                    INSERT INTO dim_risk
-                    VALUES (%(id)s, %(slug)s, %(title)s, %(is_archived)s, %(taxonomy)s, %(measurement_units)s, %(amplification_factor)s, %(description)s)''')
+                risks = dp.resources[0].read(keyed=True)
+                table_name = 'dim_risk'
+                columns = ', '.join(['id', 'slug', 'title',
+                                     'is_archived', 'taxonomy', 'measurement_units',
+                                     'amplification_factor', 'description'])
+
+                statement = text(f'INSERT INTO {table_name} {columns} VALUES (:id, :slug, :title, :is_archived, :taxonomy, :measurement_units, :amplification_factor, :description)')
+
                 for risk in risks:
                     # description is too long and not needed here
                     risk['description']=''
-                    conn.execute(query,risk)
+                    conn.execute(statement,risk)
         conn.close()

@@ -210,9 +240,9 @@ def aggregate(self):
             FROM(
                 SELECT DISTINCT (ip), date_trunc('day', date) AS date, risk, asn, country FROM logentry
             ) AS foo
-            GROUP BY date, asn, risk, country HAVING count(*) > %(threshold)s ORDER BY date DESC, country ASC, asn ASC, risk ASC)
+            GROUP BY date, asn, risk, country HAVING count(*) > :threshold ORDER BY date DESC, country ASC, asn ASC, risk ASC)
         ''')
-        conn.execute(query, {'threshold': self.country_count_threshold})
+        conn.execute(text(query), {'threshold': self.country_count_threshold})
         conn.close()

@@ -224,7 +254,7 @@ def update_amplified_count(self):
             SET count_amplified = count*amplification_factor
             FROM dim_risk WHERE risk=id
         ''')
-        conn.execute(query)
+        conn.execute(text(query))
         conn.close()
         logging.info('Aggregation Finished!')

@@ -289,8 +319,11 @@ def run(self):


     def drop_tables(self, tables):
+        conn = self.connRDS.connect()
+        transaction = conn.begin()
         for tablename in tables:
-            self.connRDS.execute("DROP TABLE IF EXISTS %(table)s CASCADE",{"table": AsIs(tablename)})
+            conn.execute(text("DROP TABLE IF EXISTS :table CASCADE"),{"table": AsIs(tablename)})
+        transaction.commit()


     def download_and_load(self):
@@ -304,7 +337,7 @@ def download_and_load(self):
         logging.info('Loading into RDS ...')
         # TODO: replace shelling out to psql
         copy_command = dedent('''
-            psql {uri} -c "\COPY fact_count FROM {tmp}/count.csv WITH delimiter as ',' null '' csv;"
+            psql {uri} -c "\\COPY fact_count FROM {tmp}/count.csv WITH delimiter as ',' null '' csv;"
         ''')
         os.system(copy_command.format(tmp=self.tmpdir,uri=self.config.get('rds_uri')))

@@ -313,29 +346,33 @@ def load_ref_data_rds(self):
         logging.info('Loading reference_data to RDS ...')
         conn = self.connRDS.connect()
         # creating dim_asn table here with other ref data
-        conn.execute('DROP TABLE IF EXISTS data__asn___asn CASCADE')
-        create_asn = 'CREATE TABLE data__asn___asn(number BIGINT, title TEXT, country TEXT)'
+        transaction = conn.begin()
+        conn.execute(text('DROP TABLE IF EXISTS data__asn___asn'))
+        create_asn = text('CREATE TABLE data__asn___asn(number BIGINT, title TEXT, country TEXT)')
         conn.execute(create_asn)
+        transaction.commit()
         for url in self.ref_data_urls:
+            print(url)
             # Loading of asn with push_datapackage takes more then 2 hours
             # So have to download locally and save (takes ~5 seconds)
             if 'asn' not in url:
-                push_datapackage(descriptor=url, backend='sql', engine=conn)
+                print(f'Skipping {url}')
+                # push_datapackage(descriptor=url, backend='sql', engine=conn)
             else:
                 dp = datapackage.DataPackage(url)
                 # local path will be returned if not found remote one (for tests)
-                if dp.resources[0].remote_data_path:
-                    r = requests.get(dp.resources[0].remote_data_path)
-                    with open(join(self.tmpdir, 'asn.csv'),"wb") as fp:
-                        fp.write(r.content)
-                else:
-                    shutil.copy(dp.resources[0].local_data_path,join(self.tmpdir, 'asn.csv'))
-                # TODO: replace shelling out
-                copy_command = dedent('''
-                    psql {uri} -c "\COPY data__asn___asn FROM {tmp}/asn.csv WITH delimiter as ',' csv header;"
-                ''')
-                os.system(copy_command.format(tmp=self.tmpdir,uri=self.config.get('rds_uri')))
+                if len(dp.resources) > 0:
+                    data = dp.resources[0].read()
+                    import csv
+                    with open(join(self.tmpdir, 'asn.csv'),"w") as fp:
+                        fw = csv.writer(fp)
+                        fw.writerows(data)
+                    copy_command = dedent('''
+                        psql {uri} -c "\\COPY data__asn___asn FROM {tmp}/asn.csv WITH delimiter as ',' csv header;"
+                    ''')
+                    print(f'Executing psql command {copy_command.format(tmp=self.tmpdir, uri=self.config.get("rds_uri"))}')
+                    os.system(copy_command.format(tmp=self.tmpdir, uri=self.config.get('rds_uri')))
         conn.close()

@@ -367,11 +404,11 @@ def create_tables(self):
             count_amplified FLOAT
         )''')
-        conn.execute(create_risk)
-        conn.execute(create_country)
-        conn.execute(create_asn)
-        conn.execute(create_time)
-        conn.execute(create_count)
+        conn.execute(text(create_risk))
+        conn.execute(text(create_country))
+        conn.execute(text(create_asn))
+        conn.execute(text(create_time))
+        conn.execute(text(create_count))
         self.create_or_update_cubes(conn, create_cube)
         conn.close()

diff --git a/models.py b/models.py
new file mode 100644
index 0000000..bf2403a
--- /dev/null
+++ b/models.py
@@ -0,0 +1,39 @@
+from sqlalchemy import Column, String, Integer, BigInteger, Boolean, Float, Text, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class LogEntry(Base):
+    __tablename__ = 'logentry'
+
+    id = Column(Integer)
+    date = Column(TIMESTAMP)
+    ip = Column(String(32))
+    risk = Column(Integer)
+    asn = Column(BigInteger)
+    country = Column(String(2))
+
+
+class Risk(Base):
+    __tablename__ = "dim_risk"
+
+    id = Column(Integer)
+    slug = Column(String(32))
+    title = Column(String(32))
+    is_archived = Column(Boolean)
+    taxonomy = Column(String(16))
+    measurement_units = Column(String(32))
+    amplification_factor = Column(Float)
+    description = Column(Text)
+
+
+class Count(Base):
+    __tablename__ = "count"
+
+    date = Column(TIMESTAMP)
+    risk = Column(Integer)
+    country = Column(String(2))
+    asn = Column(BigInteger)
+    count = Column(Integer)
+    count_amplified = Column(Float)
diff --git a/requirements.txt b/requirements.txt
index 8443a30..148a1b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ psycopg2==2.9.7
 boto3==1.28.22
 rfc3986==2.0.0
 sqlalchemy-redshift==0.8.14
+redshift_connector==2.0.913
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/aggregation_tests.py b/tests/aggregation_tests.py
index 350356e..03c014b 100644
--- a/tests/aggregation_tests.py
+++ b/tests/aggregation_tests.py
@@ -9,9 +9,12 @@
 from io import StringIO
 from textwrap import dedent
+from time import sleep

 from psycopg2.extensions import AsIs
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, text, inspect
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.engine import Inspector

 from aggregator.main import Aggregator, LoadToRDS

@@ -36,25 +39,15 @@ def setUp(self):
         # create tables
         self.aggregator.create_tables()

-
     def test_all_tables_created(self):
         '''
         Checks if all necessary tables are created for redshift
         '''
-        self.aggregator.create_tables()
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'logentry'})
-        self.assertEqual(self.cursor.fetchone()[0], True)
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'count'})
-        self.assertEqual(self.cursor.fetchone()[0], True)
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'dim_risk'})
-        self.assertEqual(self.cursor.fetchone()[0], True)
-
+        inspector = inspect(self.aggregator.connRedshift)
+        tables_list = inspector.get_table_names()
+        for table_name in ['logentry', 'dim_risk', 'count']:
+            print(f'Checking for table {table_name}')
+            self.assertTrue(table_name in tables_list)

     def test_drop_tables(self):
         '''
         Checks if all tables are dropped
@@ -64,21 +57,11 @@ def test_drop_tables(self):
             self.aggregator.connRedshift,
             ['logentry', 'count', 'dim_risk']
         )
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'logentry'}
-        )
-        self.assertEqual(self.cursor.fetchone()[0], False)
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'count'}
-        )
-        self.assertEqual(self.cursor.fetchone()[0], False)
-        self.cursor.execute(
-            'select exists(select * from information_schema.tables where table_name=%(table)s)',
-            {'table': 'dim_risk'}
-        )
-        self.assertEqual(self.cursor.fetchone()[0], False)
+        inspector = inspect(self.aggregator.connRedshift)
+        tables_list = inspector.get_table_names()
+        for table_name in ['logentry', 'dim_risk', 'count']:
+            print(f'Checking for table {table_name}')
+            self.assertFalse(table_name in tables_list)


     def test_referenece_data_loaded(self):
@@ -87,6 +70,7 @@ def test_referenece_data_loaded(self):
         '''
         # load data
         self.aggregator.load_ref_data()
+        print('Finished loading reference data')
         self.cursor.execute('SELECT * FROM dim_risk')
         self.assertEqual(self.cursor.fetchone(),
                          (0, u'test-risk', u'Test Risk', False, 'Testable','count', 0.13456, u''))
@@ -283,11 +267,11 @@ def test_aplified_count(self):
         self.aggregator.create_tables()
         # GIVEN 4 entries of the same day, country, ASN an IP but different risks
         scan_csv = dedent('''\
-            ts,ip,risk_id,asn,cc
-            2016-09-28T00:00:01+00:00,71.3.0.1,1,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,2,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,4,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
+            id,ts,ip,risk_id,asn,cc
+            1,2016-09-28T00:00:01+00:00,71.3.0.1,1,4444,US
+            2,2016-09-28T00:00:01+00:00,71.3.0.1,2,4444,US
+            3,2016-09-28T00:00:01+00:00,71.3.0.1,4,4444,US
+            4,2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
         ''')
         self.cursor.copy_expert("COPY logentry from STDIN csv header", StringIO(scan_csv))
         # import ref data
@@ -297,7 +281,7 @@ def test_aplified_count(self):
         self.aggregator.update_amplified_count()

         self.maxDiff = None
-        self.cursor.execute('select * from count;')
+        self.cursor.execute(text('select * from count;'))
         self.assertEqual(
             self.cursor.fetchall(),
             [
@@ -316,21 +300,21 @@ def test_aplified_count_when_grouped(self):
         self.aggregator.create_tables()
         # GIVEN 4 entries of the same day, country, ASN, but different risks
         scan_csv = dedent('''\
-            ts,ip,risk_id,asn,cc
-            2016-09-28T00:00:01+00:00,71.3.0.1,1,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.2,1,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.3,1,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,2,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.2,2,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,4,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.2,4,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.3,4,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.4,4,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.2,5,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.3,5,4444,US
-            2016-09-28T00:00:01+00:00,71.3.0.4,5,4444,US
+            id,ts,ip,risk_id,asn,cc
+            1,2016-09-28T00:00:01+00:00,71.3.0.1,1,4444,US
+            2,2016-09-28T00:00:01+00:00,71.3.0.2,1,4444,US
+            3,2016-09-28T00:00:01+00:00,71.3.0.3,1,4444,US
+            4,2016-09-28T00:00:01+00:00,71.3.0.1,2,4444,US
+            5,2016-09-28T00:00:01+00:00,71.3.0.2,2,4444,US
+            6,2016-09-28T00:00:01+00:00,71.3.0.1,4,4444,US
+            7,2016-09-28T00:00:01+00:00,71.3.0.2,4,4444,US
+            8,2016-09-28T00:00:01+00:00,71.3.0.3,4,4444,US
+            9,2016-09-28T00:00:01+00:00,71.3.0.4,4,4444,US
+            10,2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
+            11,2016-09-28T00:00:01+00:00,71.3.0.1,5,4444,US
+            12,2016-09-28T00:00:01+00:00,71.3.0.2,5,4444,US
+            13,2016-09-28T00:00:01+00:00,71.3.0.3,5,4444,US
+            14,2016-09-28T00:00:01+00:00,71.3.0.4,5,4444,US
         ''')
         self.cursor.copy_expert("COPY logentry from STDIN csv header", StringIO(scan_csv))
         # import ref data
@@ -340,7 +324,7 @@ def test_aplified_count_when_grouped(self):
         self.aggregator.update_amplified_count()

         self.maxDiff = None
-        self.cursor.execute('select * from count;')
+        self.cursor.execute(text('select * from count;'))
         self.assertEqual(
             self.cursor.fetchall(),
             [
@@ -365,8 +349,8 @@ def setUp(self):
         self.tablenames = [
             'fact_count', 'agg_risk_country_week',
             'agg_risk_country_month', 'agg_risk_country_quarter',
-            'agg_risk_country_year', 'dim_risk', 'dim_country',
-            'dim_asn', 'dim_date'
+            'agg_risk_country_year', 'data__risk___risk', 'data__country___country',
+            'data__asn___asn', 'dim_date'
         ]
         # snipet for fact_count table
         self.counts = dedent('''
@@ -524,4 +508,4 @@ def test_create_manifest(self):
             'mandatory': True}
         ]}
         manifest = self.aggregator.create_manifest(datapackage, 's3://test.bucket/test/key')
-        self.assertEquals(manifest,expected_manifest)
+        self.assertEqual(manifest, expected_manifest)
diff --git a/tests/config.test.json b/tests/config.test.json
index 561a8cf..ff1b240 100644
--- a/tests/config.test.json
+++ b/tests/config.test.json
@@ -1,6 +1,6 @@
 {
-  "rds_uri": "postgres://cg_test_user:secret@localhost/cg_test_db",
-  "redshift_uri": "postgres://cg_test_user:secret@localhost/cg_test_db",
+  "rds_uri": "postgresql://cg_test_user:secret@localhost/cg_test_db",
+  "redshift_uri": "postgresql://cg_test_user:secret@localhost/cg_test_db",
   "role_arn":"",
   "source_path": "",
   "dest_path": "",
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 15cc901..aba4b47 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -1,3 +1,3 @@
 sqlalchemy
-mock==2.0.0
+mock==5.1.0
 nose==1.3.7