Skip to content

Commit

Permalink
Merge pull request #248 from Unity-Technologies/ability_to_disable_la…
Browse files Browse the repository at this point in the history
…beling

Added ability to disable labeling on an object by disabling its Labeling component
  • Loading branch information
sleal-unity authored Mar 12, 2021
2 parents 3502511 + 3d110c4 commit 07ea8c1
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 76 deletions.
2 changes: 2 additions & 0 deletions com.unity.perception/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ The newly added `LabelManager` class now enables custom Labelers to access the l

Improved UI for `KeypointTemplate` and added useful default colors for keypoint and skeleton definitions.

Added the ability to switch ground-truth labeling on or off for an object at runtime by enabling or disabling its `Labeling` component.

### Changed

Renamed all appearances of the term `KeyPoint` within types and names to `Keypoint`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static class InstanceIdToColorMapping
const uint k_HslCount = 64;
const uint k_ColorsPerAlpha = 256 * 256 * 256;
const uint k_InvalidPackedColor = 255; // packed uint for color (0, 0, 0, 255);
static readonly Color32 k_InvalidColor = new Color(0, 0, 0, 255);
public static readonly Color32 invalidColor = new Color(0, 0, 0, 255);
static readonly float k_GoldenRatio = (1 + Mathf.Sqrt(5)) / 2;
const int k_HuesInEachValue = 30;

Expand Down Expand Up @@ -143,7 +143,7 @@ public static Color32 GetColorFromPackedColor(uint color)
/// <returns>Returns true if the ID was mapped to a non-black color, otherwise returns false</returns>
public static bool TryGetColorFromInstanceId(uint id, out Color32 color)
{
color = k_InvalidColor;
color = invalidColor;
if (id > maxId) return false;

var packed = GetColorForId(id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,20 +332,21 @@ bool DoesTemplateContainJoint(JointLabel jointLabel)

void ProcessLabel(Labeling labeledEntity)
{
// Cache out the data of a labeled game object the first time we see it, this will
// save performance each frame. Also checks to see if a labeled game object can be annotated.
if (!m_KnownStatus.ContainsKey(labeledEntity.instanceId))
if (idLabelConfig.TryGetLabelEntryFromInstanceId(labeledEntity.instanceId, out var labelEntry))
{
var cached = new CachedData()
// Cache out the data of a labeled game object the first time we see it, this will
// save performance each frame. Also checks to see if a labeled game object can be annotated.
if (!m_KnownStatus.ContainsKey(labeledEntity.instanceId))
{
status = false,
animator = null,
keypoints = new KeypointEntry(),
overrides = new List<(JointLabel, int)>()
};
var cached = new CachedData()
{
status = false,
animator = null,
keypoints = new KeypointEntry(),
overrides = new List<(JointLabel, int)>()
};


if (idLabelConfig.TryGetLabelEntryFromInstanceId(labeledEntity.instanceId, out var labelEntry))
{
var entityGameObject = labeledEntity.gameObject;

cached.keypoints.instance_id = labeledEntity.instanceId;
Expand Down Expand Up @@ -373,55 +374,55 @@ void ProcessLabel(Labeling labeledEntity)
cached.status = true;
}
}
}

m_KnownStatus[labeledEntity.instanceId] = cached;
}
m_KnownStatus[labeledEntity.instanceId] = cached;
}

var cachedData = m_KnownStatus[labeledEntity.instanceId];
var cachedData = m_KnownStatus[labeledEntity.instanceId];

