diff --git a/genesis/utils/mjcf.py b/genesis/utils/mjcf.py index acf13d7..2326e0b 100644 --- a/genesis/utils/mjcf.py +++ b/genesis/utils/mjcf.py @@ -1,4 +1,5 @@ import os +import xml.etree.ElementTree as ET import mujoco import numpy as np @@ -12,8 +13,6 @@ from . import mesh as mu from .misc import get_assets_dir -import xml.etree.ElementTree as ET - def extract_compiler_attributes(xml_path): # Parse the XML file @@ -302,16 +301,11 @@ def parse_geom(mj, i_g, scale, convexify, surface, xml_path): uv_coordinates = tmesh.vertices[:, :2].copy() uv_coordinates -= uv_coordinates.min(axis=0) uv_coordinates /= uv_coordinates.max(axis=0) - image = Image.open(os.path.join(assets_dir, tex_path)) + image = Image.open(os.path.join(assets_dir, tex_path)).convert("RGBA") image_array = np.array(image) - if image_array.ndim == 2: # convert gray image to RGBA - rgba = np.zeros((image_array.shape[0], image_array.shape[1], 4), dtype=np.uint8) - rgba[:, :, :3] = image_array[:, :, None] - rgba[:, :, 3] = 255 - image_array = rgba tex_repeat = mj.mat_texrepeat[mat_id].astype(int) image_array = np.tile(image_array, (tex_repeat[0], tex_repeat[1], 1)) - visual = TextureVisuals(uv=uv_coordinates, image=Image.fromarray(image_array)) + visual = TextureVisuals(uv=uv_coordinates, image=Image.fromarray(image_array, mode="RGBA")) tmesh.visual = visual elif mj_type == mujoco.mjtGeom.mjGEOM_MESH: @@ -326,6 +320,10 @@ def parse_geom(mj, i_g, scale, convexify, surface, xml_path): if tex_start >= 0: tex_end = mj.mesh_texcoordadr[i + 1] if not last else mj.mesh_texcoord.shape[0] + if tex_end == -1: + tex_end = tex_start + (vert_end - vert_start) + assert tex_end - tex_start == vert_end - vert_start + mat_id = mj.geom_matid[i_g] tex_id = next((x for x in mj.mat_texid[mat_id] if x != -1), None) if not tex_id is None: @@ -336,7 +334,12 @@ def parse_geom(mj, i_g, scale, convexify, surface, xml_path): # TODO: check if we can parse tag with mj model texturedir = extract_compiler_attributes(xml_path)["texturedir"] assets_dir = os.path.join(get_assets_dir(), os.path.join(os.path.dirname(xml_path), texturedir)) - visual = TextureVisuals(uv=uv, image=Image.open(os.path.join(assets_dir, tex_path))) + + image = Image.open(os.path.join(assets_dir, tex_path)).convert("RGBA") + image_array = np.array(image) + tex_repeat = mj.mat_texrepeat[mat_id].astype(int) + image_array = np.tile(image_array, (tex_repeat[0], tex_repeat[1], 1)) + visual = TextureVisuals(uv=uv, image=Image.fromarray(image_array, mode="RGBA")) tmesh = trimesh.Trimesh( vertices=mj.mesh_vert[vert_start:vert_end],