Skip to content

Commit

Permalink
retrieve all elements from graphql lists
Browse files Browse the repository at this point in the history
  • Loading branch information
vavalomi committed Jan 19, 2022
1 parent 30172f2 commit 77432f6
Show file tree
Hide file tree
Showing 17 changed files with 792 additions and 186 deletions.
2 changes: 1 addition & 1 deletion ssaw/__about__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
__title__ = 'ssaw'
__version__ = '0.6.2'
__version__ = '0.6.3'
__description__ = 'Survey Solutions API Wrapper'
__url__ = 'https://github.com/vavalomi/ssaw'
__author__ = 'Zurab Sajaia'
Expand Down
50 changes: 39 additions & 11 deletions ssaw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,29 @@

from sgqlc.endpoint.requests import RequestsEndpoint
from sgqlc.operation import Operation
from sgqlc.types import Arg, Int, Variable

from .exceptions import (
ForbiddenError, GraphQLError,
NotAcceptableError, NotFoundError, UnauthorizedError
)
from .headquarters import Client
from .headquarters_schema import HeadquartersMutation
from .headquarters_schema import HeadquartersMutation, HeadquartersQuery


GRAPHQL_PAGE_SIZE = 100


class HQBase(object):
_apiprefix: str = ""

def __init__(self, client: Client, workspace: str = None) -> None:
self._hq = client
self.workspace = workspace or client.workspace
self.workspace = client.workspace if workspace is None else workspace

@property
def url(self) -> str:
if self.workspace:
path = '/' + self.workspace + self._apiprefix
else:
path = self._apiprefix
return self._hq.baseurl + path
return self._hq.baseurl + '/' + self.workspace + self._apiprefix

