From 4688ffc172d1d2e41a63f51e5171ecca968278ee Mon Sep 17 00:00:00 2001 From: dbuscombe-usgs Date: Tue, 27 Jul 2021 19:45:39 -0700 Subject: [PATCH] 7/27/21 --- README.md | 58 ++ doodler.py | 4 +- install/Dockerfile.miniconda | 23 + my_defaults.py | 4 +- src/__pycache__/defaults.cpython-36.pyc | Bin 0 -> 317 bytes .../image_segmentation.cpython-36.pyc | Bin 0 -> 13429 bytes utils/gen_npz_4_zoo.py | 6 +- utils/plot_label_generation.py | 879 ++++++++---------- utils/viz_npz.py | 12 +- 9 files changed, 485 insertions(+), 501 deletions(-) create mode 100644 install/Dockerfile.miniconda create mode 100644 src/__pycache__/defaults.cpython-36.pyc create mode 100644 src/__pycache__/image_segmentation.cpython-36.pyc diff --git a/README.md b/README.md index 05259bf..4e4324c 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,64 @@ More demonstration videos (older version of the program): ![Coast Train example 2](https://raw.githubusercontent.com/dbuscombe-usgs/dash_doodler/main/assets/logos/doodler-demo-2-9-21-short-coast2.gif) + + ## Acknowledgements diff --git a/doodler.py b/doodler.py index ae12933..b94015f 100644 --- a/doodler.py +++ b/doodler.py @@ -377,7 +377,7 @@ def shapes_seg_pair_as_dict(d, key, seg, remove_old=True): # Slider for specifying pen width dcc.Slider( id="crf-downsample-slider", - min=2, + min=1, max=6, step=1, value=DEFAULT_CRF_DOWNSAMPLE, @@ -410,7 +410,7 @@ def shapes_seg_pair_as_dict(d, key, seg, remove_old=True): # Slider for specifying pen width dcc.Slider( id="rf-downsample-slider", - min=2, + min=1, max=20, step=1, value=DEFAULT_RF_DOWNSAMPLE, diff --git a/install/Dockerfile.miniconda b/install/Dockerfile.miniconda new file mode 100644 index 0000000..81898a5 --- /dev/null +++ b/install/Dockerfile.miniconda @@ -0,0 +1,23 @@ +# +FROM continuumio/miniconda3 +LABEL maintainer "Doodler, by Dr Daniel Buscombe, Marda Science/USGS " +WORKDIR / +# The code to run when container is started: +COPY ./ ./ + +COPY install/dashdoodler.yml . +RUN conda env create -f dashdoodler.yml + +# Make RUN commands use the new environment: +SHELL ["conda", "run", "-n", "dashdoodler", "/bin/bash", "-c"] + +EXPOSE 8050/tcp +EXPOSE 8050/udp +EXPOSE 80 +EXPOSE 8080 + +# set environment variables +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "dashdoodler", "python", "doodler.py"] diff --git a/my_defaults.py b/my_defaults.py index db945f6..107a454 100644 --- a/my_defaults.py +++ b/my_defaults.py @@ -1,6 +1,6 @@ DEFAULT_PEN_WIDTH = 3 -DEFAULT_CRF_DOWNSAMPLE = 2 -DEFAULT_RF_DOWNSAMPLE = 2 +DEFAULT_CRF_DOWNSAMPLE = 1 +DEFAULT_RF_DOWNSAMPLE = 1 DEFAULT_CRF_THETA = 1 DEFAULT_CRF_MU = 1 DEFAULT_CRF_GTPROB = 0.9 diff --git a/src/__pycache__/defaults.cpython-36.pyc b/src/__pycache__/defaults.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9efe3fc0563fe07552f0e7f6a9cfd45b870282d GIT binary patch literal 317 zcmXr!<>k`uo{?~nm4V?g0}@~avK@f9SOG|+Fhnt=Fh((^Fhwz?Fh?<`utc!{`K&3d zQEVw}!3>)0FPVVKUNVCSE)bysB2s`v`q?ue@W$RxlRZk%#nsI*)F&i9z|}84+|wn* zBT5V=NPN_<&QoOk;Ir)&}_jS)M9s+1r zur)nB{qFwW-`BW2I$HRfi~sUq?^8wjcjdre7Wwbv2|ZO&n8H*~vE*B|G<<8G?i-fj zXRM5$wX!PJ>0ZvyTX~T;yb-@(6+}Mc75$P`68Wq*>X)su$gAF1YuuVZ&zx8BC#^|Q zn)j-H)vBf~rdre13|fqMv;Lel=g(X7qNd<2_>0z}$QQjA{8QE`|3&LX-?U8sCF`ZM z_i5`i`l{X;>x_7xwa(&Q^3M6^t@HlN*2|*ZsP~FrvuYw=_Fna0vtCmbM{8YRWB85l zX{ecCl`jAy`FC+IBdxh1|>TUKadkyb*n0imCz5a~Auw2uVT+{U;hlj{GcGTrg zQg;2dx*gV=UfYf?U#h9e=J8f?BtsAaKO|q^Z+K(ugXLjU7uJ6>eWUj$6Eyl84{@4wj`l{>Uo7-M8(GAeM z-sEk+?%Au37t$!)3EiGkGZKS!{Z3+Fku1r7Y@>(gcu9^qp1tKT3|--l-+tuOLmyT3 z^|t4PN%c@!d$r{>A{sUqFxR)&Xmy1~J9yOg9yym7*1Q(9Io4@qnfu6Pj+JwRN4Dp( z1~69q(c8fD0iLjgBvN8v1Jm|2oU`7WHP@X@n>h`P@0n5Cbc4uQbGZ3z9E~xbv1*4j zuZ{}5)ZnibzJRLVFD@;;9`eTPVqx{rS@WGBvLm-0EOoX3E}RR7{x;34AyG8`0@8*~ zdHiMYP+lzTTgcGDTI!CnqxSVk8|IBzi`7}w#rmvLRbqqQ3reicDA<77d$VBT564cB z4NW`14m#b)1Q5+tCwlBSf$0Y*XQsa?+5)R|-wm)tfHHCdeA%0|OroPpqWJ-y%|!JR z^|sikOct@iB$rNe1s@@C2PIolODaE)-0;Ndv}toY(_d2J4x>8q&I z@aUpmN6v`V9d$w3(fU-Lq4GXzY16T`p!Ag}E8(}J_mypZTid2Fa-zSnq4RBg<f6LmRJLw*)x9X+lh+6k%yMO;UHSV5v_i|V3Uz_FFotX5XX zHU1__8v;j=p`)Xt7034v$Y4V_zLrLqfvsoez<6hs4Xvd|1~#Z|?naqdYh_t>4lS-J zpFaZrrH(Q$INgf%-Enq*7`=2|D#$dt&tcrVLqi*z#6_W@>ONqc=*PB zT4@orxgCvZGXT{#uo#D*UtZ!bV+Qe+7m8LeSPCdptE%VmI}1RKtD4Hchq@u4a`+em>bLkf0x07spirOU^JWYz z)>=6NrT|&14>CUlFoOWr&qR3~eg*B`!?7NAz>ICLf!5(X5=Aq$f?8Fl1@WB)JZAw_0oV6XJLF%- z5+Ec7erq~Oj`S_Q*NDdmE+#mYw)qucqsJP^OxrN{-!q(Lj5Cc0zWN4OIWt&>W%mpa zse!uvZ-ZDn^_nASq&YD9v{XPnlYU`D8ka^y>vU|C8Q_m#6>=>Pm+qq#2r#!S=uZIh zhGm4d?;zdseMn)$Z!;&!uM#0aoY%6F5;HdE_Lh}#!@!1s2)jO?!XYI(K5H3EbA7Ae zc7i&995|k3pg6P&kCq*GZGE-PLkq0&(oqPNAen5ifI79_#&F>*5=A`?h>WXwz-C-E z)iT}&Qcf(hA*2j4#07@b_B$|3;49=fXaw>AIh;d37pb)OmM(;79+1*e0%q3&m9;WE z3NscUb-3d*!WUE`2JE8G)_$YNNtmLBi3{53QB_&49}o2JcA;3Ob>92_=7l!623D>%0}Zn zogPy+w69hElXx^P%_{wJa7*R{11|%|P%4@EG2@OMGwygi9*?Dc&_9LI#BhC;sD%6| z8wJ-Y#~D`M)BBTAxiuC~#Fe-Vsa}n%t?_u0V1V2NwCOR_SMW?i4>tgrDpVq^sl_Ao zYw9>%Xy%}(DYT!8r(%8Bep0kQ3D(tgj%hXpDxSt}X5wj3`1GFEnx*wLLC398l+O`8 z18-(d#G4s`GlYLL>8Z~~^YIKgaIOiC!DiVURyY$apsW;+?&yPYpEKsXSl>)sZj!FL zX#XSB7oowOZt!NEJ54C}uLxmx!GtDKJ1r_#*4@x-c7q0~qEHN}74$+=G(@l59C=n9J zX&w)oHkX!`fD}!KJ5VxA=o&l%3O*~KGZtw;w-0+28f3qb58Lksa}ILUSxUBu=a4h(Q0{9rXTO`L{2l0~ToNiDYLJNzQAptzm$~!0OtZ77cORew^gPh&Ra|=clMK8L>j&f{6`2U=^-8LFiy5 zm#$gm(1|`Fv;2DC@U<;#(q`=A)O7#Q?qY#?QokH$W#z7%M2 zlArFOTi8OPR10bWR@@?#mNLwRvi#|%v|0Ls!dbw(iWUW7Aad2JeP3TSZ_EikfHr5|#RPxuuC!x~u&Fcqc#2U8EyeP{m* z8U?f*?Ppp)3q8qzB?4OlMqpOtb0VLMGp#K8=VDkJlS-VUbcXst7H?;_g~k*WT19A0 zB?gmF$UJCKEEB7#-a{iptKHF_LOr(Ib4e79hFb1c!O=1q<+XDNE*^-1E(4@KQ zbvY;QF=XK8g9o7ti}t|-A&vGmc>W`R{5XN}+O2j&u+{;O3Gq~b-ULU_lPgpl9eLz9`KW$hxaz^m zBqsbWyu{M?mQ{PF2aB(ZQ(e01;jp~$U1SE$d#|oQsKCYgM{KnBiV%-~M9W$=-G}dl z>ucO)(rSK{7EB@#8%!B9To|2>1AoSARQE9@Kc<8v0l!PhJtW|FZ8%$#5f@TEYPY>G zc>%j!bprwWu>Kgz7uZzCvl|Y%B&8L8f%+3vtm$W=3?is#b7tb)vlIS4iYs!b~pZp89n011o# zLPbbMy^=SyS9CbWv_;JT(<_5vf;ozu0Y@2?WWgGzbudi4VY`bmD(4&M|DQC7qWs%? zdAB*;GJZxr+h54i_)j6(D$?OD8fcWJ|AcC zChJk=(i+%1I`PTWw$FgW7FdbpSbk4?qC8QbXixN}!A4l&i+r3TjwE;zaqHt&Q5>$K z7?oP1tuk@09Tmhe+WWzcW4Q_`PcX!#p1uU$0m^vrz-&OjUv%*wuEGT^-yY^QodL4`DcA(;76vH8keyl(n|2pkO-6GGkYwO=i~s>SXqEkoLJ& zJ{p0u8V+grIe}czKTrnOH03~3LKn!#`Bu3iAnqOl#NPxLk(7vmB5bqZ_*7JCVP zk^(DyNq)AH%)!rAf_4D}-Og{9hCPaD4~!}(^#HOQkpCBn5`CPudvQQn!bFv%d5BVo z?+Z#H-VI7QOOyf{Q5jNl5eaIM+gA5-DaDAe)>18V1Pyo>gYZX4mcfe&AOpIID4#t_ zH`G%0IZQ#kjS5<(2x=K4y)#E~;uyl&S6-g8uH)F!LP~H0)r3461N*9{J$?;3L57Ja z!S`vsBvcWN;nGQ6IpBYONYz9?oKq4~LWgiNS=OjlSZ7<5BerE_eaH80M2nhUH(alc zN{Lw#;7p-;;<59`G)#*U8qg{duL}{M zMF`N5idF$RYQtY1RqALZbL7pO%hguPQ`(_)a_DA8`RyOkdXOThCsGqAKD-o$e zIU-Yp*R4#yEL`CD##a`EITn|xRb2Rz!Xo8(?7-JpT;#ut5KIxo4TtQG3O#eYKY`VZ zz#R$q8JI^|d`-k-;%kgd5n>Hs2b9ajWin~zQt{97eK;d!NnF_=6XTI+HxgG;1wnxW zPid8zF=)pM2vJscu&34-Lr|ta2?yt-+?#OdV%6~^Rxjr+oisOMcbJ)K<09t$4We|U zE>vm$Xq*^Nt3vM6VgK^L@DM8;gA;U8t}$4Ts?pSr5>2ylVqN4WrOOLz8k^cf_~$ue zPK)`gaEM~nJGlIC;_SlliIr;nm(grog)_FoX7}{x^q3p2ca}znBUk%_edY&l_U-ca z*ddI8BUoSzIe+J3>Bv3lq!z@k1pZ;Y{~b@a=3yuj{w)%N$0^UieY*gBn+?tax(hLd zci{TtW!Dt8#pH9H*bUb#VQ1dFR63)~3TN&HytlL&e`laje14fROM{u}K`jS<4Z0>ELxZUYs zhgN0XWy}e1iDe7kDT)|}z3H0)yxZigq%*+%n)F9_i6!V!VQ=C# zK>%m@2q%sP*v0rYG&)prPB!YMZg?7m+BzZpd{}MjiIt}Q8^?P+$r&ktO)DWk=TI`% z7KJ{^h0RBdonJ(f6h0-#+jY*{JW*FHgw!9wj?BOl4Ap%MEv(`JY&jqqNvw(>ZIZ`* zBB7zmfG^32eakc{1@&|ThyZdj|2dgmocvUlj#ct9<#ISD<~Ab!3BD4e(e7-)V@N7H z9E#1lkb@(1A~I_bmOr0(tnmXrL+M`3-FyGj>-D=gq{RIN%{w~4={gN=jSWAg+uN$p z%oLgyh%Xye50|Aai=OM(Q({4+GhmhuVue4J#?8t`{7>YDglhebP-up=>%6%~;`*=9 zPJx4PUrbZ3plfgt?w93t0}jGGqSxaYI2nbOT|zH75#e4m5XmmXWk^v}axE4h-t#*2 zakWw*2P1TGVLumOUCUaEvZ9t!^ro1&R1IWWLLaCGLS5irVSSBBS|$EY;UNuPXzzb3 zRD)f30-qX)k_qh`;wdwp>Tg;%bEXdv}6;#7z|T*i%s zylhDxy+K=wee?e~Zd}nnbF@b)wzE(Piir6tEGy?c+77ppC<0kTIfQV(`n4APHX7YU z6dvkKkz#iww4hc@K=rBYl%7HX$9Opi_Nl16i`afVRmJ_KrxbBL{zPESL%2_|5usSz zipHqdb9)FMBoNbNoX3#4~A}N~KrMnQuv;DdL{Cy>W%bI>Co{>;M;M1PPGlyq>^V8?e^7&i+h21PHz)ZZbqhp@M z{!IUcc(yeiFN)KiLr$tH^T-MD{I&M*7x$IVltBB*$?JgdA7Tzge`oQ`#k1JqVmwER z&fw(x)$Q5sxfAz;2=a4wI}cM1CY_uqo+sS?TIT`g$`Neis(`ft%M+L>l*=mW#-VOi z2RINTnk{hvYTZegCvYIHLM?e=`8X+DkzP%@r9-zxAqttW-Fh>)BuPOws3W{hH!a~& zYC;zGKE4xph#dxw)LalE3)=D)^T+ny)9$*}K*OH9Zlb%XCwLvq?A z>SH*3enOx3DZv$Rg?~uN(4(FdBZy*(eL+vSg_93SiYt&0=OIDzFeS@c@1;ZcUx&{1 z13mdIbb`;pr@;dxhkPBRA_!jx$rt!Krpec_i=E@F;1}G{Q}2ej^GY5Ja{moH7++If z-`^qsGZgG^K)5VEqlM#9;O=|z_{9!ggZu{C_1+i3-SpzJ89}qTNI_WB?(X08MKByX z1EaVzcc6L-H4J)O=x`{Hy*FuyWAN#md%}1jwsaelTe-({lPU^(Z{HUgU?Q|NPIos3 zn2Bk=u_z+TZiMT2|Q1GD+B@Rg`{`A=yW5n(1wr8FkQybf zE>f4$L9{FqzLFD8nUP+bsyeGyw0Wpxiy*vl;}rZj6mceh4uweQO^PAqs%Tj$lG=9# zDqfaUJJkT@8@Os$Q48?xOp|h`oran>XkS1dF@h-Je~y`J`pU{mErW!6_!L2vS{9!N zzrQ2gYE}k^(%BLSlH|be>GFt(Z;7CjaB^GI;f4p>D_9DgF8&{&9^iUKb zyw4)4W=(dsVEG|J!JA8i3@)CgMSKMft?>hGx{>RJ;(CO*hG9(%MupWNFSK$}j;$=> zAD+7^>>XiRN#rM{m1u4|(b}OPGOIEez;E*o{s*GHwk57Mi%ZUQwZX-qip#Slkzax+ z!j&s=858y{-Spw&{|dy#3rW3nOIgyMAUnD@AOa=Q;kiwpf<_6%l1mC+zt8aVSHZKo zp7Smhg%o*&<$-GmbY-2QbK>$k(s6vv;zw`TB6>cPFXqSaTfnnU`d#6Q(DEpEW{^Yv Ijr{ii0c;NN*Z=?k literal 0 HcmV?d00001 diff --git a/utils/gen_npz_4_zoo.py b/utils/gen_npz_4_zoo.py index fe6b197..46f94b1 100644 --- a/utils/gen_npz_4_zoo.py +++ b/utils/gen_npz_4_zoo.py @@ -68,7 +68,11 @@ def make_npz(): data[k] = dat[k] del dat - classes = data['classes'] + try: + classes = data['classes'] + except: + classes = ['water', 'land'] + class_string = '_'.join([c.strip() for c in classes]) savez_dict = dict() diff --git a/utils/plot_label_generation.py b/utils/plot_label_generation.py index 6737257..f7d65cf 100644 --- a/utils/plot_label_generation.py +++ b/utils/plot_label_generation.py @@ -41,6 +41,8 @@ import matplotlib import matplotlib.pyplot as plt +from imageio import imwrite + ###=========================================================== try: sys.path.insert(1, '../') @@ -83,371 +85,347 @@ def gen_plot_seq(orig_distance, save_mode): #### loop through each file for anno_file in tqdm(files): - # print("Working on %s" % (file)) - print("Working on %s" % (anno_file)) - dat = np.load(anno_file) - data = dict() - for k in dat.keys(): - data[k] = dat[k] - del dat - - class_label_names = data['classes'] - - NUM_LABEL_CLASSES = len(class_label_names) - if NUM_LABEL_CLASSES<=10: - class_label_colormap = px.colors.qualitative.G10 + if os.path.exists(anno_file.replace('.npz','_label.png')): + print('%s exists ... skipping' % (anno_file.replace('.npz','_label.png'))) + continue else: - class_label_colormap = px.colors.qualitative.Light24 - - # we can't have fewer colors than classes - assert NUM_LABEL_CLASSES <= len(class_label_colormap) - - colormap = [ - tuple([fromhex(h[s : s + 2]) for s in range(0, len(h), 2)]) - for h in [c.replace("#", "") for c in class_label_colormap] - ] - - cmap = matplotlib.colors.ListedColormap(class_label_colormap[:NUM_LABEL_CLASSES]) - cmap2 = matplotlib.colors.ListedColormap(['#000000']+class_label_colormap[:NUM_LABEL_CLASSES]) - - - savez_dict = dict() - - ## if more than one label ... - if len(np.unique(data['doodles']))>2: - - img = data['image'] - del data['image'] - - #================================ - ##fig1 - img versus standardized image - plt.subplot(121) - plt.imshow(img); plt.axis('off') - plt.title('a) Original', loc='left', fontsize=7) - - # #standardization using adjusted standard deviation - img = standardize(img) - - #================================ - ##fig2 - img / doodles - plt.subplot(122) - plt.imshow(img); plt.axis('off') - plt.title('b) Filtered', loc='left', fontsize=7) - plt.savefig(anno_file.replace('.npz','_image_filt_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - - tmp = data['doodles'].astype('float') - tmp[tmp==0] = np.nan - - ## do plot of images and doodles - plt.imshow(img) - plt.imshow(tmp, alpha=0.25, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2) #'inferno') - plt.axis('off') - plt.colorbar(shrink=0.5) - plt.savefig(anno_file.replace('.npz','_image_doodles_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - del tmp - - ## "analytical toola" e.g. compute annotations per unit area of image and per class label - is there an ideal number or threshold not to go below or above? - - #####=========================== RF - - if np.ndim(img)==3: - features = extract_features( - img, - multichannel=True, - intensity=True, - edges=True, - texture=True, - sigma_min=1, #SIGMA_MIN, - sigma_max=16, #SIGMA_MAX, - ) - else: - features = extract_features( - np.dstack((img,img,img)), - multichannel=True, - intensity=True, - edges=True, - texture=True, - sigma_min=1, #SIGMA_MIN, - sigma_max=16, #SIGMA_MAX, - ) - - counter=1 - for k in [0,1,2,3,4]: - plt.subplot(2,5,counter) - plt.imshow(features[k].reshape((img.shape[0], img.shape[1])), cmap='gray'); plt.axis('off') - if k==0: - plt.title('a) Smallest scale', loc='left', fontsize=7) - counter+=1 - - for k in [70,71,72,73,74]: - plt.subplot(2,5,counter) - plt.imshow(features[k].reshape((img.shape[0], img.shape[1])), cmap='gray'); plt.axis('off') - if k==70: - plt.title('b) Largest scale', loc='left', fontsize=7) - counter+=1 - - plt.savefig(anno_file.replace('.npz','_image_feats_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - - #================================ - doodles = data['doodles'] - training_data = features[:, doodles > 0].T - training_labels = doodles[doodles > 0].ravel() - del doodles - - training_data = training_data[::DEFAULT_RF_DOWNSAMPLE] - training_labels = training_labels[::DEFAULT_RF_DOWNSAMPLE] - - if save_mode: - savez_dict['color_doodles'] = data['color_doodles'].astype('uint8') - savez_dict['doodles'] = data['doodles'].astype('uint8') - savez_dict['settings'] = data['settings'] - savez_dict['label'] = data['label'].astype('uint8') - - del data - - #================================ - clf = make_pipeline( - StandardScaler(), - MLPClassifier( - solver='lbfgs', alpha=1, random_state=1, max_iter=2000, - early_stopping=True, hidden_layer_sizes=[100, 100], - )) - clf.fit(training_data, training_labels) - - #================================ - - del training_data, training_labels - - # use model in predictive mode - sh = features.shape - features_use = features.reshape((sh[0], np.prod(sh[1:]))).T - - if save_mode: - savez_dict['features'] = features.astype('float16') - del features - - rf_result = clf.predict(features_use) - #del features_use - rf_result = rf_result.reshape(sh[1:]) - - #================================ - plt.imshow(img) - plt.imshow(rf_result-1, alpha=0.25, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap) #'inferno') - plt.axis('off') - plt.colorbar(shrink=0.5) - plt.savefig(anno_file.replace('.npz','_image_label_RF_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - - #================================ - plt.subplot(221); plt.imshow(rf_result-1, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap); plt.axis('off') - plt.title('a) Original', loc='left', fontsize=7) - - rf_result_filt = filter_one_hot(rf_result, 2*rf_result.shape[0]) - if save_mode: - savez_dict['rf_result_filt'] = rf_result_filt - - plt.subplot(222); plt.imshow(rf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') - plt.title('b) Filtered', loc='left', fontsize=7) - - if rf_result_filt.shape[0]>512: - ## filter based on distance - rf_result_filt = filter_one_hot_spatial(rf_result_filt, orig_distance) - - if save_mode: - savez_dict['rf_result_spatfilt'] = rf_result_filt - - plt.subplot(223); plt.imshow(rf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') - plt.title('c) Spatially filtered', loc='left', fontsize=7) - - # rf_result_filt_inp = inpaint_zeros(rf_result_filt).astype('uint8') - - rf_result_filt = rf_result_filt.astype('float') - rf_result_filt[rf_result_filt==0] = np.nan - rf_result_filt_inp = inpaint_nans(rf_result_filt).astype('uint8') - - plt.subplot(224); plt.imshow(rf_result_filt_inp, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') - plt.title('d) Inpainted', loc='left', fontsize=7) - plt.savefig(anno_file.replace('.npz','_rf_label_filtered_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - - ###======================================================== - #### demo of the spatial filter - if NUM_LABEL_CLASSES==2: + # print("Working on %s" % (file)) + print("Working on %s" % (anno_file)) + dat = np.load(anno_file) + data = dict() + for k in dat.keys(): + data[k] = dat[k] + del dat + # print(data['image'].shape) - distance = orig_distance #3 - shrink_factor= 0.66 - rf_result_filt = filter_one_hot(rf_result, 2*rf_result.shape[0]) + if 'classes' not in locals(): - lstack = (np.arange(rf_result_filt.max()) == rf_result_filt[...,None]-1).astype(int) #one-hot encode + try: + classes = data['classes'] + except: + Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing + classfile = askopenfilename(title='Select file containing class (label) names', filetypes=[("Pick classes.txt file","*.txt")]) - plt.figure(figsize=(12,16)) - plt.subplots_adjust(wspace=0.2, hspace=0.5) + with open(classfile) as f: + classes = f.readlines() - plt.subplot(631) - plt.imshow(img); plt.imshow(rf_result_filt-1, cmap='gray', alpha=0.25) - plt.axis('off'); plt.title('a) Label', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + class_label_names = [c.strip() for c in classes] + NUM_LABEL_CLASSES = len(class_label_names) - plt.subplot(635) - plt.imshow(img); plt.imshow(lstack[:,:,0], cmap='gray', alpha=0.25) - plt.axis('off'); plt.title('b) "Zero-hot"', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + if NUM_LABEL_CLASSES<=10: + class_label_colormap = px.colors.qualitative.G10 + else: + class_label_colormap = px.colors.qualitative.Light24 - plt.subplot(636) - plt.imshow(img); plt.imshow(lstack[:,:,1], cmap='gray', alpha=0.25) - plt.axis('off'); plt.title('c) "One-hot"', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + # we can't have fewer colors than classes + assert NUM_LABEL_CLASSES <= len(class_label_colormap) - tmp = np.zeros_like(rf_result_filt) - for kk in range(lstack.shape[-1]): - l = lstack[:,:,kk] - d = ndimage.distance_transform_edt(l) - l[d2: - plt.subplot(6,3,12) - plt.imshow(img); plt.imshow(rf_result_filt, cmap='gray', alpha=0.25) - plt.axis('off'); plt.title('g) Label encoded with zero class', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); - - ##double distance - distance *= 3 - tmp = np.zeros_like(rf_result_filt) - for kk in range(lstack.shape[-1]): - l = lstack[:,:,kk] - d = ndimage.distance_transform_edt(l) - l[d 0].T + training_labels = doodles[doodles > 0].ravel() + del doodles + + training_data = training_data[::DEFAULT_RF_DOWNSAMPLE] + training_labels = training_labels[::DEFAULT_RF_DOWNSAMPLE] + if save_mode: - savez_dict['crf_tta'] = [r.astype('uint8') for r in R] - savez_dict['crf_tta_weights'] = W + savez_dict['color_doodles'] = data['color_doodles'].astype('uint8') + savez_dict['doodles'] = data['doodles'].astype('uint8') + savez_dict['settings'] = data['settings'] + savez_dict['label'] = data['label'].astype('uint8') - crf_result = np.round(np.average(np.dstack(R), axis=-1, weights = W)).astype('uint8') - del R, W, n, w, r + del data #================================ - plt.subplot(221); plt.imshow(crf_result-1, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap); plt.axis('off') - plt.title('a) Original', loc='left', fontsize=7) + clf = make_pipeline( + StandardScaler(), + MLPClassifier( + solver='adam', alpha=1, random_state=1, max_iter=2000, + early_stopping=True, hidden_layer_sizes=[100, 60], + )) + clf.fit(training_data, training_labels) + + #================================ + + del training_data, training_labels - crf_result_filt = filter_one_hot(crf_result, 2*crf_result.shape[0]) + # use model in predictive mode + sh = features.shape + features_use = features.reshape((sh[0], np.prod(sh[1:]))).T if save_mode: - savez_dict['crf_result_filt'] = crf_result_filt - savez_dict['crf_result'] = crf_result-1 + savez_dict['features'] = features.astype('float16') + del features + + rf_result = clf.predict(features_use) + #del features_use + rf_result = rf_result.reshape(sh[1:]) - del crf_result + #================================ + plt.imshow(img) + plt.imshow(rf_result-1, alpha=0.25, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap) #'inferno') + plt.axis('off') + plt.colorbar(shrink=0.5) + plt.savefig(anno_file.replace('.npz','_image_label_RF_labelgen.png'), dpi=200, bbox_inches='tight') + plt.close() - plt.subplot(222); plt.imshow(crf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + #================================ + plt.subplot(221); plt.imshow(rf_result-1, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap); plt.axis('off') + plt.title('a) Original', loc='left', fontsize=7) + + rf_result_filt = filter_one_hot(rf_result, 2*rf_result.shape[0]) + if save_mode: + savez_dict['rf_result_filt'] = rf_result_filt + + plt.subplot(222); plt.imshow(rf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') plt.title('b) Filtered', loc='left', fontsize=7) - if crf_result_filt.shape[0]>512: + if rf_result_filt.shape[0]>512: ## filter based on distance - crf_result_filt = filter_one_hot_spatial(crf_result_filt, distance) + rf_result_filt = filter_one_hot_spatial(rf_result_filt, orig_distance) if save_mode: - savez_dict['rf_result_spatfilt'] = crf_result_filt + savez_dict['rf_result_spatfilt'] = rf_result_filt - plt.subplot(223); plt.imshow(crf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + plt.subplot(223); plt.imshow(rf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') plt.title('c) Spatially filtered', loc='left', fontsize=7) - crf_result_filt = crf_result_filt.astype('float') - crf_result_filt[crf_result_filt==0] = np.nan - crf_result_filt_inp = inpaint_nans(crf_result_filt).astype('uint8') - del crf_result_filt + # rf_result_filt_inp = inpaint_zeros(rf_result_filt).astype('uint8') + + rf_result_filt = rf_result_filt.astype('float') + rf_result_filt[rf_result_filt==0] = np.nan + rf_result_filt_inp = inpaint_nans(rf_result_filt).astype('uint8') - plt.subplot(224); plt.imshow(crf_result_filt_inp, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') - plt.title('d) Inpainted (final label)', loc='left', fontsize=7) + plt.subplot(224); plt.imshow(rf_result_filt_inp, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + plt.title('d) Inpainted', loc='left', fontsize=7) - plt.savefig(anno_file.replace('.npz','_crf_label_filtered_labelgen.png'), dpi=200, bbox_inches='tight') + plt.savefig(anno_file.replace('.npz','_rf_label_filtered_labelgen.png'), dpi=200, bbox_inches='tight') plt.close() - else: + ###======================================================== + #### demo of the spatial filter + + if NUM_LABEL_CLASSES==2: + + distance = orig_distance #3 + shrink_factor= 0.66 + rf_result_filt = filter_one_hot(rf_result, 2*rf_result.shape[0]) + + lstack = (np.arange(rf_result_filt.max()) == rf_result_filt[...,None]-1).astype(int) #one-hot encode + + plt.figure(figsize=(12,16)) + plt.subplots_adjust(wspace=0.2, hspace=0.5) + + plt.subplot(631) + plt.imshow(img); plt.imshow(rf_result_filt-1, cmap='gray', alpha=0.25) + plt.axis('off'); plt.title('a) Label', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + + plt.subplot(635) + plt.imshow(img); plt.imshow(lstack[:,:,0], cmap='gray', alpha=0.25) + plt.axis('off'); plt.title('b) "Zero-hot"', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + + plt.subplot(636) + plt.imshow(img); plt.imshow(lstack[:,:,1], cmap='gray', alpha=0.25) + plt.axis('off'); plt.title('c) "One-hot"', loc='left', fontsize=7) #plt.colorbar(shrink=shrink_factor); + + tmp = np.zeros_like(rf_result_filt) + for kk in range(lstack.shape[-1]): + l = lstack[:,:,kk] + d = ndimage.distance_transform_edt(l) + l[d1: + if save_mode: + savez_dict['crf_tta'] = [r.astype('uint8') for r in R] + savez_dict['crf_tta_weights'] = W - crf_result, n = crf_refine(rf_result_filt_inp, img, DEFAULT_CRF_THETA, DEFAULT_CRF_MU, DEFAULT_CRF_DOWNSAMPLE, DEFAULT_CRF_GTPROB) + crf_result = np.round(np.average(np.dstack(R), axis=-1, weights = W)).astype('uint8') + del R, W, n, w, r #================================ plt.subplot(221); plt.imshow(crf_result-1, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap); plt.axis('off') @@ -466,7 +444,7 @@ def gen_plot_seq(orig_distance, save_mode): if crf_result_filt.shape[0]>512: ## filter based on distance - crf_result_filt = filter_one_hot_spatial(crf_result_filt, orig_distance) + crf_result_filt = filter_one_hot_spatial(crf_result_filt, distance) if save_mode: savez_dict['rf_result_spatfilt'] = crf_result_filt @@ -474,7 +452,6 @@ def gen_plot_seq(orig_distance, save_mode): plt.subplot(223); plt.imshow(crf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') plt.title('c) Spatially filtered', loc='left', fontsize=7) - #crf_result_filt_inp = inpaint_zeros(crf_result_filt).astype('uint8') crf_result_filt = crf_result_filt.astype('float') crf_result_filt[crf_result_filt==0] = np.nan crf_result_filt_inp = inpaint_nans(crf_result_filt).astype('uint8') @@ -485,49 +462,97 @@ def gen_plot_seq(orig_distance, save_mode): plt.savefig(anno_file.replace('.npz','_crf_label_filtered_labelgen.png'), dpi=200, bbox_inches='tight') plt.close() + else: - crf_result_filt_inp = rf_result_filt_inp.copy() - - #================================ - plt.imshow(img) - plt.imshow(crf_result_filt_inp-1, alpha=0.25, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap) #'inferno') - plt.axis('off') - plt.colorbar(shrink=0.5) - plt.savefig(anno_file.replace('.npz','_image_label_final_labelgen.png'), dpi=200, bbox_inches='tight') - plt.close() - - if save_mode: - tosave = (np.arange(crf_result_filt_inp.max()) == crf_result_filt_inp[...,None]-1).astype(int) - savez_dict['final_label'] = tosave.astype('uint8')#crf_result_filt_inp-1 - savez_dict['image'] = (255*img).astype('uint8') - del img, crf_result_filt_inp - - ### if only one label - else: - if save_mode: - savez_dict['color_doodles'] = data['color_doodles'].astype('uint8') - savez_dict['doodles'] = data['doodles'].astype('uint8') - savez_dict['settings'] = data['settings'] - savez_dict['label'] = data['label'].astype('uint8') - v = np.unique(data['doodles']).max()#[0]-1 - if v==2: - tmp = np.zeros_like(data['label']) - tmp+=1 - else: - tmp = np.ones_like(data['label'])*v - tosave = (np.arange(tmp.max()) == tmp[...,None]-1).astype(int) - savez_dict['final_label'] = tosave.astype('uint8').squeeze() - savez_dict['crf_tta'] = None - savez_dict['crf_tta_weights'] = None - savez_dict['crf_result'] =None - savez_dict['rf_result_spatfilt'] = None - savez_dict['crf_result_filt'] = None - savez_dict['image'] = data['image'].astype('uint8') - del data - np.savez(anno_file.replace('.npz','_labelgen.npz'), **savez_dict ) - del savez_dict - plt.close('all') + if len(np.unique(rf_result_filt_inp.flatten()))>1: + + crf_result, n = crf_refine(rf_result_filt_inp, img, DEFAULT_CRF_THETA, DEFAULT_CRF_MU, DEFAULT_CRF_DOWNSAMPLE, DEFAULT_CRF_GTPROB) + + #================================ + plt.subplot(221); plt.imshow(crf_result-1, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap); plt.axis('off') + plt.title('a) Original', loc='left', fontsize=7) + + crf_result_filt = filter_one_hot(crf_result, 2*crf_result.shape[0]) + + if save_mode: + savez_dict['crf_result_filt'] = crf_result_filt + savez_dict['crf_result'] = crf_result-1 + + del crf_result + + plt.subplot(222); plt.imshow(crf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + plt.title('b) Filtered', loc='left', fontsize=7) + + if crf_result_filt.shape[0]>512: + ## filter based on distance + crf_result_filt = filter_one_hot_spatial(crf_result_filt, orig_distance) + + if save_mode: + savez_dict['rf_result_spatfilt'] = crf_result_filt + + plt.subplot(223); plt.imshow(crf_result_filt, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + plt.title('c) Spatially filtered', loc='left', fontsize=7) + + #crf_result_filt_inp = inpaint_zeros(crf_result_filt).astype('uint8') + crf_result_filt = crf_result_filt.astype('float') + crf_result_filt[crf_result_filt==0] = np.nan + crf_result_filt_inp = inpaint_nans(crf_result_filt).astype('uint8') + del crf_result_filt + + plt.subplot(224); plt.imshow(crf_result_filt_inp, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap2); plt.axis('off') + plt.title('d) Inpainted (final label)', loc='left', fontsize=7) + + plt.savefig(anno_file.replace('.npz','_crf_label_filtered_labelgen.png'), dpi=200, bbox_inches='tight') + plt.close() + else: + crf_result_filt_inp = rf_result_filt_inp.copy() + + #================================ + plt.imshow(img) + plt.imshow(crf_result_filt_inp-1, alpha=0.25, vmin=0, vmax=NUM_LABEL_CLASSES, cmap=cmap) #'inferno') + plt.axis('off') + plt.colorbar(shrink=0.5) + plt.savefig(anno_file.replace('.npz','_image_label_final_labelgen.png'), dpi=200, bbox_inches='tight') + plt.close() + + + if save_mode: + tosave = (np.arange(crf_result_filt_inp.max()) == crf_result_filt_inp[...,None]-1).astype(int) + savez_dict['final_label'] = tosave.astype('uint8')#crf_result_filt_inp-1 + savez_dict['image'] = (255*img).astype('uint8') + del img, crf_result_filt_inp + + imwrite(anno_file.replace('.npz','_label.png'), np.argmax(savez_dict['final_label'],-1).astype('uint8')) + imwrite(anno_file.replace('.npz','_doodles.png'), savez_dict['doodles'].astype('uint8')) + + + ### if only one label + else: + if save_mode: + savez_dict['color_doodles'] = data['color_doodles'].astype('uint8') + savez_dict['doodles'] = data['doodles'].astype('uint8') + savez_dict['settings'] = data['settings'] + savez_dict['label'] = data['label'].astype('uint8') + v = np.unique(data['doodles']).max()#[0]-1 + if v==2: + tmp = np.zeros_like(data['label']) + tmp+=1 + else: + tmp = np.ones_like(data['label'])*v + tosave = (np.arange(tmp.max()) == tmp[...,None]-1).astype(int) + savez_dict['final_label'] = tosave.astype('uint8').squeeze() + savez_dict['crf_tta'] = None + savez_dict['crf_tta_weights'] = None + savez_dict['crf_result'] =None + savez_dict['rf_result_spatfilt'] = None + savez_dict['crf_result_filt'] = None + savez_dict['image'] = data['image'].astype('uint8') + del data + + np.savez(anno_file.replace('.npz','_labelgen.npz'), **savez_dict ) + del savez_dict + plt.close('all') @@ -572,139 +597,3 @@ def gen_plot_seq(orig_distance, save_mode): #ok, dooo it gen_plot_seq(orig_distance, save_mode) - - - # - # DEFAULT_CRF_THETA=1 - # DEFAULT_CRF_MU=1 - # DEFAULT_CRF_DOWNSAMPLE=2 - # DEFAULT_CRF_GTPROB=.9 - - # Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing - # classfile = askopenfilename(title='Select file containing class (label) names', filetypes=[("Pick classes.txt file","*.txt")]) - # - # with open(classfile) as f: - # classes = f.readlines() - # - # class_label_names = [c.strip() for c in classes] - - - # do_sim = False #True - # - # if do_sim: - # data_file = 'tmp.npz' - # rf_file = 'tmp.pkl.z' - # if do_sim: - # try: #first time around - # file_training_data, file_training_labels = load(data_file) - # training_data = np.concatenate((file_training_data, training_data)) - # training_labels = np.concatenate((file_training_labels, training_labels)) - # except: - # pass - # - # try: #first time around - # os.remove(data_file) - # except: - # pass - # - # dump((training_data, training_labels), data_file, compress=True) #save new file - # try: #first time around - # clf = load(rf_file) #load last model from file - # except: - # clf = RandomForestClassifier(n_estimators=DEFAULT_RF_NESTIMATORS, n_jobs=-1,class_weight="balanced_subsample", min_samples_split=5) - # - # try: #first time around - # os.remove(rf_file) - # except: - # pass - # - # clf = RandomForestClassifier(n_estimators=DEFAULT_RF_NESTIMATORS, n_jobs=-1,class_weight="balanced_subsample", min_samples_split=5) - # clf.fit(training_data, training_labels) - # - # dump(clf, rf_file, compress=True) #save new file - - # else: - - # clf = RandomForestClassifier(n_estimators=DEFAULT_RF_NESTIMATORS, n_jobs=-1,class_weight="balanced_subsample", min_samples_split=5) - # - # scaler = StandardScaler() - # training_data = scaler.fit_transform(training_data) - # - # clf.fit(training_data, training_labels) - #================================ - # plt.figure(figsize=(12,12)) - # plt.subplots_adjust(hspace=0.5) - # counter = 1 - # loc='abcdefghijklmnop' - # for pair in [ (0,1), (0,2), (0,3), (0,4), (1,2), (1,3), (1,4), (2,3), (2,4) ]: - # #clf2 = RandomForestClassifier(n_estimators=DEFAULT_RF_NESTIMATORS, n_jobs=-1,class_weight="balanced_subsample", min_samples_split=5) - # clf2 = make_pipeline( - # StandardScaler(), - # MLPClassifier( - # solver='lbfgs', alpha=1, random_state=1, max_iter=2000, - # early_stopping=True, hidden_layer_sizes=[100, 100], - # )) - # - # clf2.fit(training_data[:,pair], training_labels) - # - # # Now plot the decision boundary using a fine mesh as input to a - # # filled contour plot - # plot_step = .05 - # x_min, x_max = training_data[:, pair[0]].min() - 2, training_data[:, pair[0]].max() + 2 - # y_min, y_max = training_data[:, pair[1]].min() - 2, training_data[:, pair[1]].max() + 2 - # xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step), - # np.arange(y_min, y_max, plot_step)) - # - # #visualize rf decision surface - # ax=plt.subplot(3,3,counter) - # Z = clf2.predict(np.c_[xx.ravel(), yy.ravel()]) - # Z = Z.reshape(xx.shape) - # cs = plt.contourf(xx, yy, Z, cmap=cmap) - # plt.title(loc[counter-1]+')', loc='left', fontsize=8) - # plt.scatter(training_data[:, 0], training_data[:, 1], c=training_labels, - # cmap=cmap,edgecolor='k', s=5, lw=.25) - # for tick in ax.xaxis.get_major_ticks(): - # tick.label.set_fontsize(7) - # for tick in ax.yaxis.get_major_ticks(): - # tick.label.set_fontsize(7) - # #plt.axis('off') - # counter+=1 - # - # #plt.show() - # plt.savefig(anno_file.replace('.npz','_RFdecsurf_labelgen.png'), dpi=200, bbox_inches='tight') - # plt.close() - - # #first two features only (location and intensity) - # rf_result2 = clf2.predict(features_use[:,:2]) - # #del features_use - # rf_result2 = rf_result2.reshape(sh[1:]) - - #================================ - #visualize rf feature importances - #Feature importances are provided by the fitted attribute feature_importances_ - #and they are computed as the mean and standard deviation of accumulation of the impurity decrease within each tree. - # importances = clf.feature_importances_ - # - # plt.figure(figsize=(8,12)) - # plt.subplots_adjust(wspace=0.1, hspace=0.1) - # plt.subplot(3, 1, 1) - # plt.bar(np.arange(len(importances)), importances) - # for f in np.argsort(importances)[-4:]: - # plt.axvline(x=f, ymin=0, ymax=1, color='r', linestyle='--') - # plt.ylabel("Feature importance (non-dim.)") #Mean decrease in impurity - # plt.xlabel("Feature") - # plt.title('a)',loc='left', fontsize=7) - # - # counter=3 - # syms='bcdefghijk' - # for f in np.argsort(importances)[-4:]: - # plt.subplot(3,2,counter) - # plt.imshow(features_use[:,f].reshape(sh[1:]), cmap='gray') - # plt.axis('off'); plt.title(syms[counter-3]+') Feature '+str(f),loc='left', fontsize=7) - # counter+=1 - # - # plt.savefig(anno_file.replace('.npz','_RF_featimps_labelgen.png'), dpi=200, bbox_inches='tight') - # plt.close() - #================================ - - # imsave(anno_file.replace('.npz','_label_RF_col_labelgen.png'), label_to_colors(rf_result-1, img[:,:,0]==0, alpha=128, colormap=class_label_colormap, color_class_offset=0, do_alpha=False), check_contrast=False) diff --git a/utils/viz_npz.py b/utils/viz_npz.py index 21061f2..62ee480 100644 --- a/utils/viz_npz.py +++ b/utils/viz_npz.py @@ -92,8 +92,18 @@ def do_viz_npz(npz_type): except: pass del dat + print(data['image'].shape) - classes = data['classes'] + if 'classes' not in locals(): + + try: + classes = data['classes'] + except: + Tk().withdraw() # we don't want a full GUI, so keep the root window from appearing + classfile = askopenfilename(title='Select file containing class (label) names', filetypes=[("Pick classes.txt file","*.txt")]) + + with open(classfile) as f: + classes = f.readlines() class_label_names = [c.strip() for c in classes]