
TorchModuleWrapper in documentation #18849

Closed
emi-dm opened this issue Nov 29, 2023 · 2 comments
Assignees
Labels
stat:awaiting keras-eng Awaiting response from Keras engineer type:docs Need to modify the documentation type:feature The user is asking for a new feature.

Comments


emi-dm commented Nov 29, 2023

Where is `TorchModuleWrapper` in the documentation?

It would be nice to add a few usage examples.

@sachinprasadhs sachinprasadhs added the type:docs Need to modify the documentation label Nov 29, 2023
@sachinprasadhs
Collaborator

Hi,

You can find example usage of `TorchModuleWrapper` in the section below.

Here's an example of how the `TorchModuleWrapper` can be used with vanilla
PyTorch modules.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import keras
from keras.layers import TorchModuleWrapper


class Classifier(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Wrap `torch.nn.Module`s with `TorchModuleWrapper`
        # if they contain parameters.
        self.conv1 = TorchModuleWrapper(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))
        )
        self.conv2 = TorchModuleWrapper(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))
        )
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(p=0.5)
        self.fc = TorchModuleWrapper(nn.Linear(1600, 10))

    def call(self, inputs):
        x = F.relu(self.conv1(inputs))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return F.softmax(x, dim=1)


model = Classifier()
model.build((1, 28, 28))
print("Output shape:", model(torch.ones(1, 1, 28, 28).to("cuda")).shape)

model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# `train_loader` is a `torch.utils.data.DataLoader` yielding (image, label) batches.
model.fit(train_loader, epochs=5)
```
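The `train_loader` referenced above isn't defined in the snippet. A minimal sketch of one, using dummy MNIST-shaped tensors as a stand-in for a real dataset (the names `images` and `labels` are assumptions for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy data: 64 grayscale 28x28 images with integer class labels 0-9.
# Substitute your own tensors (or a torchvision dataset) in practice.
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))

train_loader = DataLoader(
    TensorDataset(images, labels), batch_size=16, shuffle=True
)

# Each iteration yields a (images, labels) batch suitable for `model.fit`.
for batch_images, batch_labels in train_loader:
    print(batch_images.shape, batch_labels.shape)
    break
```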
"""

@fchollet
Collaborator

Sure, we'll add it to the docs as part of a new layer category.

@sachinprasadhs sachinprasadhs added type:feature The user is asking for a new feature. stat:awaiting keras-eng Awaiting response from Keras engineer labels Nov 29, 2023
@emi-dm emi-dm closed this as completed Dec 4, 2023