Skip to content

Commit

Permalink
adding private key authentication and role parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
iireland-ii committed Nov 12, 2024
1 parent 4cfee2d commit 76bfd0f
Showing 1 changed file with 35 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class SnowflakeTableRetriever:
database="<DATABASE-NAME>",
db_schema="<SCHEMA-NAME>",
warehouse="<WAREHOUSE-NAME>",
role="<ROLE-NAME>"
)
# When database and schema are provided during component initialization.
Expand Down Expand Up @@ -69,7 +70,10 @@ def __init__(
user: str,
account: str,
api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008
private_key_file: Optional[str] = None,
private_key_file_pwd: Optional[str] = None,
database: Optional[str] = None,
role: Optional[str] = None,
db_schema: Optional[str] = None,
warehouse: Optional[str] = None,
login_timeout: Optional[int] = None,
Expand All @@ -80,15 +84,21 @@ def __init__(
:param api_key: Snowflake account password.
:param database: Name of the database to use.
:param db_schema: Name of the schema to use.
:param role: Name of role to use.
:param warehouse: Name of the warehouse to use.
:param login_timeout: Timeout in seconds for login. By default, 60 seconds.
:param private_key_file: Location of private key
:param private_key_file_pwd: Password for private key file
"""

self.user = user
self.account = account
self.api_key = api_key
self.private_key_file = private_key_file
self.private_key_file_pwd = private_key_file_pwd
self.database = database
self.db_schema = db_schema
self.role = role
self.warehouse = warehouse
self.login_timeout = login_timeout or 60

Expand All @@ -104,8 +114,11 @@ def to_dict(self) -> Dict[str, Any]:
user=self.user,
account=self.account,
api_key=self.api_key.to_dict(),
private_key_file=self.private_key_file,
private_key_file_pwd=self.private_key_file_pwd,
database=self.database,
db_schema=self.db_schema,
role=self.role,
warehouse=self.warehouse,
login_timeout=self.login_timeout,
)
Expand Down Expand Up @@ -208,14 +221,16 @@ def _has_select_privilege(privileges: list, table_name: str) -> bool:
:param privileges: List of privileges.
:param table_name: Name of the table.
"""

print(privileges)
for privilege in reversed(privileges):
if table_name.lower() == privilege[3].lower() and re.match(
pattern="truncate|update|insert|delete|operate|references",
pattern="select",
string=privilege[1],
flags=re.IGNORECASE,
):
return False
return True
else:
False

return True

Expand All @@ -226,23 +241,25 @@ def _check_privilege(
user: str,
) -> bool:
"""
Check whether a user has a `select`-only access to the table.
Check whether a user has `select` access to the table.
:param conn: An open connection to Snowflake.
:param query: The query from where to extract table names to check read-only access.
"""
cur = conn.cursor()
if self.role is None:
cur.execute(f"SHOW GRANTS TO USER {user};")

cur.execute(f"SHOW GRANTS TO USER {user};")

# Get user's latest role
roles = cur.fetchall()
if not roles:
logger.error("User does not exist")
return False
# Get user's latest role
roles = cur.fetchall()
if not roles:
logger.error("User does not exist")
return False

# Last row second column from GRANT table
role = roles[-1][1]
# Last row second column from GRANT table
role = roles[-1][1]
else:
role = self.role

# Get role privilege
cur.execute(f"SHOW GRANTS TO ROLE {role};")
Expand Down Expand Up @@ -280,11 +297,14 @@ def _fetch_data(
connect_params={
"user": self.user,
"account": self.account,
"password": self.api_key.resolve_value(),
"password": self.api_key,
"private_key_file": self.private_key_file,
"private_key_file_pwd": self.private_key_file_pwd,
"database": self.database,
"schema": self.db_schema,
"warehouse": self.warehouse,
"login_timeout": self.login_timeout,
"role": self.role,
"login_timeout": self.login_timeout
}
)
if conn is None:
Expand Down

0 comments on commit 76bfd0f

Please sign in to comment.