You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Summary:
Pull Request resolved: #2787
# context
* previous diff triggered S495021
* the error message is like
```
ModelGenerationPlatformError("AttributeError: '_EmbeddingBagProxy' object has no attribute 'weight'")
```
* This is because in some flow the EBC module is fx traced so there is no actual EBC but a Proxy. Without full context it's risky to push this change.
* as a workaround, we'll just convert the unsharded EBC back to float32 so it's compatible with the input KJT.weight of float32
NOTE: this hacky change (unsharded EBC float16 ==> float32) is only needed in the tests, where we want to compare the results from sharded EBC.
WARNING: We make a strong assumption here that in any unsharded EBC (with dtype=float16) use case, the input KJT.weights should never be float32.
Reviewed By: basilwong
Differential Revision: D70712348
fbshipit-source-id: f2abaa601adf3052ea322cf326363da8bfef96c3
0 commit comments