allow 0 object transformer blocks
hkchengrex committed Jun 4, 2024
1 parent b8930f0 commit e869fab
Showing 6 changed files with 48 additions and 32 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -3,6 +3,7 @@
"editor.formatOnSaveMode": "file",
"editor.formatOnSave": true,
"editor.defaultFormatter": "eeyore.yapf",
"editor.formatOnType": false,
},

"cSpell.words": [
51 changes: 28 additions & 23 deletions cutie/inference/memory_manager.py
@@ -95,6 +95,9 @@ def _get_sensory_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
return torch.stack([self.sensory[obj] for obj in obj_ids], dim=1)

def _get_object_mem_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
if obj_ids[0] not in self.obj_v:
# should only happen when the object transformer has been disabled
return None
return torch.stack([self.obj_v[obj] for obj in obj_ids], dim=1)

def _get_visual_values_by_ids(self, obj_ids: List[int]) -> torch.Tensor:
@@ -185,7 +188,8 @@ def read(self, pix_feat: torch.Tensor, query_key: torch.Tensor, selection: torch
this_msk_value).view(bs, len(objects), self.CV, h, w)
pixel_readout = network.pixel_fusion(pix_feat, visual_readout, this_sensory,
this_last_mask)
this_obj_mem = self._get_object_mem_by_ids(objects).unsqueeze(2)
this_obj_mem = self._get_object_mem_by_ids(objects)
this_obj_mem = this_obj_mem.unsqueeze(2) if this_obj_mem is not None else None
readout_memory, aux_features = network.readout_query(pixel_readout, this_obj_mem)
for i, obj in enumerate(objects):
all_readout_mem[obj] = readout_memory[:, i]
@@ -219,7 +223,7 @@ def add_memory(self,
bs = key.shape[0]
assert shrinkage.shape[0] == bs
assert msk_value.shape[0] == bs
assert obj_value.shape[0] == bs
assert obj_value is None or obj_value.shape[0] == bs

self.engaged = True
if self.H is None or self.config_stale:
@@ -245,25 +249,26 @@
selection = selection.flatten(start_dim=2)

# insert object values into object memory
for obj_id, obj in enumerate(objects):
if obj in self.obj_v:
"""streaming average
each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
first embed_dim keeps track of the sum of embeddings
the last dim keeps the total count
averaging is done inside the object transformer
incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
"""
last_acc = self.obj_v[obj][:, :, -1]
new_acc = last_acc + obj_value[:, obj_id, :, -1]

self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
obj_value[:, obj_id, :, :-1])
self.obj_v[obj][:, :, -1] = new_acc
else:
self.obj_v[obj] = obj_value[:, obj_id]
if obj_value is not None:
for obj_id, obj in enumerate(objects):
if obj in self.obj_v:
"""streaming average
each self.obj_v[obj] is (1/2)*num_summaries*(embed_dim+1)
first embed_dim keeps track of the sum of embeddings
the last dim keeps the total count
averaging is done inside the object transformer
incoming obj_value is (1/2)*num_objects*num_summaries*(embed_dim+1)
self.obj_v[obj] = torch.cat([self.obj_v[obj], obj_value[:, obj_id]], dim=0)
"""
last_acc = self.obj_v[obj][:, :, -1]
new_acc = last_acc + obj_value[:, obj_id, :, -1]

self.obj_v[obj][:, :, :-1] = (self.obj_v[obj][:, :, :-1] +
obj_value[:, obj_id, :, :-1])
self.obj_v[obj][:, :, -1] = new_acc
else:
self.obj_v[obj] = obj_value[:, obj_id]

# convert mask value tensor into a dict for insertion
msk_values = {obj: msk_value[:, obj_id] for obj_id, obj in enumerate(objects)}
@@ -280,7 +285,7 @@ def add_memory(self,
if self.work_mem.non_perm_size(bucket_id) >= self.max_work_tokens:
# Remove obsolete features if needed
if self.long_mem.non_perm_size(bucket_id) >= (self.max_long_tokens -
self.num_prototypes):
self.num_prototypes):
self.long_mem.remove_obsolete_features(
bucket_id,
self.max_long_tokens - self.num_prototypes - self.buffer_tokens)
@@ -368,7 +373,7 @@ def update_sensory(self, sensory: torch.Tensor, ids: List[int]):
def get_sensory(self, ids: List[int]):
# returns (1/2)*num_objects*C*H*W
return self._get_sensory_by_ids(ids)

def clear_non_permanent_memory(self):
self.work_mem.clear_non_permanent_memory()
if self.use_long_term:
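The streaming average described in the docstring above stores a running sum plus a count instead of a mean, so each update is a pure addition and the division happens once wherever the memory is consumed (per the docstring, inside the object transformer). A minimal standalone sketch of that layout (the sizes below are illustrative, not the repo's defaults):

import torch

# Illustrative sizes; the real embed_dim/num_summaries come from the model config.
embed_dim, num_summaries = 4, 3

# One per-object slot shaped (batch, num_summaries, embed_dim + 1):
# the first embed_dim channels hold the running sum, the last channel the count.
slot = torch.zeros(1, num_summaries, embed_dim + 1)

def accumulate(slot: torch.Tensor, incoming: torch.Tensor) -> None:
    """Fold one incoming summary of the same shape into the slot, in place."""
    slot[:, :, :-1] += incoming[:, :, :-1]  # running sum of embeddings
    slot[:, :, -1] += incoming[:, :, -1]    # running count

for _ in range(5):
    accumulate(slot, torch.rand(1, num_summaries, embed_dim + 1))

# The consumer recovers the mean with a single division at read time:
mean = slot[:, :, :-1] / slot[:, :, -1:].clamp(min=1e-6)
print(mean.shape)  # torch.Size([1, 3, 4])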
2 changes: 1 addition & 1 deletion cutie/model/aux_modules.py
@@ -71,7 +71,7 @@ def forward(self, pix_feat: torch.Tensor, aux_input: Dict[str, torch.Tensor],
# B*num_objects*H*W
logits = self.sensory_aux(pix_feat, sensory)
aux_output['sensory_logits'] = self._aggregate_with_selector(logits, selector)
if self.use_query_aux:
if self.use_query_aux and q_logits is not None:
# B*num_objects*num_levels*H*W
aux_output['q_logits'] = self._aggregate_with_selector(
torch.stack(q_logits, dim=2),
16 changes: 12 additions & 4 deletions cutie/model/cutie.py
@@ -16,7 +16,6 @@


class CUTIE(nn.Module):

def __init__(self, cfg: DictConfig, *, single_object=False):
super().__init__()
self.cfg = cfg
@@ -28,17 +27,20 @@ def __init__(self, cfg: DictConfig, *, single_object=False):
self.pixel_dim = model_cfg.pixel_dim
self.embed_dim = model_cfg.embed_dim
self.single_object = single_object
self.object_transformer_enabled = model_cfg.object_transformer.num_blocks > 0

log.info(f'Single object: {self.single_object}')
log.info(f'Object transformer enabled: {self.object_transformer_enabled}')

self.pixel_encoder = PixelEncoder(model_cfg)
self.pix_feat_proj = nn.Conv2d(self.ms_dims[0], self.pixel_dim, kernel_size=1)
self.key_proj = KeyProjection(model_cfg)
self.mask_encoder = MaskEncoder(model_cfg, single_object=single_object)
self.mask_decoder = MaskDecoder(model_cfg)
self.pixel_fuser = PixelFeatureFuser(model_cfg, single_object=single_object)
self.object_transformer = QueryTransformer(model_cfg)
self.object_summarizer = ObjectSummarizer(model_cfg)
if self.object_transformer_enabled:
self.object_transformer = QueryTransformer(model_cfg)
self.object_summarizer = ObjectSummarizer(model_cfg)
self.aux_computer = AuxComputer(cfg)

self.register_buffer("pixel_mean", torch.Tensor(model_cfg.pixel_mean).view(-1, 1, 1), False)
@@ -80,7 +82,11 @@ def encode_mask(
others,
deep_update=deep_update,
chunk_size=chunk_size)
object_summaries, object_logits = self.object_summarizer(masks, mask_value, need_weights)
if self.object_transformer_enabled:
object_summaries, object_logits = self.object_summarizer(masks, mask_value,
need_weights)
else:
object_summaries, object_logits = None, None
return mask_value, new_sensory, object_summaries, object_logits

def transform_key(self,
@@ -156,6 +162,8 @@ def readout_query(self,
*,
selector=None,
need_weights=False) -> (torch.Tensor, Dict[str, torch.Tensor]):
if not self.object_transformer_enabled:
return pixel_readout, None
return self.object_transformer(pixel_readout,
obj_memory,
selector=selector,
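The enabling pattern in this file is a config-gated no-op: with object_transformer.num_blocks set to 0, QueryTransformer and ObjectSummarizer are never constructed, and readout_query falls through to an identity over the pixel readout with None aux features. A stripped-down sketch of that shape, with a toy nn.Linear as a purely hypothetical stand-in for the real transformer:

import torch
import torch.nn as nn

class GatedReadout(nn.Module):
    """Toy version of the CUTIE toggle: skip construction entirely when disabled."""

    def __init__(self, num_blocks: int, dim: int = 8):
        super().__init__()
        self.enabled = num_blocks > 0
        if self.enabled:
            self.transformer = nn.Linear(dim, dim)  # stand-in for QueryTransformer

    def readout_query(self, pixel_readout, obj_memory):
        if not self.enabled:
            return pixel_readout, None  # identity passthrough, no aux features
        return self.transformer(pixel_readout), {}

m = GatedReadout(num_blocks=0)
x = torch.randn(2, 8)
out, aux = m.readout_query(x, obj_memory=None)
assert out is x and aux is None  # with 0 blocks the readout passes through untouched

Skipping construction, rather than building the modules and bypassing them, also keeps the disabled modules' parameters out of the checkpoint's state dict, which is presumably why the creation itself is conditional.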
8 changes: 5 additions & 3 deletions cutie/model/train_wrapper.py
@@ -61,7 +61,8 @@ def get_ms_feat_ti(ti)

# add the time dimension
msk_values = msk_val.unsqueeze(3) # B*num_objects*C*T*H*W
obj_values = obj_val.unsqueeze(2) # B*num_objects*T*Q*C
obj_values = obj_val.unsqueeze(
2) if obj_val is not None else None # B*num_objects*T*Q*C

for ti in range(1, seq_length):
if ti <= self.num_ref_frames:
@@ -101,10 +102,11 @@ def get_ms_feat_ti(ti):
masks,
deep_update=deep_update)
msk_values = torch.cat([msk_values, msk_val.unsqueeze(3)], 3)
obj_values = torch.cat([obj_values, obj_val.unsqueeze(2)], 2)
obj_values = torch.cat([obj_values, obj_val.unsqueeze(2)],
2) if obj_val is not None else None

out[f'masks_{ti}'] = masks
out[f'logits_{ti}'] = logits
out[f'aux_{ti}'] = aux_output

return out
return out
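The training wrapper applies the same None discipline across the time axis: object summaries get a fresh time dimension and are concatenated only when the summarizer produced them; otherwise None keeps propagating. A self-contained toy of that accumulation, with shapes following the B*num_objects*T*Q*C comment (the concrete sizes are made up):

import torch

def append_time(acc, new, dim: int = 2):
    """Add a time axis to `new` at `dim` and concatenate onto `acc`, tolerating None."""
    if new is None:
        return None  # object transformer disabled: keep propagating None
    new = new.unsqueeze(dim)
    return new if acc is None else torch.cat([acc, new], dim=dim)

obj_values = None
for _ in range(3):  # three frames, each yielding a B*num_objects*Q*C summary
    obj_values = append_time(obj_values, torch.randn(1, 2, 4, 8))
print(obj_values.shape)  # torch.Size([1, 2, 3, 4, 8])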
2 changes: 1 addition & 1 deletion cutie/model/transformer/object_transformer.py
@@ -202,4 +202,4 @@ def _get_aux_mask(self, logits: torch.Tensor, selector: torch.Tensor) -> torch.T

aux_mask[torch.where(aux_mask.sum(-1) == aux_mask.shape[-1])] = False

return aux_mask
return aux_mask
