diff --git a/aimmo-game-worker/avatar_runner.py b/aimmo-game-worker/avatar_runner.py index 380f1a222..9a0540762 100644 --- a/aimmo-game-worker/avatar_runner.py +++ b/aimmo-game-worker/avatar_runner.py @@ -11,6 +11,7 @@ import simulation.action as avatar_action import simulation.direction as direction +from print_collector import LogManager from simulation.action import WaitAction, Action from user_exceptions import InvalidActionException @@ -35,11 +36,12 @@ def add_actions_to_globals(): __metaclass__ = type restricted_globals = dict(__builtins__=safe_builtins) +log_manager = LogManager() restricted_globals['_getattr_'] = _getattr_ restricted_globals['_setattr_'] = _setattr_ restricted_globals['_getiter_'] = list -restricted_globals['_print_'] = PrintCollector +restricted_globals['_print_'] = log_manager.get_print_collector() restricted_globals['_write_'] = _write_ restricted_globals['__metaclass__'] = __metaclass__ restricted_globals['__name__'] = "Avatar" @@ -64,11 +66,13 @@ def _get_new_avatar(self, src_code): module = imp.new_module('avatar') # Create a temporary module to execute the src_code in module.__dict__.update(restricted_globals) - byte_code = compile_restricted(src_code, filename='', mode='exec') - exec(byte_code, restricted_globals) + try: + byte_code = compile_restricted(src_code, filename='', mode='exec') + exec(byte_code, restricted_globals) + except SyntaxWarning as w: + pass module.__dict__['Avatar'] = restricted_globals['Avatar'] - return module.Avatar() def _update_avatar(self, src_code): @@ -96,7 +100,6 @@ def _should_update(self, src_code): def process_avatar_turn(self, world_map, avatar_state, src_code): output_log = StringIO() - src_code = self.get_printed(src_code) avatar_updated = self._avatar_src_changed(src_code) try: @@ -104,15 +107,17 @@ def process_avatar_turn(self, world_map, avatar_state, src_code): sys.stderr = output_log self._update_avatar(src_code) action = self.decide_action(world_map, avatar_state) - + self.print_logs() # When an InvalidActionException is raised, the traceback might not contain # reference to the user's code as it can still technically be correct. so we # handle this case explicitly to avoid printing out unwanted parts of the traceback except InvalidActionException as e: + self.print_logs() print(e) action = WaitAction().serialise() except Exception as e: + self.print_logs() user_traceback = self.get_only_user_traceback() for trace in user_traceback: print(trace) @@ -125,27 +130,18 @@ def process_avatar_turn(self, world_map, avatar_state, src_code): sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ - logs = self.clean_logs(output_log.getvalue()) - - return {'action': action, 'log': logs, 'avatar_updated': avatar_updated} + return {'action': action, 'log': output_log.getvalue(), 'avatar_updated': avatar_updated} def decide_action(self, world_map, avatar_state): try: - action, printed = self.avatar.handle_turn(world_map, avatar_state) - print(printed) + action = self.avatar.handle_turn(world_map, avatar_state) if not isinstance(action, Action): raise InvalidActionException(action) return action.serialise() - except TypeError: + except TypeError as e: + print(e) raise InvalidActionException(None) - def clean_logs(self, logs): - getattr_pattern = "" - - clean_logs = re.sub(getattr_pattern, '', logs) - - return clean_logs - @staticmethod def get_only_user_traceback(): """ If the traceback does not contain any reference to the user code, found by '', @@ -159,20 +155,7 @@ def get_only_user_traceback(): return traceback_list[start_of_user_traceback:] @staticmethod - def get_printed(src_code): - """ This method adds ', printed' to the end of the handle_turn return statement. - This is due to the fact that restricted python's PrintCollector requires this - explicitly, in order to get whatever has been printed by the user's code. """ - src_code = src_code.split('\n') - new_src_code = [] - in_handle_turn = False - for line in src_code: - if "def handle_turn" == line.strip()[0:15]: - in_handle_turn = True - elif "def" == line.strip()[0:3]: - in_handle_turn = False - if "return" == line.strip()[0:6] and in_handle_turn: - line = line + ', printed' - new_src_code.append(line) - - return '\n'.join(new_src_code) + def print_logs(): + if not log_manager.is_empty(): + print(log_manager.get_logs(), end='') + log_manager.clear_logs() diff --git a/aimmo-game-worker/print_collector.py b/aimmo-game-worker/print_collector.py new file mode 100644 index 000000000..1f841f6fe --- /dev/null +++ b/aimmo-game-worker/print_collector.py @@ -0,0 +1,40 @@ +class LogManager(object): + """ Wrapper for the PrintCollector which allows logs to have + state. The class definition for the PrintCollector is + passed into the globals for the user's code before execution. """ + def __init__(self): + class PrintCollector(object): + """ Collect written text, and return it when called. """ + + def __init__(print_collector, _getattr_=None): + print_collector.logs = self.logs + print_collector._getattr_ = _getattr_ + + def write(print_collector, text): + print_collector.logs.append(text) + + def __call__(print_collector): + return ''.join(print_collector.logs) + + def _call_print(print_collector, *objects, **kwargs): + if kwargs.get('file', None) is None: + kwargs['file'] = print_collector + else: + print_collector._getattr_(kwargs['file'], 'write') + + print(*objects, **kwargs) + + self.logs = [] + self.print_collector = PrintCollector + + def get_print_collector(self): + return self.print_collector + + def get_logs(self): + return ''.join(self.logs) + + def is_empty(self): + return self.logs == [] + + def clear_logs(self): + self.logs = [] diff --git a/aimmo-game-worker/tests/test_avatar_runner.py b/aimmo-game-worker/tests/test_avatar_runner.py index 35ef7f47c..8e90fcea8 100644 --- a/aimmo-game-worker/tests/test_avatar_runner.py +++ b/aimmo-game-worker/tests/test_avatar_runner.py @@ -17,8 +17,8 @@ def test_runner_does_not_crash_on_code_errors(self): def handle_turn(self, world_map, avatar_state): assert False''' - runner = AvatarRunner(avatar=avatar, auto_update=False) - action = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code='')['action'] + runner = AvatarRunner() + action = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar)['action'] self.assertEqual(action, {'action_type': 'wait'}) def test_runner_updates_code_on_change(self): @@ -195,3 +195,55 @@ def handle_turn(self, world_map, avatar_state): self.assertFalse('/usr/src/app/' in response['log']) response = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar2) self.assertFalse('/usr/src/app/' in response['log']) + + def test_print_collector_outputs_logs(self): + avatar = '''class Avatar: + def handle_turn(self, world_map, avatar_state): + print('I AM A PRINT STATEMENT') + return MoveAction(direction.NORTH) + + ''' + + runner = AvatarRunner() + response = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar) + self.assertTrue('I AM A PRINT STATEMENT' in response['log']) + + def test_print_collector_outputs_multiple_prints(self): + avatar = '''class Avatar: + def handle_turn(self, world_map, avatar_state): + print('I AM A PRINT STATEMENT') + print('I AM ALSO A PRINT STATEMENT') + return MoveAction(direction.NORTH) + + ''' + runner = AvatarRunner() + response = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar) + self.assertTrue('I AM A PRINT STATEMENT' in response['log']) + self.assertTrue('I AM ALSO A PRINT STATEMENT' in response['log']) + + def test_print_collector_outputs_prints_from_different_scopes(self): + avatar = '''class Avatar: + def handle_turn(self, world_map, avatar_state): + print('I AM NOT A NESTED PRINT') + self.foo() + return MoveAction(direction.NORTH) + + def foo(self): + print('I AM A NESTED PRINT') + + ''' + runner = AvatarRunner() + response = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar) + self.assertTrue('I AM NOT A NESTED PRINT' in response['log']) + self.assertTrue('I AM A NESTED PRINT' in response['log']) + + def test_print_collector_prints_output_and_runtime_error_if_exists(self): + avatar = '''class Avatar: + def handle_turn(self, world_map, avatar_state): + print('THIS CODE IS BROKEN') + return None + ''' + runner = AvatarRunner() + response = runner.process_avatar_turn(world_map={}, avatar_state={}, src_code=avatar) + self.assertTrue('THIS CODE IS BROKEN' in response['log']) + self.assertTrue('"None" is not a valid action object.' in response['log']) diff --git a/version.txt b/version.txt index f7abe273d..70d5b25fa 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.2 \ No newline at end of file +0.4.3 \ No newline at end of file