diff --git a/.gitignore b/.gitignore
index 9f11b75..a047a94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
.idea/
+__pycache__/
\ No newline at end of file
diff --git a/DummyVisualisation.py b/DummyVisualisation.py
new file mode 100644
index 0000000..6cf0ab8
--- /dev/null
+++ b/DummyVisualisation.py
@@ -0,0 +1,20 @@
+from mirror.visualisations.Visualisation import Visualisation
+
+class DummyVisualisation(Visualisation):
+
+ def __call__(self, inputs, layer):
+ return inputs.repeat(self.params['repeat']['value'],1, 1, 1)
+
+ @property
+ def name(self):
+ return 'dummy'
+
+ def init_params(self):
+ return {'repeat' : {
+ 'type' : 'slider',
+ 'min' : 1,
+ 'max' : 100,
+ 'value' : 3,
+ 'step': 1,
+ 'params': {}
+ }}
diff --git a/README.md b/README.md
index 5aec8f7..f205233 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
This is a raw beta so expect lots of things to change and improve over time.
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/mirror/master/mirror/resources/mirror.gif)
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/mirror.gif?raw=true)
### Getting started
@@ -17,37 +17,77 @@ Basic example:
```python
from mirror import mirror
+from mirror.visualisations import DeepDream
from PIL import Image
-from torchvision.models import resnet101
+from torchvision.models import resnet101, resnet18, vgg16
from torchvision.transforms import ToTensor, Resize, Compose
-# create a model
-model = resnet101(True)
-cat = Image.open("cat.jpg")
+# create a model
+model = vgg16(pretrained=True)
+
+cat = Image.open("./cat.jpg")
# resize the image and make it a tensor
input = Compose([Resize((224,224)), ToTensor()])(cat)
# add 1 dim for batch
-input = input.view(1,3,224,224)
+input = input.unsqueeze(0)
# call mirror with the input and the model
-mirror(input, model)
+mirror(input, model, visualisations=[DeepDream])
```
It will automatic open a new tab in your browser
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/mirror.jpg?raw=true)
+
+### Create a Visualisation
+
+You can find an example below
+
+```python
+from mirror.visualisations.Visualisation import Visualisation
+
+class DummyVisualisation(Visualisation):
+
+ def __call__(self, inputs, layer):
+ return inputs.repeat(self.params['repeat']['value'],1, 1, 1)
+
+ @property
+ def name(self):
+ return 'dummy'
+
+ def init_params(self):
+ return {'repeat' : {
+ 'type' : 'slider',
+ 'min' : 1,
+ 'max' : 100,
+ 'value' : 3,
+ 'step': 1,
+ 'params': {}
+ }}
+
+```
+
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/dummy.jpg?raw=true)
+
+The `__call__` function is called each time you click a layer or change a value in the options on the right.
+
+The `init_params` function returns a dictionary of options that will be showed on the right drawer of the application. For now only `slider` and `radio` are supported
### TODO
-- Support multiple inputs and cache them
-- Make a generic abstraction of a visualisation in order to add more features
+- [x] Cache reused layer
+- [x] Make a generic abstraction of a visualisation in order to add more features
+- [ ] Add more options for the parameters (dropdown, text)
+- [ ] Support multiple inputs
+- [ ] Support multiple models
- Add all visualisation present here https://github.com/utkuozbulak/pytorch-cnn-visualizations
- * [Gradient visualization with vanilla backpropagation](#gradient-visualization)
- * [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
- * [Gradient visualization with saliency maps](#gradient-visualization) [4]
- * [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
- * [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
- * [Smooth grad](#smooth-grad) [8]
- * [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
- * [Inverted image representations](#inverted-image-representations) [5]
- * [Deep dream](#deep-dream) [10]
- * [Class specific image generation](#class-specific-image-generation) [4]
+ * [ ] [Gradient visualization with vanilla backpropagation](#gradient-visualization)
+ * [ ] [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
+ * [ ] [Gradient visualization with saliency maps](#gradient-visualization) [4]
+ * [ ] [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
+ * [ ] [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
+ * [ ] [Smooth grad](#smooth-grad) [8]
+ * [x] [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
+ * [ ] [Inverted image representations](#inverted-image-representations) [5]
+ * [x] [Deep dream](#deep-dream) [10]
+ * [ ] [Class specific image generation](#class-specific-image-generation) [4]
diff --git a/example.py b/example.py
index 6217e8b..1a395cf 100644
--- a/example.py
+++ b/example.py
@@ -1,16 +1,18 @@
from mirror import mirror
+from mirror.visualisations import DeepDream
from PIL import Image
-from torchvision.models import resnet101
+from torchvision.models import resnet101, resnet18, vgg16
from torchvision.transforms import ToTensor, Resize, Compose
-model = resnet101(True)
-
-cat = Image.open("cat.jpg")
+# create a model
+model = vgg16(pretrained=True)
+cat = Image.open("./cat.jpg")
+# resize the image and make it a tensor
input = Compose([Resize((224,224)), ToTensor()])(cat)
-
-input = input.view(1,3,224,224)
-
-mirror(input, model)
\ No newline at end of file
+# add 1 dim for batch
+input = input.unsqueeze(0)
+# call mirror with the input and the model
+mirror(input, model, visualisations=[DeepDream])
\ No newline at end of file
diff --git a/mirror/__pycache__/app.cpython-36.pyc b/mirror/__pycache__/app.cpython-36.pyc
index bf49284..ddc41b1 100644
Binary files a/mirror/__pycache__/app.cpython-36.pyc and b/mirror/__pycache__/app.cpython-36.pyc differ
diff --git a/mirror/__pycache__/server.cpython-36.pyc b/mirror/__pycache__/server.cpython-36.pyc
index 464d4a4..5120b3f 100644
Binary files a/mirror/__pycache__/server.cpython-36.pyc and b/mirror/__pycache__/server.cpython-36.pyc differ
diff --git a/mirror/__pycache__/tree.cpython-36.pyc b/mirror/__pycache__/tree.cpython-36.pyc
index 2a71c0d..81753e6 100644
Binary files a/mirror/__pycache__/tree.cpython-36.pyc and b/mirror/__pycache__/tree.cpython-36.pyc differ
diff --git a/mirror/app.py b/mirror/app.py
index a804ca8..c8be4bb 100644
--- a/mirror/app.py
+++ b/mirror/app.py
@@ -1,13 +1,15 @@
import webbrowser
from .tree import Tracer
-from .server import build
+from .server import Builder
-def mirror(input, model):
+def mirror(input, model, visualisations=[]):
tracer = Tracer(module=model)
tracer(input)
- app = build(input, model, tracer)
+ builder = Builder()
+
+ app = builder.build(input, model, tracer, visualisations)
webbrowser.open_new('http://localhost:5000') # opens in default browser
diff --git a/mirror/client/package-lock.json b/mirror/client/package-lock.json
index 90aeae9..429ad77 100644
--- a/mirror/client/package-lock.json
+++ b/mirror/client/package-lock.json
@@ -14134,6 +14134,11 @@
"resolved": "https://registry.npmjs.org/throat/-/throat-4.1.0.tgz",
"integrity": "sha1-iQN8vJLFarGJJua6TLsgDhVnKmo="
},
+ "throttle-debounce": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/throttle-debounce/-/throttle-debounce-2.0.1.tgz",
+ "integrity": "sha512-Sr6jZBlWShsAaSXKyNXyNicOrJW/KtkDqIEwHt4wYwWA2wa/q67Luhqoujg48V8hTk60wB56tYrJJn6jc2R7VA=="
+ },
"through": {
"version": "2.3.8",
"resolved": "http://registry.npmjs.org/through/-/through-2.3.8.tgz",
diff --git a/mirror/client/package.json b/mirror/client/package.json
index 9691b59..2e90d92 100644
--- a/mirror/client/package.json
+++ b/mirror/client/package.json
@@ -13,6 +13,7 @@
"react-script": "^2.0.5",
"react-scripts": "^2.0.4",
"reactstrap": "^6.5.0",
+ "throttle-debounce": "^2.0.1",
"unstated": "^2.1.1"
},
"scripts": {
diff --git a/mirror/client/src/App.js b/mirror/client/src/App.js
index 3787ad3..1344d17 100644
--- a/mirror/client/src/App.js
+++ b/mirror/client/src/App.js
@@ -18,6 +18,7 @@ import LinearProgress from '@material-ui/core/LinearProgress';
import LayerOutputs from './Module/LayerOutputs/LayerOutputs'
import MoreIcon from '@material-ui/icons/MoreVert';
+import Hidden from '@material-ui/core/Hidden';
const drawerWidth = 300;
@@ -25,24 +26,38 @@ const styles = theme => ({
root: {
flexGrow: 1,
// height: 440,
- zIndex: 1,
+ // zIndex: 1,
// overflow: 'hidden',
- position: 'relative',
+ // position: 'relative',
display: 'flex',
+ flexDirection : 'row',
minHeight: '100vh'
},
+ typography: {
+ useNextVariants: true,
+ },
appBar: {
zIndex: theme.zIndex.drawer + 1,
},
+ drawer: {
+ width: drawerWidth,
+ flexShrink: 0,
+ },
drawerPaper: {
- position: 'relative',
+ flexShrink: 0,
+
+ // position: 'relative',
width: drawerWidth,
},
content: {
flexGrow: 1,
+ // marginLeft: '300px',
+ // position: 'fixed',
+ // width: '100%',
+ // height: '100%',
backgroundColor: theme.palette.background.default,
padding: theme.spacing.unit * 3,
- minWidth: 0, // So the Typography noWrap works
+ // minWidth: 0, // So the Typography noWrap works
},
toolbar: theme.mixins.toolbar,
@@ -53,9 +68,16 @@ const styles = theme => ({
zIndex: 9999
},
- settn: {
+ settings: {
width: '300px !important'
-}
+ },
+
+ sliders : {
+ width: '200px !important'
+ },
+
+ layersOuput : {
+ }
})
function MyAppBar({ module, classes }) {
@@ -66,9 +88,12 @@ function MyAppBar({ module, classes }) {
Mirror
+
+
+
)
@@ -80,6 +105,7 @@ class App extends Component {
}
+
toggleSettings = () => {
const openSettings = !this.state.openSettings
this.setState({ openSettings })
@@ -94,23 +120,40 @@ class App extends Component {
-
+
{module.state.isLoading ? (
) : ''}
-
+
+
+
+
+
+
+
+
+
)}
diff --git a/mirror/client/src/Module/LayerOutputs/LayerOutputs.js b/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
index 844ab45..9862841 100644
--- a/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
+++ b/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
@@ -6,25 +6,54 @@ import Slider from '@material-ui/lab/Slider';
import Button from '@material-ui/core/Button';
import Typography from '@material-ui/core/Typography';
+class AtomicImage extends Component {
+ constructor(props) {
+ super(props);
+ this.state = {dimensions: {}};
+ this.onImgLoad = this.onImgLoad.bind(this);
+ }
+ onImgLoad({target:img}) {
+ this.setState({dimensions:{height:img.offsetHeight,
+ width:img.offsetWidth}});
+ }
+ render(){
+ const {src, size} = this.props;
+ const {width, height} = this.state.dimensions;
+
+ const style = {
+ // backgroundImage : `url(${src})`,
+ // backgroundSize: 100% 100%,
+ height: `${height * size}px`,
+ width: `${width * size}px`
+ }
+
+
+ return (
+
+ );
+ }
+ }
+
+
const Image = ({ src, size }) => {
const style = {
// backgroundImage : `url(${src})`,
// backgroundSize: 100% 100%,
- height: `${100 * (size / 20)}px`,
- width: `${100 * (size / 20)}px`
+ height: `${100 * (size / 10)}px`,
+ width: `${100 * (size / 10)}px`
}
return (
-
+
)
}
class LayerOutputs extends React.Component {
render() {
- const { module } = this.props
+ const { module, classes } = this.props
return (
-
+
+ {this.props.module.state.outputs.length == 0 ?
+ (Nothing to show
) :
+ this.props.module.state.next ? (
+ ) : ''}
+
diff --git a/mirror/client/src/ModuleContainer/ModuleContainer.js b/mirror/client/src/ModuleContainer/ModuleContainer.js
index ee79aea..183e553 100644
--- a/mirror/client/src/ModuleContainer/ModuleContainer.js
+++ b/mirror/client/src/ModuleContainer/ModuleContainer.js
@@ -7,11 +7,14 @@ import querystring from 'querystring'
class ModuleContainer extends Container {
state = {
tree: null,
+ visualisations: [],
+ currentVisualisation: {},
isLoading: false,
open: false,
outputs: [],
layer: { name : ''},
last: 0,
+ next: false,
settings: { size: 50 }
}
@@ -29,17 +32,18 @@ class ModuleContainer extends Container {
}
- getLayerOutputs = async (layer=this.state.layer) => {
- const isSameLayer = layer.id == this.state.layer.id
- const last = isSameLayer ? this.state.last + 64 : 0
+ getLayerOutputs = async (layer=this.state.layer, start=false) => {
+ var isSameLayer = layer.id == this.state.layer.id
+ if(start) isSameLayer = false
+ var last = isSameLayer ? this.state.last + 64 : 0
try {
await this.setState({ isLoading: true })
const res = await axios.get(api.getModuleLayerOutput(layer.id, last), { params: { last } })
- var outputs = isSameLayer ? this.state.outputs.concat(res.data) : res.data
-
- await this.setState({ outputs, layer, last, isLoading: false })
+ var outputs = isSameLayer ? this.state.outputs.concat(res.data.links) : res.data.links
+
+ await this.setState({ outputs, layer, last, isLoading: false, next: res.data.next })
} catch {
await this.setState({ isLoading: false })
}
@@ -48,6 +52,34 @@ class ModuleContainer extends Container {
changeSettings = async(settings) => {
await this.setState( { settings })
}
+
+ async getVisualisations() {
+ await this.setState({ isLoading: true })
+
+ const res = await axios.get(api.GET_VISUALISATIONS)
+ const visualisations = res.data.visualisations
+ const currentVisualisation = res.data.current
+ await this.setState({ visualisations, isLoading: false, currentVisualisation })
+ }
+
+ async setVisualisationsSettings(data){
+ await this.setState({ isLoading: true })
+
+ const res = await axios.put(api.PUT_VISUALISATIONS, data)
+
+ const currentVisualisation = res.data
+ var visualisations = [...this.state.visualisations]
+ // TODO could be smarter!
+ for(let key in visualisations){
+ if(visualisations[key].name == currentVisualisation.name){
+ visualisations[key] = currentVisualisation
+ }
+ }
+ await this.setState({isLoading: false, currentVisualisation, visualisations })
+
+ }
+
+
}
export default ModuleContainer
\ No newline at end of file
diff --git a/mirror/client/src/Settings/Settings.js b/mirror/client/src/Settings/Settings.js
index 126744e..ac80840 100644
--- a/mirror/client/src/Settings/Settings.js
+++ b/mirror/client/src/Settings/Settings.js
@@ -13,28 +13,154 @@ import ModuleContainer from '../ModuleContainer/ModuleContainer'
import { Provider, Subscribe } from 'unstated';
import Slider from '@material-ui/lab/Slider';
+import Radio from '@material-ui/core/Radio';
import { withStyles } from '@material-ui/core/styles';
+import {debounce} from 'throttle-debounce';
+
+
+class VisualisationSettings extends Component{
+
+ update = (value) => {
+ var param = {...this.props.param, ...value}
+ // console.log('value', value)
+ // console.log('this.props.visualisation,', this.props.param)
+ // console.log('visualisation,', param)
+ const key = this.props.name
+ var fromDown = {}
+ fromDown[key] = param
+ this.props.update( fromDown)
+ }
+
+ makeParam = (param) => {
+ if(param.type == 'slider') {
+ var { max, min, step, value } = param
+ return (
this.update({...param, ...{value : v}})) }
+ />)
+ }
+
+ if(param.type == 'radio') {
+ return ( this.update({...param, ...{value : !param.value}}) }
+ value="a"
+ name="radio-button-demo"
+ aria-label="A"
+ />)
+ }
+ }
+
+ render(){
+ const { classes, module, param, name} = this.props
+ const {params} = param
+
+ return (
+
+
+
+ {name}
+
+ {this.makeParam(param)}
+
+
+
+
+ {Object.keys(params).map((k, i) => ())}
+
+
+ )
+ }
+}
+
+class VisualisationSettingsRoot extends Component {
+ update = (value, down=true) => {
+ console.log('from down', value)
+
+ if(down) {
+ var visualisation = {...this.props.visualisation.params, ...value}
+ visualisation = {...this.props.visualisation, params: visualisation}
+ } else{
+ visualisation = {...this.props.visualisation, ...value}
+ }
+ console.log(visualisation)
+ this.props.module.setVisualisationsSettings(visualisation)
+ this.props.module.getLayerOutputs(this.props.module.state.layer, true)
+
+ }
+
+ render(){
+ const { module, visualisation} = this.props
+ const { name, params} = visualisation
+
+ return (
+
+
+
+ {name}
+
+ this.update({...visualisation, value:v }, false) }
+ value="a"
+ name="radio-button-demo"
+ aria-label="A"
+ />
+
+ {Object.keys(params).map((k, i) => ())}
+
+ )
+ }
+}
+
class Settings extends Component {
+ componentDidMount(){
+ this.props.module.getVisualisations()
+ }
+
+ onVisualisationSettingsChange = (data) => {
+ this.props.module.setVisualisationsSettings(data)
+ }
+
handleSlider = (e, size) => {
const { module } = this.props
var settings = module.state.settings
settings = Object.assign({}, settings, { size })
- console.log(settings)
+
module.changeSettings(settings)
}
+
render() {
- const { toogle, classes, open, module } = this.props
+ const { toogle, classes, open, module, small=false } = this.props
return (
+ classes={{
+ paper: classes.drawerPaper,
+ }}
+ >
-
@@ -43,19 +169,31 @@ class Settings extends Component {
-
-
+
+
+
+
+ Visualisations
+ {this.props.module.state.visualisations.map((v,i)=> )}
+
+
+
+
)
}
diff --git a/mirror/client/src/api.js b/mirror/client/src/api.js
index f1cfac4..6cb2fa6 100644
--- a/mirror/client/src/api.js
+++ b/mirror/client/src/api.js
@@ -1,6 +1,8 @@
const api = {
GET_MODULE : '/api/model',
+ GET_VISUALISATIONS: '/api/visualisation',
+ PUT_VISUALISATIONS: '/api/visualisation',
getModuleLayerOutput : (id) => `/api/model/layer/output/${id}`
}
diff --git a/mirror/deepdream.py b/mirror/deepdream.py
index a0184f5..1e2fd9e 100644
--- a/mirror/deepdream.py
+++ b/mirror/deepdream.py
@@ -15,17 +15,18 @@
from torch.autograd import Variable
import scipy.ndimage as nd
import numpy as np
+from skimage.util import view_as_blocks, view_as_windows, montage
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image, ImageFilter, ImageChops
-
+import matplotlib.pyplot as plt
# In[26]:
-# IMG_PATH = './pytorch-cnn-visualizations/input_images/dd_tree.jpg'
-IMG_PATH = './the_starry_night-wallpaper-1920x1200.jpg'
-# IMG_PATH = './the_starry_night-wallpaper-2560x1600.jpg'
+# IMG_PATH = '../pytorch-cnn-visualizations/input_images/dd_tree.jpg'
+# IMG_PATH = './resources/the_starry_night-wallpaper-1920x1200.jpg'
+IMG_PATH = './resources/the_starry_night-wallpaper-2560x1600.jpg'
# IMG_PATH = './sky-dd.jpeg'
# In[27]:
@@ -34,12 +35,12 @@
pil_img = Image.open(IMG_PATH)
# In[28]:
-
+print(pil_img.size)
img_transform = transforms.Compose([
# transforms.Resize((224, 224)),
transforms.ToTensor(),
- # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# In[5]:
@@ -50,7 +51,7 @@
# In[6]:
-model = models.vgg16(pretrained=True).cuda()
+model = models.resnet50(pretrained=True).cuda()
import time
# In[37]:
@@ -65,15 +66,7 @@ def __init__(self, module, layer):
self.transformMean = [0.485, 0.456, 0.406]
self.transformStd = [0.229, 0.224, 0.225]
- # self.transform_preprocess = transforms.Normalize(
- # mean=self.transformMean,
- # std=self.transformStd
- # )
-
self.transform_preprocess = transforms.Compose([
- transforms.ToPILImage(),
- transforms.Resize((224, 224)),
- transforms.ToTensor(),
transforms.Normalize(
mean=self.transformMean,
std=self.transformStd
@@ -82,7 +75,7 @@ def __init__(self, module, layer):
self.mean = torch.Tensor(self.transformStd).cuda()
self.std = torch.Tensor(self.transformMean).cuda()
- self.lr = 0.2
+ self.lr = 0.1
self.out = None
self.register_hooks()
@@ -116,17 +109,6 @@ def step(self, image, steps=5, save=False):
except:
pass
- # if save:
- # dreamed = self.image_var.data.squeeze()
- # c, w, h = dreamed.shape
- #
- # dreamed = dreamed.view((w, h, c))
- # dreamed = torch.clamp(dreamed, 0.0, 1.0)
- # dreamed = dreamed * self.std + self.mean
- # dreamed = dreamed.view((c, w, h))
- # dreamed_pil = TF.to_pil_image(dreamed.squeeze().cpu())
- #
- # dreamed_pil.save('./dream-' + str(i) + ".jpg", "JPEG")
dreamed = self.image_var.data.squeeze()
c, w, h = dreamed.shape
@@ -155,16 +137,12 @@ def deep_dream(self, image, n, top, scale_factor):
from_down = TF.to_pil_image(from_down.squeeze().cpu())
from_down = TF.resize(from_down, (w, h), Image.ANTIALIAS)
- # from_down = torch.nn.functional.interpolate(from_down, size=(w, h))
-
- # down, image = TF.to_pil_image(from_down.cpu().squeeze()), TF.to_pil_image(image.cpu().squeeze())
-
image = ImageChops.blend(from_down, image, 0.6)
image = TF.to_tensor(image).cuda()
n = n - 1
- return self.step(image, steps=3, save=top == n + 1)
+ return self.step(image, steps=8, save=top == n + 1)
def __call__(self, image, n_repeat=6, scale_factor=0.7):
@@ -172,28 +150,51 @@ def __call__(self, image, n_repeat=6, scale_factor=0.7):
-# In[9]:
+print(model)
+original_image = np.array(pil_img)
+N = 8
+h, w, c = original_image.shape
+h_N, w_N = h // N, w // N
+images = view_as_windows(original_image, (h_N, w_N, 3), (h_N,w_N, 3)).squeeze()
-print(model)
-start = time.time()
-dd = DeepDream(model, model.features[28])
-dreamed = dd(img_transform(pil_img).unsqueeze(0))
-end = time.time()
-print('{:.4f}'.format(end - start))
+dd = DeepDream(model, model.layer4[0].conv2)
+
+rec = None
+# print(images.shape)
+for rows in images:
+ col = None
+ for cols in rows:
+ image = Image.fromarray(cols.astype('uint8'), 'RGB')
+
+ dreamed = dd(img_transform(image).unsqueeze(0), n_repeat=6, scale_factor=0.7)
+ # dreamed = torch.nn.functional.interpolate(dreamed, scale_factor=0.7)
+ dreamed = transforms.ToPILImage()(dreamed.squeeze().cpu())
+
+ dreamed_np = np.array(dreamed)
+
+ if col is None: col = dreamed_np
+ else: col = np.hstack((col, dreamed_np))
+ # plt.imshow(dreamed_np)
+ # plt.show()
+ # print(image.shape)
+ # plt.imshow(image)
+ # plt.show()
+ if rec is None: rec = col
+ else: rec = np.vstack((rec, col))
-# In[35]:
+print(rec.shape)
-def plot_tensor(tensor):
- tensor = torch.nn.functional.interpolate(tensor, scale_factor=0.7)
- img = transforms.ToPILImage()(tensor.squeeze().cpu())
+img = Image.fromarray(rec.astype('uint8'), 'RGB')
+# print(model)
- return img
-img = plot_tensor(dreamed)
-# img.show()
+#
+# # In[35]:
+# #
+# # img.show()
img.save('dream' + ".jpg", "JPEG")
# In[36]:
diff --git a/mirror/dream.jpg b/mirror/dream.jpg
new file mode 100644
index 0000000..909e850
Binary files /dev/null and b/mirror/dream.jpg differ
diff --git a/mirror/server.py b/mirror/server.py
index f9a0e78..b18219b 100644
--- a/mirror/server.py
+++ b/mirror/server.py
@@ -1,77 +1,143 @@
-from flask import Flask, request, Response, send_file, jsonify
+import json
+import io
+import numpy as np
+import torch
+import time
+from flask import Flask, request, Response, send_file, jsonify
+from .visualisations import WeightsVisualisation
from PIL import Image
+import pprint
+from torchvision.transforms import ToPILImage
-import io
+class Builder:
+ def __init__(self):
+ self.outputs = None
+ self.cache = {}
+ self.visualisations = {}
+ self.current_vis = None
+ self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+ def build(self, input, model, tracer, visualisations=[]):
+ input = input.to(self.device)
+ model = model.to(self.device)
+
+ visualisations = [v(model, tracer) for v in visualisations]
+ self.visualisations = [WeightsVisualisation(model, tracer), *visualisations]
+
+ self.name2visualisations = { v.name : v for v in self.visualisations}
+ self.current_vis = self.visualisations[0]
+
+ app = Flask(__name__)
+ MAX_LINKS_EVERY_REQUEST = 64
+
+
+ @app.route('/')
+ def root():
+ return app.send_static_file('index.html')
+
+ @app.route('/api/model', methods=['GET'])
+ def api_model():
+ model = tracer.serialized
+
+ response = jsonify(model)
+
+ return response
+
+ @app.route('/api/model/layer/')
+ def api_model_layer(id):
+ id = int(id)
+ name = str(tracer.idx_to_value[id])
+
+ return Response(response=name)
+
+ @app.route('/api/visualisation', methods=['GET'])
+ def api_visualisations():
+ serialised = [v.properties for v in self.visualisations]
+
+ response = jsonify({ 'visualisations': serialised,
+ 'current': self.current_vis.properties})
+
+ return response
+
+ @app.route('/api/visualisation', methods=['PUT'])
+ def api_visualisation():
+ data = json.loads(request.data.decode())
+
+ vis_key = data['name']
+
+ if vis_key not in self.name2visualisations:
+ response = Response(status=500, response='Visualisation {} not supported or does not exist'.format(vis_key))
+ else:
+ # TODO I should think on a cleaver way to update properties and params
+ self.name2visualisations[vis_key].properties = data
+ self.name2visualisations[vis_key].params = self.name2visualisations[vis_key].properties['params']
+ self.current_vis = self.name2visualisations[vis_key]
+ self.name2visualisations[vis_key].cache = {}
+
+ response = jsonify(self.name2visualisations[vis_key].properties)
+ return response
-def build(input, model, tracer):
- app = Flask(__name__)
- MAX_LINKS_EVERY_REQUEST = 64
+ @app.route('/api/model/layer/output/')
+ def api_model_layer_output(id):
+ try:
+ layer = tracer.idx_to_value[id].v
- @app.route('/')
- def root():
- return app.send_static_file('index.html')
+ if input not in self.current_vis.cache: self.current_vis.cache[input] = {}
+ # TODO need to cache for vis
+ layer_cache = self.current_vis.cache[input]
- @app.route('/api/model', methods=['GET'])
- def api_model():
- model = tracer.serialized
+ # layer_cache[layer] = self.current_vis(input, layer)
+ if layer not in layer_cache: layer_cache[layer] = self.current_vis(input, layer)
+ else: print('cached')
+ self.outputs = layer_cache[layer]
- response = jsonify(model)
+ outputs = self.outputs
- return response
+ if len(outputs.shape) < 3: raise ValueError
- @app.route('/api/model/layer/')
- def api_model_layer(id):
- id = int(id)
- name = str(tracer.idx_to_value[id])
+ last = int(request.args['last'])
+ max = min((last + MAX_LINKS_EVERY_REQUEST), outputs.shape[0])
- return Response(response=name)
+ response = ['/api/model/image/{}/{}/{}/{}/{}'.format(hash(input),
+ hash(self.current_vis),
+ hash(time.time()),
+ id,
+ i) for i in range(last, max)]
- @app.route('/api/model/layer/output/')
- def api_model_layer_output(id):
- try:
- model, inputs, outputs = tracer.idx_to_value[id].traced[0]
- if len(outputs.shape) < 3: raise ValueError
- # mode = request.args.get('mode')
+ response = jsonify({ 'links' : response, 'next': last + 1< max})
- last = int(request.args['last'])
- max = min((last + MAX_LINKS_EVERY_REQUEST), outputs.shape[1])
- if last >= max: raise StopIteration
+ except KeyError:
+ response = Response(status=500, response='Index not found.')
+ except ValueError:
+ response = Response(status=404, response='Outputs must be an array of images')
+ except StopIteration:
+ response = jsonify({ 'links' : [], 'next': False})
- response = ['/api/model/image/{}/{}'.format(id, i) for i in range(last, max)]
- response = jsonify(response)
+ return response
- except KeyError:
- response = Response(status=500, response='Index not found.')
- except ValueError:
- response = Response(status=404, response='Outputs are not images.')
- except StopIteration:
- response = Response(status=404, response='No more.')
+ @app.route('/api/model/image////