diff --git a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py index aa6f5ff4d..057502e0d 100644 --- a/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py +++ b/integrations/snowflake/src/haystack_integrations/components/retrievers/snowflake/snowflake_table_retriever.py @@ -39,6 +39,7 @@ class SnowflakeTableRetriever: database="", db_schema="", warehouse="", + role="" ) # When database and schema are provided during component initialization. @@ -69,7 +70,10 @@ def __init__( user: str, account: str, api_key: Secret = Secret.from_env_var("SNOWFLAKE_API_KEY"), # noqa: B008 + private_key_file: Optional[str] = None, + private_key_file_pwd: Optional[str] = None, database: Optional[str] = None, + role: Optional[str] = None, db_schema: Optional[str] = None, warehouse: Optional[str] = None, login_timeout: Optional[int] = None, @@ -80,15 +84,21 @@ def __init__( :param api_key: Snowflake account password. :param database: Name of the database to use. :param db_schema: Name of the schema to use. + :param role: Name of role to use. :param warehouse: Name of the warehouse to use. :param login_timeout: Timeout in seconds for login. By default, 60 seconds. + :param private_key_file: Location of private key + :param private_key_file_pwd: Password for private key file """ self.user = user self.account = account self.api_key = api_key + self.private_key_file = private_key_file + self.private_key_file_pwd = private_key_file_pwd self.database = database self.db_schema = db_schema + self.role = role self.warehouse = warehouse self.login_timeout = login_timeout or 60 @@ -104,8 +114,11 @@ def to_dict(self) -> Dict[str, Any]: user=self.user, account=self.account, api_key=self.api_key.to_dict(), + private_key_file=self.private_key_file, + private_key_file_pwd=self.private_key_file_pwd, database=self.database, db_schema=self.db_schema, + role=self.role, warehouse=self.warehouse, login_timeout=self.login_timeout, ) @@ -208,14 +221,16 @@ def _has_select_privilege(privileges: list, table_name: str) -> bool: :param privileges: List of privileges. :param table_name: Name of the table. """ - + print(privileges) for privilege in reversed(privileges): if table_name.lower() == privilege[3].lower() and re.match( - pattern="truncate|update|insert|delete|operate|references", + pattern="select", string=privilege[1], flags=re.IGNORECASE, ): - return False + return True + else: + False return True @@ -226,23 +241,25 @@ def _check_privilege( user: str, ) -> bool: """ - Check whether a user has a `select`-only access to the table. + Check whether a user has `select` access to the table. :param conn: An open connection to Snowflake. :param query: The query from where to extract table names to check read-only access. """ cur = conn.cursor() + if self.role is None: + cur.execute(f"SHOW GRANTS TO USER {user};") - cur.execute(f"SHOW GRANTS TO USER {user};") - - # Get user's latest role - roles = cur.fetchall() - if not roles: - logger.error("User does not exist") - return False + # Get user's latest role + roles = cur.fetchall() + if not roles: + logger.error("User does not exist") + return False - # Last row second column from GRANT table - role = roles[-1][1] + # Last row second column from GRANT table + role = roles[-1][1] + else: + role = self.role # Get role privilege cur.execute(f"SHOW GRANTS TO ROLE {role};") @@ -280,11 +297,14 @@ def _fetch_data( connect_params={ "user": self.user, "account": self.account, - "password": self.api_key.resolve_value(), + "password": self.api_key, + "private_key_file": self.private_key_file, + "private_key_file_pwd": self.private_key_file_pwd, "database": self.database, "schema": self.db_schema, "warehouse": self.warehouse, - "login_timeout": self.login_timeout, + "role": self.role, + "login_timeout": self.login_timeout } ) if conn is None: