Commit

[readme]
Your Name committed Nov 17, 2024
1 parent 6a0e1f9 commit 44ad40e
Showing 2 changed files with 112 additions and 21 deletions.
25 changes: 25 additions & 0 deletions README.md
@@ -214,6 +214,31 @@ However, they did not use this larger dataset, instead collecting their own cont
The paper does not mention if they plan to release their datasets publicly.


## Implementation Details from the paper

```txt
4 Implementation Details
To train our policies, we use action chunking with transformers (ACT) [23] and diffusion policy
[64]. The policies were trained using the endoscope and wrist camera images as input, which are all
downsized to 224 × 224 × 3. The original surgical endoscope images were 1024 × 1280 × 3 and the
wrist images were 480 × 640 × 3. Kinematics data is not provided as input, unlike in other imitation
learning approaches, because it is generally inconsistent due to the design limitations of the dVRK.
The policy outputs include the end-effector (delta) position, (delta) orientation, and jaw angle for
both arms. We leave further specific implementation details to Appendix A.
```
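
For orientation, here is a minimal sketch of the input/output shapes implied by this passage; the variable names and the `downsample` helper are illustrative, not taken from the paper's code.

```python
# Hedged sketch of the policy I/O described above; names are illustrative.
import torch
import torch.nn.functional as F

B = 2
endo_l = torch.rand(B, 3, 1024, 1280)  # left endoscope view, original resolution
endo_r = torch.rand(B, 3, 1024, 1280)  # right endoscope view
wrist_l = torch.rand(B, 3, 480, 640)   # left wrist camera
wrist_r = torch.rand(B, 3, 480, 640)   # right wrist camera

def downsample(img: torch.Tensor) -> torch.Tensor:
    """Resize a batch of images to the 224 x 224 policy input size."""
    return F.interpolate(img, size=(224, 224), mode="bilinear", align_corners=False)

# [B, 4, 3, 224, 224]: all camera views downsized for the policy
views = torch.stack([downsample(v) for v in (endo_l, endo_r, wrist_l, wrist_r)], dim=1)

# Policy output per step: delta position (3) + delta orientation (6D, 6) + jaw (1)
# per arm = 10 dims per arm, i.e. a 20-dim action for both arms.
action = torch.zeros(B, 20)
print(views.shape, action.shape)
```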

### Appendix A

```txt
The main modifications include changing the input layers to accept four images: the left/right surgical endoscope views and the left/right wrist camera views. The output dimensions are also
revised to generate end-effector poses, which amounts to a 10-dim vector for each arm (position [3] + orientation [6] + jaw angle [1] = 10), and thus a 20-dim vector in total for both arms. The
orientation is modeled using a 6D rotation representation following [21], where the 6 elements correspond to the first two columns of the rotation matrix. Since the network predictions may not
be orthonormal, a Gram-Schmidt process is performed to orthonormalize the two predicted vectors, and their cross product generates the remaining third column of the rotation matrix.
For diffusion policy, similar modifications are made, such as changing the input and output dimensions of the network appropriately. The specific hyperparameters for training
are shown in Tables 3 and 4.
```
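
A minimal sketch of the 6D-orientation handling described above, assuming the standard Gram-Schmidt construction; the function name is illustrative and not taken from the paper's code.

```python
# Hedged sketch: recover a full rotation matrix from the 6D representation
# (the first two columns), via Gram-Schmidt plus a cross product.
import torch
import torch.nn.functional as F

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Convert a [..., 6] prediction into a [..., 3, 3] rotation matrix."""
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)                                          # first column
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)   # orthogonalize
    b3 = torch.cross(b1, b2, dim=-1)                                      # third column
    return torch.stack((b1, b2, b3), dim=-1)                              # columns -> matrix

R = rotation_6d_to_matrix(torch.randn(2, 6))  # each R[i] is orthonormal with det = +1
```

Together with the 3-dim delta position and 1-dim jaw angle, this 6D block gives the 10-dim per-arm action described above.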

# Todo

- [ ] Add training logic (in progress)
108 changes: 87 additions & 21 deletions srt_torch/main.py
@@ -13,6 +13,7 @@
from typing import Optional, Tuple
from dataclasses import dataclass
from loguru import logger
import timm

# Configure logging
logger.add("srt.log", rotation="500 MB")
@@ -98,47 +99,112 @@ def from_tensor(tensor: torch.Tensor) -> "RobotAction":
right_gripper=tensor[..., 19:20],
)


class ImageEncoder(nn.Module):
"""Encodes multiple camera views into latent space."""
"""Encodes multiple camera views into latent space using EVA Large."""

def __init__(self, config: ModelConfig):
super().__init__()
self.config = config

self.backbone = nn.Sequential(
nn.Conv2d(
config.num_channels, 64, 7, stride=2, padding=3
),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, stride=2, padding=1),
nn.Conv2d(64, 128, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(128, 256, 3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),

# Initialize EVA Large backbone
self.backbone = timm.create_model(
'eva_large_patch14_336',
pretrained=True,
num_classes=0, # Remove classification head
global_pool='avg', # Use average pooling
)

self.projection = nn.Linear(256, config.hidden_dim)

# Freeze backbone parameters (optional, can be controlled via config)
if getattr(config, 'freeze_backbone', True):
for param in self.backbone.parameters():
param.requires_grad = False

# Projection head to match transformer dimensions
self.projection = nn.Sequential(
nn.Linear(1024, config.hidden_dim * 2), # EVA Large outputs 1024 features
nn.LayerNorm(config.hidden_dim * 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.hidden_dim * 2, config.hidden_dim),
nn.LayerNorm(config.hidden_dim),
nn.Dropout(config.dropout)
)

# Initialize projection weights
self._init_weights()

# ImageNet normalization statistics used to preprocess inputs
self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

def _init_weights(self):
"""Initialize the weights of the projection layers."""
for m in self.projection.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

def preprocess_images(self, images: torch.Tensor) -> torch.Tensor:
"""Normalize and preprocess images for EVA model."""
# Ensure images are float and in range [0, 1]
if images.dtype == torch.uint8:
images = images.float() / 255.0

# Normalize with ImageNet statistics
images = (images - self.mean) / self.std

# Resize if needed
if images.shape[-2:] != (336, 336): # EVA Large expected input size
images = F.interpolate(
images,
size=(336, 336),
mode='bilinear',
align_corners=False
)

return images

def forward(self, images: torch.Tensor) -> torch.Tensor:
"""
Args:
images: [B, 4, C, H, W] tensor of camera views
B: batch size
4: number of views (stereo_left, stereo_right, wrist_left, wrist_right)
C: channels (3 for RGB)
H, W: height and width
Returns:
[B, 4, D] encoded features
[B, 4, D] encoded features where D is config.hidden_dim
"""
B = images.shape[0]
features = []

# Process each camera view separately
for i in range(4):
x = self.backbone(images[:, i])
x = self.projection(x)
# Extract and preprocess single view
view = images[:, i] # [B, C, H, W]
view = self.preprocess_images(view)

# Extract features through EVA backbone
with torch.cuda.amp.autocast(enabled=True): # Enable AMP for efficiency
x = self.backbone(view) # [B, 1024]

# Project to transformer dimension
x = self.projection(x) # [B, hidden_dim]
features.append(x)

return torch.stack(features, dim=1)
# Stack all views
features = torch.stack(features, dim=1) # [B, 4, hidden_dim]

return features

def get_output_dim(self) -> int:
"""Return the output dimension of the encoder."""
return self.config.hidden_dim

class TransformerBlock(nn.Module):
"""Standard transformer encoder/decoder block."""
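
For context, a hedged usage sketch of the revised `ImageEncoder`; the `ModelConfig` field values below are assumptions, and the class may require additional fields not shown in this diff.

```python
# Hedged usage sketch for the EVA-based ImageEncoder introduced in this commit.
# ModelConfig field values are assumptions, not the repository defaults.
import torch
from srt_torch.main import ImageEncoder, ModelConfig

config = ModelConfig(hidden_dim=512, dropout=0.1, num_channels=3)  # assumed values
encoder = ImageEncoder(config).eval()

# Four camera views per sample: left/right endoscope + left/right wrist
images = torch.rand(2, 4, 3, 480, 640)   # any H, W; resized to 336 x 336 internally
with torch.no_grad():
    features = encoder(images)           # [2, 4, config.hidden_dim]
print(features.shape, encoder.get_output_dim())
```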