if (cachedData.status)
{
var animator = cachedData.animator;
var keypoints = cachedData.keypoints.keypoints;

// Go through all of the rig keypoints and get their location
for (var i = 0; i < activeTemplate.keypoints.Length; i++)
if (cachedData.status)
{
var pt = activeTemplate.keypoints[i];
if (pt.associateToRig)
var animator = cachedData.animator;
var keypoints = cachedData.keypoints.keypoints;

// Go through all of the rig keypoints and get their location
for (var i = 0; i < activeTemplate.keypoints.Length; i++)
{
var bone = animator.GetBoneTransform(pt.rigLabel);
if (bone != null)
var pt = activeTemplate.keypoints[i];
if (pt.associateToRig)
{
var loc = ConvertToScreenSpace(bone.position);
keypoints[i].index = i;
keypoints[i].x = loc.x;
keypoints[i].y = loc.y;
keypoints[i].state = 2;
var bone = animator.GetBoneTransform(pt.rigLabel);
if (bone != null)
{
var loc = ConvertToScreenSpace(bone.position);
keypoints[i].index = i;
keypoints[i].x = loc.x;
keypoints[i].y = loc.y;
keypoints[i].state = 2;
}
}
}
}

// Go through all of the additional or override points defined by joint labels and get
// their locations
foreach (var (joint, idx) in cachedData.overrides)
{
var loc = ConvertToScreenSpace(joint.transform.position);
keypoints[idx].index = idx;
keypoints[idx].x = loc.x;
keypoints[idx].y = loc.y;
keypoints[idx].state = 2;
}
// Go through all of the additional or override points defined by joint labels and get
// their locations
foreach (var (joint, idx) in cachedData.overrides)
{
var loc = ConvertToScreenSpace(joint.transform.position);
keypoints[idx].index = idx;
keypoints[idx].x = loc.x;
keypoints[idx].y = loc.y;
keypoints[idx].state = 2;
}

cachedData.keypoints.pose = "unset";
cachedData.keypoints.pose = "unset";

if (cachedData.animator != null)
{
cachedData.keypoints.pose = GetPose(cachedData.animator);
}
if (cachedData.animator != null)
{
cachedData.keypoints.pose = GetPose(cachedData.animator);
}

m_AsyncAnnotations[m_CurrentFrame].keypoints[labeledEntity.instanceId] = cachedData.keypoints;
m_AsyncAnnotations[m_CurrentFrame].keypoints[labeledEntity.instanceId] = cachedData.keypoints;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,24 @@ void IGroundTruthGenerator.SetupMaterialProperties(MaterialPropertyBlock mpb, Re
if (m_IdLabelConfig.TryGetMatchingConfigurationEntry(labeling, out _, out var index))
{
Debug.Assert(index < k_DefaultValue, "Too many entries in the label config");
if (m_InstanceIdToLabelEntryIndexLookup.Length <= instanceId)
if (labeling.enabled)
{
var oldLength = m_InstanceIdToLabelEntryIndexLookup.Length;
m_InstanceIdToLabelEntryIndexLookup.Resize((int)instanceId + 1, NativeArrayOptions.ClearMemory);
if (m_InstanceIdToLabelEntryIndexLookup.Length <= instanceId)
{
var oldLength = m_InstanceIdToLabelEntryIndexLookup.Length;
m_InstanceIdToLabelEntryIndexLookup.Resize((int)instanceId + 1, NativeArrayOptions.ClearMemory);

for (var i = oldLength; i < instanceId; i++)
m_InstanceIdToLabelEntryIndexLookup[i] = k_DefaultValue;
for (var i = oldLength; i < instanceId; i++)
m_InstanceIdToLabelEntryIndexLookup[i] = k_DefaultValue;
}
m_InstanceIdToLabelEntryIndexLookup[(int)instanceId] = (ushort)index;
}
else if (m_InstanceIdToLabelEntryIndexLookup.Length > instanceId)
{
m_InstanceIdToLabelEntryIndexLookup[(int)instanceId] = k_DefaultValue;
}
m_InstanceIdToLabelEntryIndexLookup[(int)instanceId] = (ushort)index;
}
else if (m_InstanceIdToLabelEntryIndexLookup.Length > (int)instanceId)
else if (m_InstanceIdToLabelEntryIndexLookup.Length > instanceId)
{
m_InstanceIdToLabelEntryIndexLookup[(int)instanceId] = k_DefaultValue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,6 @@ public bool Deactivate(IGroundTruthGenerator generator)
return m_ActiveGenerators.Remove(generator);
}

/// <summary>
/// Registers a labeling component
/// </summary>
/// <param name="labeling">the component to register</param>
internal void Register(Labeling labeling)
{
m_LabelsPendingRegistration.Add(labeling);
}

/// <summary>
/// Unregisters a labeling component
/// </summary>
Expand Down
14 changes: 10 additions & 4 deletions com.unity.perception/Runtime/GroundTruth/Labeling/Labeling.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using UnityEditor;
using UnityEngine.Serialization;
Expand Down Expand Up @@ -34,14 +35,19 @@ public class Labeling : MonoBehaviour
/// </summary>
public uint instanceId { get; private set; }

void Awake()
void OnDestroy()
{
labelManager.Register(this);
labelManager.Unregister(this);
}

void OnDestroy()
void OnEnable()
{
labelManager.Unregister(this);
RefreshLabeling();
}

void OnDisable()
{
RefreshLabeling();
}

void Reset()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ public override void SetupMaterialProperties(
if (!found)
Debug.LogError($"Could not get a unique color for {instanceId}");

mpb.SetVector(k_SegmentationIdProperty, (Color)color);
if (labeling.enabled)
mpb.SetVector(k_SegmentationIdProperty, (Color)color);
else
mpb.SetVector(k_SegmentationIdProperty, (Color) InstanceIdToColorMapping.invalidColor);
#if PERCEPTION_DEBUG
Debug.Log($"Assigning id. Frame {Time.frameCount} id {id}");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,25 @@ public override void SetupMaterialProperties(
{
var entry = new SemanticSegmentationLabelEntry();
var found = false;
foreach (var l in m_LabelConfig.labelEntries)
if (labeling.enabled)
{
if (labeling.labels.Contains(l.label))
foreach (var l in m_LabelConfig.labelEntries)
{
entry = l;
found = true;
break;
if (labeling.labels.Contains(l.label))
{
entry = l;
found = true;
break;
}
}
}

// Set the labeling ID so that it can be accessed in ClassSemanticSegmentationPass.shader
if (found)
mpb.SetVector(k_LabelingId, entry.color);
else
mpb.SetVector(k_LabelingId, Color.black);

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,43 @@ public IEnumerator TryGet_ReturnsFalse_ForNonMatchingLabel_WhenAllObjectsAreDest
}
}

[UnityTest]
public IEnumerator TryGet_ReturnsFalse_ForMatchingLabelWithDisabledLabelingComponent()
{
var label = "label";
var labeledPlane = TestHelper.CreateLabeledPlane(label: label);
AddTestObjectForCleanup(labeledPlane);
var config = ScriptableObject.CreateInstance<IdLabelConfig>();
var labeling = labeledPlane.GetComponent<Labeling>();

config.Init(new[]
{
new IdLabelEntry()
{
id = 1,
label = label
},
});
using (var cache = new LabelEntryMatchCache(config, Allocator.Persistent))
{
labeling.enabled = false;
//allow label to be registered
yield return null;
Assert.IsFalse(cache.TryGetLabelEntryFromInstanceId(labeledPlane.GetComponent<Labeling>().instanceId, out var labelEntry, out var index));
Assert.AreEqual(-1, index);

labeling.enabled = true;
yield return null;
Assert.IsTrue(cache.TryGetLabelEntryFromInstanceId(labeledPlane.GetComponent<Labeling>().instanceId, out labelEntry, out index));
Assert.AreEqual(0, index);
Assert.AreEqual(config.labelEntries[0], labelEntry);

labeling.enabled = false;
yield return null;
Assert.IsFalse(cache.TryGetLabelEntryFromInstanceId(labeledPlane.GetComponent<Labeling>().instanceId, out labelEntry, out index));
Assert.AreEqual(-1, index);
}
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,50 @@ void OnSegmentationImageReceived(NativeArray<Color32> data)
Assert.AreEqual(1, timesSegmentationImageReceived);
}

[UnityTest]
public IEnumerator SemanticSegmentationPass_WithMatchingButDisabledLabel_ProducesBlack()
{
int timesSegmentationImageReceived = 0;
var expectedPixelValue = new Color32(0, 0, 0, 255);
void OnSegmentationImageReceived(NativeArray<Color32> data)
{
timesSegmentationImageReceived++;
CollectionAssert.AreEqual(Enumerable.Repeat(expectedPixelValue, data.Length), data.ToArray());
}

var cameraObject = SetupCameraSemanticSegmentation(a => OnSegmentationImageReceived(a.data), false);

var gameObject = TestHelper.CreateLabeledPlane();
gameObject.GetComponent<Labeling>().enabled = false;
AddTestObjectForCleanup(gameObject);
yield return null;
//destroy the object to force all pending segmented image readbacks to finish and events to be fired.
DestroyTestObject(cameraObject);
Assert.AreEqual(1, timesSegmentationImageReceived);
}

[UnityTest]
public IEnumerator InstanceSegmentationPass_WithMatchingButDisabledLabel_ProducesBlack()
{
int timesSegmentationImageReceived = 0;
var expectedPixelValue = new Color32(0, 0, 0, 255);
void OnSegmentationImageReceived(NativeArray<Color32> data)
{
CollectionAssert.AreEqual(Enumerable.Repeat(expectedPixelValue, data.Length), data);
timesSegmentationImageReceived++;
}

var cameraObject = SetupCameraInstanceSegmentation((frame, data, renderTexture) => OnSegmentationImageReceived(data));

var gameObject = TestHelper.CreateLabeledPlane();
gameObject.GetComponent<Labeling>().enabled = false;
AddTestObjectForCleanup(gameObject);
yield return null;
//destroy the object to force all pending segmented image readbacks to finish and events to be fired.
DestroyTestObject(cameraObject);
Assert.AreEqual(1, timesSegmentationImageReceived);
}

[UnityTest]
public IEnumerator SemanticSegmentationPass_WithEmptyFrame_ProducesBlack([Values(false, true)] bool showVisualizations)
{
Expand Down

0 comments on commit 07ea8c1

Please sign in to comment.