From 710554c64bcc36dee035d23bb687e81fa9e023cc Mon Sep 17 00:00:00 2001 From: pseusys Date: Wed, 6 Nov 2024 19:57:03 +0800 Subject: [PATCH] all user input escapedin ydb --- chatsky/context_storages/database.py | 2 +- chatsky/context_storages/ydb.py | 94 ++++++++++++++++++++-------- 2 files changed, 70 insertions(+), 26 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index dece1ea98..df6608ed8 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -61,7 +61,7 @@ def __init__( @staticmethod def _verify_field_name(method: Callable): - def verifier(self, *args, **kwargs): + def verifier(self: "DBContextStorage", *args, **kwargs): field_name = args[1] if len(args) >= 1 else kwargs.get("field_name", None) if field_name is None: raise ValueError(f"For method {method.__name__} argument 'field_name' is not found!") diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index 34fd063fe..ff5de9fba 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -55,6 +55,9 @@ class YDBContextStorage(DBContextStorage): :param table_name: The name of the table to use. """ + _LIMIT_VAR = "limit" + _KEY_VAR = "key" + is_asynchronous = True def __init__( @@ -136,12 +139,15 @@ async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, byt async def callee(session: Session) -> Optional[Tuple[int, int, int, bytes, bytes]]: query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; SELECT {self._current_turn_id_column_name}, {self._created_at_column_name}, {self._updated_at_column_name}, {self._misc_column_name}, {self._framework_data_column_name} FROM {self.main_table} - WHERE {self._id_column_name} = "{ctx_id}"; + WHERE {self._id_column_name} = ${self._id_column_name}; """ # noqa: E501 result_sets = await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + }, commit_tx=True ) return ( result_sets[0].rows[0][self._current_turn_id_column_name], @@ -157,17 +163,19 @@ async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at: async def callee(session: Session) -> None: query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; DECLARE ${self._current_turn_id_column_name} AS Uint64; DECLARE ${self._created_at_column_name} AS Uint64; DECLARE ${self._updated_at_column_name} AS Uint64; DECLARE ${self._misc_column_name} AS String; DECLARE ${self._framework_data_column_name} AS String; UPSERT INTO {self.main_table} ({self._id_column_name}, {self._current_turn_id_column_name}, {self._created_at_column_name}, {self._updated_at_column_name}, {self._misc_column_name}, {self._framework_data_column_name}) - VALUES ("{ctx_id}", ${self._current_turn_id_column_name}, ${self._created_at_column_name}, ${self._updated_at_column_name}, ${self._misc_column_name}, ${self._framework_data_column_name}); + VALUES (${self._id_column_name}, ${self._current_turn_id_column_name}, ${self._created_at_column_name}, ${self._updated_at_column_name}, ${self._misc_column_name}, ${self._framework_data_column_name}); """ # noqa: E501 await session.transaction(SerializableReadWrite()).execute( await session.prepare(query), { + f"${self._id_column_name}": ctx_id, f"${self._current_turn_id_column_name}": turn_id, f"${self._created_at_column_name}": crt_at, f"${self._updated_at_column_name}": upd_at, @@ -184,11 +192,14 @@ def construct_callee(table_name: str) -> Callable[[Session], Awaitable[None]]: async def callee(session: Session) -> None: query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; DELETE FROM {table_name} - WHERE {self._id_column_name} = "{ctx_id}"; + WHERE {self._id_column_name} = ${self._id_column_name}; """ # noqa: E501 await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + }, commit_tx=True ) return callee @@ -201,21 +212,32 @@ async def callee(session: Session) -> None: @DBContextStorage._verify_field_name async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]: async def callee(session: Session) -> List[Tuple[int, bytes]]: - limit, key = "", "" + declare, prepare, limit, key = list(), dict(), "", "" if isinstance(self._subscripts[field_name], int): - limit = f"LIMIT {self._subscripts[field_name]}" + declare += [f"DECLARE ${self._LIMIT_VAR} AS Uint64;"] + prepare.update({f"${self._LIMIT_VAR}": self._subscripts[field_name]}) + limit = f"LIMIT ${self._LIMIT_VAR}" elif isinstance(self._subscripts[field_name], Set): - keys = ", ".join([str(e) for e in self._subscripts[field_name]]) - key = f"AND {self._key_column_name} IN ({keys})" + values = list() + for i, k in enumerate(self._subscripts[field_name]): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Utf8;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) + values += [f"${self._KEY_VAR}_{i}"] + key = f"AND {self._KEY_VAR} IN ({', '.join(values)})" query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; + {" ".join(declare)} SELECT {self._key_column_name}, {field_name} FROM {self.turns_table} - WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL {key} + WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL {key} ORDER BY {self._key_column_name} DESC {limit}; """ # noqa: E501 result_sets = await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + **prepare, + }, commit_tx=True ) return [ (e[self._key_column_name], e[field_name]) for e in result_sets[0].rows @@ -228,12 +250,15 @@ async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]: async def callee(session: Session) -> List[int]: query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; SELECT {self._key_column_name} FROM {self.turns_table} - WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL; + WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL; """ # noqa: E501 result_sets = await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + }, commit_tx=True ) return [ e[self._key_column_name] for e in result_sets[0].rows @@ -244,15 +269,24 @@ async def callee(session: Session) -> List[int]: @DBContextStorage._verify_field_name async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]: async def callee(session: Session) -> List[Tuple[int, bytes]]: + declare, prepare = list(), dict() + for i, k in enumerate(keys): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) query = f""" PRAGMA TablePathPrefix("{self.database}"); + DECLARE ${self._id_column_name} AS Utf8; + {" ".join(declare)} SELECT {self._key_column_name}, {field_name} FROM {self.turns_table} - WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL - AND {self._key_column_name} IN ({', '.join([str(e) for e in keys])}); + WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL + AND {self._key_column_name} IN ({", ".join(prepare.keys())}); """ # noqa: E501 result_sets = await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), dict(), commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + **prepare, + }, commit_tx=True ) return [ (e[self._key_column_name], e[field_name]) for e in result_sets[0].rows @@ -266,20 +300,30 @@ async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tup return async def callee(session: Session) -> None: - keys = [str(k) for k, _ in items] - placeholders = {k: f"${field_name}_{i}" for i, (k, v) in enumerate(items) if v is not None} - declarations = "\n".join(f"DECLARE {p} AS String;" for p in placeholders.values()) - values = ", ".join(f"(\"{ctx_id}\", {keys[i]}, {placeholders.get(k, 'NULL')})" for i, (k, _) in enumerate(items)) + declare, prepare, values = list(), dict(), list() + for i, (k, v) in enumerate(items): + declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"] + prepare.update({f"${self._KEY_VAR}_{i}": k}) + if v is not None: + declare += [f"DECLARE ${field_name}_{i} AS String;"] + prepare.update({f"${field_name}_{i}": v}) + value_param = f"${field_name}_{i}" + else: + value_param = "NULL" + values += [f"(${self._id_column_name}, ${self._KEY_VAR}_{i}, {value_param})"] query = f""" PRAGMA TablePathPrefix("{self.database}"); - {declarations} + DECLARE ${self._id_column_name} AS Utf8; + {" ".join(declare)} UPSERT INTO {self.turns_table} ({self._id_column_name}, {self._key_column_name}, {field_name}) - VALUES {values}; + VALUES {", ".join(values)}; """ # noqa: E501 + await session.transaction(SerializableReadWrite()).execute( - await session.prepare(query), - {placeholders[k]: v for k, v in items if k in placeholders.keys()}, - commit_tx=True + await session.prepare(query), { + f"${self._id_column_name}": ctx_id, + **prepare, + }, commit_tx=True ) await self.pool.retry_operation(callee)