
Fix ingest to use auth token
barrycarey committed May 18, 2024
1 parent 4981915 commit 18bcdf9
Showing 3 changed files with 98 additions and 27 deletions.
4 changes: 2 additions & 2 deletions redditrepostsleuth/core/util/helpers.py
@@ -492,8 +492,8 @@ def get_next_ids(start_id, count):
return ids

def generate_next_ids(start_id, count):
start_num = base36decode(start_id)
for id_num in range(start_num, start_num + count):
#start_num = base36decode(start_id)
for id_num in range(start_id, start_id + count):
yield base36encode(id_num)


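
With this change, generate_next_ids treats start_id as an already base36-decoded integer, so the caller decodes the string ID once and passes the number straight through (the rewritten ingest_range below does exactly that). A minimal sketch of the intended flow, using illustrative stand-ins for the repo's base36encode/base36decode helpers:

    # Illustrative stand-ins; the real helpers live in redditrepostsleuth/core/util/helpers.py
    ALPHABET = '0123456789abcdefghijklmnopqrstuvwxyz'

    def base36decode(post_id: str) -> int:
        return int(post_id, 36)

    def base36encode(number: int) -> str:
        encoded = ''
        while number > 0:
            number, rem = divmod(number, 36)
            encoded = ALPHABET[rem] + encoded
        return encoded or '0'

    def generate_next_ids(start_id: int, count: int):
        # start_id is now an int; the caller is responsible for decoding
        for id_num in range(start_id, start_id + count):
            yield base36encode(id_num)

    # The caller decodes once, then streams sequential post IDs
    start = base36decode('1ctkrlw')
    next_three = list(generate_next_ids(start, 3))  # ['1ctkrlw', '1ctkrlx', '1ctkrly']
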
1 change: 1 addition & 0 deletions redditrepostsleuth/core/util/reddithelpers.py
@@ -2,6 +2,7 @@
from typing import Text, Optional, List

import requests
from asyncpraw import Reddit as AsyncReddit
from praw import Reddit
from praw.exceptions import APIException
from praw.models import Subreddit
120 changes: 95 additions & 25 deletions redditrepostsleuth/ingestsvc/ingestsvc.py
@@ -5,10 +5,11 @@
import time
from asyncio import ensure_future, gather, run, TimeoutError, CancelledError
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Union, Generator

from aiohttp import ClientSession, ClientTimeout, ClientConnectorError, TCPConnector, \
ServerDisconnectedError, ClientOSError
from praw import Reddit

from redditrepostsleuth.core.celery.tasks.ingest_tasks import save_new_post, save_new_posts
from redditrepostsleuth.core.config import Config
@@ -20,7 +21,6 @@
from redditrepostsleuth.core.model.misc_models import BatchedPostRequestJob, JobStatus
from redditrepostsleuth.core.util.helpers import get_reddit_instance, get_newest_praw_post_id, get_next_ids, \
base36decode, generate_next_ids
from redditrepostsleuth.core.util.objectmapping import reddit_submission_to_post
from redditrepostsleuth.core.util.utils import build_reddit_query_string

log = configure_logger(name='redditrepostsleuth')
@@ -88,6 +88,9 @@ async def fetch_page_as_job(job: BatchedPostRequestJob, session: ClientSession)
elif resp.status == 429:
log.warning('Data API Rate Limit')
job.status = JobStatus.RATELIMIT
elif resp.status == 500:
log.warning('Reddit Server Error')
job.status = JobStatus.ERROR
else:
log.warning('Unexpected request status %s - %s', resp.status, job.url)
job.status = JobStatus.ERROR
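
The new 500 branch gives Reddit server errors their own log line; the job still ends up marked as an error, same as the generic fallback. A condensed sketch of the resulting mapping from HTTP status to job status (the JobStatus members here are assumed from what the diff shows, not the full enum in misc_models):

    from enum import Enum, auto

    class JobStatus(Enum):  # assumed subset of the real enum
        STARTED = auto()
        SUCCESS = auto()
        RATELIMIT = auto()
        ERROR = auto()

    def status_for_response(http_status: int) -> JobStatus:
        if http_status == 200:
            return JobStatus.SUCCESS
        if http_status == 429:
            return JobStatus.RATELIMIT  # rate limited; the caller re-queues the job
        if http_status == 500:
            return JobStatus.ERROR      # new in this commit: Reddit server error
        return JobStatus.ERROR          # any other status is an unexpected failure

    assert status_for_response(500) is JobStatus.ERROR
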
@@ -107,29 +110,47 @@ async def fetch_page_as_job(job: BatchedPostRequestJob, session: ClientSession)

return job

async def ingest_range(newest_post_id: Union[str, int], oldest_post_id: Union[str, int], alt_headers: dict = None) -> None:
if isinstance(newest_post_id, str):
newest_post_id = base36decode(newest_post_id)

if isinstance(oldest_post_id, str):
oldest_post_id = base36decode(oldest_post_id)

missing_ids = generate_next_ids(oldest_post_id, newest_post_id - oldest_post_id)
log.info('Total missing IDs: %s', newest_post_id - oldest_post_id)
await ingest_sequence(missing_ids, alt_headers=alt_headers)


async def ingest_range(newest_post_id: str, oldest_post_id: str) -> None:

async def ingest_sequence(ids: Union[list[int], Generator[int, None, None]], alt_headers: dict = None) -> None:
"""
Take a sequence of post IDs and attempt to ingest them.
Mainly used to catch any posts missed while the script is down
:param ids: List or generator of base36-decoded post IDs to fetch
:param alt_headers: Optional auth headers to use in place of the defaults
"""
missing_ids = generate_next_ids(oldest_post_id, base36decode(newest_post_id) - base36decode(oldest_post_id))
batch = []

if isinstance(ids, list):
def id_gen(list_of_ids):
for id in list_of_ids:
yield id
ids = id_gen(ids)

saved_posts = 0
tasks = []
conn = TCPConnector(limit=0)
async with ClientSession(connector=conn, headers=HEADERS) as session:

async with ClientSession(connector=conn, headers=alt_headers or HEADERS) as session:
while True:
try:
chunk = list(itertools.islice(missing_ids, 100))
chunk = list(itertools.islice(ids, 100))
except StopIteration:
break

#url = f'{config.util_api}/reddit/info?submission_ids={build_reddit_query_string(chunk)}'
url = f'https://api.reddit.com/api/info?id={build_reddit_query_string(chunk)}'
url = f'https://oauth.reddit.com/api/info?id={build_reddit_query_string(chunk)}'
job = BatchedPostRequestJob(url, chunk, JobStatus.STARTED)
tasks.append(ensure_future(fetch_page_as_job(job, session)))
if len(tasks) >= 50 or len(chunk) == 0:
@@ -151,6 +172,7 @@ async def ingest_range(newest_post_id: str, oldest_post_id: str) -> None:
if post['data']['removed_by_category'] in REMOVAL_REASONS_TO_SKIP:
continue
posts_to_save.append(post['data'])
saved_posts += 1

else:
tasks.append(ensure_future(fetch_page_as_job(j, session)))
@@ -167,6 +189,7 @@ async def ingest_range(newest_post_id: str, oldest_post_id: str) -> None:
if len(chunk) == 0:
break

log.info('Saved posts: %s', saved_posts)
log.info('Finished backfill ')
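
The backfill path is now split in two: ingest_range only decodes its two base36 endpoints and hands the resulting integer range to ingest_sequence, which accepts either a list or a generator of decoded IDs and slices them into batches of 100 per /api/info request. A minimal sketch of that batching pattern (names are illustrative, not the module's API):

    import itertools
    from typing import Iterable, Iterator

    def chunk_ids(ids: Iterable[int], size: int = 100) -> Iterator[list[int]]:
        it = iter(ids)  # works for lists and generators alike; the commit wraps lists in a small generator instead
        while True:
            chunk = list(itertools.islice(it, size))
            if not chunk:  # an empty chunk is what actually ends the loop
                break
            yield chunk

    # 250 decoded IDs become requests of 100, 100 and 50
    batches = list(chunk_ids(range(1000, 1250)))
    assert [len(b) for b in batches] == [100, 100, 50]
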


@@ -179,25 +202,60 @@ def queue_posts_for_ingest(posts: List[Post]):
for post in posts:
save_new_post.apply_async((post,))

def get_request_delay(submissions: list[dict], current_req_delay: int, target_ingest_delay: int = 30) -> int:
ingest_delay = datetime.utcnow() - datetime.utcfromtimestamp(
submissions[0]['data']['created_utc'])
log.info('Current Delay: %s', ingest_delay)

if ingest_delay.seconds > target_ingest_delay:
new_delay = current_req_delay - 1 if current_req_delay > 0 else 0
else:
new_delay = current_req_delay + 1

log.info('New Delay: %s', new_delay)
return new_delay
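
get_request_delay turns the polling rate into a simple feedback loop: if the newest submission in the batch is older than the target lag (30 seconds by default), the next pass sleeps one second less; otherwise it backs off by one second. A condensed restatement with a worked example (the listing shape mirrors Reddit's /api/info response):

    import time
    from datetime import datetime

    def get_request_delay(submissions: list[dict], current_req_delay: int, target_ingest_delay: int = 30) -> int:
        # Age of the newest submission in this batch
        ingest_delay = datetime.utcnow() - datetime.utcfromtimestamp(submissions[0]['data']['created_utc'])
        if ingest_delay.seconds > target_ingest_delay:
            return current_req_delay - 1 if current_req_delay > 0 else 0  # behind target: poll sooner
        return current_req_delay + 1  # caught up: back off a little

    # Worked example: the newest post is ~45s old, so a 5-second delay drops to 4
    batch = [{'data': {'created_utc': time.time() - 45}}]
    assert get_request_delay(batch, 5) == 4
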

def get_auth_headers(reddit: Reddit) -> dict:
"""
Force praw to make a call so it holds a valid access token.
Hacky, but I'd rather let praw handle the tokens
:param reddit:
:return:
"""
reddit.user.me()
return {**HEADERS, **{'Authorization': f'Bearer {reddit.auth._reddit._core._authorizer.access_token}'}}
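
get_auth_headers leans on praw: reddit.user.me() forces an authenticated request so praw fetches or refreshes an access token, and the private attribute chain then lifts that bearer token into plain request headers. Those headers let the aiohttp calls hit oauth.reddit.com instead of the anonymous api.reddit.com endpoint. A hedged sketch of how the headers are consumed (HEADERS and the fullname format are assumptions for illustration):

    import asyncio
    from aiohttp import ClientSession

    HEADERS = {'User-Agent': 'repost-sleuth-ingest (illustrative)'}

    async def fetch_info(fullnames: list[str], auth_headers: dict) -> str:
        # e.g. fullnames = ['t3_1ctkrlw', 't3_1ctkrlx']; the service builds this via build_reddit_query_string
        url = f"https://oauth.reddit.com/api/info?id={','.join(fullnames)}"
        async with ClientSession(headers=auth_headers) as session:
            async with session.get(url) as resp:
                return await resp.text()

    # auth_headers = get_auth_headers(reddit)          # from the diff above
    # asyncio.run(fetch_info(['t3_1ctkrlw'], auth_headers))
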

async def main() -> None:
log.info('Starting post ingestor')
reddit = get_reddit_instance(config)
allowed_submission_delay_seconds = 30
missed_id_retry_count = 2000

newest_id = get_newest_praw_post_id(reddit)
uowm = UnitOfWorkManager(get_db_engine(config))
auth_headers = get_auth_headers(reddit)

with uowm.start() as uow:
oldest_post = uow.posts.get_newest_post()
oldest_id = oldest_post.post_id

await ingest_range(newest_id, oldest_id)
await ingest_range(newest_id, oldest_id, alt_headers=auth_headers)

delay = 0
request_delay = 0
missed_ids = [] # IDs that we didn't get results back for or had a removal reason
last_token_refresh = datetime.utcnow()
while True:

if (datetime.utcnow() - last_token_refresh).seconds > 600:
log.info('Refreshing token')
auth_headers = get_auth_headers(reddit)
last_token_refresh = datetime.utcnow()

ids_to_get = get_next_ids(newest_id, 100)
#url = f'{config.util_api}/reddit/info?submission_ids={build_reddit_query_string(ids_to_get)}'
url = f'https://api.reddit.com/api/info?id={build_reddit_query_string(ids_to_get)}'
async with ClientSession(headers=HEADERS) as session:

url = f'https://oauth.reddit.com/api/info?id={build_reddit_query_string(ids_to_get)}'
async with ClientSession(headers=auth_headers) as session:
try:
log.debug('Sending fetch request')
results = await fetch_page(url, session)
@@ -215,18 +273,12 @@ async def main() -> None:
continue

res_data = json.loads(results)

if not res_data or not len(res_data['data']['children']):
log.info('No results')
continue

log.info('%s results returned from API', len(res_data['data']['children']))
if len(res_data['data']['children']) < 91:
delay += 1
log.debug('Delay increased by 1. Current delay: %s', delay)
else:
if delay > 0:
delay -= 1
log.debug('Delay decreased by 1. Current delay: %s', delay)

posts_to_save = []
for post in res_data['data']['children']:
@@ -235,17 +287,35 @@ async def main() -> None:
posts_to_save.append(post['data'])

log.info('Sending %s posts to save queue', len(posts_to_save))
# queue_posts_for_ingest([reddit_submission_to_post(submission) for submission in posts_to_save])

queue_posts_for_ingest(posts_to_save)

ingest_delay = datetime.utcnow() - datetime.utcfromtimestamp(
res_data['data']['children'][0]['data']['created_utc'])
log.info('Current Delay: %s', ingest_delay)
request_delay = get_request_delay(res_data['data']['children'], request_delay, allowed_submission_delay_seconds)

newest_id = res_data['data']['children'][-1]['data']['id']

time.sleep(delay)

saved_ids = [x['id'] for x in posts_to_save]
missing_ids_in_this_req = list(set(ids_to_get).difference(saved_ids))
missed_ids += [base36decode(x) for x in missing_ids_in_this_req]
time.sleep(request_delay)

log.info('Missed IDs: %s', len(missed_ids))
if len(missed_ids) > missed_id_retry_count:
await ingest_sequence(missed_ids, alt_headers=auth_headers)
missed_ids = []
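
The main loop now also keeps a ledger of IDs that never came back from a request (or were skipped for a removal reason): the requested batch is diffed against what was actually queued, the leftovers are base36-decoded, and once the backlog passes missed_id_retry_count they are replayed through ingest_sequence with the same auth headers. A small sketch of that bookkeeping, with int(x, 36) standing in for the repo's base36decode:

    def find_missed(ids_to_get: list[str], posts_to_save: list[dict]) -> list[int]:
        saved = {post['id'] for post in posts_to_save}
        return [int(post_id, 36) for post_id in set(ids_to_get) - saved]

    missed_ids: list[int] = []
    missed_ids += find_missed(['1ctkrlw', '1ctkrlx'], [{'id': '1ctkrlw'}])
    assert missed_ids == [int('1ctkrlx', 36)]
    # if len(missed_ids) > missed_id_retry_count:
    #     await ingest_sequence(missed_ids, alt_headers=auth_headers)
    #     missed_ids = []
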


async def temp_backfill():
reddit = get_reddit_instance(config)
uowm = UnitOfWorkManager(get_db_engine(config))
get_newest_praw_post_id(reddit)
new_headers = {**HEADERS, **{'Authorization': f'Bearer {reddit.auth._reddit._core._authorizer.access_token}'}}

with uowm.start() as uow:
oldest_post = uow.posts.get_newest_post()
oldest_id = oldest_post.post_id

await ingest_range(oldest_id, '1ctkrlw', alt_headers=new_headers)
if __name__ == '__main__':
run(main())
