Skip to content

Commit

Permalink
Merge pull request #35 from lsloan/upsteam_master
Browse files Browse the repository at this point in the history
  • Loading branch information
lsloan authored Nov 6, 2023
2 parents d067ab8 + 387de3a commit d69d77b
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 120 deletions.
25 changes: 14 additions & 11 deletions migration/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from urllib.parse import parse_qs, urlparse

import httpx
import trio
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt
)
import trio


logger = logging.getLogger(__name__)

Expand All @@ -35,11 +34,11 @@ class API:
client: httpx.AsyncClient

def __init__(
self,
url: str,
key: str,
endpoint_type: EndpointType = EndpointType.REST,
timeout: float = 10.0
self,
url: str,
key: str,
endpoint_type: EndpointType = EndpointType.REST,
timeout: float = 10.0
):
headers = {'Authorization': f'Bearer {key}'}
timeoutsConfiguration = httpx.Timeout(timeout, pool=None)
Expand All @@ -66,7 +65,8 @@ def get_next_page_params(resp: httpx.Response) -> dict[str, Any] | None:
before_sleep=before_sleep_log(logger, logging.WARN),
sleep=trio.sleep
)
async def get(self, url: str, params: dict[str, Any] | None = None) -> GetResponse:
async def get(self, url: str,
params: dict[str, Any] | None = None) -> GetResponse:
resp = await self.client.get(url=url, params=params)
resp.raise_for_status()
data = resp.json()
Expand All @@ -83,20 +83,23 @@ async def get(self, url: str, params: dict[str, Any] | None = None) -> GetRespon
async def put(self, url: str, params: dict[str, Any] | None = None) -> Any:
resp = await self.client.put(url=url, params=params)
if resp.status_code == httpx.codes.UNPROCESSABLE_ENTITY:
logger.warning(f"HTTP {resp.status_code}: PUT {resp.url}; response: '{resp.text}'")
logger.warning(
f'HTTP {resp.status_code}: PUT {resp.url}; '
f'response: {repr(resp.text)}')
return None
resp.raise_for_status()
return resp.json()

async def get_results_from_pages(
self, endpoint: str, params: dict[str, Any] | None = None, page_size: int = 50, limit: int | None = None
self, endpoint: str, params: dict[str, Any] | None = None,
page_size: int = 50, limit: int | None = None
) -> list[dict[str, Any]]:
extra_params: dict[str, Any]
if params is not None:
extra_params = params
else:
extra_params = {}
extra_params.update({ 'per_page': page_size })
extra_params.update({'per_page': page_size})

more_pages = True
page_num = 1
Expand Down
2 changes: 1 addition & 1 deletion migration/data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from abc import ABC
from dataclasses import dataclass


@dataclass(frozen=True)
Expand Down
5 changes: 3 additions & 2 deletions migration/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ class DB:
def __init__(self, dialect: Dialect, params: DBParams):
params['password'] = quote_plus(params['password'])
core_string = '{user}:{password}@{host}:{port}/{name}'.format(**params)
self.engine = sqlalchemy.create_engine(f'{dialect.value}://{core_string}')
self.engine = sqlalchemy.create_engine(
f'{dialect.value}://{core_string}')

def get_connection(self) -> sqlalchemy.engine.Connection:
if self.connection is None:
self.connection = self.engine.connect()
return self.connection

def close_connection(self) -> None:
if self.connection is not None:
self.connection.close()
Expand Down
7 changes: 4 additions & 3 deletions migration/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ def run():
env_file_name: str = os.path.join(root_dir, 'env')

if os.path.exists(env_file_name):
logger.info(f'Setting environment from file "{env_file_name}".')
logger.info(f'Setting environment from file {repr(env_file_name)}.')
load_dotenv(env_file_name, verbose=True)
else:
logger.info(f'File "{env_file_name}" not found. '
'Using existing environment.')
logger.info(f'File {repr(env_file_name)} not found. '
'Using existing environment.')

logger.info('Parameters from environment…')

Expand Down Expand Up @@ -227,5 +227,6 @@ def run():

logger.info('Migration complete.')


if '__main__' == __name__:
run()
70 changes: 43 additions & 27 deletions migration/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from db import DB
from utils import chunk_integer, time_execution


logger = logging.getLogger(__name__)


Expand All @@ -31,7 +30,8 @@ async def get_tools_installed_in_account(self) -> list[ExternalTool]:
pass

@abstractmethod
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
async def get_courses_in_terms(self, term_ids: list[int],
limit: int | None = None) -> list[Course]:
pass


Expand All @@ -52,23 +52,27 @@ async def get_term_names(self, term_ids: list[int]) -> Dict[int, str]:
return term_names

async def get_tools_installed_in_account(self) -> list[ExternalTool]:
params = {"include_parents": True}
results = await self.api.get_results_from_pages(f'/accounts/{self.account_id}/external_tools', params)
tools = [ExternalTool(id=tool_dict['id'], name=tool_dict['name']) for tool_dict in results]
params = {'include_parents': True}
results = await self.api.get_results_from_pages(
f'/accounts/{self.account_id}/external_tools', params)
tools = [ExternalTool(id=tool_dict['id'], name=tool_dict['name']) for
tool_dict in results]
return tools

@time_execution
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
async def get_courses_in_terms(self, term_ids: list[int],
limit: int | None = None) -> list[Course]:
limit_chunks = None
if limit is not None:
limit_chunks = chunk_integer(limit, len(term_ids))

