-
Notifications
You must be signed in to change notification settings - Fork 243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add BASNet to keras hub #1984
base: master
Are you sure you want to change the base?
Add BASNet to keras hub #1984
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR Laxma! it looks great! can you add a demo colab to verify outputs? Thanks!
input_data=self.images, | ||
) | ||
|
||
def test_end_to_end_model_predict(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add presets test and also self.run_task_test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Started a review. But I think this might need some bigger considerations.
Some questions:
- What are the weight sources we are using here?
- What are the outputs? Looks like there is a list of outputs instead of a normal classification style output.
- How does training this model look?
In general, I think there there will be a lot of core infrastructure that will break if the arch does not have a backbone class. DeepLabV3 is probably much closer to what we will need to go for.
But I think I need answers to the above questions to better suggest a design.
) | ||
|
||
|
||
@keras_hub_export( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can remove the list for single symbol exports.
@keras_hub_export("keras_hub.models.BASNet")
Also I think in keeping with our other naming this should be "keras_hub.models.BASNetImageSegmenter"
right?
"keras_hub.models.BASNet", | ||
] | ||
) | ||
class BASNet(ImageSegmenter): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BASNetImageSegmenter
A Keras model implementing the BASNet architecture for semantic | ||
segmentation. | ||
|
||
References: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generally we put these after args I think
backbone: `keras.Model`. The backbone network for the model that is | ||
used as a feature extractor for BASNet prediction encoder. Currently | ||
supported backbones are ResNet18 and ResNet34. Default backbone is | ||
`keras_cv.models.ResNet34Backbone()`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is false, update the docstring!
|
||
Example: | ||
```python | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove newline
) | ||
model.fit(images, labels, epochs=3) | ||
``` | ||
""" # noqa: E501 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove noqa
outputs.extend(predict_model.outputs) | ||
|
||
outputs = [ | ||
keras.layers.Activation("sigmoid", dtype="float32")(_) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generally we use _
as a dummy variable that is unused. this is used, so use x
maybe?
loss=keras.losses.BinaryCrossentropy(from_logits=False), | ||
metrics=["accuracy"], | ||
) | ||
model.fit(images, labels, epochs=3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this work when the model has multiple outputs it looks like?
Refer this issue for more details