Merge pull request #5 from Kawaeee/streamlit-cpu
Streamlit cpu
Kawaeee authored May 4, 2021
2 parents 73e30df + d4ff43c commit 4819334
Showing 3 changed files with 38 additions and 10 deletions.
28 changes: 23 additions & 5 deletions README.md
@@ -50,13 +50,15 @@
|Batch Size | 32 |
|Optimizer | ADAM |

## Reproduction
## Model Reproduction
* Reproducing the model requires our datasets. You can send me an e-mail at [email protected] if you are interested.

- Install dependencies
```Bash
pip install -r requirements.txt
```
- Install dependencies
- Remove the `+cpu` suffixes and the `--find-links` line in `requirements.txt` to get CUDA support

```Bash
pip install -r requirements.txt
```

- Run the `train.py` Python script

@@ -65,3 +67,19 @@
```

- Open and run the notebook for prediction: `predictor.ipynb`

## Streamlit Reproduction
- Install dependencies

```Bash
pip install -r requirements.txt
```

- Run the Streamlit app

```Bash
streamlit run streamlit_app.py
```

- The Streamlit web application will be hosted at http://localhost:8501 by default (see the example below for changing the port)
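
If port 8501 is already in use, Streamlit's `--server.port` option serves the app on another port; a minimal example (the port 8080 here is arbitrary):

```Bash
# Serve the Streamlit app on an alternative port
streamlit run streamlit_app.py --server.port 8080
```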

6 changes: 4 additions & 2 deletions requirements.txt
@@ -4,6 +4,8 @@ numpy==1.20.2
tqdm==4.60.0
pillow==8.2.0
streamlit==0.80.0
torch==1.8.1
torchvision==0.9.1
# [STREAMLIT] Remove the "+cpu" suffixes and the "--find-links" line in requirements.txt to get CUDA support
--find-links https://download.pytorch.org/whl/torch_stable.html
torch==1.8.1+cpu
torchvision==0.9.1+cpu
psutil==5.8.0
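
Per that [STREAMLIT] comment, a CUDA-capable environment would drop the `--find-links` line and the `+cpu` suffixes, so the PyTorch pins would read as the versions removed above:

```
torch==1.8.1
torchvision==0.9.1
```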
14 changes: 11 additions & 3 deletions streamlit_app.py
Expand Up @@ -79,6 +79,7 @@
}

# Model configuration
# The Streamlit server does not provide a GPU, so we will go with CPU!
processing_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

img_normalizer = transforms.Normalize(
@@ -95,7 +96,7 @@
)


@st.cache(allow_output_mutation=True, max_entries=3, ttl=1800)
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=2, ttl=600)
def initialize_model(device=processing_device):
"""Retrieves the butt_bread trained model and maps it to the CPU by default, can also specify GPU here."""
model = models.resnet152(pretrained=False).to(device)
@@ -110,6 +111,7 @@ def initialize_model(device=processing_device):

return model

@st.cache(max_entries=5, ttl=300)
def predict(img, model):
"""Make a prediction on a single image"""
input_img = img_transformer(img).float()
@@ -135,6 +137,10 @@ def predict(img, model):
},
}

input_img = None
pred_logit_tensor = None
pred_probs = None

return json_output

def download_model():
@@ -143,7 +149,7 @@ def download_model():
print("Downloading butt_bread model !!")
req = requests.get(model_url_path, allow_redirects=True)
open("buttbread_resnet152_3.h5", "wb").write(req.content)
return True
req = None

return True

@@ -222,6 +228,8 @@ def health_check():
st.image(resized_image)
st.write("Prediction:")
st.json(prediction)
img = None
resized_image = None
prediction = None

# Reset model after use
model = None
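
The pattern above (short-lived `st.cache` entries plus nulling out references once they are used) is how the app keeps memory down on Streamlit's CPU-only hosting. Below is a minimal standalone sketch of that pattern, assuming the `streamlit==0.80.0` API pinned in requirements.txt (`st.cache` was replaced by newer caching APIs in later releases); `load_model`, `predict_probs`, and the untrained ResNet-152 stand-in are illustrative only, and the sketch loads the model inside the cached predictor rather than passing it in as the app does.

```Python
import streamlit as st
import torch
from torchvision import models

# Cache the heavy model object: at most 2 entries, expired after 10 minutes.
# allow_output_mutation skips re-hashing the returned module to check for mutation.
@st.cache(allow_output_mutation=True, suppress_st_warning=True, max_entries=2, ttl=600)
def load_model(device: str = "cpu"):
    # Illustrative stand-in for initialize_model(); the real app loads its .h5 checkpoint here.
    model = models.resnet152(pretrained=False).to(device)
    model.eval()
    return model

# Cached predictions get an even shorter lifetime, since every upload is different.
@st.cache(max_entries=5, ttl=300)
def predict_probs(pixels):
    model = load_model()
    with torch.no_grad():
        logits = model(pixels)
    probs = torch.softmax(logits, dim=1).tolist()
    # Null out intermediates so the memory-limited host can reclaim them sooner.
    logits = None
    model = None
    return probs

if __name__ == "__main__":
    print(predict_probs(torch.zeros(1, 3, 224, 224)))
```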
