diff --git a/terraform/src/lambda.py b/terraform/src/lambda.py index 5f4350c..709d153 100644 --- a/terraform/src/lambda.py +++ b/terraform/src/lambda.py @@ -82,42 +82,55 @@ def create_group(): def get_groups(): try: if is_admin: + # Scan the entire groups table response = group_table.scan() - result = response["Items"] - while "LastEvaluateKey" in response: + groups = response["Items"] + while "LastEvaluatedKey" in response: response = group_table.scan(ExclusiveStartKey=response["LastEvaluatedKey"]) - result.extend(response["Items"]) + groups.extend(response["Items"]) - # get totals used + # Scan the entire tickets table + ticket_response = ticket_table.scan() + tickets = ticket_response["Items"] + while "LastEvaluatedKey" in ticket_response: + ticket_response = ticket_table.scan(ExclusiveStartKey=ticket_response["LastEvaluatedKey"]) + tickets.extend(ticket_response["Items"]) + + # Process tickets in memory to count ticket types for each group data = [] - for g in response["Items"]: - total = {} - for ticket_type in ['adult', 'child', 'vehicle']: - response = ticket_table.scan( - FilterExpression="group_id = :group_id AND ticket_type = :ticket_type", - ExpressionAttributeValues={ - ":group_id": g['group_id'], - ":ticket_type": ticket_type - } - ) - total[ticket_type] = response.get('Count', 0) - g[ticket_type + '_used'] = total[ticket_type] - data.append(g) + for group in groups: + group_id = group['group_id'] + total = {'adult': 0, 'child': 0, 'vehicle': 0} + for ticket in tickets: + if ticket['group_id'] == group_id: + ticket_type = ticket['ticket_type'] + if ticket_type in total: + total[ticket_type] += 1 + + group['adult_used'] = total['adult'] + group['child_used'] = total['child'] + group['vehicle_used'] = total['vehicle'] + + data.append(group) body = { "groups": data } + return build_response(200, body) + else: body = { "Operation": "GET_GROUPS", "Message": "ACCESS_DENIED" } return build_response(403, body) + except: logger.exception("ERROR - GET_GROUPS") + def get_group(group_id): try: if (is_admin or (user_id == group_id)) and group_id: @@ -470,4 +483,4 @@ class DecimalEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, decimal.Decimal): return str(o) - return super().default(o) \ No newline at end of file + return super().default(o)