Add pre-trained models from HuggingFace and add accuracy testing #164

Merged · 9 commits · Jun 19, 2022
Changes from 7 commits
78 changes: 55 additions & 23 deletions Artifacts.toml
@@ -1,39 +1,71 @@
[densenet121]
git-tree-sha1 = "ffc7f7ed1e7f67baca4b76f6c100e0d5042ff063"
[vgg11]
git-tree-sha1 = "78ffe7d74c475cc28175f9e23a545ce2f17b1520"
lazy = true

[[densenet121.download]]
sha256 = "3fd10f0be70cf072fa7f1358f1fbbe01138440dbcaec1b7c8e007084382c1557"
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/densenet121-0.1.1.tar.gz"
[[vgg11.download]]
sha256 = "9703268c19ca2ae34036ca3588664a96dc0ca8d9d6458db78657299c6879880c"
url = "https://huggingface.co/FluxML/vgg11/resolve/275b202a8a4d10b59eef74285921d278b51fdbdb/vgg11.tar.gz"

[googlenet]
git-tree-sha1 = "56cc81845fcca30508fe81da18c7ba0d96d72cdd"
[vgg13]
git-tree-sha1 = "ed006dd09cc24342d4dcd9e2cfaa8c84f063c27a"
lazy = true

[[googlenet.download]]
sha256 = "8ab8d60cc26e81451473badc9dc749b5ffc170a11bc00fb4b203da34fbfdc996"
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/googlenet-0.1.1.tar.gz"
[[vgg13.download]]
sha256 = "ef27949024f5716f7656b3318b06964d76587851f15d9a9127c2b55e5faee288"
url = "https://huggingface.co/FluxML/vgg13/resolve/9593b269ee2c24ce5924d3667496a0d7458a6cb4/vgg13.tar.gz"

[vgg16]
git-tree-sha1 = "759df92ca502324d8624e1c5a940db227908fb9e"
lazy = true

[[vgg16.download]]
sha256 = "f9bad8d9d2c79bc4ebab840f2faded2a0c26c6b2a84f979525964eebcd1886ab"
url = "https://huggingface.co/FluxML/vgg16/resolve/57fdb74b1640815f17eae1a28ae67f0fc1c603db/vgg16.tar.gz"

[vgg19]
git-tree-sha1 = "67f5e867f297086cc911c2cb7985bec8ac1ab23d"
lazy = true

[[vgg19.download]]
sha256 = "5fe26391572b9f6ac84eaa0541d27e959f673f82e6515026cdcd3262cbd93ceb"
url = "https://huggingface.co/FluxML/vgg19/resolve/88e9056f60b054eccdc190a2eeb23731d5c693b6/vgg19.tar.gz"

[resnet18]
git-tree-sha1 = "7b555ed2708e551bfdbcb7e71b25001f4b3731c6"
lazy = true

[[resnet18.download]]
sha256 = "d5782fd873a3072df251c7a4b3cf16efca8ee1da1180ff815bc107833f84bb26"
url = "https://huggingface.co/FluxML/resnet18/resolve/ef9c74047fda4a4a503b1f72553ec05acc90929f/resnet18.tar.gz"

[resnet34]
git-tree-sha1 = "e6e79666cd0fc81cd828508314e6c7f66df8d43d"
lazy = true

[[resnet34.download]]
sha256 = "a8dec13609a86f7a2adac6a44b3af912a863bc2d7319120066c5fdaa04c3f395"
url = "https://huggingface.co/FluxML/resnet34/resolve/42061ddb463902885eea4fcc85275462a5445987/resnet34.tar.gz"

[resnet50]
git-tree-sha1 = "ea3effeaf1ea3969ed5c609f5db5cd0e456ce799"
git-tree-sha1 = "5c442ffd6c51a70c3bc36d849fca86beced446d4"
lazy = true

[[resnet50.download]]
sha256 = "17760ae50e3d59ed7d74c3dfcdb9f0eeaccec1e2ccd095663955c9fed4f318a8"
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/resnet50-0.1.1.tar.gz"
sha256 = "5325920ec91c2a4499ad7e659961f9eaac2b1a3a2905ca6410eaa593ecd35503"
url = "https://huggingface.co/FluxML/resnet50/resolve/10e601719e1cd5b0cab87ce7fd1e8f69a07ce042/resnet50.tar.gz"

[squeezenet]
git-tree-sha1 = "e0e53eb402efe4693417db8cbcc31519e74c8c74"
[resnet101]
git-tree-sha1 = "694a8563ec20fb826334dd663d532b10bb2b3c97"
lazy = true

[[squeezenet.download]]
sha256 = "a3e60f2731296cdf0f32b79badd227eb8dad88a9bee8c828dbe60382869c50f0"
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/squeezenet-0.1.1.tar.gz"
[[resnet101.download]]
sha256 = "f4d737ce640957c30f76bfa642fc9da23e6852d81474d58a2338c1148e55bff0"
url = "https://huggingface.co/FluxML/resnet101/resolve/ea37819163cc3f4a41989a6239ce505e483b112d/resnet101.tar.gz"

[vgg19]
git-tree-sha1 = "072056ec63bf7308cf89885e91852666e191e80a"
[resnet152]
git-tree-sha1 = "55eb883248a276d710d75ecaecfbd2427e50cc0a"
lazy = true

[[vgg19.download]]
sha256 = "0fa000609965604b9d249e84190c30d067d443d73e6c8e340ef09bd013d0bc90"
url = "https://github.com/FluxML/MetalheadWeights/releases/download/v0.1.1/vgg19-0.1.1.tar.gz"
[[resnet152.download]]
sha256 = "57be335e6828d1965c9d11f933d2d41f51e5e534f9bfdbde01c6144fa8862a4d"
url = "https://huggingface.co/FluxML/resnet152/resolve/ba28814d5746643387b5c0e1d2269104e5e9bc8d/resnet152.tar.gz"
4 changes: 3 additions & 1 deletion Project.toml
@@ -10,6 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
@@ -21,6 +22,7 @@ NNlib = "0.7.34, 0.8"
julia = "1.6"

[extras]
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[publish]
@@ -29,4 +31,4 @@ theme = "_flux-theme"
title = "Metalhead.jl"

[targets]
test = ["Test"]
test = ["Images", "Test"]
6 changes: 4 additions & 2 deletions README.md
@@ -16,8 +16,8 @@

| Model Name | Function | Pre-trained? |
|:-------------------------------------------------|:------------------------------------------------------------------------------------------|:------------:|
| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | N |
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | N |
| [VGG](https://arxiv.org/abs/1409.1556) | [`VGG`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.VGG.html) | Y (w/o BN) |
| [ResNet](https://arxiv.org/abs/1512.03385) | [`ResNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ResNet.html) | Y |
| [GoogLeNet](https://arxiv.org/abs/1409.4842) | [`GoogLeNet`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.GoogLeNet.html) | N |
| [Inception-v3](https://arxiv.org/abs/1512.00567) | [`Inceptionv3`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv3.html) | N |
| [Inception-v4](https://arxiv.org/abs/1602.07261) | [`Inceptionv4`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.Inceptionv4.html) | N |
@@ -35,6 +35,8 @@
| [ConvNeXt](https://arxiv.org/abs/2201.03545) | [`ConvNeXt`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvNeXt.html) | N |
| [ConvMixer](https://arxiv.org/abs/2201.09792) | [`ConvMixer`](https://fluxml.ai/Metalhead.jl/dev/docstrings/Metalhead.ConvMixer.html) | N |

To contribute new models, see our [contributing docs](https://fluxml.ai/Metalhead.jl/dev/docs/developer-guide/contributing.html).

## Getting Started

You can find the Metalhead.jl getting started guide [here](https://fluxml.ai/Metalhead.jl/dev/docs/tutorials/quickstart.html).
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,4 +1,6 @@
[deps]
DataAugmentation = "88a5189c-e7ff-4f85-ac6b-e6158070f02e"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Publish = "f065f642-d108-4f50-8aa5-6749150a895a"
40 changes: 40 additions & 0 deletions docs/dev-guide/contributing.md
@@ -0,0 +1,40 @@
# Contributing to Metalhead.jl

We welcome contributions from anyone to Metalhead.jl! Thank you for taking the time to make our ecosystem better.

You can contribute by fixing bugs, adding new models, or adding pre-trained weights. If you aren't ready to write some code, but you think you found a bug or have a feature request, please [post an issue](https://github.com/FluxML/Metalhead.jl/issues/new/choose).

Before continuing, make sure you read the [FluxML contributing guide](https://github.com/FluxML/Flux.jl/blob/master/CONTRIBUTING.md) for general guidelines and tips.

## Fixing bugs

To fix a bug in Metalhead.jl, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). It would be helpful to file an issue first so that we can confirm the bug.

## Adding models

To add a new model architecture to Metalhead.jl, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). Keep in mind a few guiding principles for how this package is designed:

- reuse layers from Flux as much as possible (e.g. use `Parallel` before defining a `Bottleneck` struct)
- adhere as closely as possible to a reference such as a published paper (i.e. the structure of your model should be intuitive based on the paper)
- use generic functional builders (e.g. [`resnet`](#) is the core function that builds "ResNet-like" models based on the principles in the paper)
- use multiple dispatch to add convenience constructors that wrap your functional builder (a minimal sketch of this pattern follows below)

When in doubt, just open a PR! We are more than happy to help review your code to help it align with the rest of the library. After adding a model, you might consider adding some pre-trained weights (see below).
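
As a concrete illustration of the last two points, here is a minimal sketch of the builder-plus-wrapper pattern. All names (`toynet`, `ToyNet`, the channel configs) are hypothetical and only show the shape of the code, not an actual Metalhead model:

```julia
using Flux

# Generic functional builder: returns a plain Flux model for a whole family of architectures.
function toynet(channel_config::AbstractVector{<:Integer}; inchannels = 3, nclasses = 1000)
    layers = []
    channels = inchannels
    for width in channel_config
        push!(layers, Conv((3, 3), channels => width, relu; pad = 1))
        push!(layers, MaxPool((2, 2)))
        channels = width
    end
    classifier = Chain(GlobalMeanPool(), Flux.flatten, Dense(channels, nclasses))
    return Chain(Chain(layers...), classifier)
end

# Thin wrapper type so the model has a recognizable name.
struct ToyNet{T}
    layers::T
end
Flux.@functor ToyNet
(m::ToyNet)(x) = m.layers(x)

# Convenience constructor added via dispatch, mirroring e.g. `ResNet(depth)`.
const TOYNET_CONFIGS = Dict(11 => [64, 128, 256], 18 => [64, 128, 256, 512])
ToyNet(depth::Integer = 11; kwargs...) = ToyNet(toynet(TOYNET_CONFIGS[depth]; kwargs...))
```

The wrapper keeps the user-facing API small (e.g. `ToyNet(18)(rand(Float32, 32, 32, 3, 1))`), while the functional builder stays reusable for variants and for pre-trained weight loading.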

## Adding pre-trained weights

To add pre-trained weights for an existing model or new model, you can [open a PR](https://github.com/FluxML/Metalhead.jl/pulls). Below, we describe the steps you should follow to get there.

All Metalhead.jl model artifacts are hosted on HuggingFace. You can find the FluxML account [here](https://huggingface.co/FluxML). This [documentation from HuggingFace](https://huggingface.co/docs/hub/models) will provide you with an introduction to their Model Hub. In short, the Model Hub is a collection of Git repositories, similar to Julia packages on GitHub. This means you can [make a pull request to our HuggingFace repositories](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to upload updated weight artifacts, just like you would make a PR on GitHub to upload code.

1. Train your model or port the weights from another framework.
2. Save the model using [BSON.jl](https://github.com/JuliaIO/BSON.jl) with `BSON.@save "modelname.bson" model`. It is important that your model is saved under the key `model`.
3. Compress the saved model as a tarball using `tar -cvzf modelname.tar.gz modelname.bson`.
4. Obtain the SHAs (see the [Pkg docs](https://pkgdocs.julialang.org/v1/artifacts/#Basic-Usage)). Edit the `Artifacts.toml` file in the Metalhead.jl repository and add an entry for your model. You can leave the URL empty for now (a sketch of steps 2–4 follows this list).
5. Open a PR on Metalhead.jl. Be sure to ping a maintainer (e.g. `@darsnack`) to let us know that you are adding a pre-trained weight. We will create a model repository on HuggingFace if it does not already exist.
6. Open a PR to the [corresponding HuggingFace repo](https://huggingface.co/FluxML). Do this by going to the "Community" tab in the HuggingFace repository. PRs and discussions are shown as the same thing in the HuggingFace web app. You can use your local Git program to clone the repo and make PRs if you wish. Check out the [guide on PRs to HuggingFace](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for more information.
7. Copy the download URL for the model file that you added to HuggingFace. Make sure to grab the URL for a specific commit and not for the `main` branch.
8. Update your Metalhead.jl PR by adding the URL to the Artifacts.toml.
9. If the tests pass for your weights, we will merge your PR!
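
For steps 2–4, the sketch below shows one way to produce the BSON file, the tarball, and the hashes using the `Pkg.Artifacts` helpers from the Pkg docs linked above. The name `mymodel` is a placeholder, `model` is assumed to be your trained Flux model from step 1, and this is a convenience script rather than a required part of the process.

```julia
using BSON, Pkg.Artifacts

# Step 2: save the trained model under the key `model` (`model` comes from step 1).
BSON.@save "mymodel.bson" model

# Steps 3-4: stage the file as an artifact to get its git-tree-sha1,
# then archive it as a tarball to get the sha256 for the download stanza.
tree_hash = create_artifact() do dir
    cp("mymodel.bson", joinpath(dir, "mymodel.bson"))
end
tarball_sha256 = archive_artifact(tree_hash, "mymodel.tar.gz")

# Record the entry in Artifacts.toml; the URL and sha256 can be added later (step 8).
bind_artifact!("Artifacts.toml", "mymodel", tree_hash; lazy = true, force = true)
@info "Artifact hashes" tree_hash tarball_sha256
```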

If you want to fix existing weights, then you can follow the same set of steps.
35 changes: 32 additions & 3 deletions docs/tutorials/quickstart.md
@@ -5,15 +5,44 @@
using Flux, Metalhead
```

Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the ResNet-18 model.
Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the pre-trained ResNet-18 model.
{cell=quickstart}
```julia
using Flux, Metalhead

model = ResNet(18)
model = ResNet(18; pretrain = true)
```

Now, we can use this model with Flux like any other model. Below, we train it on some randomly generated data.
Now, we can use this model with Flux like any other model.

First, let's check the model's prediction on a test image, using the ImageNet class labels.
{cell=quickstart}
```julia
using Images

# test image
img = Images.load(download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg"))
```
We'll use the popular [DataAugmentation.jl](https://github.com/lorenzoh/DataAugmentation.jl) library to crop our input image, convert it to a plain array, and normalize the pixels.
{cell=quickstart}
```julia
using DataAugmentation

DATA_MEAN = (0.485, 0.456, 0.406)
DATA_STD = (0.229, 0.224, 0.225)

augmentations = CenterCrop((224, 224)) |>
ImageToTensor() |>
Normalize(DATA_MEAN, DATA_STD)
data = apply(augmentations, Image(img)) |> itemdata

# ImageNet class labels
labels = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))

# `data` is (224, 224, 3); add a singleton batch dimension before calling the model
Flux.onecold(model(reshape(data, size(data)..., 1)), labels)
```

Below, we train it on some randomly generated data.

```julia
using Flux: onehotbatch
# ... (remainder of the training example is collapsed in this diff view)
```
1 change: 1 addition & 0 deletions src/Metalhead.jl
@@ -7,6 +7,7 @@ using BSON
using Artifacts, LazyArtifacts
using Statistics
using MLUtils
using Random

import Functors

3 changes: 2 additions & 1 deletion src/convnets/inception.jl
@@ -579,8 +579,9 @@ Creates an Xception model.

`Xception` does not currently support pretrained weights.
"""
function Xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
layers = xception(; inchannels, dropout, nclasses)
pretrain && loadpretrain!(layers, "xception")
return Xception(layers)
end

2 changes: 1 addition & 1 deletion src/convnets/resnet.jl
@@ -258,6 +258,6 @@ function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000)
@assert depth in keys(resnet_config) "`depth` must be one of $(sort(collect(keys(resnet_config))))"
config, block = resnet_config[depth]
model = ResNet(config...; block = block, nclasses = nclasses)
pretrain && loadpretrain!(model, string("ResNet", depth))
pretrain && loadpretrain!(model, string("resnet", depth))
return model
end
4 changes: 2 additions & 2 deletions src/convnets/vgg.jl
@@ -171,9 +171,9 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses
fcsize = 4096,
dropout = 0.5)
if pretrain && !batchnorm
loadpretrain!(model, string("VGG", depth))
loadpretrain!(model, string("vgg", depth))
elseif pretrain
loadpretrain!(model, "VGG$(depth)-BN)")
loadpretrain!(model, "vgg$(depth)-bn)")
end
return model
end
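
Taken together with the `Artifacts.toml` changes, these renames mean the pre-trained weights resolve by their lowercase artifact names. A quick usage check, based only on the constructors shown in this diff:

```julia
using Metalhead

# Each `pretrain = true` call resolves the corresponding lazy artifact by name.
vgg = VGG(16; pretrain = true)        # "vgg16"
resnet = ResNet(50; pretrain = true)  # "resnet50"
```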