Skip to content

Commit

Permalink
Ready to push
Browse files Browse the repository at this point in the history
  • Loading branch information
Fra committed Dec 11, 2018
1 parent 047abb3 commit 4c2ca88
Show file tree
Hide file tree
Showing 27 changed files with 130 additions and 71 deletions.
2 changes: 1 addition & 1 deletion DummyVisualisation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .Visualisation import Visualisation
from mirror.visualisations.Visualisation import Visualisation

class DummyVisualisation(Visualisation):

Expand Down
72 changes: 54 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,73 @@ Basic example:

```python
from mirror import mirror
from mirror.visualisations import DeepDream

from PIL import Image

from torchvision.models import resnet101
from torchvision.models import resnet101, resnet18, vgg16
from torchvision.transforms import ToTensor, Resize, Compose
# create a model
model = resnet101(True)

cat = Image.open("cat.jpg")
# create a model
model = vgg16(pretrained=True)

cat = Image.open("./cat.jpg")
# resize the image and make it a tensor
input = Compose([Resize((224,224)), ToTensor()])(cat)
# add 1 dim for batch
input = input.view(1,3,224,224)
input = input.unsqueeze(0)
# call mirror with the input and the model
mirror(input, model)
mirror(input, model, visualisations=[DeepDream])
```

It will automatic open a new tab in your browser

### Create a Visualisation

You can find an example below

```python
from mirror.visualisations.Visualisation import Visualisation

class DummyVisualisation(Visualisation):

def __call__(self, inputs, layer):
return inputs.repeat(self.params['repeat']['value'],1, 1, 1)

@property
def name(self):
return 'dummy'

def init_params(self):
return {'repeat' : {
'type' : 'slider',
'min' : 1,
'max' : 100,
'value' : 3,
'step': 1,
'params': {}
}}

```

The `__call__` function is called each time you click a layer or change a value in the options on the right.

The `init_params` parameters function returns a dictionary of options that will be showed on the rigth drawer of the application. For know only `slider` and `radio` is supported

