diff --git a/README.md b/README.md index 239707b..daf0ee3 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,80 @@ if __name__ == "__main__": ``` +## Model Architecture + +```mermaid +flowchart TB + subgraph Inputs["Input Observations"] + SL[Stereo Left Image] + SR[Stereo Right Image] + WL[Wrist Left Image] + WR[Wrist Right Image] + end + + subgraph ImageEncoder["Image Encoder"] + direction TB + CNN["CNN Backbone + Conv2d layers + ReLU + MaxPool"] + Proj["Projection Layer + Linear(256, hidden_dim)"] + CNN --> Proj + end + + subgraph TransformerEncoder["Transformer Encoder (x4 layers)"] + direction TB + SA["Self Attention"] + FF["Feed Forward"] + N1["LayerNorm"] + N2["LayerNorm"] + SA --> N1 + N1 --> FF + FF --> N2 + end + + subgraph TransformerDecoder["Transformer Decoder (x7 layers)"] + direction TB + CA["Cross Attention"] + FFD["Feed Forward"] + N3["LayerNorm"] + N4["LayerNorm"] + CA --> N3 + N3 --> FFD + FFD --> N4 + end + + subgraph ActionPredictor["Action Predictor"] + direction TB + MLP["MLP Layers"] + Out["Output Layer + 20-dim vector"] + MLP --> Out + end + + subgraph Outputs["Action Outputs"] + LP["Left Position (3)"] + LR["Left Rotation (6)"] + LG["Left Gripper (1)"] + RP["Right Position (3)"] + RR["Right Rotation (6)"] + RG["Right Gripper (1)"] + end + + SL & SR & WL & WR --> ImageEncoder + ImageEncoder --> |"[B, 4, D]"| TransformerEncoder + TransformerEncoder --> |"Memory"| TransformerDecoder + TransformerDecoder --> |"[B, D]"| ActionPredictor + ActionPredictor --> LP & LR & LG & RP & RR & RG + + style Inputs fill:#e1f5fe,stroke:#01579b + style ImageEncoder fill:#fff3e0,stroke:#e65100 + style TransformerEncoder fill:#f3e5f5,stroke:#4a148c + style TransformerDecoder fill:#e8f5e9,stroke:#1b5e20 + style ActionPredictor fill:#fbe9e7,stroke:#bf360c + style Outputs fill:#f3e5f5,stroke:#4a148c +``` + ## Training Example ** on progress **