From c24a68fd6c9bbd60be770ed4b59940ed8c8b77c9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 19 Jun 2022 13:16:26 +0530 Subject: [PATCH] Add more review feedback --- docs/dev-guide/contributing.md | 2 +- test/runtests.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/dev-guide/contributing.md b/docs/dev-guide/contributing.md index 8a83f1057..75574b033 100644 --- a/docs/dev-guide/contributing.md +++ b/docs/dev-guide/contributing.md @@ -35,6 +35,6 @@ All Metalhead.jl model artifacts are hosted using HuggingFace. You can find the 6. Open a PR to the [corresponding HuggingFace repo](https://huggingface.co/FluxML). Do this by going to the "Community" tab in the HuggingFace repository. PRs and discussions are shown as the same thing in the HuggingFace web app. You can use your local Git program to make clone the repo and make PRs if you wish. Check out the [guide on PRs to HuggingFace](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) for more information. 7. Copy the download URL for the model file that you added to HuggingFace. Make sure to grab the URL for a specific commit and not for the `main` branch. 8. Update your Metalhead.jl PR by adding the URL to the Artifacts.toml. -9. If the tests pass for your weights, we will merge your PR! +9. If the tests pass for your weights, we will merge your PR! Your model should pass the `acctest` function in the Metalhead.jl test suite. If your model already exists in the repo, then these tests are already in place, and you can add your model configuration to the `PRETRAINED_MODELS` list in the `runtests.jl` file. Please refer to the ResNet tests as an example. If you want to fix existing weights, then you can follow the same set of steps. diff --git a/test/runtests.jl b/test/runtests.jl index b8b1ef74b..f1a9787b9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,7 @@ function gradtest(model, input) return true end -function normalize(data) +function normalize_imagenet(data) cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1)) cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1)) return (data .- cmean) ./ cstd @@ -33,7 +33,7 @@ end const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg") const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224)) # CHW -> WHC -const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,1)) |> normalize +const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,1)) |> normalize_imagenet # image net labels const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))