feat: new JobFile model to store UrsaDB iterator batch files #420

Open · wants to merge 7 commits into master
Changes from 4 commits
14 changes: 14 additions & 0 deletions src/db.py
@@ -24,6 +24,7 @@
from .models.job import Job
from .models.jobagent import JobAgent
from .models.match import Match
from .models.queryresult import QueryResult
from .schema import MatchesSchema, ConfigSchema
from .config import app_config

@@ -111,6 +112,19 @@ def add_match(self, job: JobId, match: Match) -> None:
session.add(match)
session.commit()

def add_queryresult(self, job_id: int | None, files: List[str]) -> None:
with self.session() as session:
obj = QueryResult(job_id=job_id, files=files)
session.add(obj)
session.commit()

def remove_queryresult(self, job_id: int | None) -> None:
with self.session() as session:
session.query(QueryResult).where(
QueryResult.job_id == job_id
).delete()
session.commit()

def job_contains(self, job: JobId, ordinal: int, file_path: str) -> bool:
"""Make sure that the file path is in the job results."""
with self.session() as session:
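The two helpers above persist and later discard the file list for a job. A minimal usage sketch, assuming `db` is an already-initialized `Database` instance (the job id and paths are illustrative):

```python
# Hypothetical values; persists the UrsaDB match list for job 42.
db.add_queryresult(job_id=42, files=["/mnt/samples/a.exe", "/mnt/samples/b.dll"])
# ... batches get enqueued while the list is persisted ...
db.remove_queryresult(job_id=42)  # deletes all QueryResult rows for this job
```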
5 changes: 3 additions & 2 deletions src/e2etests/test_api.py
@@ -9,6 +9,7 @@
import requests
import random
import os
import pprint

from ..lib.ursadb import UrsaDb # noqa

@@ -261,7 +262,7 @@ def request_query(log, i, taints=[]):
"taints": taints,
},
)
log.info("API response: %s", res.json())
log.info("API response: %s\n", pprint.pformat(res.json()))
res.raise_for_status()

query_hash = res.json()["query_hash"]
@@ -270,7 +271,7 @@
res = requests.get(
f"http://web:5000/api/matches/{query_hash}?offset=0&limit=50"
)
log.info("API response: %s", res.json())
log.info("API response: %s\n", pprint.pformat(res.json()))
if res.json()["job"]["status"] == "done":
break
time.sleep(1)
5 changes: 2 additions & 3 deletions src/lib/ursadb.py
@@ -63,7 +63,7 @@ def query(
command += f"with taints {taints_whole_str} "
if dataset:
command += f'with datasets ["{dataset}"] '
command += f"into iterator {query};"
command += f"{query};"

start = time.perf_counter()
res = self.__execute(command, recv_timeout=-1)
@@ -75,8 +75,7 @@

return {
"time": (end - start),
"iterator": res["result"]["iterator"],
"file_count": res["result"]["file_count"],
"files": res["result"]["files"],
}

def pop(self, iterator: str, count: int) -> PopResult:
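With the `into iterator` clause dropped, UrsaDB returns the complete match list in one response, so `query()` no longer hands back an iterator name and file count. A hedged sketch of the new return shape (the call and values are illustrative, assuming `ursa` is a connected `UrsaDb` instance):

```python
result = ursa.query('"MZ"')  # illustrative query string
# result == {
#     "time": 0.042,                    # seconds, measured via time.perf_counter()
#     "files": ["/mnt/samples/a.exe"],  # full file list; no "iterator"/"file_count"
# }
```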
32 changes: 32 additions & 0 deletions src/migrations/versions/4e4c88411541_create_queryresult_model.py
@@ -0,0 +1,32 @@
"""create Queryresult model
Revision ID: 4e4c88411541
Revises: dbb81bd4d47f
Create Date: 2024-10-17 14:31:49.278443
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4e4c88411541"
down_revision = "dbb81bd4d47f"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"queryresult",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("job_id", sa.Integer(), nullable=False),
sa.Column("files", sa.ARRAY(sa.String()), nullable=True),
sa.ForeignKeyConstraint(
["job_id"],
["job.internal_id"],
),
sa.PrimaryKeyConstraint("id"),
)


def downgrade() -> None:
op.drop_table("queryresult")
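For reference, the migration can be exercised programmatically through Alembic's command API; a sketch assuming a standard `alembic.ini` in the working directory:

```python
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")             # config path is an assumption
command.upgrade(cfg, "4e4c88411541")    # creates the queryresult table
command.downgrade(cfg, "dbb81bd4d47f")  # drops it again
```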
8 changes: 8 additions & 0 deletions src/models/queryresult.py
@@ -0,0 +1,8 @@
from sqlmodel import Field, SQLModel, ARRAY, Column, String
from typing import List, Union


class QueryResult(SQLModel, table=True):
id: Union[int, None] = Field(default=None, primary_key=True)
job_id: Union[int, None] = Field(foreign_key="job.internal_id")
files: List[str] = Field(sa_column=Column(ARRAY(String)))
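A short sketch of writing a row through this model, assuming a PostgreSQL engine (the `ARRAY` column type requires PostgreSQL) and an illustrative DSN:

```python
from sqlmodel import Session, create_engine

engine = create_engine("postgresql://mquery:mquery@localhost/mquery")  # assumed DSN
with Session(engine) as session:
    qr = QueryResult(job_id=1, files=["/mnt/samples/a.exe"])
    session.add(qr)
    session.commit()
    session.refresh(qr)
    print(qr.id)  # primary key assigned by the database
```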
44 changes: 20 additions & 24 deletions src/tasks.py
@@ -4,6 +4,7 @@
from redis import Redis
from contextlib import contextmanager
import yara # type: ignore
from itertools import accumulate

from .db import Database, JobId
from .util import make_sha256_tag
@@ -240,9 +241,10 @@ def query_ursadb(job_id: JobId, dataset_id: str, ursadb_query: str) -> None:
if "error" in result:
raise RuntimeError(result["error"])

file_count = result["file_count"]
iterator = result["iterator"]
logging.info(f"Iterator {iterator} contains {file_count} files")
files = result["files"]
agent.db.add_queryresult(job.internal_id, files)

file_count = len(files)

total_files = agent.db.update_job_files(job_id, file_count)
if job.files_limit and total_files > job.files_limit:
@@ -251,42 +253,36 @@
"Try a more precise query."
)

batches = __get_batch_sizes(file_count)
# add len(batches) new tasks, -1 to account for this task
agent.add_tasks_in_progress(job, len(batches) - 1)
batch_sizes = __get_batch_sizes(file_count)
# add len(batch_sizes) new tasks, -1 to account for this task
agent.add_tasks_in_progress(job, len(batch_sizes) - 1)

batched_files = (
files[batch_end - batch_size : batch_end]
for batch_end, batch_size in zip(
accumulate(batch_sizes), batch_sizes
)
)
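The `accumulate`/`zip` pairing above turns the per-batch sizes into running end offsets and slices `files` into consecutive, non-overlapping batches. A standalone worked example with illustrative data:

```python
from itertools import accumulate

files = ["a", "b", "c", "d", "e"]
batch_sizes = [2, 2, 1]
# accumulate(batch_sizes) yields the running ends: 2, 4, 5
batches = [
    files[end - size : end]
    for end, size in zip(accumulate(batch_sizes), batch_sizes)
]
assert batches == [["a", "b"], ["c", "d"], ["e"]]
```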

for batch in batches:
for batch_files in batched_files:
agent.queue.enqueue(
run_yara_batch,
job_id,
iterator,
batch,
batch_files,
job_timeout=app_config.rq.job_timeout,
)

agent.db.dataset_query_done(job_id)
agent.db.remove_queryresult(job.internal_id)


def run_yara_batch(job_id: JobId, iterator: str, batch_size: int) -> None:
def run_yara_batch(job_id: JobId, batch_files: List[str]) -> None:
"""Actually scans files, and updates a database with the results."""
with job_context(job_id) as agent:
job = agent.db.get_job(job_id)
if job.status == "cancelled":
logging.info("Job was cancelled, returning...")
return

pop_result = agent.ursa.pop(iterator, batch_size)
logging.info("job %s: Pop successful: %s", job_id, pop_result)
if pop_result.was_locked:
# Iterator is currently locked, re-enqueue self
agent.queue.enqueue(
run_yara_batch,
job_id,
iterator,
batch_size,
job_timeout=app_config.rq.job_timeout,
)
return

agent.execute_yara(job, pop_result.files)
agent.execute_yara(job, batch_files)
agent.add_tasks_in_progress(job, -1)