results: list[dict[str, Any]] = []
for i, term_id in enumerate(term_ids):
limit_for_term = limit_chunks[i] if limit_chunks is not None else None
limit_for_term = limit_chunks[
i] if limit_chunks is not None else None
term_results = await self.api.get_results_from_pages(
f'/accounts/{self.account_id}/courses',
params={ 'enrollment_term_id': term_id },
params={'enrollment_term_id': term_id},
page_size=50,
limit=limit_for_term
)
Expand Down Expand Up @@ -104,14 +108,16 @@ async def get_tools_installed_in_account(self) -> list[ExternalTool]:

async def get_subaccount_ids(self) -> list[int]:
results = await self.api.get_results_from_pages(
f'/accounts/{self.account_id}/sub_accounts', { 'recursive': True }
f'/accounts/{self.account_id}/sub_accounts',
{'recursive': True}
)
sub_account_ids = [result['id'] for result in results]
logger.debug(sub_account_ids)
return sub_account_ids

@time_execution
async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = None) -> list[Course]:
async def get_courses_in_terms(self, term_ids: list[int],
limit: int | None = None) -> list[Course]:
account_ids = [self.account_id] + await self.get_subaccount_ids()

conn = self.db.get_connection()
Expand All @@ -127,14 +133,15 @@ async def get_courses_in_terms(self, term_ids: list[int], limit: int | None = No
where t.canvas_id in :term_ids
and a.canvas_id in :account_ids
and c.workflow_state != 'deleted'
{"limit :result_limit" if limit is not None else ''};
{'limit :result_limit' if limit is not None else ''};
''')
extra_bind_params = {}
if limit is not None:
extra_bind_params['result_limit'] = limit
statement = statement.bindparams(
sqlalchemy.bindparam('term_ids', value=term_ids, expanding=True),
sqlalchemy.bindparam('account_ids', value=account_ids, expanding=True),
sqlalchemy.bindparam('account_ids', value=account_ids,
expanding=True),
**extra_bind_params
)
results = conn.execute(statement).all()
Expand All @@ -157,7 +164,8 @@ class CourseManager:
api: API

@staticmethod
def find_tab_by_tool_id(tool_id: int, tabs: list[ExternalToolTab]) -> ExternalToolTab | None:
def find_tab_by_tool_id(tool_id: int, tabs: list[
ExternalToolTab]) -> ExternalToolTab | None:
for tab in tabs:
if tab.tool_id == tool_id:
return tab
Expand All @@ -179,7 +187,8 @@ def create_course_log_message(self, message: str) -> str:
return f'{self.course} | {message}'

async def get_tool_tabs(self) -> list[ExternalToolTab]:
results = await self.api.get_results_from_pages(f'/courses/{self.course.id}/tabs')
results = await self.api.get_results_from_pages(
f'/courses/{self.course.id}/tabs')

tabs: list[ExternalToolTab] = []
for result in results:
Expand All @@ -188,10 +197,11 @@ async def get_tool_tabs(self) -> list[ExternalToolTab]:
tabs.append(CourseManager.convert_data_to_tool_tab(result))
return tabs

async def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position: int | None = None):
params: dict[str, Any] = { "hidden": is_hidden }
async def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool,
position: int | None = None):
params: dict[str, Any] = {'hidden': is_hidden}
if position is not None:
params.update({ "position": position })
params.update({'position': position})

result = await self.api.put(
f'/courses/{self.course.id}/tabs/{tab.id}',
Expand All @@ -201,47 +211,53 @@ async def update_tool_tab(self, tab: ExternalToolTab, is_hidden: bool, position:
return CourseManager.convert_data_to_tool_tab(result)

async def replace_tool_tab(
self, source_tab: ExternalToolTab, target_tab: ExternalToolTab
self, source_tab: ExternalToolTab, target_tab: ExternalToolTab
) -> tuple[ExternalToolTab, ExternalToolTab]:
logger.debug([source_tab, target_tab])

# Source tool is hidden in course, don't do anything
if source_tab.is_hidden:
logger.debug(self.create_course_log_message(
f'Skipping replacement for {[source_tab, target_tab]}; source tool is hidden.'
f'Skipping replacement for {[source_tab, target_tab]}; '
'source tool is hidden.'
))
return (source_tab, target_tab)
else:
if not target_tab.is_hidden:
logger.warning(self.create_course_log_message(
f'Both tools ({[source_tab, target_tab]}) are currently available. ' +
'Rolling back will hide the target tool!'
f'Both tools ({[source_tab, target_tab]}) are currently '
'available. Rolling back will hide the target tool!'
))
logger.debug(self.create_course_log_message(
f'Skipping update for {target_tab}; tool is already available.'
f'Skipping update for {target_tab}; '
'tool is already available.'
))
new_target_tab = target_tab
else:
target_position = source_tab.position
new_target_tab = await self.update_tool_tab(tab=target_tab, is_hidden=False, position=target_position)
new_target_tab = await self.update_tool_tab(
tab=target_tab, is_hidden=False, position=target_position)
if new_target_tab is not None:
logger.info(self.create_course_log_message(
f"Made available target tool in course's navigation: {new_target_tab}"
'Made available target tool in course navigation: '
f'{new_target_tab}'
))

# Always hide the source tool if it's available
new_source_tab = await self.update_tool_tab(tab=source_tab, is_hidden=True)
new_source_tab = await self.update_tool_tab(tab=source_tab,
is_hidden=True)
if new_source_tab is not None:
logger.info(self.create_course_log_message(
f"Hid source tool in course's navigation: {new_source_tab}"
f'Hid source tool in course navigation: {new_source_tab}'
))

return (new_source_tab, new_target_tab)


class AccountManagerFactory:

def get_manager(self, account_id: int, api: API, db: DB | None) -> AccountManagerBase:
def get_manager(self, account_id: int, api: API,
db: DB | None) -> AccountManagerBase:
if db is not None:
return WarehouseAccountManager(account_id, db, api)
else:
Expand Down
Loading

0 comments on commit d69d77b

Please sign in to comment.