diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3c9dff0..462947f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,35 +1,37 @@ -# Automatically build the project and run any configured tests for every push -# and submitted pull request. This can help catch issues that only occur on -# certain platforms or Java versions, and provides a first line of defence -# against bad commits. - +--- name: build and test -on: [pull_request, push] +on: + - pull_request + - push defaults: run: shell: bash -l {0} - jobs: build: name: build runs-on: self-hosted strategy: matrix: - # Use these Java versions - java: [ 21, ] # Current Java LTS & minimum supported by Minecraft + java: + - 21 steps: - name: lock file run: lockfile /tmp/minecraft-test-lock - name: checkout vereya uses: actions/checkout@v3 with: - repository: trueagi-io/Vereya + repository: noSkill/Vereya path: Vereya + ref: server clean: false - name: install vereya - run: rm /home/tester/.minecraft/mods/* ; ls && cd $GITHUB_WORKSPACE/Vereya/ && ./gradlew build && cp $GITHUB_WORKSPACE/Vereya/build/libs/* /home/tester/.minecraft/mods/ + run: rm /home/tester/.minecraft/mods/* ; ls && cd $GITHUB_WORKSPACE/Vereya/ && + ./gradlew build && cp $GITHUB_WORKSPACE/Vereya/build/libs/* /home/tester/.minecraft/mods/ && + mkdir $GITHUB_WORKSPACE/Vereya/server/mods/ ; + cp $GITHUB_WORKSPACE/Vereya/build/libs/* $GITHUB_WORKSPACE/Vereya/server/mods/ - name: install fabric - run: rsync -v $GITHUB_WORKSPACE/Vereya/fabric/* /home/tester/.minecraft/mods/ + run: rsync -v $GITHUB_WORKSPACE/Vereya/fabric/* /home/tester/.minecraft/mods/ && + cp $GITHUB_WORKSPACE/Vereya/fabric/* $GITHUB_WORKSPACE/Vereya/server/mods/ - name: checkout tagilmo uses: actions/checkout@v3 with: @@ -42,12 +44,15 @@ jobs: env: DISPLAY: :99 GITHUB_WORKSPACE: $GITHUB_WORKSPACE + - name: start minecraft server + run: cd Vereya/server && ./launch.sh & + env: + GITHUB_WORKSPACE: $GITHUB_WORKSPACE - name: run test run: | ps a|grep [j]ava && conda activate py31 && cd $GITHUB_WORKSPACE/tests/vereya && python run_tests.py - - name: save java logs if: always() uses: actions/upload-artifact@v3 diff --git a/examples/8_manyagents.py b/examples/8_manyagents.py new file mode 100755 index 0000000..295b3b2 --- /dev/null +++ b/examples/8_manyagents.py @@ -0,0 +1,64 @@ +from time import sleep +from tagilmo.utils.vereya_wrapper import MCConnector +from tagilmo.utils.mission_builder import AgentSection +import tagilmo.utils.mission_builder as mb +from examples.log import setup_logger + + +setup_logger() +agents = [AgentSection(name='Cristina'), AgentSection(name='Crestina')] +miss = mb.MissionXML(agentSections=agents) +# https://www.chunkbase.com/apps/superflat-generator +# flat world not working currently +# miss.setWorld(mb.flatworld("3;7,25*1,3*3,2;1;stronghold,biome_1,village,decoration,dungeon,lake,mineshaft,lava_lake")) +#miss.addAgent(1) +world = mb.defaultworld( + seed='5', + forceReset="false", + forceReuse="true") +miss.setWorld(world) + +mc = MCConnector(miss) +mc.safeStart() + + +fullStatKeys = ['XPos', 'YPos', 'ZPos', 'Pitch', 'Yaw'] +stats_old = [0]*len(fullStatKeys) +seenObjects = [] +nearEntities = [] +gridObjects = [] + +for i in range(600): + mc.observeProc() + + stats_new = [mc.getFullStat(key) for key in fullStatKeys] + if stats_new != stats_old and stats_new[0] != None: + print(' '.join(['%s: %.2f' % (fullStatKeys[n], stats_new[n]) for n in range(len(stats_new))])) + stats_old = stats_new + + crossObj = mc.getLineOfSight('type') + if crossObj is not None: + crossObj += ' ' + mc.getLineOfSight('hitType') + if not crossObj in seenObjects: + seenObjects += [crossObj] + print('******** Novel object in line-of-sight : ', crossObj) + + nearEnt = mc.getNearEntities() + if nearEnt != None: + for e in nearEnt: + if e['name'] != 'Agent-0': + if not e['name'] in nearEntities: + nearEntities += [e['name']] + print('++++++++ Novel nearby entity: ', e['name']) + elif abs(e['x'] - mc.getFullStat('XPos')) + abs(e['y'] - mc.getFullStat('YPos')) < 1: + print('!!!!!!!! Very close entity ', e['name']) + + grid = mc.getNearGrid() + for o in (grid if grid is not None else []): + if not o in gridObjects: + gridObjects += [o] + print('-------- Novel grid object: ', o) + + sleep(0.5) + +# run the script and control the agent manually to see updates diff --git a/examples/8_manyagents1.py b/examples/8_manyagents1.py new file mode 100755 index 0000000..7ce43f8 --- /dev/null +++ b/examples/8_manyagents1.py @@ -0,0 +1,65 @@ +from examples.log import setup_logger +from time import sleep +from tagilmo.utils.vereya_wrapper import MCConnector +from tagilmo.utils.mission_builder import AgentSection +import tagilmo.utils.mission_builder as mb + + + +setup_logger('app1.log') +agents = [AgentSection(name='Cristina'), AgentSection(name='Crestina')] +miss = mb.MissionXML(agentSections=agents) +# https://www.chunkbase.com/apps/superflat-generator +# flat world not working currently +# miss.setWorld(mb.flatworld("3;7,25*1,3*3,2;1;stronghold,biome_1,village,decoration,dungeon,lake,mineshaft,lava_lake")) +#miss.addAgent(1) +world = mb.defaultworld( + seed='5', + forceReset="false", + forceReuse="true") +miss.setWorld(world) + +mc = MCConnector(miss, agentId=1) +mc.safeStart() + + +fullStatKeys = ['XPos', 'YPos', 'ZPos', 'Pitch', 'Yaw'] +stats_old = [0]*len(fullStatKeys) +seenObjects = [] +nearEntities = [] +gridObjects = [] + +for i in range(600): + mc.observeProc() + + stats_new = [mc.getFullStat(key) for key in fullStatKeys] + if stats_new != stats_old and stats_new[0] != None: + print(' '.join(['%s: %.2f' % (fullStatKeys[n], stats_new[n]) for n in range(len(stats_new))])) + stats_old = stats_new + + crossObj = mc.getLineOfSight('type') + if crossObj is not None: + crossObj += ' ' + mc.getLineOfSight('hitType') + if not crossObj in seenObjects: + seenObjects += [crossObj] + print('******** Novel object in line-of-sight : ', crossObj) + + nearEnt = mc.getNearEntities() + if nearEnt != None: + for e in nearEnt: + if e['name'] != 'Agent-0': + if not e['name'] in nearEntities: + nearEntities += [e['name']] + print('++++++++ Novel nearby entity: ', e['name']) + elif abs(e['x'] - mc.getFullStat('XPos')) + abs(e['y'] - mc.getFullStat('YPos')) < 1: + print('!!!!!!!! Very close entity ', e['name']) + + grid = mc.getNearGrid() + for o in (grid if grid is not None else []): + if not o in gridObjects: + gridObjects += [o] + print('-------- Novel grid object: ', o) + + sleep(0.5) + +# run the script and control the agent manually to see updates diff --git a/examples/9_connect.py b/examples/9_connect.py new file mode 100755 index 0000000..add6c59 --- /dev/null +++ b/examples/9_connect.py @@ -0,0 +1,64 @@ +from time import sleep +from tagilmo.utils.vereya_wrapper import MCConnector +from tagilmo.utils.mission_builder import AgentSection +import tagilmo.utils.mission_builder as mb +from examples.log import setup_logger + + +setup_logger() +agents = [AgentSection(name='Cristina')] +miss = mb.MissionXML(agentSections=agents) +# https://www.chunkbase.com/apps/superflat-generator +# flat world not working currently +# miss.setWorld(mb.flatworld("3;7,25*1,3*3,2;1;stronghold,biome_1,village,decoration,dungeon,lake,mineshaft,lava_lake")) +#miss.addAgent(1) +world = mb.defaultworld( + seed='5', + forceReset="false", + forceReuse="true") +miss.setWorld(world) + +mc = MCConnector(miss, serverIp='127.0.0.1', serverPort=25565) +mc.safeStart() + + +fullStatKeys = ['XPos', 'YPos', 'ZPos', 'Pitch', 'Yaw'] +stats_old = [0]*len(fullStatKeys) +seenObjects = [] +nearEntities = [] +gridObjects = [] + +for i in range(600): + mc.observeProc() + + stats_new = [mc.getFullStat(key) for key in fullStatKeys] + if stats_new != stats_old and stats_new[0] != None: + print(' '.join(['%s: %.2f' % (fullStatKeys[n], stats_new[n]) for n in range(len(stats_new))])) + stats_old = stats_new + + crossObj = mc.getLineOfSight('type') + if crossObj is not None: + crossObj += ' ' + mc.getLineOfSight('hitType') + if not crossObj in seenObjects: + seenObjects += [crossObj] + print('******** Novel object in line-of-sight : ', crossObj) + + nearEnt = mc.getNearEntities() + if nearEnt != None: + for e in nearEnt: + if e['name'] != 'Agent-0': + if not e['name'] in nearEntities: + nearEntities += [e['name']] + print('++++++++ Novel nearby entity: ', e['name']) + elif abs(e['x'] - mc.getFullStat('XPos')) + abs(e['y'] - mc.getFullStat('YPos')) < 1: + print('!!!!!!!! Very close entity ', e['name']) + + grid = mc.getNearGrid() + for o in (grid if grid is not None else []): + if not o in gridObjects: + gridObjects += [o] + print('-------- Novel grid object: ', o) + + sleep(0.5) + +# run the script and control the agent manually to see updates diff --git a/examples/log.py b/examples/log.py index eac787b..08e8b19 100644 --- a/examples/log.py +++ b/examples/log.py @@ -1,6 +1,6 @@ import logging -def setup_logger(): +def setup_logger(log_file='app.log'): # create logger logger = logging.getLogger() logger.setLevel(logging.DEBUG) @@ -14,7 +14,7 @@ def setup_logger(): logger.addHandler(ch) - f = logging.handlers.RotatingFileHandler('app.log') + f = logging.handlers.RotatingFileHandler(log_file) f.setFormatter(formatter) f.setLevel(logging.DEBUG) logger.addHandler(f) diff --git a/tagilmo/VereyaPython/agent_host.py b/tagilmo/VereyaPython/agent_host.py index 672b9b6..c3aed2d 100644 --- a/tagilmo/VereyaPython/agent_host.py +++ b/tagilmo/VereyaPython/agent_host.py @@ -66,7 +66,7 @@ def __init__(self) -> None: def startMission(self, mission: MissionSpec, client_pool: List[ClientInfo], mission_record: MissionRecordSpec, role: int, - unique_experiment_id: str): + unique_experiment_id: str, serverIp=None, serverPort=0): logger.debug('startMission') self.world_state.clear() self.testSchemasCompatible() @@ -86,32 +86,30 @@ def startMission(self, mission: MissionSpec, client_pool: List[ClientInfo], if self.world_state.is_mission_running: raise MissionException("A mission is already running.", MissionErrorCode.MISSION_ALREADY_RUNNING) - self.initializeOurServers( mission, mission_record, role, unique_experiment_id ) + self.initializeOurServers( mission, mission_record, role, + unique_experiment_id, + serverIp, serverPort) pool = None - if role == 0: - logger.info("creating mission") - # We are the agent responsible for the integrated server. - # If we are part of a multi-agent mission, our mission should have been started before any of the others are attempted. - # This means we are in a position to reserve clients in the client pool: - reservedAgents = asyncio.run_coroutine_threadsafe(self.reserveClients(client_pool, - mission.getNumberOfAgents()), - self.io_service).result() - - if len(reservedAgents) != mission.getNumberOfAgents(): - # Not enough clients available - go no further. - logger.error("Failed to reserve sufficient clients - throwing MissionException.") - if (mission.getNumberOfAgents() == 1): - raise MissionException("Failed to find an available client for self.mission - tried all the clients in the supplied client pool.", MissionErrorCode.MISSION_INSUFFICIENT_CLIENTS_AVAILABLE) - else: - raise MissionException("There are not enough clients available in the ClientPool to start self." + str(mission.getNumberOfAgents()) + " agent mission.", MissionErrorCode.MISSION_INSUFFICIENT_CLIENTS_AVAILABLE) - pool = reservedAgents - else: - logger.info(f"our role {role}, joining existing mission") + + logger.info("creating mission") + # assume each agent is run by it's own AgentHost + reservedAgents = asyncio.run_coroutine_threadsafe(self.reserveClients(client_pool, + 1), + self.io_service).result() + + if len(reservedAgents) != 1: + # Not enough clients available - go no further. + logger.error("Failed to reserve sufficient clients - throwing MissionException.") + if (mission.getNumberOfAgents() == 1): + raise MissionException("Failed to find an available client for self.mission - tried all the clients in the supplied client pool.", MissionErrorCode.MISSION_INSUFFICIENT_CLIENTS_AVAILABLE) + else: + raise MissionException("There are not enough clients available in the ClientPool to start self." + str(mission.getNumberOfAgents()) + " agent mission.", MissionErrorCode.MISSION_INSUFFICIENT_CLIENTS_AVAILABLE) + pool = reservedAgents assert self.current_mission_init is not None - if( mission.getNumberOfAgents() > 1 and role > 0 \ - and not self.current_mission_init.hasMinecraftServerInformation()): - raise NotImplementedError("role > 0 is not implemented yet") + #if( mission.getNumberOfAgents() > 1 and role > 0 \ + # and not self.current_mission_init.hasMinecraftServerInformation()): + # raise NotImplementedError("role > 0 is not implemented yet") # work through the client pool until we find a client to run our mission for us assert pool @@ -147,6 +145,7 @@ async def reserveClients(self, client_pool: List[ClientInfo], clients_required: logger.info("Reserving client, received reply from " + str(item.ip_address) + ": " + reply) malmo_reservation_prefix = "MALMOOK" malmo_mismatch = "MALMOERRORVERSIONMISMATCH" + malmo_busy = "MALMOBUSY" if reply.startswith(malmo_reservation_prefix): # Successfully reserved self.client. reservedClients.add(item) @@ -223,7 +222,6 @@ def findClient(self, client_pool: List[ClientInfo]): tried all the clients in the \ supplied client pool.", MissionErrorCode.MISSION_INSUFFICIENT_CLIENTS_AVAILABLE ) - def peekWorldState(self) -> WorldState: with self.world_state_mutex: # Copy while holding lock. @@ -243,10 +241,11 @@ def getRecordingTemporaryDirectory(self) -> str: def initializeOurServers(self, mission: MissionSpec, mission_record: MissionRecordSpec, role: int, - unique_experiment_id: str) -> None: + unique_experiment_id: str, serverIp: Optional[str], serverPort:int) -> None: logging.debug("Initialising servers...") # make a MissionInit structure with default settings - self.current_mission_init = MissionInitSpec.from_param(mission, unique_experiment_id, role) + self.current_mission_init = MissionInitSpec.from_param(mission, unique_experiment_id, role, + serverIp, serverPort) self.current_mission_record = MissionRecord(mission_record) self.current_role = role self.listenForMissionControlMessages(self.current_mission_init.getAgentMissionControlPort()) diff --git a/tagilmo/VereyaPython/mission_init_spec.py b/tagilmo/VereyaPython/mission_init_spec.py index 06d7031..27af360 100644 --- a/tagilmo/VereyaPython/mission_init_spec.py +++ b/tagilmo/VereyaPython/mission_init_spec.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from .mission_spec import MissionSpec from .mission_init_xml import MissionInitXML from .client_info import default_client_mission_control_port @@ -7,10 +7,12 @@ @dataclass(slots=True, frozen=True) class MissionInitSpec: - mission_init: MissionInitXML = MissionInitXML() + mission_init: MissionInitXML = field( + default_factory=lambda: MissionInitXML() + ) @staticmethod - def from_param(mission_spec: MissionSpec, unique_experiment_id: str, role: int) -> 'MissionInitSpec': + def from_param(mission_spec: MissionSpec, unique_experiment_id: str, role: int, server_ip: str=None, server_port: int=0) -> 'MissionInitSpec': self = MissionInitSpec() # construct a default MissionInit using the provided MissionSpec self.mission_init.client_agent_connection.client_ip_address = "127.0.0.1" @@ -29,6 +31,8 @@ def from_param(mission_spec: MissionSpec, unique_experiment_id: str, role: int) self.mission_init.experiment_uid = unique_experiment_id self.mission_init.mission = mission_spec.mission self.mission_init.platform_version = MALMO_VERSION + self.mission_init.minecraft_server.connection_address = server_ip + self.mission_init.minecraft_server.connection_port = server_port return self @staticmethod diff --git a/tagilmo/VereyaPython/mission_init_xml.py b/tagilmo/VereyaPython/mission_init_xml.py index 6d3a4b6..34b6b5f 100644 --- a/tagilmo/VereyaPython/mission_init_xml.py +++ b/tagilmo/VereyaPython/mission_init_xml.py @@ -45,11 +45,13 @@ class MissionInitXML: minecraft_server: MinecraftServer client_agent_connection: ClientAgentConnection - def __init__(self, xml_text=None): + def __init__(self, xml_text=None, client_role=0, minecraft_server=None): self.client_agent_connection = ClientAgentConnection() - self.minecraft_server = MinecraftServer() + if minecraft_server is None: + minecraft_server = MinecraftServer() + self.minecraft_server = minecraft_server self.mission = None - self.client_role = 0 + self.client_role = client_role self.schema_version = '' self.platform_version = '' self.experiment_uid = '' diff --git a/tagilmo/VereyaPython/mission_spec.py b/tagilmo/VereyaPython/mission_spec.py index 5468a40..2366267 100644 --- a/tagilmo/VereyaPython/mission_spec.py +++ b/tagilmo/VereyaPython/mission_spec.py @@ -22,6 +22,7 @@ def getRoleValue(self, role: int, videoType: str, what: str) -> Optional[int]: elif what == 'w': tmp = vid.find('./{*}Width') assert tmp is not None + assert tmp.text is not None return int(tmp.text) elif what == 'h': tmp = vid.find('./{*}Height') diff --git a/tagilmo/VereyaPython/string_server.py b/tagilmo/VereyaPython/string_server.py index 3ead0f3..232ae43 100644 --- a/tagilmo/VereyaPython/string_server.py +++ b/tagilmo/VereyaPython/string_server.py @@ -24,7 +24,7 @@ def __init__(self, self.handle_string = handle_string self.log_name = log_name self.server = TCPServer(self.io_service, self.port, self.__cb, self.log_name) - self.writer = None + self.writer: TimestampedStringWriter = None def start(self) -> None: fut = asyncio.run_coroutine_threadsafe(self.server.startAccept(), self.io_service) diff --git a/tagilmo/utils/mission_builder.py b/tagilmo/utils/mission_builder.py index acf7e7a..592ec99 100755 --- a/tagilmo/utils/mission_builder.py +++ b/tagilmo/utils/mission_builder.py @@ -1,3 +1,5 @@ +from xml.sax.saxutils import quoteattr +from typing import Optional #skipped: # # 10 @@ -21,7 +23,7 @@ def xml(self): class ServerInitialConditions: - + def __init__(self, day_always=False, time_start_string=None, time_pass_string=None, weather_string=None, spawning_string="true", allowedmobs_string=None): self.day_always = day_always @@ -62,7 +64,8 @@ def xml(self): def flatworld(generatorString, forceReset="false", seed=''): - return '' + return '' + def defaultworld(seed=None, forceReset=False, forceReuse=False): if isinstance(forceReset, bool): @@ -73,10 +76,11 @@ def defaultworld(seed=None, forceReset=False, forceReuse=False): if forceReset: world_str += 'forceReset="' + forceReset + '" ' if forceReuse: - world_str += 'forceReuse="' + forceReuse + '" ' + world_str += 'forceReuse="' + str(forceReuse).lower() + '" ' world_str += '/>' return world_str + def fileworld(uri2save, forceReset="false"): str = '' + + class MissionXML: def __init__(self, about=About(), serverSection=ServerSection(), agentSections=[AgentSection()], namespace=None): @@ -500,7 +513,7 @@ def __init__(self, about=About(), serverSection=ServerSection(), agentSections=[ self.about = about self.serverSection = serverSection self.agentSections = agentSections - + def hasVideo(self): for section in self.agentSections: if section.hasVideo(): @@ -515,13 +528,13 @@ def hasSegmentation(self): def setSummary(self, summary_string): self.about.summary = summary_string - + def setWorld(self, worldgenerator_xml): self.serverSection.handlers.worldgenerator = worldgenerator_xml - + def setTimeLimit(self, timeLimitMs): self.serverSection.handlers.timeLimitMs = str(timeLimitMs) - + def addAgent(self, nCount=1, agentSections=None): if agentSections: self.agentSections += agentSections @@ -536,7 +549,7 @@ def setObservations(self, observations, nAgent=None): ag.agenthandlers.observations = observations else: self.agentSections[nAgent].agenthandlers.observations = observations - + def getAgentNames(self): return [ag.name for ag in self.agentSections] diff --git a/tagilmo/utils/vereya_wrapper.py b/tagilmo/utils/vereya_wrapper.py index e4f369b..5f800f7 100644 --- a/tagilmo/utils/vereya_wrapper.py +++ b/tagilmo/utils/vereya_wrapper.py @@ -58,33 +58,39 @@ def setMissionXML(self, missionXML, module=VP): self.mission = module.MissionSpec(missionXML.xml(), True) self.mission_record = module.MissionRecordSpec() - def __init__(self, missionXML, serverIp='127.0.0.1'): + def __init__(self, missionXML, clientIp='127.0.0.1', agentId=0, setupAll=False, serverIp=None, serverPort=None): self.missionDesc = None self.mission = None self.mission_record = None self.prev_mobs = defaultdict(set) # host -> set mapping - self.agentId = 0 + self.agentId = agentId + self.setupAll = setupAll self._data_lock = threading.RLock() - self.setUp(VP, missionXML, serverIp=serverIp) - - def setUp(self, module, missionXML, serverIp='127.0.0.1'): self.serverIp = serverIp - self.setMissionXML(missionXML, module) + self.serverPort = serverPort + self.setUp(VP, missionXML, clientIp=clientIp) + + def setUp(self, module, missionXML, clientIp='127.0.0.1'): + self.clientIp = clientIp + self.setMissionXML(missionXML, module ) agentIds = len(missionXML.agentSections) self.agent_hosts = dict() - self.agent_hosts.update({n: module.AgentHost() for n in range(agentIds)}) - self.agent_hosts[0].parse( sys.argv ) + if self.setupAll: + self.agent_hosts.update({n: module.AgentHost() for n in range(agentIds)}) + else: + self.agent_hosts[self.agentId] = module.AgentHost() + self.agent_hosts[self.agentId].parse( sys.argv ) if self.receivedArgument('recording_dir'): - recordingsDirectory = get_recordings_directory(self.agent_hosts[0]) + recordingsDirectory = get_recordings_directory(self.agent_hosts[self.agentId]) self.mission_record.recordRewards() self.mission_record.recordObservations() self.mission_record.recordCommands() self.mission_record.setDestination(recordingsDirectory + "//" + "lastRecording.tgz") - if self.agent_hosts[0].receivedArgument("record_video"): + if self.agent_hosts[self.agentId].receivedArgument("record_video"): self.mission_record.recordMP4(24, 2000000) self.client_pool = module.ClientPool() for x in range(10000, 10000 + agentIds): - self.client_pool.add( module.ClientInfo(serverIp, x) ) + self.client_pool.add( module.ClientInfo(clientIp, x) ) self.worldStates = [None] * agentIds self.observe = {k: None for k in range(agentIds)} self.isAlive = [True] * agentIds @@ -93,23 +99,34 @@ def setUp(self, module, missionXML, serverIp='127.0.0.1'): self._last_obs = dict() # agent_host -> TimestampedString self._all_mobs = set() - - def getVersion(self, num=0) -> str: + def getVersion(self, num=None) -> str: + if num is None: + num = self.agentId return self.agent_hosts[num].version - def receivedArgument(self, arg): - return self.agent_hosts[0].receivedArgument(arg) + def receivedArgument(self, arg, num=None): + if num is None: + num = self.agentId + return self.agent_hosts[num].receivedArgument(arg) def safeStart(self): # starting missions expId = str(uuid.uuid4()) # will not work for multithreading, distributed agents, etc. (should be same for all agents to join the same server/mission) - for role in range(len(self.agent_hosts)): + # chosen with fair dice + expId = '4' + if self.setupAll: + r = self.agent_hosts.keys() + else: + r = [self.agentId] + for role in r: used_attempts = 0 max_attempts = 5 while True: try: # Attempt start: - self.agent_hosts[role].startMission(self.mission, self.client_pool, self.mission_record, role, expId) + self.agent_hosts[role].startMission(self.mission, self.client_pool, + self.mission_record, role, expId, + self.serverIp, self.serverPort) #self.agent_hosts[role].startMission(self.mission, self.mission_record) break except (VP.MissionException) as e: @@ -199,7 +216,9 @@ def connect(name=None, video=False, seed=None): mc.safeStart() return mc - def is_mission_running(self, agentId=0): + def is_mission_running(self, agentId=None): + if agentId is None: + agentId = self.agentId world_state = self.agent_hosts[agentId].getWorldState() return world_state.is_mission_running @@ -212,7 +231,9 @@ def sendCommand(self, command, agentId=None): self.agent_hosts[agentId].sendCommand(command) def observeProc(self, agentId=None): - r = range(len(self.agent_hosts)) if agentId is None else range(agentId, agentId+1) + r = self.agent_hosts.keys() + if agentId is not None: + r = [agentId] for n in r: self.worldStates[n] = self.agent_hosts[n].getWorldState() self.isAlive[n] = self.worldStates[n].is_mission_running @@ -276,18 +297,26 @@ def _process_mobs(self, data, host): self.prev_mobs[host] = mobs self._all_mobs = set().union(*self.prev_mobs.values()) - def getImageFrame(self, agentId=0): + def getImageFrame(self, agentId=None): + if agentId is None: + agentId = self.agentId return self.frames[agentId] - def getSegmentationFrame(self, agentId=0): + def getSegmentationFrame(self, agentId=None): + if agentId is None: + agentId = self.agentId return self.segmentation_frames[agentId] - def getImage(self, agentId=0): + def getImage(self, agentId=None): + if agentId is None: + agentId = self.agentId if self.frames[agentId] is not None: return numpy.frombuffer(self.frames[agentId].pixels, dtype=numpy.uint8) return None - def getSegmentation(self, agentId=0): + def getSegmentation(self, agentId=None): + if agentId is None: + agentId = self.agentId if self.segmentation_frames[agentId] is not None: return numpy.frombuffer(self.segmentation_frames[agentId].pixels, dtype=numpy.uint8) return None @@ -314,7 +343,6 @@ def getFullStat(self, key, agentId=None): 'MobsKilled', 'PlayersKilled', 'DamageTaken', 'DamageDealt' keys (new) : 'input_type', 'isPaused' """ - return self.getParticularObservation(key, agentId) def getLineOfSights(self, agentId=None): @@ -428,7 +456,9 @@ def stop(self, idx=0): def getHumanInputs(self, agentId=None): return self.getParticularObservation('input_events', agentId) - def placeBlock(self, x: int, y: int, z: int, block_name: str, placement: str, agentId=0): + def placeBlock(self, x: int, y: int, z: int, block_name: str, placement: str, agentId=None): + if agentId is None: + agentId = self.agentId self.agent_hosts[agentId].sendCommand("placeBlock {} {} {} {} {}".format(x, y, z, block_name, placement)) def _sendMotionCommand(self, command, value, agentId=None): @@ -461,7 +491,7 @@ class RobustObserver: explicitlyPoseChangingCommands = ['move', 'jump', 'pitch', 'turn'] implicitlyPoseChangingCommands = ['attack'] - def __init__(self, mc, agentId = 0): + def __init__(self, mc, agentId=0): self.mc = mc self.passableBlocks = [] self.agentId = agentId diff --git a/tests/vereya/common.py b/tests/vereya/common.py index c538927..ce5a355 100644 --- a/tests/vereya/common.py +++ b/tests/vereya/common.py @@ -9,7 +9,11 @@ from tagilmo.utils.vereya_wrapper import MCConnector, RobustObserver from base_test import BaseTest -def init_mission(mc, start_x, start_z, seed, forceReset="false", forceReuse="false", start_y=78, worldType = "default", drawing_decorator = None): + +def init_mission(mc, start_x, start_z, seed, forceReset="false", + forceReuse="false", start_y=78, worldType = "default", drawing_decorator=None, serverIp=None, serverPort=0): + + want_depth = False video_producer = mb.VideoProducer(width=320 * 4, height=240 * 4, want_depth=want_depth) @@ -22,13 +26,15 @@ def init_mission(mc, start_x, start_z, seed, forceReset="false", forceReuse="fal print('starting at ({0}, {1})'.format(start_x, start_y)) + start = [start_x, start_y, start_z, 1] + if all(x is None for x in [start_x, start_y, start_z]): + start = None #miss = mb.MissionXML(namespace="ProjectMalmo.microsoft.com", miss = mb.MissionXML( - agentSections=[mb.AgentSection(name='Cristina', - agenthandlers=agent_handlers, - # depth - agentstart=mb.AgentStart([start_x, start_y, start_z, 1]))], - serverSection=mb.ServerSection(handlers=mb.ServerHandlers(drawingdecorator=drawing_decorator))) + agentSections=[mb.AgentSection(name='Cristina', + agenthandlers=agent_handlers, + agentstart=mb.AgentStart(start))], + serverSection=mb.ServerSection(handlers=mb.ServerHandlers(drawingdecorator=drawing_decorator))) flat_json = {"biome":"minecraft:plains", "layers":[{"block":"minecraft:diamond_block","height":1}], "structures":{"structures": {"village":{}}}} @@ -60,7 +66,7 @@ def init_mission(mc, start_x, start_z, seed, forceReset="false", forceReuse="fal os.mkdir('./observations') if mc is None: - mc = MCConnector(miss) + mc = MCConnector(miss, serverIp=serverIp, serverPort=serverPort) mc.mission_record.setDestination('./observations/') mc.mission_record.is_recording_observations = True obs = RobustObserver(mc) diff --git a/tests/vereya/test_agent.py b/tests/vereya/test_agent.py index 88a2317..6d6622a 100644 --- a/tests/vereya/test_agent.py +++ b/tests/vereya/test_agent.py @@ -36,12 +36,13 @@ def run(self): for act in acts: self.rob.sendCommand(act) + class TestAgent(BaseTest): mc = None @classmethod def setUpClass(self, *args, **kwargs): - start = (4.0, 69.0, 68) + start = (21.0, 135.0, 21) mc, obs = init_mission(None, start_x=start[0], start_y=start[1], start_z=start[2], forceReset='true', seed='2') self.mc = mc self.rob = obs @@ -79,6 +80,17 @@ def test_agent(self): inv = mc.getInventory() self.assertEqual(count_items(inv, item_to_obtain), 1, msg=f"check if {item_to_obtain} was crafted") + +class TestAgentServer(TestAgent): + @classmethod + def setUpClass(self, *args, **kwargs): + mc, obs = init_mission(None, start_x=-34, start_y=104, start_z=16, forceReset='false', forceReuse=True, seed='2', serverIp='127.0.0.1', serverPort=25565) + self.mc = mc + self.rob = obs + assert mc.safeStart() + time.sleep(4) + + def main(): VereyaPython.setupLogger() unittest.main()