diff --git a/.gitignore b/.gitignore
index 9f11b75..a047a94 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 .idea/
+__pycache__/
\ No newline at end of file
diff --git a/DummyVisualisation.py b/DummyVisualisation.py
new file mode 100644
index 0000000..6cf0ab8
--- /dev/null
+++ b/DummyVisualisation.py
@@ -0,0 +1,20 @@
+from mirror.visualisations.Visualisation import Visualisation
+
+class DummyVisualisation(Visualisation):
+
+    def __call__(self, inputs, layer):
+        return inputs.repeat(self.params['repeat']['value'], 1, 1, 1)
+
+    @property
+    def name(self):
+        return 'dummy'
+
+    def init_params(self):
+        return {'repeat' : {
+            'type' : 'slider',
+            'min' : 1,
+            'max' : 100,
+            'value' : 3,
+            'step': 1,
+            'params': {}
+        }}
diff --git a/README.md b/README.md
index 5aec8f7..f205233 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 
 This is a raw beta, so expect lots of things to change and improve over time.
 
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/mirror/master/mirror/resources/mirror.gif)
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/mirror.gif?raw=true)
 
 ### Getting started
 
@@ -17,37 +17,77 @@ Basic example:
 
 ```python
 from mirror import mirror
+from mirror.visualisations import DeepDream
 from PIL import Image
-from torchvision.models import resnet101
+from torchvision.models import resnet101, resnet18, vgg16
 from torchvision.transforms import ToTensor, Resize, Compose
 
-# create a model
-model = resnet101(True)
-cat = Image.open("cat.jpg")
+# create a model
+model = vgg16(pretrained=True)
+
+cat = Image.open("./cat.jpg")
 # resize the image and make it a tensor
 input = Compose([Resize((224,224)), ToTensor()])(cat)
 # add 1 dim for batch
-input = input.view(1,3,224,224)
+input = input.unsqueeze(0)
 # call mirror with the input and the model
-mirror(input, model)
+mirror(input, model, visualisations=[DeepDream])
 ```
 
 It will automatically open a new tab in your browser
 
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/mirror.jpg?raw=true)
+
+### Create a Visualisation
+
+You can find an example below:
+
+```python
+from mirror.visualisations.Visualisation import Visualisation
+
+class DummyVisualisation(Visualisation):
+
+    def __call__(self, inputs, layer):
+        return inputs.repeat(self.params['repeat']['value'], 1, 1, 1)
+
+    @property
+    def name(self):
+        return 'dummy'
+
+    def init_params(self):
+        return {'repeat' : {
+            'type' : 'slider',
+            'min' : 1,
+            'max' : 100,
+            'value' : 3,
+            'step': 1,
+            'params': {}
+        }}
+```
+
+![alt](https://github.com/FrancescoSaverioZuppichini/mirror/blob/develop/resources/dummy.jpg?raw=true)
+
+The `__call__` function is called each time you click a layer or change a value in the options on the right.
+
+The `init_params` function returns a dictionary of options that will be shown in the right drawer of the application.
+For now only `slider` and `radio` are supported.
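A `radio` option uses the same dictionary shape but carries a boolean `value` that the client toggles on each click (this is how the new `Settings.js` further down renders `param.type == 'radio'`). A minimal sketch, assuming a hypothetical `invert` visualisation; only the dictionary shape is taken from the diff:

```python
from mirror.visualisations.Visualisation import Visualisation

class InvertVisualisation(Visualisation):
    # hypothetical visualisation demonstrating a boolean 'radio' option

    def __call__(self, inputs, layer):
        # invert pixel intensities (inputs are in [0, 1] after ToTensor)
        # when the radio control is switched on
        return 1.0 - inputs if self.params['invert']['value'] else inputs

    @property
    def name(self):
        return 'invert'

    def init_params(self):
        return {'invert': {
            'type': 'radio',   # rendered as a Radio control in the drawer
            'value': False,    # toggled by the client on each click
            'params': {}
        }}
```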
 
 ### TODO
-- Support multiple inputs and cache them
-- Make a generic abstraction of a visualisation in order to add more features
+- [x] Cache reused layers
+- [x] Make a generic abstraction of a visualisation in order to add more features
+- [ ] Add more options for the parameters (dropdown, text)
+- [ ] Support multiple inputs
+- [ ] Support multiple models
 - Add all visualisations present here https://github.com/utkuozbulak/pytorch-cnn-visualizations
-    * [Gradient visualization with vanilla backpropagation](#gradient-visualization)
-    * [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
-    * [Gradient visualization with saliency maps](#gradient-visualization) [4]
-    * [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
-    * [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
-    * [Smooth grad](#smooth-grad) [8]
-    * [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
-    * [Inverted image representations](#inverted-image-representations) [5]
-    * [Deep dream](#deep-dream) [10]
-    * [Class specific image generation](#class-specific-image-generation) [4]
+    * [ ] [Gradient visualization with vanilla backpropagation](#gradient-visualization)
+    * [ ] [Gradient visualization with guided backpropagation](#gradient-visualization) [1]
+    * [ ] [Gradient visualization with saliency maps](#gradient-visualization) [4]
+    * [ ] [Gradient-weighted [3] class activation mapping](#gradient-visualization) [2]
+    * [ ] [Guided, gradient-weighted class activation mapping](#gradient-visualization) [3]
+    * [ ] [Smooth grad](#smooth-grad) [8]
+    * [x] [CNN filter visualization](#convolutional-neural-network-filter-visualization) [9]
+    * [ ] [Inverted image representations](#inverted-image-representations) [5]
+    * [x] [Deep dream](#deep-dream) [10]
+    * [ ] [Class specific image generation](#class-specific-image-generation) [4]
diff --git a/example.py b/example.py
index 6217e8b..1a395cf 100644
--- a/example.py
+++ b/example.py
@@ -1,16 +1,18 @@
 from mirror import mirror
+from mirror.visualisations import DeepDream
 from PIL import Image
-from torchvision.models import resnet101
+from torchvision.models import resnet101, resnet18, vgg16
 from torchvision.transforms import ToTensor, Resize, Compose
 
-model = resnet101(True)
-
-cat = Image.open("cat.jpg")
+# create a model
+model = vgg16(pretrained=True)
+cat = Image.open("./cat.jpg")
+# resize the image and make it a tensor
 input = Compose([Resize((224,224)), ToTensor()])(cat)
-
-input = input.view(1,3,224,224)
-
-mirror(input, model)
\ No newline at end of file
+# add 1 dim for batch
+input = input.unsqueeze(0)
+# call mirror with the input and the model
+mirror(input, model, visualisations=[DeepDream])
\ No newline at end of file
diff --git a/mirror/__pycache__/app.cpython-36.pyc b/mirror/__pycache__/app.cpython-36.pyc
index bf49284..ddc41b1 100644
Binary files a/mirror/__pycache__/app.cpython-36.pyc and b/mirror/__pycache__/app.cpython-36.pyc differ
diff --git a/mirror/__pycache__/server.cpython-36.pyc b/mirror/__pycache__/server.cpython-36.pyc
index 464d4a4..5120b3f 100644
Binary files a/mirror/__pycache__/server.cpython-36.pyc and b/mirror/__pycache__/server.cpython-36.pyc differ
diff --git a/mirror/__pycache__/tree.cpython-36.pyc b/mirror/__pycache__/tree.cpython-36.pyc
index 2a71c0d..81753e6 100644
Binary files a/mirror/__pycache__/tree.cpython-36.pyc and b/mirror/__pycache__/tree.cpython-36.pyc differ
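A note on the `example.py` change above: `input.view(1, 3, 224, 224)` was replaced with `input.unsqueeze(0)`, which prepends the batch dimension without hard-coding the image size. A quick sanity check in plain PyTorch, nothing project-specific:

```python
import torch

x = torch.rand(3, 224, 224)      # a C x H x W image tensor, as produced by ToTensor()
batched = x.unsqueeze(0)         # prepend a batch dimension -> 1 x C x H x W
assert batched.shape == (1, 3, 224, 224)
# equivalent to the old call for this fixed size, but works for any H and W
assert torch.equal(batched, x.view(1, 3, 224, 224))
```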
diff --git a/mirror/app.py b/mirror/app.py
index a804ca8..c8be4bb 100644
--- a/mirror/app.py
+++ b/mirror/app.py
@@ -1,13 +1,15 @@
 import webbrowser
 from .tree import Tracer
-from .server import build
+from .server import Builder
 
-def mirror(input, model):
+def mirror(input, model, visualisations=[]):
     tracer = Tracer(module=model)
     tracer(input)
-    app = build(input, model, tracer)
+    builder = Builder()
+
+    app = builder.build(input, model, tracer, visualisations)
 
     webbrowser.open_new('http://localhost:5000')  # opens in default browser
diff --git a/mirror/client/package-lock.json b/mirror/client/package-lock.json
index 90aeae9..429ad77 100644
--- a/mirror/client/package-lock.json
+++ b/mirror/client/package-lock.json
@@ -14134,6 +14134,11 @@
       "resolved": "https://registry.npmjs.org/throat/-/throat-4.1.0.tgz",
       "integrity": "sha1-iQN8vJLFarGJJua6TLsgDhVnKmo="
     },
+    "throttle-debounce": {
+      "version": "2.0.1",
+      "resolved": "https://registry.npmjs.org/throttle-debounce/-/throttle-debounce-2.0.1.tgz",
+      "integrity": "sha512-Sr6jZBlWShsAaSXKyNXyNicOrJW/KtkDqIEwHt4wYwWA2wa/q67Luhqoujg48V8hTk60wB56tYrJJn6jc2R7VA=="
+    },
     "through": {
       "version": "2.3.8",
       "resolved": "http://registry.npmjs.org/through/-/through-2.3.8.tgz",
diff --git a/mirror/client/package.json b/mirror/client/package.json
index 9691b59..2e90d92 100644
--- a/mirror/client/package.json
+++ b/mirror/client/package.json
@@ -13,6 +13,7 @@
     "react-script": "^2.0.5",
     "react-scripts": "^2.0.4",
     "reactstrap": "^6.5.0",
+    "throttle-debounce": "^2.0.1",
     "unstated": "^2.1.1"
   },
   "scripts": {
diff --git a/mirror/client/src/App.js b/mirror/client/src/App.js
index 3787ad3..1344d17 100644
--- a/mirror/client/src/App.js
+++ b/mirror/client/src/App.js
@@ -18,6 +18,7 @@ import LinearProgress from '@material-ui/core/LinearProgress';
 import LayerOutputs from './Module/LayerOutputs/LayerOutputs'
 
 import MoreIcon from '@material-ui/icons/MoreVert';
+import Hidden from '@material-ui/core/Hidden';
 
 const drawerWidth = 300;
 
@@ -25,24 +26,38 @@ const styles = theme => ({
   root: {
     flexGrow: 1,
     // height: 440,
-    zIndex: 1,
+    // zIndex: 1,
     // overflow: 'hidden',
-    position: 'relative',
+    // position: 'relative',
     display: 'flex',
+    flexDirection : 'row',
     minHeight: '100vh'
   },
+  typography: {
+    useNextVariants: true,
+  },
   appBar: {
     zIndex: theme.zIndex.drawer + 1,
   },
+  drawer: {
+    width: drawerWidth,
+    flexShrink: 0,
+  },
   drawerPaper: {
-    position: 'relative',
+    flexShrink: 0,
+
+    // position: 'relative',
     width: drawerWidth,
   },
   content: {
     flexGrow: 1,
+    // marginLeft: '300px',
+    // position: 'fixed',
+    // width: '100%',
+    // height: '100%',
     backgroundColor: theme.palette.background.default,
     padding: theme.spacing.unit * 3,
-    minWidth: 0, // So the Typography noWrap works
+    // minWidth: 0, // So the Typography noWrap works
   },
 
   toolbar: theme.mixins.toolbar,
@@ -53,9 +68,16 @@
     zIndex: 9999
   },
 
-  settn: {
+  settings: {
     width: '300px !important'
-}
+  },
+
+  sliders : {
+    width: '200px !important'
+  },
+
+  layersOuput : {
+  }
 })
 
 function MyAppBar({ module, classes }) {
   return (
             Mirror
+ + + ) @@ -80,6 +105,7 @@ class App extends Component { } + toggleSettings = () => { const openSettings = !this.state.openSettings this.setState({ openSettings }) @@ -94,23 +120,40 @@ class App extends Component {
-
+ {module.state.isLoading ? () : ''}
- +
+ + + + + + + +
 )}
diff --git a/mirror/client/src/Module/LayerOutputs/LayerOutputs.js b/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
index 844ab45..9862841 100644
--- a/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
+++ b/mirror/client/src/Module/LayerOutputs/LayerOutputs.js
@@ -6,25 +6,54 @@ import Slider from '@material-ui/lab/Slider';
 import Button from '@material-ui/core/Button';
 import Typography from '@material-ui/core/Typography';
 
+class AtomicImage extends Component {
+    constructor(props) {
+        super(props);
+        this.state = {dimensions: {}};
+        this.onImgLoad = this.onImgLoad.bind(this);
+    }
+    onImgLoad({target:img}) {
+        this.setState({dimensions:{height:img.offsetHeight,
+                                   width:img.offsetWidth}});
+    }
+    render(){
+        const {src, size} = this.props;
+        const {width, height} = this.state.dimensions;
+
+        const style = {
+            // backgroundImage : `url(${src})`,
+            // backgroundSize: 100% 100%,
+            height: `${height * size}px`,
+            width: `${width * size}px`
+        }
+
+        return (
+        );
+    }
+}
+
+
 const Image = ({ src, size }) => {
     const style = {
         // backgroundImage : `url(${src})`,
         // backgroundSize: 100% 100%,
-        height: `${100 * (size / 20)}px`,
-        width: `${100 * (size / 20)}px`
+        height: `${100 * (size / 10)}px`,
+        width: `${100 * (size / 10)}px`
     }
 
     return (
-
+
     )
 }
 
 class LayerOutputs extends React.Component {
     render() {
-        const { module } = this.props
+        const { module, classes } = this.props
         return (
-
+
+                {this.props.module.state.outputs.length == 0 ?
+                    (
+                        Nothing to show
+                    ) :
+                    this.props.module.state.next ? (
+                    ) : ''}
+
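The `links` array and `next` flag this component renders come from the paginated layer-output endpoint (see the `server.py` changes further down, where `MAX_LINKS_EVERY_REQUEST = 64`). A minimal sketch of that paging contract in plain Python, with a hypothetical helper name and a simplified link format (the real links also embed input, visualisation, and time hashes):

```python
MAX_LINKS_EVERY_REQUEST = 64  # page size used by the server

def page_of_links(layer_id, n_outputs, last):
    """Return one page of image links plus a flag telling the client
    whether more pages exist, mirroring the server.py logic below."""
    stop = min(last + MAX_LINKS_EVERY_REQUEST, n_outputs)
    links = ['/api/model/image/{}/{}'.format(layer_id, i) for i in range(last, stop)]
    return {'links': links, 'next': last + 1 < stop}

# first page for a layer with 100 outputs: 64 links, and more to come
page = page_of_links(3, 100, 0)
assert len(page['links']) == 64 and page['next']
```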
diff --git a/mirror/client/src/ModuleContainer/ModuleContainer.js b/mirror/client/src/ModuleContainer/ModuleContainer.js
index ee79aea..183e553 100644
--- a/mirror/client/src/ModuleContainer/ModuleContainer.js
+++ b/mirror/client/src/ModuleContainer/ModuleContainer.js
@@ -7,11 +7,14 @@ import querystring from 'querystring'
 class ModuleContainer extends Container {
     state = {
         tree: null,
+        visualisations: [],
+        currentVisualisation: {},
         isLoading: false,
         open: false,
         outputs: [],
         layer: { name : ''},
         last: 0,
+        next: false,
         settings: { size: 50 }
     }
@@ -29,17 +32,18 @@
     }
 
-    getLayerOutputs = async (layer=this.state.layer) => {
-        const isSameLayer = layer.id == this.state.layer.id
-        const last = isSameLayer ? this.state.last + 64 : 0
+    getLayerOutputs = async (layer=this.state.layer, start=false) => {
+        var isSameLayer = layer.id == this.state.layer.id
+        if(start) isSameLayer = false
+        var last = isSameLayer ? this.state.last + 64 : 0
         try {
             await this.setState({ isLoading: true })
             const res = await axios.get(api.getModuleLayerOutput(layer.id, last), {
                 params: { last }
             })
-            var outputs = isSameLayer ? this.state.outputs.concat(res.data) : res.data
-
-            await this.setState({ outputs, layer, last, isLoading: false })
+            var outputs = isSameLayer ? this.state.outputs.concat(res.data.links) : res.data.links
+
+            await this.setState({ outputs, layer, last, isLoading: false, next: res.data.next })
         } catch {
             await this.setState({ isLoading: false })
         }
     }
@@ -48,6 +52,34 @@
     changeSettings = async(settings) => {
         await this.setState( { settings })
     }
+
+    async getVisualisations() {
+        await this.setState({ isLoading: true })
+
+        const res = await axios.get(api.GET_VISUALISATIONS)
+        const visualisations = res.data.visualisations
+        const currentVisualisation = res.data.current
+        await this.setState({ visualisations, isLoading: false, currentVisualisation })
+    }
+
+    async setVisualisationsSettings(data){
+        await this.setState({ isLoading: true })
+
+        const res = await axios.put(api.PUT_VISUALISATIONS, data)
+
+        const currentVisualisation = res.data
+        var visualisations = [...this.state.visualisations]
+        // TODO could be smarter!
+        for(let key in visualisations){
+            if(visualisations[key].name == currentVisualisation.name){
+                visualisations[key] = currentVisualisation
+            }
+        }
+        await this.setState({isLoading: false, currentVisualisation, visualisations })
+
+    }
+
+
 }
 
 export default ModuleContainer
\ No newline at end of file
diff --git a/mirror/client/src/Settings/Settings.js b/mirror/client/src/Settings/Settings.js
index 126744e..ac80840 100644
--- a/mirror/client/src/Settings/Settings.js
+++ b/mirror/client/src/Settings/Settings.js
@@ -13,28 +13,154 @@ import ModuleContainer from '../ModuleContainer/ModuleContainer'
 import { Provider, Subscribe } from 'unstated';
 
 import Slider from '@material-ui/lab/Slider';
+import Radio from '@material-ui/core/Radio';
 import { withStyles } from '@material-ui/core/styles';
+import {debounce} from 'throttle-debounce';
+
+
+class VisualisationSettings extends Component{
+
+    update = (value) => {
+        var param = {...this.props.param, ...value}
+        // console.log('value', value)
+        // console.log('this.props.visualisation,', this.props.param)
+        // console.log('visualisation,', param)
+        const key = this.props.name
+        var fromDown = {}
+        fromDown[key] = param
+        this.props.update( fromDown)
+    }
+
+    makeParam = (param) => {
+        if(param.type == 'slider') {
+            var { max, min, step, value } = param
+            return ( this.update({...param, ...{value : v}})) }
+            />)
+        }
+
+        if(param.type == 'radio') {
+            return ( this.update({...param, ...{value : !param.value}}) }
+            value="a"
+            name="radio-button-demo"
+            aria-label="A"
+            />)
+        }
+    }
+
+    render(){
+        const { classes, module, param, name} = this.props
+        const {params} = param
+
+        return (
+
+
{name}
+
+ {this.makeParam(param)} +
+
+
+ + {Object.keys(params).map((k, i) => ())} + +
+ ) + } +} + +class VisualisationSettingsRoot extends Component { + update = (value, down=true) => { + console.log('from down', value) + + if(down) { + var visualisation = {...this.props.visualisation.params, ...value} + visualisation = {...this.props.visualisation, params: visualisation} + } else{ + visualisation = {...this.props.visualisation, ...value} + } + console.log(visualisation) + this.props.module.setVisualisationsSettings(visualisation) + this.props.module.getLayerOutputs(this.props.module.state.layer, true) + + } + + render(){ + const { module, visualisation} = this.props + const { name, params} = visualisation + + return ( + + + +
{name}
+
+ this.update({...visualisation, value:v }, false) } + value="a" + name="radio-button-demo" + aria-label="A" + /> +
+ {Object.keys(params).map((k, i) => ())} +
+ ) + } +} + class Settings extends Component { + componentDidMount(){ + this.props.module.getVisualisations() + } + + onVisualisationSettingsChange = (data) => { + this.props.module.setVisualisationsSettings(data) + } + handleSlider = (e, size) => { const { module } = this.props var settings = module.state.settings settings = Object.assign({}, settings, { size }) - console.log(settings) + module.changeSettings(settings) } + render() { - const { toogle, classes, open, module } = this.props + const { toogle, classes, open, module, small=false } = this.props return ( + classes={{ + paper: classes.drawerPaper, + }} + >
- @@ -43,19 +169,31 @@ class Settings extends Component { - - + + + + +
Visualisations
+ {this.props.module.state.visualisations.map((v,i)=> )} +
+
+ +
 ) }
diff --git a/mirror/client/src/api.js b/mirror/client/src/api.js
index f1cfac4..6cb2fa6 100644
--- a/mirror/client/src/api.js
+++ b/mirror/client/src/api.js
@@ -1,6 +1,8 @@
 const api = {
     GET_MODULE : '/api/model',
+    GET_VISUALISATIONS: '/api/visualisation',
+    PUT_VISUALISATIONS: '/api/visualisation',
     getModuleLayerOutput : (id) => `/api/model/layer/output/${id}`
 }
diff --git a/mirror/deepdream.py b/mirror/deepdream.py
index a0184f5..1e2fd9e 100644
--- a/mirror/deepdream.py
+++ b/mirror/deepdream.py
@@ -15,17 +15,18 @@
 from torch.autograd import Variable
 import scipy.ndimage as nd
 import numpy as np
+from skimage.util import view_as_blocks, view_as_windows, montage
 
 from torchvision import transforms
 import torchvision.transforms.functional as TF
 
 from PIL import Image, ImageFilter, ImageChops
-
+import matplotlib.pyplot as plt
 
 # In[26]:
 
-# IMG_PATH = './pytorch-cnn-visualizations/input_images/dd_tree.jpg'
-IMG_PATH = './the_starry_night-wallpaper-1920x1200.jpg'
-# IMG_PATH = './the_starry_night-wallpaper-2560x1600.jpg'
+# IMG_PATH = '../pytorch-cnn-visualizations/input_images/dd_tree.jpg'
+# IMG_PATH = './resources/the_starry_night-wallpaper-1920x1200.jpg'
+IMG_PATH = './resources/the_starry_night-wallpaper-2560x1600.jpg'
 # IMG_PATH = './sky-dd.jpeg'
 
 # In[27]:
@@ -34,12 +35,12 @@
 pil_img = Image.open(IMG_PATH)
 
 # In[28]:
 
-
+print(pil_img.size)
 img_transform = transforms.Compose([
     # transforms.Resize((224, 224)),
     transforms.ToTensor(),
-    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 ])
 
 # In[5]:
@@ -50,7 +51,7 @@
 
 # In[6]:
 
-model = models.vgg16(pretrained=True).cuda()
+model = models.resnet50(pretrained=True).cuda()
 import time
 
 # In[37]:
@@ -65,15 +66,7 @@ def __init__(self, module, layer):
         self.transformMean = [0.485, 0.456, 0.406]
         self.transformStd = [0.229, 0.224, 0.225]
-        # self.transform_preprocess = transforms.Normalize(
-        #     mean=self.transformMean,
-        #     std=self.transformStd
-        # )
-
         self.transform_preprocess = transforms.Compose([
-            transforms.ToPILImage(),
-            transforms.Resize((224, 224)),
-            transforms.ToTensor(),
             transforms.Normalize(
                 mean=self.transformMean,
                 std=self.transformStd
@@ -82,7 +75,7 @@
         self.mean = torch.Tensor(self.transformStd).cuda()
         self.std = torch.Tensor(self.transformMean).cuda()
 
-        self.lr = 0.2
+        self.lr = 0.1
         self.out = None
 
         self.register_hooks()
@@ -116,17 +109,6 @@ def step(self, image, steps=5, save=False):
             except:
                 pass
 
-        # if save:
-        #     dreamed = self.image_var.data.squeeze()
-        #     c, w, h = dreamed.shape
-        #
-        #     dreamed = dreamed.view((w, h, c))
-        #     dreamed = torch.clamp(dreamed, 0.0, 1.0)
-        #     dreamed = dreamed * self.std + self.mean
-        #     dreamed = dreamed.view((c, w, h))
-        #     dreamed_pil = TF.to_pil_image(dreamed.squeeze().cpu())
-        #
-        #     dreamed_pil.save('./dream-' + str(i) + ".jpg", "JPEG")
         dreamed = self.image_var.data.squeeze()
         c, w, h = dreamed.shape
@@ -155,16 +137,12 @@ def deep_dream(self, image, n, top, scale_factor):
         from_down = TF.to_pil_image(from_down.squeeze().cpu())
         from_down = TF.resize(from_down, (w, h), Image.ANTIALIAS)
 
-        # from_down = torch.nn.functional.interpolate(from_down, size=(w, h))
-
-        # down, image = TF.to_pil_image(from_down.cpu().squeeze()), TF.to_pil_image(image.cpu().squeeze())
-
         image = ImageChops.blend(from_down, image, 0.6)
         image = TF.to_tensor(image).cuda()
 
         n = n - 1
 
-        return self.step(image, steps=3, save=top == n + 1)
+        return self.step(image, steps=8, save=top == n + 1)
 
     def __call__(self, image, n_repeat=6, scale_factor=0.7):
@@ -172,28 +150,51 @@ def __call__(self, image, n_repeat=6, scale_factor=0.7):
 
-# In[9]:
+print(model)
 
+original_image = np.array(pil_img)
+N = 8
+h, w, c = original_image.shape
+h_N, w_N = h // N, w // N
+images = view_as_windows(original_image, (h_N, w_N, 3), (h_N, w_N, 3)).squeeze()
 
-print(model)
-start = time.time()
-dd = DeepDream(model, model.features[28])
-dreamed = dd(img_transform(pil_img).unsqueeze(0))
-end = time.time()
-print('{:.4f}'.format(end - start))
+dd = DeepDream(model, model.layer4[0].conv2)
+
+rec = None
+# print(images.shape)
+for rows in images:
+    col = None
+    for cols in rows:
+        image = Image.fromarray(cols.astype('uint8'), 'RGB')
+
+        dreamed = dd(img_transform(image).unsqueeze(0), n_repeat=6, scale_factor=0.7)
+        # dreamed = torch.nn.functional.interpolate(dreamed, scale_factor=0.7)
+        dreamed = transforms.ToPILImage()(dreamed.squeeze().cpu())
+
+        dreamed_np = np.array(dreamed)
+
+        if col is None: col = dreamed_np
+        else: col = np.hstack((col, dreamed_np))
+        # plt.imshow(dreamed_np)
+        # plt.show()
+        # print(image.shape)
+        # plt.imshow(image)
+        # plt.show()
+    if rec is None: rec = col
+    else: rec = np.vstack((rec, col))
 
-# In[35]:
+print(rec.shape)
 
-def plot_tensor(tensor):
-    tensor = torch.nn.functional.interpolate(tensor, scale_factor=0.7)
-    img = transforms.ToPILImage()(tensor.squeeze().cpu())
+img = Image.fromarray(rec.astype('uint8'), 'RGB')
+# print(model)
 
-    return img
 
-img = plot_tensor(dreamed)
-# img.show()
+#
+# # In[35]:
+#
+# # img.show()
 img.save('dream' + ".jpg", "JPEG")
 
 # In[36]:
diff --git a/mirror/dream.jpg b/mirror/dream.jpg
new file mode 100644
index 0000000..909e850
Binary files /dev/null and b/mirror/dream.jpg differ
diff --git a/mirror/server.py b/mirror/server.py
index f9a0e78..b18219b 100644
--- a/mirror/server.py
+++ b/mirror/server.py
@@ -1,77 +1,143 @@
-from flask import Flask, request, Response, send_file, jsonify
+import json
+import io
+import numpy as np
+import torch
+import time
 
+from flask import Flask, request, Response, send_file, jsonify
+from .visualisations import WeightsVisualisation
 from PIL import Image
+import pprint
+from torchvision.transforms import ToPILImage
 
-import io
+class Builder:
+    def __init__(self):
+        self.outputs = None
+        self.cache = {}
+        self.visualisations = {}
+        self.current_vis = None
+        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
+    def build(self, input, model, tracer, visualisations=[]):
+        input = input.to(self.device)
+        model = model.to(self.device)
+
+        visualisations = [v(model, tracer) for v in visualisations]
+        self.visualisations = [WeightsVisualisation(model, tracer), *visualisations]
+
+        self.name2visualisations = { v.name : v for v in self.visualisations}
+        self.current_vis = self.visualisations[0]
+
+        app = Flask(__name__)
+        MAX_LINKS_EVERY_REQUEST = 64
+
+
+        @app.route('/')
+        def root():
+            return app.send_static_file('index.html')
+
+        @app.route('/api/model', methods=['GET'])
+        def api_model():
+            model = tracer.serialized
+
+            response = jsonify(model)
+
+            return response
+
+        @app.route('/api/model/layer/<id>')
+        def api_model_layer(id):
+            id = int(id)
+            name = str(tracer.idx_to_value[id])
+
+            return Response(response=name)
+
+        @app.route('/api/visualisation', methods=['GET'])
+        def api_visualisations():
+            serialised = [v.properties for v in self.visualisations]
+
+            response = jsonify({ 'visualisations': serialised,
+                                 'current': self.current_vis.properties})
+
+            return response
+
+        @app.route('/api/visualisation', methods=['PUT'])
+        def api_visualisation():
+            data = json.loads(request.data.decode())
+
+            vis_key = data['name']
+
+            if vis_key not in self.name2visualisations:
+                response = Response(status=500, response='Visualisation {} not supported or does not exist'.format(vis_key))
+            else:
+                # TODO I should think of a cleverer way to update properties and params
+                self.name2visualisations[vis_key].properties = data
+                self.name2visualisations[vis_key].params = self.name2visualisations[vis_key].properties['params']
+                self.current_vis = self.name2visualisations[vis_key]
+                self.name2visualisations[vis_key].cache = {}
+
+                response = jsonify(self.name2visualisations[vis_key].properties)
+            return response
 
-def build(input, model, tracer):
-    app = Flask(__name__)
-    MAX_LINKS_EVERY_REQUEST = 64
+        @app.route('/api/model/layer/output/<id>')
+        def api_model_layer_output(id):
+            try:
+                layer = tracer.idx_to_value[id].v
 
-    @app.route('/')
-    def root():
-        return app.send_static_file('index.html')
+                if input not in self.current_vis.cache: self.current_vis.cache[input] = {}
+                # TODO need to cache for vis
+                layer_cache = self.current_vis.cache[input]
 
-    @app.route('/api/model', methods=['GET'])
-    def api_model():
-        model = tracer.serialized
+                # layer_cache[layer] = self.current_vis(input, layer)
+                if layer not in layer_cache: layer_cache[layer] = self.current_vis(input, layer)
+                else: print('cached')
+                self.outputs = layer_cache[layer]
 
-        response = jsonify(model)
+                outputs = self.outputs
 
-        return response
+                if len(outputs.shape) < 3: raise ValueError
 
-    @app.route('/api/model/layer/<id>')
-    def api_model_layer(id):
-        id = int(id)
-        name = str(tracer.idx_to_value[id])
+                last = int(request.args['last'])
+                max = min((last + MAX_LINKS_EVERY_REQUEST), outputs.shape[0])
 
-        return Response(response=name)
+                response = ['/api/model/image/{}/{}/{}/{}/{}'.format(hash(input),
+                                                                     hash(self.current_vis),
+                                                                     hash(time.time()),
+                                                                     id,
+                                                                     i) for i in range(last, max)]
 
-    @app.route('/api/model/layer/output/<id>')
-    def api_model_layer_output(id):
-        try:
-            model, inputs, outputs = tracer.idx_to_value[id].traced[0]
-            if len(outputs.shape) < 3: raise ValueError
-            # mode = request.args.get('mode')
+                response = jsonify({ 'links' : response, 'next': last + 1 < max})
 
-            last = int(request.args['last'])
-            max = min((last + MAX_LINKS_EVERY_REQUEST), outputs.shape[1])
-            if last >= max: raise StopIteration
+            except KeyError:
+                response = Response(status=500, response='Index not found.')
+            except ValueError:
+                response = Response(status=404, response='Outputs must be an array of images')
+            except StopIteration:
+                response = jsonify({ 'links' : [], 'next': False})
 
-            response = ['/api/model/image/{}/{}'.format(id, i) for i in range(last, max)]
-            response = jsonify(response)
+            return response
 
-        except KeyError:
-            response = Response(status=500, response='Index not found.')
-        except ValueError:
-            response = Response(status=404, response='Outputs are not images.')
-        except StopIteration:
-            response = Response(status=404, response='No more.')
 
+        @app.route('/api/model/image////
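For reference, the new endpoints can be exercised without the React client. A minimal sketch using `requests` against a locally running instance (endpoint paths are taken from the diff above; the PUT payload reuses the `properties` dict that the GET returns, and the layer id is hypothetical, it would normally come from `/api/model`):

```python
import requests

BASE = 'http://localhost:5000'

# list the registered visualisations and the currently active one
vis = requests.get(BASE + '/api/visualisation').json()
print([v['name'] for v in vis['visualisations']], 'current:', vis['current']['name'])

# switch the current visualisation by PUTting back a (possibly edited) properties dict;
# the server matches it by its 'name' field and resets that visualisation's cache
current = requests.put(BASE + '/api/visualisation', json=vis['visualisations'][-1]).json()

# fetch the first page of outputs for a layer: a list of image links plus a 'next' flag
page = requests.get(BASE + '/api/model/layer/output/0', params={'last': 0}).json()
print(len(page['links']), page['next'])
```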