### TODO
- Support multiple inputs and cache them
- Make a generic abstraction of a visualisation in order to add more features
- [x] Cache reused layer
- [x] Make a generic abstraction of a visualisation in order to add more features
- [ ] Add more options for the parameters (dropdown, text)
- [ ] Support multiple inputs
- [ ] Support multiple models
- Add all visualisation present here https://github.com/utkuozbulak/pytorch-cnn-visualizations
* [Gradient visualization with vanilla backpropagation](#gradient-visualization)
* [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
* [Gradient visualization with saliency maps](#gradient-visualization) [4]
* [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
* [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
* [Smooth grad](#smooth-grad) [8]
* [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
* [Inverted image representations](#inverted-image-representations) [5]
* [Deep dream](#deep-dream) [10]
* [Class specific image generation](#class-specific-image-generation) [4]
* [ ] [Gradient visualization with vanilla backpropagation](#gradient-visualization)
* [ ] [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
* [ ] [Gradient visualization with saliency maps](#gradient-visualization) [4]
* [ ] [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
* [ ] [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
* [ ] [Smooth grad](#smooth-grad) [8]
* [x] [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
* [ ] [Inverted image representations](#inverted-image-representations) [5]
* [x] [Deep dream](#deep-dream) [10]
* [ ] [Class specific image generation](#class-specific-image-generation) [4]
20 changes: 7 additions & 13 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
from mirror import mirror
from mirror.visualisations import DeepDream

from PIL import Image

from torchvision.models import resnet101, resnet18, vgg16
from torchvision.transforms import ToTensor, Resize, Compose

# model = resnet101(True)
model = vgg16(True)

# cat = Image.open("/home/francesco/Documents/mirror/mirror/resources/sky-dd.jpeg")
# create a model
model = vgg16(pretrained=True)

cat = Image.open("./cat.jpg")

#
# cat = Image.open("/home/francesco/Documents/mirror/mirror/resources/the_starry_night-wallpaper-1920x1200.jpg")

# resize the image and make it a tensor
input = Compose([Resize((224,224)), ToTensor()])(cat)
#
# input = Compose([ToTensor()])(cat)

# add 1 dim for batch
input = input.unsqueeze(0)

mirror(input, model)
# call mirror with the input and the model
mirror(input, model, visualisations=[DeepDream])
Binary file modified mirror/__pycache__/app.cpython-36.pyc
Binary file not shown.
Binary file modified mirror/__pycache__/server.cpython-36.pyc
Binary file not shown.
6 changes: 3 additions & 3 deletions mirror/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from .tree import Tracer
from .server import Builder

def mirror(input, model):
def mirror(input, model, visualisations=[]):
tracer = Tracer(module=model)
tracer(input)

builder = Builder()

app = builder.build(input, model, tracer)
app = builder.build(input, model, tracer, visualisations)

# webbrowser.open_new('http://localhost:5000') # opens in default browser
webbrowser.open_new('http://localhost:5000') # opens in default browser

app.run(host="0.0.0.0", port=5000)

2 changes: 1 addition & 1 deletion mirror/client/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class App extends Component {
<LayerOutputs module={module} classes={classes}/>
</main>

<Hidden mdDown >
<Hidden smDown >
<Settings
toogle={module.toogleDrawer}
open={module.state.open}
Expand Down
5 changes: 3 additions & 2 deletions mirror/client/src/Module/LayerOutputs/LayerOutputs.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,10 @@ class LayerOutputs extends React.Component {
<Grid item>
{this.props.module.state.outputs.length == 0 ?
(<h1>Nothing to show</h1>) :
(<Button variant="contained" color="primary" onClick={() => module.getLayerOutputs()}>
this.props.module.state.next ? (
<Button variant="contained" color="primary" onClick={() => module.getLayerOutputs()}>
More
</Button>)}
</Button>) : ''}

</Grid>
</Grid>
Expand Down
11 changes: 6 additions & 5 deletions mirror/client/src/ModuleContainer/ModuleContainer.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class ModuleContainer extends Container {
outputs: [],
layer: { name : ''},
last: 0,
next: false,
settings: { size: 50 }
}

Expand All @@ -40,9 +41,9 @@ class ModuleContainer extends Container {

const res = await axios.get(api.getModuleLayerOutput(layer.id, last), { params: { last } })

var outputs = isSameLayer ? this.state.outputs.concat(res.data) : res.data
var outputs = isSameLayer ? this.state.outputs.concat(res.data.links) : res.data.links

await this.setState({ outputs, layer, last, isLoading: false })
await this.setState({ outputs, layer, last, isLoading: false, next: res.data.next })
} catch {
await this.setState({ isLoading: false })
}
Expand All @@ -56,9 +57,9 @@ class ModuleContainer extends Container {
await this.setState({ isLoading: true })

const res = await axios.get(api.GET_VISUALISATIONS)
const visualisations = res.data

await this.setState({ visualisations, isLoading: false })
const visualisations = res.data.visualisations
const currentVisualisation = res.data.current
await this.setState({ visualisations, isLoading: false, currentVisualisation })
}

async setVisualisationsSettings(data){
Expand Down
3 changes: 1 addition & 2 deletions mirror/client/src/Settings/Settings.js
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ class Settings extends Component {
}

onVisualisationSettingsChange = (data) => {
console.log('onVisualisationSettingsChange', data)
this.props.module.setVisualisationsSettings(data)
}

Expand All @@ -160,7 +159,7 @@ class Settings extends Component {
classes={{
paper: classes.drawerPaper,
}}
>
>
<div className={classes.toolbar} />
<List className={classes.settings}>
<ListItem>
Expand Down
22 changes: 12 additions & 10 deletions mirror/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import time

from flask import Flask, request, Response, send_file, jsonify
from .visualisations import WeightsVisualisation, DummyVisualisation, DeepDream
from .visualisations import WeightsVisualisation
from PIL import Image
import pprint
from torchvision.transforms import ToPILImage
Expand All @@ -22,7 +22,8 @@ def build(self, input, model, tracer, visualisations=[]):
input = input.to(self.device)
model = model.to(self.device)

self.visualisations = [WeightsVisualisation(model, tracer), DeepDream(model, tracer), *visualisations]
visualisations = [v(model, tracer) for v in visualisations]
self.visualisations = [WeightsVisualisation(model, tracer), *visualisations]

self.name2visualisations = { v.name : v for v in self.visualisations}
self.current_vis = self.visualisations[0]
Expand Down Expand Up @@ -54,7 +55,8 @@ def api_model_layer(id):
def api_visualisations():
serialised = [v.properties for v in self.visualisations]

response = jsonify(serialised)
response = jsonify({ 'visualisations': serialised,
'current': self.current_vis.properties})

return response

Expand All @@ -67,11 +69,12 @@ def api_visualisation():
if vis_key not in self.name2visualisations:
response = Response(status=500, response='Visualisation {} not supported or does not exist'.format(vis_key))
else:
# TODO I should think on a cleaver way to update properties and params
self.name2visualisations[vis_key].properties = data
self.name2visualisations[vis_key].params = self.name2visualisations[vis_key].properties['params']
pprint.pprint(self.name2visualisations[vis_key].properties)
self.current_vis = self.name2visualisations[vis_key]
self.name2visualisations[vis_key].cache = {}

response = jsonify(self.name2visualisations[vis_key].properties)

return response
Expand All @@ -80,7 +83,7 @@ def api_visualisation():
def api_model_layer_output(id):
try:
layer = tracer.idx_to_value[id].v
print(self.current_vis)

if input not in self.current_vis.cache: self.current_vis.cache[input] = {}
# TODO need to cache for vis
layer_cache = self.current_vis.cache[input]
Expand All @@ -91,28 +94,27 @@ def api_model_layer_output(id):
self.outputs = layer_cache[layer]

outputs = self.outputs
print(self.outputs.shape)

if len(outputs.shape) < 3: raise ValueError

last = int(request.args['last'])
max = min((last + MAX_LINKS_EVERY_REQUEST), outputs.shape[0])

if last >= max: raise StopIteration

response = ['/api/model/image/{}/{}/{}/{}/{}'.format(hash(input),
hash(self.current_vis),
hash(time.time()),
id,
i) for i in range(last, max)]
response = jsonify(response)

response = jsonify({ 'links' : response, 'next': last + 1< max})


except KeyError:
response = Response(status=500, response='Index not found.')
except ValueError:
response = Response(status=404, response='Outputs must be an array of images')
except StopIteration:
response = Response(status=404, response='No more.')
response = jsonify({ 'links' : [], 'next': False})

return response

Expand Down
14 changes: 7 additions & 7 deletions mirror/static/asset-manifest.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
{
"main.css": "/static/css/main.8a0d906b.chunk.css",
"main.js": "/static/js/main.33f09332.chunk.js",
"main.js.map": "/static/js/main.33f09332.chunk.js.map",
"static/js/1.bd250341.chunk.js": "/static/js/1.bd250341.chunk.js",
"static/js/1.bd250341.chunk.js.map": "/static/js/1.bd250341.chunk.js.map",
"main.css": "/static/css/main.a67564f8.chunk.css",
"main.js": "/static/js/main.fef9010a.chunk.js",
"main.js.map": "/static/js/main.fef9010a.chunk.js.map",
"static/js/1.6a64618b.chunk.js": "/static/js/1.6a64618b.chunk.js",
"static/js/1.6a64618b.chunk.js.map": "/static/js/1.6a64618b.chunk.js.map",
"runtime~main.js": "/static/js/runtime~main.229c360f.js",
"runtime~main.js.map": "/static/js/runtime~main.229c360f.js.map",
"static/css/main.8a0d906b.chunk.css.map": "/static/css/main.8a0d906b.chunk.css.map",
"static/css/main.a67564f8.chunk.css.map": "/static/css/main.a67564f8.chunk.css.map",
"index.html": "/index.html",
"precache-manifest.c308b7b331ea8e3c4d3b692bf72f5a99.js": "/precache-manifest.c308b7b331ea8e3c4d3b692bf72f5a99.js",
"precache-manifest.128ea1df9a72d8aaff259ac878e1fb66.js": "/precache-manifest.128ea1df9a72d8aaff259ac878e1fb66.js",
"service-worker.js": "/service-worker.js"
}
2 changes: 2 additions & 0 deletions mirror/static/css/main.a67564f8.chunk.css

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions mirror/static/css/main.a67564f8.chunk.css.map

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 4c2ca88

Please sign in to comment.