def _make_call(self, method: str, path: str, filepath: str = None, parser=None, use_login_session=False, **kwargs):
if use_login_session:
Expand All @@ -53,13 +53,17 @@ def _make_call(self, method: str, path: str, filepath: str = None, parser=None,
return response.content

def _make_call_with_login(self, method: str, path: str, **kwargs):
url = f"{self._hq.baseurl}/Account/LogOn"
with Session() as login_session:
response = login_session.request(method="post",
url=f"{self._hq.baseurl}/Account/LogOn",
url=url,
data={"UserName": self._hq.session.auth[0],
"Password": self._hq.session.auth[1]})
if response.status_code < 300:
return login_session.request(method=method, url=path, **kwargs)
if response.url == url: # unsuccessful logon will return 200 but will not get redirected
raise UnauthorizedError()
else:
return login_session.request(method=method, url=path, **kwargs)
else:
self._process_status_code(response)

Expand All @@ -73,11 +77,35 @@ def _call_mutation(self, method_name: str, fields: list = [], **kwargs):
res = (op + cont)
return getattr(res, method_name)

def _make_graphql_call(self, op, **kwargs):
@staticmethod
def _graphql_query_operation(selector_name: str, args: dict):
op = Operation(HeadquartersQuery, variables={'take': Arg(Int), 'skip': Arg(Int), })
getattr(op, selector_name)(take=Variable('take'), skip=Variable('skip'), **args)

return op

def _get_full_list(self, op: Operation, selector_name: str, skip: int = 0, take: int = None):
query = bytes(op).decode('utf-8')

returned_count = 0
local_take = returned_count + GRAPHQL_PAGE_SIZE if take is None else take
while returned_count < local_take:
page_size = min(GRAPHQL_PAGE_SIZE, local_take - returned_count)
cont = self._make_graphql_call(query, variables={'take': page_size, 'skip': skip + returned_count})
res = getattr((op + cont), selector_name).nodes
max_index = min(len(res), local_take - returned_count)
if max_index == 0:
return
yield from res[:max_index]
returned_count += page_size
local_take = returned_count + GRAPHQL_PAGE_SIZE if take is None else take

def _make_graphql_call(self, query, variables: dict = {}, **kwargs):
if "session" not in kwargs:
kwargs["session"] = self._hq.session
endpoint = RequestsEndpoint(self._hq.baseurl + '/graphql', **kwargs)
cont = endpoint(op)

cont = endpoint(query, variables=variables)
errors = cont.get('errors')
if not errors:
return cont
Expand Down
25 changes: 7 additions & 18 deletions ssaw/interviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
from typing import Generator, Union
from uuid import UUID

from sgqlc.operation import Operation

from .base import HQBase
from .headquarters_schema import (
CalendarEvent,
HeadquartersQuery,
Interview,
InterviewsFilter,
)
Expand All @@ -22,7 +19,7 @@ class InterviewsApi(HQBase):

@fix_qid(expects={'questionnaire_id': 'hex'})
def get_list(self, fields: list = [], order=None,
skip: int = None, take: int = None, where: InterviewsFilter = None,
skip: int = 0, take: int = None, where: InterviewsFilter = None,
include_calendar_events: Union[list, tuple, bool] = False, **kwargs
) -> Generator[Interview, None, None]:
"""Get list of interviews
Expand All @@ -43,10 +40,6 @@ def get_list(self, fields: list = [], order=None,
}
if order:
interview_args["order"] = order_object("InterviewSort", order)
if skip:
interview_args["skip"] = skip
if take:
interview_args["take"] = take

if where or kwargs:
interview_args['where'] = filter_object("InterviewsFilter", where=where, **kwargs)
Expand All @@ -62,19 +55,16 @@ def get_list(self, fields: list = [], order=None,
'status',
]

op = Operation(HeadquartersQuery)
q = op.interviews(**interview_args)
op = self._graphql_query_operation('interviews', interview_args)
q = op.interviews
q.nodes.__fields__(*fields)
if include_calendar_events:
if type(include_calendar_events) in [list, tuple]:
q.nodes.calendar_event.__fields__(*include_calendar_events)
else:
q.nodes.calendar_event.__fields__()
cont = self._make_graphql_call(op)

res = (op + cont).interviews

yield from res.nodes
yield from self._get_full_list(op, 'interviews', skip=skip, take=take)

def get_info(self, interview_id: UUID) -> InterviewAnswers:
path = self.url + '/{}'.format(interview_id)
Expand Down Expand Up @@ -155,11 +145,10 @@ def comment(self, interview_id, comment, question_id: str = None, variable: str

if variable:
path = self.url + '/{}/comment-by-variable/{}'.format(interview_id, variable)
elif question_id:
path = self.url + '/{}/comment/{}'.format(interview_id, question_id)
else:
if question_id:
path = self.url + '/{}/comment/{}'.format(interview_id, question_id)
else:
raise TypeError("comment() either 'variable' or 'question_id' argument is required")
raise TypeError("comment() either 'variable' or 'question_id' argument is required")

self._make_call('post', path, params=params)

Expand Down
20 changes: 4 additions & 16 deletions ssaw/questionnaires.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from tempfile import TemporaryDirectory
from uuid import UUID

from sgqlc.operation import Operation

from .base import HQBase
from .headquarters_schema import HeadquartersQuery
from .interviews import InterviewsApi
from .models import AssignmentWebLink, QuestionnaireDocument

Expand All @@ -18,7 +15,7 @@ class QuestionnairesApi(HQBase):
_apiprefix = "/api/v1/questionnaires"

def get_list(self, fields: list = [], questionnaire_id: str = None, version: int = None,
skip: int = None, take: int = None):
skip: int = 0, take: int = None):
if not fields:
fields = [
"id",
Expand All @@ -35,20 +32,11 @@ def get_list(self, fields: list = [], questionnaire_id: str = None, version: int
q_args["id"] = questionnaire_id
if version:
q_args["version"] = version
if skip:
q_args["skip"] = skip
if take:
q_args["take"] = take

op = Operation(HeadquartersQuery)
q = op.questionnaires(**q_args)
q.nodes.__fields__(*fields)

cont = self._make_graphql_call(op)

res = (op + cont).questionnaires
op = self._graphql_query_operation('questionnaires', q_args)
op.questionnaires.nodes.__fields__(*fields)

yield from res.nodes
yield from self._get_full_list(op, 'questionnaires', skip=skip, take=take)

def statuses(self):
path = self.url + '/statuses'
Expand Down
19 changes: 8 additions & 11 deletions ssaw/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sgqlc.operation import Operation

from .base import HQBase
from .headquarters import Client
from .headquarters_schema import HeadquartersQuery, UsersFilterInput
from .headquarters_schema import User as GraphQLUser, Viewer
from .models import InterviewerAction, User, UserRole
Expand All @@ -16,28 +17,24 @@ class UsersApi(HQBase):

_apiprefix = "/api/v1"

def __init__(self, client: Client) -> None:
super().__init__(client)

def get_list(self, fields: list = [],
order=None, skip: int = None, take: int = None,
order=None, skip: int = 0, take: int = None,
where: UsersFilterInput = None, **kwargs) -> Generator[GraphQLUser, None, None]:

q_args = {
}
if order:
q_args["order"] = order_object("UsersSortInput", order)
if skip:
q_args["skip"] = skip
if take:
q_args["take"] = take
if where or kwargs:
q_args['where'] = filter_object("UsersFilterInput", where=where, **kwargs)

op = Operation(HeadquartersQuery)
q = op.users(**q_args)
q.nodes.__fields__(*fields)
cont = self._make_graphql_call(op)
res = (op + cont).users
op = self._graphql_query_operation('users', q_args)
op.users.nodes.__fields__(*fields)

yield from res.nodes
yield from self._get_full_list(op, 'users', skip=skip, take=take)

def get_info(self, id):
path = self._url_users + '/{}'.format(id)
Expand Down
12 changes: 10 additions & 2 deletions tests/headquarters/test_headquarters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pytest import raises

from ssaw import Client, MapsApi, QuestionnairesApi, headquarters_schema as schema
from ssaw import AssignmentsApi, Client, MapsApi, QuestionnairesApi, headquarters_schema as schema
from ssaw.exceptions import GraphQLError, IncompleteQuestionnaireIdError, UnauthorizedError
from ssaw.models import Group, QuestionnaireDocument
from ssaw.utils import filter_object, fix_qid, get_properties, order_object, parse_qidentity
Expand All @@ -16,9 +16,12 @@
def test_headquarters_unathorized():
s = Client('https://demo.mysurvey.solutions/', "aa", "")

with raises(UnauthorizedError):
with raises(UnauthorizedError): # graphql endpoint
next(QuestionnairesApi(s).get_list())

with raises(UnauthorizedError): # rest endpoint
next(AssignmentsApi(s).get_list())


@my_vcr.use_cassette()
def test_headquarters_graphql_error(session):
Expand All @@ -27,6 +30,11 @@ def test_headquarters_graphql_error(session):
next(MapsApi(session, workspace="dddd").get_list())


def test_headquarters_user_session_login(session, params):
with raises(UnauthorizedError):
QuestionnairesApi(session).download_web_links(params['TemplateId'], params['TemplateVersion'])


def test_utils_parse_qidentity():

random_guid = "f6a5bd80-fdb4-40b6-8759-0f7531c4a3df"
Expand Down
16 changes: 14 additions & 2 deletions tests/headquarters/test_interviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@
from . import my_vcr


@my_vcr.use_cassette(decode_compressed_response=True)
@my_vcr.use_cassette(decode_compressed_response=False)
def test_interview_list(session, params):
r = InterviewsApi(session).get_list(questionnaire_id=to_hex(params['TemplateId']))
large_take = 103
r = InterviewsApi(session).get_list(take=large_take, questionnaire_id=to_hex(params['TemplateId']))
assert isinstance(r, types.GeneratorType)
assert isinstance(next(r), Interview), "There should be a list of Interview objects"
assert len(list(r)) == large_take - 1, "We have to have all items returned"

r = list(InterviewsApi(session).get_list(take=2,
order={'created_date': 'ASC'},
fields=['created_date'],
questionnaire_id=to_hex(params['TemplateId'])))
assert r[0].created_date < r[1].created_date


@my_vcr.use_cassette()
Expand Down Expand Up @@ -92,6 +100,10 @@ def test_interview_comment(session, params):

# no way to check comments for now, make sure there are no exceptions
InterviewsApi(session).comment(params['InterviewId'], comment="aaa", variable="sex")
InterviewsApi(session).comment(params['InterviewId'], comment="aaa", question_id="fe9719791f0bde796f28d74e66d67d12")

with raises(NotAcceptableError):
InterviewsApi(session).comment(params['InterviewId'], comment="aaa", variable="sex", roster_vector=[1])


@my_vcr.use_cassette()
Expand Down
3 changes: 2 additions & 1 deletion tests/headquarters/test_questionnaires.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def test_interview_statuses(session, statuses):

@my_vcr.use_cassette()
def test_questionnaire_list(session, params):
TOTAL_QUESTIONNAIRES = 103
response = QuestionnairesApi(session).get_list()
assert isinstance(response, types.GeneratorType)
assert isinstance(next(response), Questionnaire), "Should be list of Questionnaire objects"
assert len(list(response)) == 12, "We have to have all items returned"
assert len(list(response)) == TOTAL_QUESTIONNAIRES - 1, "We have to have all items returned"

response = QuestionnairesApi(session).get_list(skip=5)
assert next(response).version == 7
Expand Down
4 changes: 4 additions & 0 deletions tests/headquarters/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ def test_user_list(admin_session):
response = UsersApi(admin_session).get_list()
assert isinstance(response, GeneratorType)
assert isinstance(next(response), GraphQLUser), "Should be list of User objects"
assert len(list(response)) == 109, "We have to have all items returned"

first_user = next(UsersApi(admin_session).get_list(order=['creation_date'], take=1))
assert first_user.role == 'ADMINISTRATOR'


@my_vcr.use_cassette()
Expand Down
Loading

0 comments on commit 77432f6

Please sign in to comment.