Skip to content

Commit

Permalink
all user input escaped in ydb
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 6, 2024
1 parent dbbbb28 commit 710554c
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 26 deletions.
2 changes: 1 addition & 1 deletion chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(

@staticmethod
def _verify_field_name(method: Callable):
def verifier(self, *args, **kwargs):
def verifier(self: "DBContextStorage", *args, **kwargs):
field_name = args[1] if len(args) >= 1 else kwargs.get("field_name", None)
if field_name is None:
raise ValueError(f"For method {method.__name__} argument 'field_name' is not found!")
Expand Down
94 changes: 69 additions & 25 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ class YDBContextStorage(DBContextStorage):
:param table_name: The name of the table to use.
"""

_LIMIT_VAR = "limit"
_KEY_VAR = "key"

is_asynchronous = True

def __init__(
Expand Down Expand Up @@ -136,12 +139,15 @@ async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, byt
async def callee(session: Session) -> Optional[Tuple[int, int, int, bytes, bytes]]:
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
SELECT {self._current_turn_id_column_name}, {self._created_at_column_name}, {self._updated_at_column_name}, {self._misc_column_name}, {self._framework_data_column_name}
FROM {self.main_table}
WHERE {self._id_column_name} = "{ctx_id}";
WHERE {self._id_column_name} = ${self._id_column_name};
""" # noqa: E501
result_sets = await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query), dict(), commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
}, commit_tx=True
)
return (
result_sets[0].rows[0][self._current_turn_id_column_name],
Expand All @@ -157,17 +163,19 @@ async def update_main_info(self, ctx_id: str, turn_id: int, crt_at: int, upd_at:
async def callee(session: Session) -> None:
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
DECLARE ${self._current_turn_id_column_name} AS Uint64;
DECLARE ${self._created_at_column_name} AS Uint64;
DECLARE ${self._updated_at_column_name} AS Uint64;
DECLARE ${self._misc_column_name} AS String;
DECLARE ${self._framework_data_column_name} AS String;
UPSERT INTO {self.main_table} ({self._id_column_name}, {self._current_turn_id_column_name}, {self._created_at_column_name}, {self._updated_at_column_name}, {self._misc_column_name}, {self._framework_data_column_name})
VALUES ("{ctx_id}", ${self._current_turn_id_column_name}, ${self._created_at_column_name}, ${self._updated_at_column_name}, ${self._misc_column_name}, ${self._framework_data_column_name});
VALUES (${self._id_column_name}, ${self._current_turn_id_column_name}, ${self._created_at_column_name}, ${self._updated_at_column_name}, ${self._misc_column_name}, ${self._framework_data_column_name});
""" # noqa: E501
await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query),
{
f"${self._id_column_name}": ctx_id,
f"${self._current_turn_id_column_name}": turn_id,
f"${self._created_at_column_name}": crt_at,
f"${self._updated_at_column_name}": upd_at,
Expand All @@ -184,11 +192,14 @@ def construct_callee(table_name: str) -> Callable[[Session], Awaitable[None]]:
async def callee(session: Session) -> None:
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
DELETE FROM {table_name}
WHERE {self._id_column_name} = "{ctx_id}";
WHERE {self._id_column_name} = ${self._id_column_name};
""" # noqa: E501
await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query), dict(), commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
}, commit_tx=True
)

