diff --git a/cnn_nst.ipynb b/cnn_nst.ipynb new file mode 100755 index 0000000..7e1f9aa --- /dev/null +++ b/cnn_nst.ipynb @@ -0,0 +1,337 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from keras.layers import Conv2D, UpSampling2D, InputLayer, Conv2DTranspose\n", + "from keras.layers import Activation, Dense, Dropout, Flatten\n", + "from keras.layers.normalization import BatchNormalization\n", + "from keras.models import Sequential\n", + "from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img\n", + "from keras.utils import plot_model\n", + "from skimage.color import rgb2lab, lab2rgb, rgb2gray, xyz2lab\n", + "from skimage.io import imsave\n", + "import numpy as np\n", + "import os\n", + "import random\n", + "import tensorflow as tf\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.image as mpimg\n", + "from skimage.transform import resize\n", + "from skimage import io, color\n", + "from PIL import Image\n", + "import warnings\n", + "import cv2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Pre_Process_For Direct SEM Images\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "or_input = 'ref/ref22_d.png'\n", + "or_output = 'tmp/input.png'\n", + "size = (400,400)\n", + "rgb = io.imread(or_input)\n", + "resized_image = resize(rgb, size)\n", + "rescaled_image = 255 * resized_image\n", + "final_image = rescaled_image.astype(np.uint8)\n", + "io.imsave(or_output,final_image)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dir_style = 'sem/' \n", + "style_input = 'AI-22.jpg'\n", + "style_output = 'tmp/style.png'\n", + "size = (400,400)\n", + "rgb = io.imread(dir_style+style_input)\n", + "resized_image = resize(rgb, size)\n", + "rescaled_image = 255 * resized_image\n", + "final_image = rescaled_image.astype(np.uint8)\n", + "io.imsave(style_output,final_image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plt.figure()\n", + "\n", + "a=fig.add_subplot(1,2,1)\n", + "img_rslt=mpimg.imread(or_output)\n", + "imgplot= plt.imshow(img_rslt)\n", + "a.set_title('Input')\n", + "a.axis('off')\n", + "a=fig.add_subplot(1,2,2)\n", + "img_gray=mpimg.imread(style_output)\n", + "imgplot= plt.imshow(img_gray)\n", + "a.set_title('Style')\n", + "\n", + "a.axis('off')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get images\n", + "\n", + "input_image = img_to_array(load_img(or_output))\n", + "input_image = np.array(input_image, dtype=float)\n", + "style_image = img_to_array(load_img(style_output))\n", + "style_image = np.array(style_image, dtype=float)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X = rgb2lab(1.0/255*input_image)[:,:,0]\n", + "Y = rgb2lab(1.0/255*input_image)[:,:,1:]\n", + "Y /= 128\n", + "X = X.reshape(1, 400, 400, 1)\n", + "Y = Y.reshape(1, 400, 400, 2)\n", + "## TO be Draw\n", + "X_style = rgb2lab(1.0/255*style_image)[:,:,0]\n", + "#X_style = style_image.reshape((image.shape[0] * image.shape[1], 3))\n", + "X_style = X_style.reshape(1, 400, 400, 1)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Building the neural network\n", + "model = Sequential()\n", + "model.add(InputLayer(input_shape=(None, None, 1)))\n", + "model.add(Conv2D(8, (3, 3), activation='relu', padding='same', strides=2))\n", + "model.add(Conv2D(8, (3, 3), activation='relu', padding='same'))\n", + "model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))\n", + "model.add(Conv2D(16, (3, 3), activation='relu', padding='same', strides=2))\n", + "model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))\n", + "model.add(Conv2D(32, (3, 3), activation='relu', padding='same', strides=2))\n", + "model.add(UpSampling2D((2, 2)))\n", + "model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))\n", + "model.add(UpSampling2D((2, 2)))\n", + "model.add(Conv2D(16, (3, 3), activation='relu', padding='same'))\n", + "model.add(UpSampling2D((2, 2)))\n", + "model.add(Conv2D(2, (3, 3), activation='tanh', padding='same'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Finish model\n", + "model.compile(optimizer='rmsprop',loss='mse')\n", + "#plot_model(model, to_file='model.png')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "history = model.fit(x=X, \n", + " y=Y,\n", + " batch_size=1,\n", + " epochs=500)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(history.history['loss'])\n", + "#plt.plot(history.history['val_loss'])\n", + "plt.title('Model loss')\n", + "plt.ylabel('Loss')\n", + "plt.xlabel('Epoch')\n", + "plt.legend(['Train', 'Test'], loc='upper left')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "color_dict = {}\n", + "color_dict_max = {}\n", + "color_dict_max_ab = {}\n", + "def same_grey_same_colorize(AB, L):\n", + " \n", + " for i in range(L.shape[0]):\n", + " for j in range(L.shape[1]):\n", + " l = L[i][j]\n", + " ab = list(map(lambda x:int(x), AB[i][j] + 128))\n", + " if l not in color_dict:\n", + " color_dict[l] = [[0]*256]*256\n", + " color_dict_max[l] = 0\n", + " color_dict[l][ab[0]][ab[1]] += 1\n", + " if color_dict[l][ab[0]][ab[1]] > color_dict_max[l]:\n", + " color_dict_max[l] = color_dict[l][ab[0]][ab[1]]\n", + " color_dict_max_ab[l] = AB[i][j]\n", + " \n", + " for i in range(L.shape[0]):\n", + " for j in range(L.shape[1]):\n", + " l = L[i][j]\n", + " AB[i][j] = color_dict_max_ab[l]\n", + " return AB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#debug code\n", + "#l = np.array([[1,3,4,1],[3,4,1,3]])\n", + "#ab = np.array([[5,7,8,5],[8,9,6,7]])\n", + "#same_grey_same_colorize(ab, l)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(model.evaluate(X_style, Y, batch_size=1))\n", + "\n", + "output = model.predict(X_style)\n", + "output *= 128\n", + "# Output colorizations\n", + "cur = np.zeros((400, 400, 3))\n", + "L = X_style[0][:,:,0]\n", + "AB = output[0]\n", + "cur[:,:,0] = L\n", + "cur[:,:,1:] = same_grey_same_colorize(AB, L)\n", + "imsave(\"or_result/\"+style_input, lab2rgb(cur))\n", + "imsave(\"or_result/img_gray_version.png\", rgb2gray(lab2rgb(cur)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plt.figure()\n", + "\n", + "a=fig.add_subplot(1,2,1)\n", + "img_rslt=mpimg.imread('or_result/'+style_input)\n", + "imgplot= plt.imshow(img_rslt)\n", + "a.set_title('Result')\n", + "a.axis('off')\n", + "a=fig.add_subplot(1,2,2)\n", + "img_gray=mpimg.imread('or_result/img_gray_version.png')\n", + "imgplot= plt.imshow(img_gray,cmap='gray')\n", + "a.set_title('Gray')\n", + "\n", + "a.axis('off')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#After Process\n", + "filename = 'or_result/'+style_input\n", + "image = cv2.imread(filename)\n", + "frame = cv2.imread(filename)\n", + "gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n", + "ret, imgf = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY+cv2.THRESH_OTSU)\n", + "#imgf = cv2.adaptiveThreshold(gray_image,11,cv2.ADAPTIVE_THRESH_GAUSSIAN_C,cv2.THRESH_BINARY,255,125)\n", + "mask = imgf\n", + "res = cv2.bitwise_and(frame, frame, mask=mask)\n", + "cv2.imwrite('cv/'+style_input,res)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "imgplt = plt.imshow(cv2.cvtColor(res,cv2.COLOR_BGR2RGB))\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}