diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
index ceec6e0013..65310a2757 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
@@ -63,6 +63,12 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
return m_Bodies[index].velocity;
}
+ ///
+ protected internal override Vector3 GetAngularVelocityAt(int index)
+ {
+ return m_Bodies[index].angularVelocity;
+ }
+
///
protected internal override Pose GetPoseAt(int index)
{
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
index d9f9c0d441..8984685f83 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
@@ -34,11 +34,21 @@ public struct PhysicsSensorSettings
///
public bool UseModelSpaceLinearVelocity;
+ ///
+ /// Whether to use model space (relative to the root body) angular velocities as observations.
+ ///
+ public bool UseModelSpaceAngularVelocity;
+
///
/// Whether to use local space (relative to the parent body) linear velocities as observations.
///
public bool UseLocalSpaceLinearVelocity;
+ ///
+ /// Whether to use local space (relative to the parent body) angular velocities as observations.
+ ///
+ public bool UseLocalSpaceAngularVelocity;
+
///
/// Whether to use joint-specific positions and angles as observations.
///
@@ -67,7 +77,8 @@ public static PhysicsSensorSettings Default()
///
public bool UseModelSpace
{
- get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
+ get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity ||
+ UseModelSpaceAngularVelocity; }
}
///
@@ -75,7 +86,8 @@ public bool UseModelSpace
///
public bool UseLocalSpace
{
- get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
+ get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity ||
+ UseLocalSpaceAngularVelocity; }
}
}
@@ -109,9 +121,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
}
}
- foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
+ if (settings.UseModelSpaceLinearVelocity)
{
- if (settings.UseModelSpaceLinearVelocity)
+ foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
+ {
+ writer.Add(vel, offset);
+ offset += 3;
+ }
+ }
+
+ if (settings.UseModelSpaceAngularVelocity)
+ {
+ foreach (var vel in poseExtractor.GetEnabledModelSpaceAngularVelocities())
{
writer.Add(vel, offset);
offset += 3;
@@ -136,9 +157,18 @@ public static int WritePoses(this ObservationWriter writer, PhysicsSensorSetting
}
}
- foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
+ if (settings.UseLocalSpaceLinearVelocity)
+ {
+ foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
+ {
+ writer.Add(vel, offset);
+ offset += 3;
+ }
+ }
+
+ if (settings.UseLocalSpaceAngularVelocity)
{
- if (settings.UseLocalSpaceLinearVelocity)
+ foreach (var vel in poseExtractor.GetEnabledLocalSpaceAngularVelocities())
{
writer.Add(vel, offset);
offset += 3;
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
index 059804377b..b673589ec5 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
@@ -21,7 +21,9 @@ public abstract class PoseExtractor
Pose[] m_LocalSpacePoses;
Vector3[] m_ModelSpaceLinearVelocities;
+ Vector3[] m_ModelSpaceAngularVelocities;
Vector3[] m_LocalSpaceLinearVelocities;
+ Vector3[] m_LocalSpaceAngularVelocities;
bool[] m_PoseEnabled;
@@ -83,6 +85,25 @@ public IEnumerable GetEnabledModelSpaceVelocities()
}
}
+ ///
+ /// Read iterator for the enabled model space angular velocities.
+ ///
+ public IEnumerable GetEnabledModelSpaceAngularVelocities()
+ {
+ if (m_ModelSpaceAngularVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_ModelSpaceAngularVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_ModelSpaceAngularVelocities[i];
+ }
+ }
+ }
+
///
/// Read iterator for the enabled local space linear velocities.
///
@@ -102,6 +123,25 @@ public IEnumerable GetEnabledLocalSpaceVelocities()
}
}
+ ///
+ /// Read iterator for the enabled local space angular velocities.
+ ///
+ public IEnumerable GetEnabledLocalSpaceAngularVelocities()
+ {
+ if (m_LocalSpaceAngularVelocities == null)
+ {
+ yield break;
+ }
+
+ for (var i = 0; i < m_LocalSpaceAngularVelocities.Length; i++)
+ {
+ if (m_PoseEnabled[i])
+ {
+ yield return m_LocalSpaceAngularVelocities[i];
+ }
+ }
+ }
+
///
/// Number of enabled poses in the hierarchy (read-only).
///
@@ -181,7 +221,9 @@ protected void Setup(int[] parentIndices)
m_LocalSpacePoses = new Pose[numPoses];
m_ModelSpaceLinearVelocities = new Vector3[numPoses];
+ m_ModelSpaceAngularVelocities = new Vector3[numPoses];
m_LocalSpaceLinearVelocities = new Vector3[numPoses];
+ m_LocalSpaceAngularVelocities = new Vector3[numPoses];
m_PoseEnabled = new bool[numPoses];
// All poses are enabled by default. Generally we'll want to disable the root though.
@@ -205,6 +247,13 @@ protected void Setup(int[] parentIndices)
///
protected internal abstract Vector3 GetLinearVelocityAt(int index);
+ ///
+ /// Return the world space angular velocity of the i'th object.
+ ///
+ ///
+ ///
+ protected internal abstract Vector3 GetAngularVelocityAt(int index);
+
///
/// Return the underlying object at the given index. This is only
/// used for display in the inspector.
@@ -232,6 +281,7 @@ public void UpdateModelSpacePoses()
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
+ var rootAngularVel = GetAngularVelocityAt(0);
for (var i = 0; i < m_ModelSpacePoses.Length; i++)
{
@@ -240,8 +290,11 @@ public void UpdateModelSpacePoses()
m_ModelSpacePoses[i] = currentModelSpacePose;
var currentBodyLinearVel = GetLinearVelocityAt(i);
- var relativeVelocity = currentBodyLinearVel - rootLinearVel;
- m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
+ var relativeLinearVel = currentBodyLinearVel - rootLinearVel;
+ m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeLinearVel;
+ var currentBodyAngularVel = GetAngularVelocityAt(i);
+ var relativeAngularVel = currentBodyAngularVel - rootAngularVel;
+ m_ModelSpaceAngularVelocities[i] = worldToModel.rotation * relativeAngularVel;
}
}
}
@@ -272,11 +325,15 @@ public void UpdateLocalSpacePoses()
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
+ var parentAngularVel = GetAngularVelocityAt(m_ParentIndices[i]);
+ var currentAngularVel = GetAngularVelocityAt(i);
+ m_LocalSpaceAngularVelocities[i] = invParent.rotation * (currentAngularVel - parentAngularVel);
}
else
{
m_LocalSpacePoses[i] = Pose.identity;
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
+ m_LocalSpaceAngularVelocities[i] = Vector3.zero;
}
}
}
@@ -296,7 +353,9 @@ public int GetNumPoseObservations(PhysicsSensorSettings settings)
obsPerPose += settings.UseLocalSpaceRotations ? 4 : 0;
obsPerPose += settings.UseModelSpaceLinearVelocity ? 3 : 0;
+ obsPerPose += settings.UseModelSpaceAngularVelocity ? 3 : 0;
obsPerPose += settings.UseLocalSpaceLinearVelocity ? 3 : 0;
+ obsPerPose += settings.UseLocalSpaceAngularVelocity ? 3 : 0;
return NumEnabledPoses * obsPerPose;
}
@@ -363,6 +422,7 @@ internal IList GetDisplayNodes()
{
return Array.Empty();
}
+
var nodesOut = new List(NumPoses);
// List of children for each node
@@ -379,6 +439,7 @@ internal IList GetDisplayNodes()
{
tree[parent] = new List();
}
+
tree[parent].Add(i);
}
@@ -422,7 +483,6 @@ internal IList GetDisplayNodes()
return nodesOut;
}
-
}
///
diff --git a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
index b54b0b5713..d47301a5b2 100644
--- a/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
+++ b/com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
@@ -3,7 +3,6 @@
namespace Unity.MLAgents.Extensions.Sensors
{
-
///
/// Utility class to track a hierarchy of RigidBodies. These are assumed to have a root node,
/// and child nodes are connect to their parents via Joints.
@@ -129,9 +128,22 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
// No velocity on the virtual root
return Vector3.zero;
}
+
return m_Bodies[index].velocity;
}
+ ///
+ protected internal override Vector3 GetAngularVelocityAt(int index)
+ {
+ if (index == 0 && m_VirtualRoot != null)
+ {
+ // No velocity on the virtual root
+ return Vector3.zero;
+ }
+
+ return m_Bodies[index].angularVelocity;
+ }
+
///
protected internal override Pose GetPoseAt(int index)
{
@@ -156,6 +168,7 @@ protected internal override Object GetObjectAt(int index)
{
return m_VirtualRoot;
}
+
return m_Bodies[index];
}
@@ -167,6 +180,11 @@ protected internal override Object GetObjectAt(int index)
///
internal Dictionary GetBodyPosesEnabled()
{
+ if (m_Bodies == null)
+ {
+ return new Dictionary();
+ }
+
var bodyPosesEnabled = new Dictionary(m_Bodies.Length);
for (var i = 0; i < m_Bodies.Length; i++)
{
@@ -205,5 +223,4 @@ internal IEnumerable GetEnabledRigidbodies()
}
}
}
-
}
diff --git a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs
index 782b7da3a9..f7a7c32852 100644
--- a/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs
+++ b/com.unity.ml-agents.extensions/Tests/Runtime/Sensors/PoseExtractorTests.cs
@@ -19,6 +19,11 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
+
+ protected internal override Vector3 GetAngularVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
}
class UselessPoseExtractor : BasicPoseExtractor
@@ -114,6 +119,10 @@ protected internal override Vector3 GetLinearVelocityAt(int index)
return Vector3.zero;
}
+ protected internal override Vector3 GetAngularVelocityAt(int index)
+ {
+ return Vector3.zero;
+ }
}
[Test]