return callee
Expand All @@ -201,21 +212,32 @@ async def callee(session: Session) -> None:
@DBContextStorage._verify_field_name
async def load_field_latest(self, ctx_id: str, field_name: str) -> List[Tuple[int, bytes]]:
async def callee(session: Session) -> List[Tuple[int, bytes]]:
limit, key = "", ""
declare, prepare, limit, key = list(), dict(), "", ""
if isinstance(self._subscripts[field_name], int):
limit = f"LIMIT {self._subscripts[field_name]}"
declare += [f"DECLARE ${self._LIMIT_VAR} AS Uint64;"]
prepare.update({f"${self._LIMIT_VAR}": self._subscripts[field_name]})
limit = f"LIMIT ${self._LIMIT_VAR}"
elif isinstance(self._subscripts[field_name], Set):
keys = ", ".join([str(e) for e in self._subscripts[field_name]])
key = f"AND {self._key_column_name} IN ({keys})"
values = list()
for i, k in enumerate(self._subscripts[field_name]):
declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Utf8;"]
prepare.update({f"${self._KEY_VAR}_{i}": k})
values += [f"${self._KEY_VAR}_{i}"]
key = f"AND {self._KEY_VAR} IN ({', '.join(values)})"
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
{" ".join(declare)}
SELECT {self._key_column_name}, {field_name}
FROM {self.turns_table}
WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL {key}
WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL {key}
ORDER BY {self._key_column_name} DESC {limit};
""" # noqa: E501
result_sets = await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query), dict(), commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
**prepare,
}, commit_tx=True
)
return [
(e[self._key_column_name], e[field_name]) for e in result_sets[0].rows
Expand All @@ -228,12 +250,15 @@ async def load_field_keys(self, ctx_id: str, field_name: str) -> List[int]:
async def callee(session: Session) -> List[int]:
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
SELECT {self._key_column_name}
FROM {self.turns_table}
WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL;
WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL;
""" # noqa: E501
result_sets = await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query), dict(), commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
}, commit_tx=True
)
return [
e[self._key_column_name] for e in result_sets[0].rows
Expand All @@ -244,15 +269,24 @@ async def callee(session: Session) -> List[int]:
@DBContextStorage._verify_field_name
async def load_field_items(self, ctx_id: str, field_name: str, keys: List[int]) -> List[Tuple[int, bytes]]:
async def callee(session: Session) -> List[Tuple[int, bytes]]:
declare, prepare = list(), dict()
for i, k in enumerate(keys):
declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"]
prepare.update({f"${self._KEY_VAR}_{i}": k})
query = f"""
PRAGMA TablePathPrefix("{self.database}");
DECLARE ${self._id_column_name} AS Utf8;
{" ".join(declare)}
SELECT {self._key_column_name}, {field_name}
FROM {self.turns_table}
WHERE {self._id_column_name} = "{ctx_id}" AND {field_name} IS NOT NULL
AND {self._key_column_name} IN ({', '.join([str(e) for e in keys])});
WHERE {self._id_column_name} = ${self._id_column_name} AND {field_name} IS NOT NULL
AND {self._key_column_name} IN ({", ".join(prepare.keys())});
""" # noqa: E501
result_sets = await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query), dict(), commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
**prepare,
}, commit_tx=True
)
return [
(e[self._key_column_name], e[field_name]) for e in result_sets[0].rows
Expand All @@ -266,20 +300,30 @@ async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tup
return

async def callee(session: Session) -> None:
keys = [str(k) for k, _ in items]
placeholders = {k: f"${field_name}_{i}" for i, (k, v) in enumerate(items) if v is not None}
declarations = "\n".join(f"DECLARE {p} AS String;" for p in placeholders.values())
values = ", ".join(f"(\"{ctx_id}\", {keys[i]}, {placeholders.get(k, 'NULL')})" for i, (k, _) in enumerate(items))
declare, prepare, values = list(), dict(), list()
for i, (k, v) in enumerate(items):
declare += [f"DECLARE ${self._KEY_VAR}_{i} AS Uint32;"]
prepare.update({f"${self._KEY_VAR}_{i}": k})
if v is not None:
declare += [f"DECLARE ${field_name}_{i} AS String;"]
prepare.update({f"${field_name}_{i}": v})
value_param = f"${field_name}_{i}"
else:
value_param = "NULL"
values += [f"(${self._id_column_name}, ${self._KEY_VAR}_{i}, {value_param})"]
query = f"""
PRAGMA TablePathPrefix("{self.database}");
{declarations}
DECLARE ${self._id_column_name} AS Utf8;
{" ".join(declare)}
UPSERT INTO {self.turns_table} ({self._id_column_name}, {self._key_column_name}, {field_name})
VALUES {values};
VALUES {", ".join(values)};
""" # noqa: E501

await session.transaction(SerializableReadWrite()).execute(
await session.prepare(query),
{placeholders[k]: v for k, v in items if k in placeholders.keys()},
commit_tx=True
await session.prepare(query), {
f"${self._id_column_name}": ctx_id,
**prepare,
}, commit_tx=True
)

await self.pool.retry_operation(callee)
Expand Down

0 comments on commit 710554c

Please sign in to comment.