diff --git a/python/mujoco/codegen/generate_spec_bindings.py b/python/mujoco/codegen/generate_spec_bindings.py index 2db5387027..d37fc8ccca 100644 --- a/python/mujoco/codegen/generate_spec_bindings.py +++ b/python/mujoco/codegen/generate_spec_bindings.py @@ -70,6 +70,19 @@ def _value_binding_code( return f'{classname}.def_property({",".join(def_property_args)});' +def _struct_binding_code( + field: ast_nodes.AnonymousStructDecl, classname: str = '', varname: str = '' +) -> str: + code = '' + name = classname + varname.title() + # explicitly generate for nested fields with arrays + if any(isinstance(f.type, ast_nodes.ArrayType) for f in field.fields): + for subfield in field.fields: + code += _binding_code(subfield, name) + # generate for the struct itself + field = ast_nodes.ValueType(name=name) + code += _value_binding_code(field, classname, varname) + return code def _array_binding_code( field: ast_nodes.ArrayType, classname: str = '', varname: str = '' @@ -227,8 +240,7 @@ def _binding_code(field: ast_nodes.StructFieldDecl, key: str) -> str: if isinstance(field.type, ast_nodes.ValueType): return _value_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.AnonymousStructDecl): - field.type = ast_nodes.ValueType(name='mjVisual'+field.name.title()) - return _value_binding_code(field.type, key, field.name) + return _struct_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.PointerType): return _ptr_binding_code(field.type, key, field.name) elif isinstance(field.type, ast_nodes.ArrayType): diff --git a/python/mujoco/specs.cc b/python/mujoco/specs.cc index c376ea35e7..2749d5e5e5 100644 --- a/python/mujoco/specs.cc +++ b/python/mujoco/specs.cc @@ -236,6 +236,8 @@ PYBIND11_MODULE(_specs, m) { py::class_ mjOption(m, "MjOption"); py::class_ mjStatistic(m, "MjStatistic"); py::class_ mjVisual(m, "MjVisual"); + py::class_ mjVisualHeadlight(m, "MjVisualHeadlight"); + py::class_ mjVisualRgba(m, "MjVisualRgba"); py::class_ mjsCompiler(m, "MjsCompiler"); DefineArray(m, "MjCharVec"); DefineArray(m, "MjStringVec"); @@ -979,6 +981,13 @@ PYBIND11_MODULE(_specs, m) { }); mjsPlugin.def("delete", [](raw::MjsPlugin& self) { mjs_delete(self.element); }); + // ============================= MJVISUAL ==================================== + mjVisual.def_property( + "global_", + [](raw::MjVisual& self) -> raw::MjVisualGlobal& { return self.global; }, + [](raw::MjVisual& self, raw::MjVisualGlobal& value) { + self.global = value; + }); #include "specs.cc.inc" } // PYBIND11_MODULE // NOLINT diff --git a/python/mujoco/specs_test.py b/python/mujoco/specs_test.py index e0d3e836ef..329f08c96f 100644 --- a/python/mujoco/specs_test.py +++ b/python/mujoco/specs_test.py @@ -848,22 +848,31 @@ def test_access_option_stat_visual(self): + + """) self.assertEqual(spec.option.timestep, 0.001) self.assertEqual(spec.stat.meansize, 0.05) self.assertEqual(spec.visual.quality.shadowsize, 4096) + self.assertEqual(spec.visual.headlight.active, 0) + self.assertEqual(spec.visual.global_, getattr(spec.visual, 'global')) + np.testing.assert_array_equal(spec.visual.rgba.camera, [0, 0, 0, 0]) spec.option.timestep = 0.002 spec.stat.meansize = 0.06 spec.visual.quality.shadowsize = 8192 + spec.visual.headlight.active = 1 + spec.visual.rgba.camera = [1, 1, 1, 1] model = spec.compile() self.assertEqual(model.opt.timestep, 0.002) self.assertEqual(model.stat.meansize, 0.06) self.assertEqual(model.vis.quality.shadowsize, 8192) + self.assertEqual(model.vis.headlight.active, 1) + np.testing.assert_array_equal(model.vis.rgba.camera, [1, 1, 1, 1]) def test_assign_list_element(self): spec = mujoco.MjSpec()