Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
budzianowski committed May 17, 2024
2 parents d2504a6 + 963f094 commit 8765660
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 51 deletions.
16 changes: 5 additions & 11 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,9 @@ clean:
# Static Checks #
# ------------------------ #

py-files := $(shell find . -name '*.py')

format:
@black $(py-files)
@ruff format $(py-files)
@black sim
@ruff format sim
.PHONY: format

format-cpp:
Expand All @@ -88,15 +86,11 @@ format-cpp:
.PHONY: format-cpp

static-checks:
@black --diff --check $(py-files)
@ruff check $(py-files)
@mypy --install-types --non-interactive $(py-files)
@black --diff --check sim
@ruff check sim
@mypy --install-types --non-interactive sim
.PHONY: lint

mypy-daemon:
@dmypy run -- $(py-files)
.PHONY: mypy-daemon

# ------------------------ #
# Unit tests #
# ------------------------ #
Expand Down
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@
[![Discord](https://img.shields.io/discord/1224056091017478166)](https://discord.gg/k5mSvCkYQh)
[![Wiki](https://img.shields.io/badge/wiki-humanoids-black)](https://humanoids.wiki)
<br />
[![python](https://img.shields.io/badge/-Python_3.8-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit)
[![black](https://img.shields.io/badge/Code%20Style-Black-black.svg?labelColor=gray)](https://black.readthedocs.io/en/stable/)
[![ruff](https://img.shields.io/badge/Linter-Ruff-red.svg?labelColor=gray)](https://github.com/charliermarsh/ruff)
<br />
[![Python Checks](https://github.com/kscalelabs/sim/actions/workflows/test.yml/badge.svg)](https://github.com/kscalelabs/sim/actions/workflows/test.yml)
[![Update Stompy S3 Model](https://github.com/kscalelabs/sim/actions/workflows/update_stompy_s3.yml/badge.svg)](https://github.com/kscalelabs/sim/actions/workflows/update_stompy_s3.yml)

</div>
Expand All @@ -29,6 +34,8 @@ The getting up task is still an open challenge!

## Getting Started

This repository requires Python 3.8 due to compatibility issues with underlying libraries. We hope to support more recent Python versions in the future.

1. Clone this repository:
```bash
git clone https://github.com/kscalelabs/sim.git
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ warn_redundant_casts = true
incremental = true
namespace_packages = false

exclude = "sim/humanoid_gym/"

[[tool.mypy.overrides]]

module = [
"isaacgym.*",
"humanoid.*",
"isaacgym.*",
"IsaacGymEnvs.*",
"mujoco.*",
]

ignore_missing_imports = true
Expand All @@ -49,6 +52,8 @@ profile = "black"
line-length = 120
target-version = "py310"

exclude = ["sim/humanoid_gym"]

[tool.ruff.lint]

select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
Expand Down
1 change: 1 addition & 0 deletions sim/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# requirements.txt

mediapy
torch
33 changes: 18 additions & 15 deletions sim/scripts/create_mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@
2. Condim 3 and 4 and difference in results
"""

import logging
import xml.dom.minidom
import xml.etree.ElementTree as ET
from pathlib import Path
import xml.dom.minidom
from typing import List, Union

from kol.formats import mjcf

from sim.stompy.joints import StompyFixed
from sim.env import stompy_mjcf_path

logger = logging.getLogger(__name__)

STOMPY_HEIGHT = 1.0


def _pretty_print_xml(xml_string):
def _pretty_print_xml(xml_string: str) -> str:
"""Formats the provided XML string into a pretty-printed version."""
parsed_xml = xml.dom.minidom.parseString(xml_string)
return parsed_xml.toprettyxml(indent=" ")
Expand All @@ -32,7 +36,8 @@ class Sim2SimRobot(mjcf.Robot):
"""A class to adapt the world in a Mujoco XML file."""

def adapt_world(self) -> None:
root = self.tree.getroot()
root: ET.Element = self.tree.getroot()

asset = root.find("asset")
asset.append(
ET.Element(
Expand Down Expand Up @@ -116,26 +121,24 @@ def adapt_world(self) -> None:
).to_xml(),
)

motors = []
sensors = []
motors: List[mjcf.Motor] = []
sensor_pos: List[mjcf.Actuatorpos] = []
sensor_vel: List[mjcf.Actuatorvel] = []
sensor_frc: List[mjcf.Actuatorfrc] = []
for joint, limits in StompyFixed.default_limits().items():
if joint in StompyFixed.default_standing().keys():
motors.append(
mjcf.Motor(
name=joint, joint=joint, gear=1, ctrlrange=(limits["lower"], limits["upper"]), ctrllimited=True
)
)
sensors.extend(
[
mjcf.Actuatorpos(name=joint + "_p", actuator=joint, user="13"),
mjcf.Actuatorvel(name=joint + "_v", actuator=joint, user="13"),
mjcf.Actuatorfrc(name=joint + "_f", actuator=joint, user="13", noise=0.001),
]
)
sensor_pos.append(mjcf.Actuatorpos(name=joint + "_p", actuator=joint, user="13"))
sensor_vel.append(mjcf.Actuatorvel(name=joint + "_v", actuator=joint, user="13"))
sensor_frc.append(mjcf.Actuatorfrc(name=joint + "_f", actuator=joint, user="13", noise=0.001))

# Add motors and sensors
root.append(mjcf.Actuator(motors).to_xml())
root.append(mjcf.Sensor(sensors).to_xml())
root.append(mjcf.Sensor(sensor_pos, sensor_vel, sensor_frc).to_xml())

# Add imus
sensors = root.find("sensor")
Expand Down Expand Up @@ -182,11 +185,11 @@ def adapt_world(self) -> None:
)
self.tree = ET.ElementTree(root)

def save(self, path: str | Path) -> None:
def save(self, path: Union[str, Path]) -> None:
rough_string = ET.tostring(self.tree.getroot(), "utf-8")
# Pretty print the XML
formatted_xml = _pretty_print_xml(rough_string)

logger.info("XML:\n%s", formatted_xml)
with open(path, "w") as f:
f.write(formatted_xml)

Expand Down
26 changes: 17 additions & 9 deletions sim/scripts/simulate_mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,25 @@
mjpython sim/scripts/simulate_mjcf.py --record
"""

import argparse
import logging
import time
from pathlib import Path
from typing import List, Union

import mediapy as media
import mujoco
import mujoco.viewer
import mediapy as media
import argparse
import numpy as np

from sim.env import stompy_mjcf_path
from sim.logging import configure_logging

logger = logging.getLogger(__name__)


def simulate(model_path, duration, framerate, record_video):
frames = []
def simulate(model_path: Union[str, Path], duration: float, framerate: float, record_video: bool) -> None:
frames: List[np.ndarray] = []
model = mujoco.MjModel.from_xml_path(model_path)
data = mujoco.MjData(model)
renderer = mujoco.Renderer(model)
Expand All @@ -37,18 +45,18 @@ def simulate(model_path, duration, framerate, record_video):
if record_video:
video_path = "mjcf_simulation.mp4"
media.write_video(video_path, frames, fps=framerate)
print(f"Video saved to {video_path}")
# print(f"Video saved to {video_path}")
logger.info("Video saved to %s", video_path)


if __name__ == "__main__":
configure_logging()

parser = argparse.ArgumentParser(description="MuJoCo Simulation")
parser.add_argument(
"--model_path", type=str, default=str(stompy_mjcf_path()), help="Path to the MuJoCo XML model file"
)
parser.add_argument("--model_path", type=str, default=str(stompy_mjcf_path()), help="Path to the MuJoCo XML file")
parser.add_argument("--duration", type=int, default=3, help="Duration of the simulation in seconds")
parser.add_argument("--framerate", type=int, default=30, help="Frame rate for video recording")
parser.add_argument("--record", action="store_true", help="Flag to record video")

args = parser.parse_args()

simulate(args.model_path, args.duration, args.framerate, args.record)
15 changes: 7 additions & 8 deletions sim/scripts/simulate_urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
"""

import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, NewType
from typing import Any, Dict, Literal, NewType

from isaacgym import gymapi, gymtorch, gymutil

Expand Down Expand Up @@ -174,14 +173,14 @@ def load_gym() -> GymParams:


def run_gym(gym: GymParams, mode: Literal["one_at_a_time", "all_at_once"] = "all_at_once") -> None:
joints = Stompy.all_joints()
last_time = time.time()
# joints = Stompy.all_joints()
# last_time = time.time()

dof_ids: Dict[str, int] = gym.gym.get_actor_dof_dict(gym.env, gym.robot)
body_ids: List[str] = gym.gym.get_actor_rigid_body_names(gym.env, gym.robot)
# dof_ids: Dict[str, int] = gym.gym.get_actor_dof_dict(gym.env, gym.robot)
# body_ids: List[str] = gym.gym.get_actor_rigid_body_names(gym.env, gym.robot)

joint_id = 0
effort = 5.0
# joint_id = 0
# effort = 5.0

while not gym.gym.query_viewer_has_closed(gym.viewer):
gym.gym.simulate(gym.sim)
Expand Down
10 changes: 7 additions & 3 deletions sim/scripts/update_stompy_s3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
# mypy: disable-error-code="import-not-found"
"""Updates the Stompy model."""

import tarfile
from pathlib import Path

from kol.logging import configure_logging
from kol.onshape.converter import Converter

STOMPY_MODEL = (
"https://cad.onshape.com/documents/71f793a23ab7562fb9dec82d/"
"w/6160a4f44eb6113d3fa116cd/e/1a95e260677a2d2d5a3b1eb3"
Expand All @@ -23,6 +21,12 @@


def main() -> None:
try:
from kol.logging import configure_logging
from kol.onshape.converter import Converter
except ImportError:
raise ImportError("Please install the `kscale-onshape-library` package to run this script.")

configure_logging()

output_dir = Path("stompy")
Expand Down
6 changes: 3 additions & 3 deletions sim/stompy/joints.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import textwrap
from abc import ABC
from typing import Dict, List, Union
from typing import Dict, List, Tuple, Union

import numpy as np

Expand All @@ -30,8 +30,8 @@ def joints(cls) -> List[str]:
]

@classmethod
def joints_motors(cls) -> List[str]:
joint_names = []
def joints_motors(cls) -> List[Tuple[str, str]]:
joint_names: List[Tuple[str, str]] = []
for attr in dir(cls):
if not attr.startswith("__"):
attr2 = getattr(cls, attr)
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Defines PyTest configuration for the project."""

import random
from typing import List

import pytest
from _pytest.python import Function
Expand All @@ -11,5 +12,5 @@ def set_random_seed() -> None:
random.seed(1337)


def pytest_collection_modifyitems(items: list[Function]) -> None:
def pytest_collection_modifyitems(items: List[Function]) -> None:
items.sort(key=lambda x: x.get_closest_marker("slow") is not None)

0 comments on commit 8765660

Please sign in to comment.