<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=2">
<meta name="theme-color" content="#222">
<meta name="generator" content="Hexo 5.3.0">
<link rel="apple-touch-icon" sizes="180x180" href="/images/apple-touch-icon-next.png">
<link rel="icon" type="image/png" sizes="32x32" href="/images/favicon-32x32-next.png">
<link rel="icon" type="image/png" sizes="16x16" href="/images/favicon-16x16-next.png">
<link rel="mask-icon" href="/images/logo.svg" color="#222">
<link rel="stylesheet" href="/css/main.css">
<link rel="stylesheet" href="/lib/font-awesome/css/all.min.css">
<script id="hexo-configurations">
var NexT = window.NexT || {};
var CONFIG = {"hostname":"example.com","root":"/","scheme":"Pisces","version":"7.8.0","exturl":false,"sidebar":{"position":"left","display":"post","padding":18,"offset":12,"onmobile":false},"copycode":{"enable":false,"show_result":false,"style":null},"back2top":{"enable":true,"sidebar":false,"scrollpercent":false},"bookmark":{"enable":false,"color":"#222","save":"auto"},"fancybox":false,"mediumzoom":false,"lazyload":false,"pangu":false,"comments":{"style":"tabs","active":null,"storage":true,"lazyload":false,"nav":null},"algolia":{"hits":{"per_page":10},"labels":{"input_placeholder":"Search for Posts","hits_empty":"We didn't find any results for the search: ${query}","hits_stats":"${hits} results found in ${time} ms"}},"localsearch":{"enable":false,"trigger":"auto","top_n_per_article":1,"unescape":false,"preload":false},"motion":{"enable":true,"async":false,"transition":{"post_block":"fadeIn","post_header":"slideDownIn","post_body":"slideDownIn","coll_header":"slideLeftIn","sidebar":"slideUpIn"}}};
</script>
<meta property="og:type" content="website">
<meta property="og:title" content="水广山">
<meta property="og:url" content="http://example.com/index.html">
<meta property="og:site_name" content="水广山">
<meta property="og:locale" content="zh_CN">
<meta property="article:author" content="Guangshan Shui">
<meta name="twitter:card" content="summary">
<link rel="canonical" href="http://example.com/">
<script id="page-configurations">
// https://hexo.io/docs/variables.html
CONFIG.page = {
sidebar: "",
isHome : true,
isPost : false,
lang : 'zh-CN'
};
</script>
<title>水广山</title>
<noscript>
<style>
.use-motion .brand,
.use-motion .menu-item,
.sidebar-inner,
.use-motion .post-block,
.use-motion .pagination,
.use-motion .comments,
.use-motion .post-header,
.use-motion .post-body,
.use-motion .collection-header { opacity: initial; }
.use-motion .site-title,
.use-motion .site-subtitle {
opacity: initial;
top: initial;
}
.use-motion .logo-line-before i { left: initial; }
.use-motion .logo-line-after i { right: initial; }
</style>
</noscript>
<!-- hexo injector head_end start -->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css">
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/style.css">
<!-- hexo injector head_end end --></head>
<body itemscope itemtype="http://schema.org/WebPage">
<div class="container use-motion">
<div class="headband"></div>
<header class="header" itemscope itemtype="http://schema.org/WPHeader">
<div class="header-inner"><div class="site-brand-container">
<div class="site-nav-toggle">
<div class="toggle" aria-label="切换导航栏">
<span class="toggle-line toggle-line-first"></span>
<span class="toggle-line toggle-line-middle"></span>
<span class="toggle-line toggle-line-last"></span>
</div>
</div>
<div class="site-meta">
<a href="/" class="brand" rel="start">
<span class="logo-line-before"><i></i></span>
<h1 class="site-title">水广山</h1>
<span class="logo-line-after"><i></i></span>
</a>
</div>
<div class="site-nav-right">
<div class="toggle popup-trigger">
</div>
</div>
</div>
<nav class="site-nav">
<ul id="menu" class="main-menu menu">
<li class="menu-item menu-item-home">
<a href="/" rel="section"><i class="fa fa-home fa-fw"></i>首页</a>
</li>
<li class="menu-item menu-item-about">
<a href="/about/" rel="section"><i class="fa fa-user fa-fw"></i>关于</a>
</li>
<li class="menu-item menu-item-tags">
<a href="/tags/" rel="section"><i class="fa fa-tags fa-fw"></i>标签</a>
</li>
<li class="menu-item menu-item-categories">
<a href="/categories/" rel="section"><i class="fa fa-th fa-fw"></i>分类</a>
</li>
<li class="menu-item menu-item-archives">
<a href="/archives/" rel="section"><i class="fa fa-archive fa-fw"></i>归档</a>
</li>
</ul>
</nav>
</div>
</header>
<div class="back-to-top">
<i class="fa fa-arrow-up"></i>
<span>0%</span>
</div>
<main class="main">
<div class="main-inner">
<div class="content-wrap">
<div class="content index posts-expand">
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2023/12/29/competition/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2023/12/29/competition/" class="post-title-link" itemprop="url">competition</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2023-12-29 19:49:50" itemprop="dateCreated datePublished" datetime="2023-12-29T19:49:50+08:00">2023-12-29</time>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2023/12/19/%E7%AB%9E%E8%B5%9B%E6%80%BB%E7%BB%93/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2023/12/19/%E7%AB%9E%E8%B5%9B%E6%80%BB%E7%BB%93/" class="post-title-link" itemprop="url">竞赛总结</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2023-12-19 16:59:41" itemprop="dateCreated datePublished" datetime="2023-12-19T16:59:41+08:00">2023-12-19</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar-check"></i>
</span>
<span class="post-meta-item-text">更新于</span>
<time title="修改时间:2023-12-29 19:34:12" itemprop="dateModified" datetime="2023-12-29T19:34:12+08:00">2023-12-29</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<h2 id="1-能源预测(时序)(比赛结束后更新)"><a href="#1-能源预测(时序)(比赛结束后更新)" class="headerlink" title="1. 能源预测(时序)(比赛结束后更新)"></a>1. 能源预测(时序)(比赛结束后更新)</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span 
class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br><span class="line">128</span><br><span class="line">129</span><br><span class="line">130</span><br><span class="line">131</span><br><span class="line">132</span><br><span class="line">133</span><br><span class="line">134</span><br><span class="line">135</span><br><span class="line">136</span><br><span class="line">137</span><br><span class="line">138</span><br><span class="line">139</span><br><span class="line">140</span><br><span class="line">141</span><br><span class="line">142</span><br><span class="line">143</span><br><span class="line">144</span><br><span class="line">145</span><br><span class="line">146</span><br><span class="line">147</span><br><span class="line">148</span><br><span class="line">149</span><br><span class="line">150</span><br><span class="line">151</span><br><span class="line">152</span><br><span class="line">153</span><br><span class="line">154</span><br><span class="line">155</span><br><span class="line">156</span><br><span class="line">157</span><br><span class="line">158</span><br><span class="line">159</span><br><span class="line">160</span><br><span class="line">161</span><br><span class="line">162</span><br><span class="line">163</span><br><span class="line">164</span><br><span class="line">165</span><br><span class="line">166</span><br><span class="line">167</span><br><span class="line">168</span><br><span class="line">169</span><br><span class="line">170</span><br><span class="line">171</span><br><span class="line">172</span><br><span class="line">173</span><br><span class="line">174</span><br><span class="line">175</span><br><span class="line">176</span><br><span class="line">177</span><br><span class="line">178</span><br><span class="line">179</span><br><span class="line">180</span><br><span class="line">181</span><br><span class="line">182</span><br><span class="line">183</span><br><span class="line">184</span><br><span class="line">185</span><br><span class="line">186</span><br><span class="line">187</span><br><span class="line">188</span><br><span class="line">189</span><br><span class="line">190</span><br><span class="line">191</span><br><span class="line">192</span><br><span class="line">193</span><br><span class="line">194</span><br><span class="line">195</span><br><span class="line">196</span><br><span class="line">197</span><br><span class="line">198</span><br><span class="line">199</span><br><span class="line">200</span><br><span class="line">201</span><br><span class="line">202</span><br><span class="line">203</span><br><span class="line">204</span><br><span class="line">205</span><br><span class="line">206</span><br><span class="line">207</span><br><span class="line">208</span><br><span class="line">209</span><br><span class="line">210</span><br><span class="line">211</span><br><span 
class="line">212</span><br><span class="line">213</span><br><span class="line">214</span><br><span class="line">215</span><br><span class="line">216</span><br><span class="line">217</span><br><span class="line">218</span><br><span class="line">219</span><br><span class="line">220</span><br><span class="line">221</span><br><span class="line">222</span><br><span class="line">223</span><br><span class="line">224</span><br><span class="line">225</span><br><span class="line">226</span><br><span class="line">227</span><br><span class="line">228</span><br><span class="line">229</span><br><span class="line">230</span><br><span class="line">231</span><br><span class="line">232</span><br><span class="line">233</span><br><span class="line">234</span><br><span class="line">235</span><br><span class="line">236</span><br><span class="line">237</span><br><span class="line">238</span><br><span class="line">239</span><br><span class="line">240</span><br><span class="line">241</span><br><span class="line">242</span><br><span class="line">243</span><br><span class="line">244</span><br><span class="line">245</span><br><span class="line">246</span><br><span class="line">247</span><br><span class="line">248</span><br><span class="line">249</span><br><span class="line">250</span><br><span class="line">251</span><br><span class="line">252</span><br><span class="line">253</span><br><span class="line">254</span><br><span class="line">255</span><br><span class="line">256</span><br><span class="line">257</span><br><span class="line">258</span><br><span class="line">259</span><br><span class="line">260</span><br><span class="line">261</span><br><span class="line">262</span><br><span class="line">263</span><br><span class="line">264</span><br><span class="line">265</span><br><span class="line">266</span><br><span class="line">267</span><br><span class="line">268</span><br><span class="line">269</span><br><span class="line">270</span><br><span class="line">271</span><br><span class="line">272</span><br><span class="line">273</span><br><span class="line">274</span><br><span class="line">275</span><br><span class="line">276</span><br><span class="line">277</span><br><span class="line">278</span><br><span class="line">279</span><br><span class="line">280</span><br><span class="line">281</span><br><span class="line">282</span><br><span class="line">283</span><br><span class="line">284</span><br><span class="line">285</span><br><span class="line">286</span><br><span class="line">287</span><br><span class="line">288</span><br><span class="line">289</span><br><span class="line">290</span><br><span class="line">291</span><br><span class="line">292</span><br><span class="line">293</span><br><span class="line">294</span><br><span class="line">295</span><br><span class="line">296</span><br><span class="line">297</span><br><span class="line">298</span><br><span class="line">299</span><br><span class="line">300</span><br><span class="line">301</span><br><span class="line">302</span><br><span class="line">303</span><br><span class="line">304</span><br><span class="line">305</span><br><span class="line">306</span><br><span class="line">307</span><br><span class="line">308</span><br><span class="line">309</span><br><span class="line">310</span><br><span class="line">311</span><br><span class="line">312</span><br><span class="line">313</span><br><span class="line">314</span><br><span class="line">315</span><br><span class="line">316</span><br><span class="line">317</span><br><span class="line">318</span><br><span 
class="line">319</span><br><span class="line">320</span><br><span class="line">321</span><br><span class="line">322</span><br><span class="line">323</span><br><span class="line">324</span><br><span class="line">325</span><br><span class="line">326</span><br><span class="line">327</span><br><span class="line">328</span><br><span class="line">329</span><br><span class="line">330</span><br><span class="line">331</span><br><span class="line">332</span><br><span class="line">333</span><br><span class="line">334</span><br><span class="line">335</span><br><span class="line">336</span><br><span class="line">337</span><br><span class="line">338</span><br><span class="line">339</span><br><span class="line">340</span><br><span class="line">341</span><br><span class="line">342</span><br><span class="line">343</span><br><span class="line">344</span><br><span class="line">345</span><br><span class="line">346</span><br><span class="line">347</span><br><span class="line">348</span><br><span class="line">349</span><br><span class="line">350</span><br><span class="line">351</span><br><span class="line">352</span><br><span class="line">353</span><br><span class="line">354</span><br><span class="line">355</span><br><span class="line">356</span><br><span class="line">357</span><br><span class="line">358</span><br><span class="line">359</span><br><span class="line">360</span><br><span class="line">361</span><br><span class="line">362</span><br><span class="line">363</span><br><span class="line">364</span><br><span class="line">365</span><br><span class="line">366</span><br><span class="line">367</span><br><span class="line">368</span><br><span class="line">369</span><br><span class="line">370</span><br><span class="line">371</span><br><span class="line">372</span><br><span class="line">373</span><br><span class="line">374</span><br><span class="line">375</span><br><span class="line">376</span><br><span class="line">377</span><br><span class="line">378</span><br><span class="line">379</span><br><span class="line">380</span><br><span class="line">381</span><br><span class="line">382</span><br><span class="line">383</span><br><span class="line">384</span><br><span class="line">385</span><br><span class="line">386</span><br><span class="line">387</span><br><span class="line">388</span><br><span class="line">389</span><br><span class="line">390</span><br><span class="line">391</span><br><span class="line">392</span><br><span class="line">393</span><br><span class="line">394</span><br><span class="line">395</span><br><span class="line">396</span><br><span class="line">397</span><br><span class="line">398</span><br><span class="line">399</span><br><span class="line">400</span><br><span class="line">401</span><br><span class="line">402</span><br><span class="line">403</span><br><span class="line">404</span><br><span class="line">405</span><br><span class="line">406</span><br><span class="line">407</span><br><span class="line">408</span><br><span class="line">409</span><br><span class="line">410</span><br><span class="line">411</span><br><span class="line">412</span><br><span class="line">413</span><br><span class="line">414</span><br><span class="line">415</span><br><span class="line">416</span><br><span class="line">417</span><br><span class="line">418</span><br><span class="line">419</span><br><span class="line">420</span><br><span class="line">421</span><br><span class="line">422</span><br><span class="line">423</span><br><span class="line">424</span><br><span class="line">425</span><br><span 
class="line">426</span><br><span class="line">427</span><br><span class="line">428</span><br><span class="line">429</span><br><span class="line">430</span><br><span class="line">431</span><br><span class="line">432</span><br><span class="line">433</span><br><span class="line">434</span><br><span class="line">435</span><br><span class="line">436</span><br><span class="line">437</span><br><span class="line">438</span><br><span class="line">439</span><br><span class="line">440</span><br><span class="line">441</span><br><span class="line">442</span><br><span class="line">443</span><br><span class="line">444</span><br><span class="line">445</span><br><span class="line">446</span><br><span class="line">447</span><br><span class="line">448</span><br><span class="line">449</span><br><span class="line">450</span><br><span class="line">451</span><br><span class="line">452</span><br><span class="line">453</span><br><span class="line">454</span><br><span class="line">455</span><br><span class="line">456</span><br><span class="line">457</span><br><span class="line">458</span><br><span class="line">459</span><br><span class="line">460</span><br><span class="line">461</span><br><span class="line">462</span><br><span class="line">463</span><br><span class="line">464</span><br><span class="line">465</span><br><span class="line">466</span><br><span class="line">467</span><br><span class="line">468</span><br><span class="line">469</span><br><span class="line">470</span><br><span class="line">471</span><br><span class="line">472</span><br><span class="line">473</span><br><span class="line">474</span><br><span class="line">475</span><br><span class="line">476</span><br><span class="line">477</span><br><span class="line">478</span><br><span class="line">479</span><br><span class="line">480</span><br><span class="line">481</span><br><span class="line">482</span><br><span class="line">483</span><br><span class="line">484</span><br><span class="line">485</span><br><span class="line">486</span><br><span class="line">487</span><br><span class="line">488</span><br><span class="line">489</span><br><span class="line">490</span><br><span class="line">491</span><br><span class="line">492</span><br><span class="line">493</span><br><span class="line">494</span><br><span class="line">495</span><br><span class="line">496</span><br><span class="line">497</span><br><span class="line">498</span><br><span class="line">499</span><br><span class="line">500</span><br><span class="line">501</span><br><span class="line">502</span><br><span class="line">503</span><br><span class="line">504</span><br><span class="line">505</span><br><span class="line">506</span><br><span class="line">507</span><br><span class="line">508</span><br><span class="line">509</span><br><span class="line">510</span><br><span class="line">511</span><br><span class="line">512</span><br><span class="line">513</span><br><span class="line">514</span><br><span class="line">515</span><br><span class="line">516</span><br><span class="line">517</span><br><span class="line">518</span><br><span class="line">519</span><br><span class="line">520</span><br><span class="line">521</span><br><span class="line">522</span><br><span class="line">523</span><br><span class="line">524</span><br><span class="line">525</span><br><span class="line">526</span><br><span class="line">527</span><br><span class="line">528</span><br><span class="line">529</span><br><span class="line">530</span><br><span class="line">531</span><br><span class="line">532</span><br><span 
class="line">533</span><br><span class="line">534</span><br><span class="line">535</span><br><span class="line">536</span><br><span class="line">537</span><br><span class="line">538</span><br><span class="line">539</span><br><span class="line">540</span><br><span class="line">541</span><br><span class="line">542</span><br><span class="line">543</span><br><span class="line">544</span><br><span class="line">545</span><br><span class="line">546</span><br><span class="line">547</span><br><span class="line">548</span><br><span class="line">549</span><br><span class="line">550</span><br><span class="line">551</span><br><span class="line">552</span><br><span class="line">553</span><br><span class="line">554</span><br><span class="line">555</span><br><span class="line">556</span><br><span class="line">557</span><br><span class="line">558</span><br><span class="line">559</span><br><span class="line">560</span><br><span class="line">561</span><br><span class="line">562</span><br><span class="line">563</span><br><span class="line">564</span><br><span class="line">565</span><br><span class="line">566</span><br><span class="line">567</span><br><span class="line">568</span><br><span class="line">569</span><br><span class="line">570</span><br><span class="line">571</span><br><span class="line">572</span><br><span class="line">573</span><br><span class="line">574</span><br><span class="line">575</span><br><span class="line">576</span><br><span class="line">577</span><br><span class="line">578</span><br><span class="line">579</span><br><span class="line">580</span><br><span class="line">581</span><br><span class="line">582</span><br><span class="line">583</span><br><span class="line">584</span><br><span class="line">585</span><br><span class="line">586</span><br><span class="line">587</span><br><span class="line">588</span><br><span class="line">589</span><br><span class="line">590</span><br><span class="line">591</span><br><span class="line">592</span><br><span class="line">593</span><br><span class="line">594</span><br><span class="line">595</span><br><span class="line">596</span><br><span class="line">597</span><br><span class="line">598</span><br><span class="line">599</span><br><span class="line">600</span><br><span class="line">601</span><br><span class="line">602</span><br><span class="line">603</span><br><span class="line">604</span><br><span class="line">605</span><br><span class="line">606</span><br><span class="line">607</span><br><span class="line">608</span><br><span class="line">609</span><br><span class="line">610</span><br><span class="line">611</span><br><span class="line">612</span><br><span class="line">613</span><br><span class="line">614</span><br><span class="line">615</span><br><span class="line">616</span><br><span class="line">617</span><br><span class="line">618</span><br><span class="line">619</span><br><span class="line">620</span><br><span class="line">621</span><br><span class="line">622</span><br><span class="line">623</span><br><span class="line">624</span><br><span class="line">625</span><br><span class="line">626</span><br><span class="line">627</span><br><span class="line">628</span><br><span class="line">629</span><br><span class="line">630</span><br><span class="line">631</span><br><span class="line">632</span><br><span class="line">633</span><br><span class="line">634</span><br><span class="line">635</span><br><span class="line">636</span><br><span class="line">637</span><br><span class="line">638</span><br><span class="line">639</span><br><span 
class="line">640</span><br><span class="line">641</span><br><span class="line">642</span><br><span class="line">643</span><br><span class="line">644</span><br><span class="line">645</span><br><span class="line">646</span><br><span class="line">647</span><br><span class="line">648</span><br><span class="line">649</span><br><span class="line">650</span><br><span class="line">651</span><br><span class="line">652</span><br><span class="line">653</span><br><span class="line">654</span><br><span class="line">655</span><br><span class="line">656</span><br><span class="line">657</span><br><span class="line">658</span><br><span class="line">659</span><br><span class="line">660</span><br><span class="line">661</span><br><span class="line">662</span><br><span class="line">663</span><br><span class="line">664</span><br><span class="line">665</span><br><span class="line">666</span><br><span class="line">667</span><br><span class="line">668</span><br><span class="line">669</span><br><span class="line">670</span><br><span class="line">671</span><br><span class="line">672</span><br><span class="line">673</span><br><span class="line">674</span><br><span class="line">675</span><br><span class="line">676</span><br><span class="line">677</span><br><span class="line">678</span><br><span class="line">679</span><br><span class="line">680</span><br><span class="line">681</span><br><span class="line">682</span><br><span class="line">683</span><br><span class="line">684</span><br><span class="line">685</span><br><span class="line">686</span><br><span class="line">687</span><br><span class="line">688</span><br><span class="line">689</span><br><span class="line">690</span><br><span class="line">691</span><br><span class="line">692</span><br><span class="line">693</span><br><span class="line">694</span><br><span class="line">695</span><br><span class="line">696</span><br><span class="line">697</span><br><span class="line">698</span><br><span class="line">699</span><br><span class="line">700</span><br><span class="line">701</span><br><span class="line">702</span><br><span class="line">703</span><br><span class="line">704</span><br><span class="line">705</span><br><span class="line">706</span><br><span class="line">707</span><br><span class="line">708</span><br><span class="line">709</span><br><span class="line">710</span><br><span class="line">711</span><br><span class="line">712</span><br><span class="line">713</span><br><span class="line">714</span><br><span class="line">715</span><br><span class="line">716</span><br><span class="line">717</span><br><span class="line">718</span><br><span class="line">719</span><br><span class="line">720</span><br><span class="line">721</span><br><span class="line">722</span><br><span class="line">723</span><br><span class="line">724</span><br><span class="line">725</span><br><span class="line">726</span><br><span class="line">727</span><br><span class="line">728</span><br><span class="line">729</span><br><span class="line">730</span><br><span class="line">731</span><br><span class="line">732</span><br><span class="line">733</span><br><span class="line">734</span><br><span class="line">735</span><br><span class="line">736</span><br><span class="line">737</span><br><span class="line">738</span><br><span class="line">739</span><br><span class="line">740</span><br><span class="line">741</span><br><span class="line">742</span><br><span class="line">743</span><br><span class="line">744</span><br><span class="line">745</span><br><span class="line">746</span><br><span 
class="line">747</span><br><span class="line">748</span><br><span class="line">749</span><br><span class="line">750</span><br><span class="line">751</span><br><span class="line">752</span><br><span class="line">753</span><br><span class="line">754</span><br><span class="line">755</span><br><span class="line">756</span><br><span class="line">757</span><br><span class="line">758</span><br><span class="line">759</span><br><span class="line">760</span><br><span class="line">761</span><br><span class="line">762</span><br><span class="line">763</span><br><span class="line">764</span><br><span class="line">765</span><br><span class="line">766</span><br><span class="line">767</span><br><span class="line">768</span><br><span class="line">769</span><br><span class="line">770</span><br><span class="line">771</span><br><span class="line">772</span><br><span class="line">773</span><br><span class="line">774</span><br><span class="line">775</span><br><span class="line">776</span><br><span class="line">777</span><br><span class="line">778</span><br><span class="line">779</span><br><span class="line">780</span><br><span class="line">781</span><br><span class="line">782</span><br><span class="line">783</span><br><span class="line">784</span><br><span class="line">785</span><br><span class="line">786</span><br><span class="line">787</span><br><span class="line">788</span><br><span class="line">789</span><br><span class="line">790</span><br><span class="line">791</span><br><span class="line">792</span><br><span class="line">793</span><br><span class="line">794</span><br><span class="line">795</span><br><span class="line">796</span><br><span class="line">797</span><br><span class="line">798</span><br><span class="line">799</span><br><span class="line">800</span><br><span class="line">801</span><br><span class="line">802</span><br><span class="line">803</span><br><span class="line">804</span><br><span class="line">805</span><br><span class="line">806</span><br><span class="line">807</span><br><span class="line">808</span><br><span class="line">809</span><br><span class="line">810</span><br><span class="line">811</span><br><span class="line">812</span><br><span class="line">813</span><br><span class="line">814</span><br><span class="line">815</span><br><span class="line">816</span><br><span class="line">817</span><br><span class="line">818</span><br><span class="line">819</span><br><span class="line">820</span><br><span class="line">821</span><br><span class="line">822</span><br><span class="line">823</span><br><span class="line">824</span><br><span class="line">825</span><br><span class="line">826</span><br><span class="line">827</span><br><span class="line">828</span><br><span class="line">829</span><br><span class="line">830</span><br><span class="line">831</span><br><span class="line">832</span><br><span class="line">833</span><br><span class="line">834</span><br><span class="line">835</span><br><span class="line">836</span><br><span class="line">837</span><br><span class="line">838</span><br><span class="line">839</span><br><span class="line">840</span><br><span class="line">841</span><br><span class="line">842</span><br><span class="line">843</span><br><span class="line">844</span><br><span class="line">845</span><br><span class="line">846</span><br><span class="line">847</span><br><span class="line">848</span><br><span class="line">849</span><br><span class="line">850</span><br><span class="line">851</span><br><span class="line">852</span><br><span class="line">853</span><br><span 
class="line">854</span><br><span class="line">855</span><br><span class="line">856</span><br><span class="line">857</span><br><span class="line">858</span><br><span class="line">859</span><br><span class="line">860</span><br><span class="line">861</span><br><span class="line">862</span><br><span class="line">863</span><br><span class="line">864</span><br><span class="line">865</span><br><span class="line">866</span><br><span class="line">867</span><br><span class="line">868</span><br><span class="line">869</span><br><span class="line">870</span><br><span class="line">871</span><br><span class="line">872</span><br><span class="line">873</span><br><span class="line">874</span><br><span class="line">875</span><br><span class="line">876</span><br><span class="line">877</span><br><span class="line">878</span><br><span class="line">879</span><br><span class="line">880</span><br><span class="line">881</span><br><span class="line">882</span><br><span class="line">883</span><br><span class="line">884</span><br><span class="line">885</span><br><span class="line">886</span><br><span class="line">887</span><br><span class="line">888</span><br><span class="line">889</span><br><span class="line">890</span><br><span class="line">891</span><br><span class="line">892</span><br><span class="line">893</span><br><span class="line">894</span><br><span class="line">895</span><br><span class="line">896</span><br><span class="line">897</span><br><span class="line">898</span><br><span class="line">899</span><br><span class="line">900</span><br><span class="line">901</span><br><span class="line">902</span><br><span class="line">903</span><br><span class="line">904</span><br><span class="line">905</span><br><span class="line">906</span><br><span class="line">907</span><br><span class="line">908</span><br><span class="line">909</span><br><span class="line">910</span><br><span class="line">911</span><br><span class="line">912</span><br><span class="line">913</span><br><span class="line">914</span><br><span class="line">915</span><br><span class="line">916</span><br><span class="line">917</span><br><span class="line">918</span><br><span class="line">919</span><br><span class="line">920</span><br><span class="line">921</span><br><span class="line">922</span><br><span class="line">923</span><br><span class="line">924</span><br><span class="line">925</span><br><span class="line">926</span><br><span class="line">927</span><br><span class="line">928</span><br><span class="line">929</span><br><span class="line">930</span><br><span class="line">931</span><br><span class="line">932</span><br><span class="line">933</span><br><span class="line">934</span><br><span class="line">935</span><br><span class="line">936</span><br><span class="line">937</span><br><span class="line">938</span><br><span class="line">939</span><br><span class="line">940</span><br><span class="line">941</span><br><span class="line">942</span><br><span class="line">943</span><br><span class="line">944</span><br><span class="line">945</span><br><span class="line">946</span><br><span class="line">947</span><br><span class="line">948</span><br><span class="line">949</span><br><span class="line">950</span><br><span class="line">951</span><br><span class="line">952</span><br><span class="line">953</span><br><span class="line">954</span><br><span class="line">955</span><br><span class="line">956</span><br><span class="line">957</span><br><span class="line">958</span><br><span class="line">959</span><br><span class="line">960</span><br><span 
class="line">961</span><br><span class="line">962</span><br><span class="line">963</span><br><span class="line">964</span><br><span class="line">965</span><br><span class="line">966</span><br><span class="line">967</span><br><span class="line">968</span><br><span class="line">969</span><br><span class="line">970</span><br><span class="line">971</span><br><span class="line">972</span><br><span class="line">973</span><br><span class="line">974</span><br><span class="line">975</span><br><span class="line">976</span><br><span class="line">977</span><br><span class="line">978</span><br><span class="line">979</span><br><span class="line">980</span><br><span class="line">981</span><br><span class="line">982</span><br><span class="line">983</span><br><span class="line">984</span><br><span class="line">985</span><br><span class="line">986</span><br><span class="line">987</span><br><span class="line">988</span><br><span class="line">989</span><br><span class="line">990</span><br><span class="line">991</span><br><span class="line">992</span><br><span class="line">993</span><br><span class="line">994</span><br><span class="line">995</span><br><span class="line">996</span><br><span class="line">997</span><br><span class="line">998</span><br><span class="line">999</span><br><span class="line">1000</span><br><span class="line">1001</span><br><span class="line">1002</span><br><span class="line">1003</span><br><span class="line">1004</span><br><span class="line">1005</span><br><span class="line">1006</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#!/usr/bin/env python</span></span><br><span class="line"><span class="comment"># coding: utf-8</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment">#!/usr/bin/env python</span></span><br><span class="line"><span class="comment"># coding: utf-8</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> datetime</span><br><span class="line"><span class="keyword">import</span> gc</span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> tqdm <span class="keyword">import</span> tqdm</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> warnings</span><br><span class="line">warnings.filterwarnings(<span class="string">'ignore'</span>) </span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> xgboost <span class="keyword">as</span> xgb</span><br><span class="line"><span class="keyword">import</span> lightgbm <span class="keyword">as</span> lgb</span><br><span class="line"><span class="keyword">from</span> catboost <span class="keyword">import</span> CatBoostRegressor, Pool</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.metrics <span class="keyword">import</span> mean_squared_error</span><br><span class="line"><span class="keyword">from</span> sklearn.model_selection <span 
class="keyword">import</span> StratifiedKFold</span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">power_forecast_history_train = pd.read_csv(<span class="string">'/opt/dataset/2023“SEED”第四届江苏大数据开发与应用大赛--新能源【复赛B榜】数据集/data/data/train/power_forecast_history.csv'</span>)</span><br><span class="line">stub_info_train = pd.read_csv(<span class="string">'/opt/dataset/2023“SEED”第四届江苏大数据开发与应用大赛--新能源【复赛B榜】数据集/data/data/train/stub_info.csv'</span>)</span><br><span class="line">power_train = pd.read_csv(<span class="string">'/opt/dataset/2023“SEED”第四届江苏大数据开发与应用大赛--新能源【复赛B榜】数据集/data/data/train/power.csv'</span>)</span><br><span class="line"></span><br><span class="line">power_forecast_history_test = pd.read_csv(<span class="string">'/opt/dataset/2023“SEED”第四届江苏大数据开发与应用大赛--新能源【复赛B榜】数据集/data/data/test/power_forecast_history.csv'</span>)</span><br><span class="line">stub_info_test = pd.read_csv(<span class="string">'/opt/dataset/2023“SEED”第四届江苏大数据开发与应用大赛--新能源【复赛B榜】数据集/data/data/test/stub_info.csv'</span>)</span><br><span class="line"></span><br><span class="line">weather = pd.read_csv(<span class="string">'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/temp_data/city_weather_v2.csv'</span>)</span><br><span class="line">h3_feature = pd.read_csv(<span class="string">'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/temp_data/h3_feature_v2.csv'</span>)</span><br><span class="line"></span><br><span class="line">len_train = <span class="built_in">len</span>(power_forecast_history_train)</span><br><span class="line"></span><br><span class="line">min_id_encode = stub_info_train[<span class="string">'id_encode'</span>].<span class="built_in">min</span>()</span><br><span class="line">max_id_encode = stub_info_train[<span class="string">'id_encode'</span>].<span class="built_in">max</span>()</span><br><span class="line"></span><br><span class="line">min_date = power_forecast_history_train[<span class="string">'ds'</span>].<span class="built_in">min</span>()</span><br><span class="line">max_date = power_forecast_history_train[<span class="string">'ds'</span>].<span class="built_in">max</span>()</span><br><span class="line"></span><br><span class="line">test_min_date = power_forecast_history_test[<span class="string">'ds'</span>].<span class="built_in">min</span>()</span><br><span class="line">test_max_date = power_forecast_history_test[<span class="string">'ds'</span>].<span class="built_in">max</span>()</span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># power_forecast_history_train = pd.read_csv('../../raw_data/复赛A榜/train/power_forecast_history.csv')</span></span><br><span class="line"><span class="comment"># stub_info_train = pd.read_csv('../../raw_data/复赛A榜/train/stub_info.csv')</span></span><br><span class="line"><span class="comment"># power_train = pd.read_csv('../../raw_data/复赛A榜/train/power.csv')</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># power_forecast_history_test = pd.read_csv('../../raw_data/复赛A榜/test/power_forecast_history.csv')</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># weather = pd.read_csv('../../raw_data/city_weather_v2.csv')</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># h3_feature = 
# In[ ]:

# power_forecast_history_train = power_forecast_history_train[power_forecast_history_train['ds'] >= 20220801]
power_forecast_history_train = power_forecast_history_train[power_forecast_history_train.hour.notnull()]

# In[ ]:

convert_dtype_dict = {
    'id_encode': np.int16,
    'hour': np.int8,
    'ele_price': np.float16,
    'ser_price': np.float16,
    'after_ser_price': np.float16,
    'total_price': np.float16,
    'f1': np.float16,
    'f2': np.float16,  # xgb errors out here: f2 contains inf, so it needs to be float32 instead
    'f3': np.float16,
    'ds': np.uint32,
    'power': np.float16,
    'parking_free': np.uint8,

    'flag': object,  # np.object was removed in NumPy >= 1.24; the builtin object is equivalent
    'h3': object,
    'ac_equipment_kw': np.uint16,
    'dc_equipment_kw': np.uint16,

    'city': object,
    # 'date': np.datetime64,
    'temp_max': np.float16,
    'temp_min': np.float16,
    'weather': object,

    'province': object,
    'district': object,
    'town': object,  # (a redundant duplicate 'city': object entry in the original dict is omitted)
}
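# Illustrative aside (not part of the original notebook), expanding the float16
# caveat flagged on 'f2' above: float16 tops out around 65504, so larger values
# overflow to inf during the cast, which xgboost then rejects; float32 is safe.
assert np.isinf(np.float16(70000.0))     # overflows float16's range -> inf
assert np.isfinite(np.float32(70000.0))  # comfortably representable in float32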
# In[ ]:

def convert_type(data, convert_dtype_dict):
    # Downcast only the columns present in the mapping to shrink memory usage.
    cols_dict = dict()
    for col in data.columns:
        if col in convert_dtype_dict:
            cols_dict[col] = convert_dtype_dict[col]

    data = data.astype(cols_dict, copy=False)
    return data

# In[ ]:

train = power_forecast_history_train.merge(power_train, on=['id_encode', 'ds', 'hour'], how='left')
train = train[train.hour.notnull()]
del power_forecast_history_train, power_train; gc.collect()

train = convert_type(train, convert_dtype_dict)
stub_info_train = convert_type(stub_info_train, convert_dtype_dict)

power_forecast_history_test = convert_type(power_forecast_history_test, convert_dtype_dict)

weather = convert_type(weather, convert_dtype_dict)
h3_feature = convert_type(h3_feature, convert_dtype_dict)

# In[ ]:

data = pd.concat([train, power_forecast_history_test])
data.sort_values(by=['id_encode', 'ds', 'hour'], inplace=True)
del train, power_forecast_history_test; gc.collect()

# In[ ]:

data = data.merge(stub_info_train, on=['id_encode'], how='left')
data = data.merge(h3_feature[['h3', 'province', 'city', 'district', 'town']],
                  on=['h3'], how='left')

# 'ds' is an integer date such as 20220801, so the string round-trip parses as YYYYMMDD.
data['date'] = pd.to_datetime(data['ds'].astype(str, copy=False))
weather['date'] = pd.to_datetime(weather['date'].astype(str, copy=False))
data = data.merge(weather[['city', 'date', 'temp_max', 'temp_min', 'weather']], on=['city', 'date'], how='left')

del stub_info_train, h3_feature, weather; gc.collect()
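# Illustrative check (not part of the original notebook): quantify what the
# convert_type downcasting buys; deep=True also counts object/string columns.
print(f"merged frame size: {data.memory_usage(deep=True).sum() / 1024 ** 2:.1f} MB")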
del stub_info_train, h3_feature, weather; gc.collect()


# ## Feature engineering

# In[ ]:


data['year'] = data['date'].dt.year.astype(np.uint16, copy=False)
data['quarter'] = data['date'].dt.quarter.astype(np.uint8, copy=False)
data['month'] = data['date'].dt.month.astype(np.uint8, copy=False)
data['day'] = data['date'].dt.day.astype(np.uint8, copy=False)
data['dayofyear'] = data['date'].dt.dayofyear.astype(np.uint16, copy=False)
data['weekofyear'] = data['date'].dt.isocalendar().week.astype(np.uint8, copy=False)
data['dayofweek'] = data['date'].dt.dayofweek.astype(np.uint8, copy=False)
data['is_wknd'] = (data['date'].dt.dayofweek // 6).astype(np.uint8, copy=False)

data['is_month_start'] = data['date'].dt.is_month_start.astype(np.uint8, copy=False)
data['is_month_end'] = data['date'].dt.is_month_end.astype(np.uint8, copy=False)
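# Note: with pandas' Monday=0 ... Sunday=6 convention, `dayofweek // 6` flags
# Sunday only. If a Saturday+Sunday flag were intended instead, a hypothetical
# variant would be:
# data['is_wknd'] = (data['date'].dt.dayofweek >= 5).astype(np.uint8)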
# In[ ]:


def f3_disceret(x):
    if x < 0.6:
        return 0
    elif x < 0.7:
        return 1
    elif x < 0.8:
        return 2
    elif x < 0.9:
        return 3
    elif x < 1:
        return 4
    else:
        return 5

data['f3_dis'] = data['f3'].apply(f3_disceret).astype(np.uint8, copy=False)


# In[ ]:


def total_price_disceret(x):
    if x < 0.75:
        return 0
    elif x < 1.2:
        return 1
    elif x < 1.5:
        return 2
    else:
        return 3

data['total_price_dis'] = data['total_price'].apply(total_price_disceret).astype(np.uint8, copy=False)


# In[ ]:


def f1_disceret(x):
    if x < 20:
        return 0
    elif x < 40:
        return 1
    elif x < 60:
        return 2
    elif x < 80:
        return 3
    elif x < 100:
        return 4
    elif x < 150:
        return 5
    else:
        return 6

data['f1_dis'] = data['f1'].apply(f1_disceret).astype(np.uint8, copy=False)
# In[ ]:


def after_ser_price_disceret(x):
    if x < 0.2:
        return 0
    elif x < 0.4:
        return 1
    elif x < 0.6:
        return 2
    else:
        return 3

data['after_ser_price_dis'] = data['after_ser_price'].apply(after_ser_price_disceret).astype(np.uint8, copy=False)


# In[ ]:


def ser_price_disceret(x):
    if x >= 0 and x < 0.25:
        return 0
    elif x < 0.35:
        return 1
    elif x < 0.5:
        return 2
    elif x < 0.65:
        return 3
    elif x < 1:
        return 4
    else:
        return 5

data['ser_price_dis'] = data['ser_price'].apply(ser_price_disceret).astype(np.uint8, copy=False)
# In[ ]:


def ele_price_disceret(x):
    if x == 0:
        return 0
    elif x > 0 and x < 0.5:
        return 1
    elif x >= 0.5 and x < 0.75:
        return 2
    elif x == 0.75:
        return 3
    elif x > 0.75 and x < 1:
        return 4
    elif x >= 1:
        return 5

data['ele_price_dis'] = data['ele_price'].apply(ele_price_disceret).astype(np.uint8, copy=False)
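# The *_disceret helpers above are all fixed-threshold binnings. An equivalent
# vectorized form with pd.cut (hypothetical alternative, shown for reference;
# assumes no NaN in the column):
# data['f3_dis'] = pd.cut(data['f3'], bins=[-np.inf, 0.6, 0.7, 0.8, 0.9, 1.0, np.inf],
#                         labels=False, right=False).astype(np.uint8)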
# In[ ]:


# Simple label encoding for the categorical columns (missing values become -1
# before being mapped to an integer code).
for feat in [
    'flag', 'h3',
    'province', 'city', 'district', 'town', 'weather']:
    if data[feat].isnull().sum() > 0:
        print(feat, data[feat].isnull().sum())
        data[feat] = data[feat].fillna(-1)
    unique = list(data[feat].unique())
    value = list(range(data[feat].nunique()))
    le = dict(zip(unique, value))
    data[feat] = data[feat].map(le).astype(np.uint8, copy=False)
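# After the fillna(-1) above, the same order-of-appearance codes can be had in
# one call with pd.factorize (hypothetical alternative):
# data[feat] = pd.factorize(data[feat])[0].astype(np.uint8)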
# In[ ]:


def count_0(x):
    # Number of zero-power records in the group.
    return (x == 0).sum()

agg_ = data.groupby(['id_encode'])['power'].agg(count_0).reset_index()
agg_.columns = ['id_encode', 'power_count_0']
agg_['power_count_0'] = agg_['power_count_0'].astype(np.uint16, copy=False)
data = data.merge(agg_, on='id_encode', how='left')
del agg_; gc.collect()


# In[ ]:


### count
for feat in [
    'id_encode',

    'parking_free', 'dc_equipment_kw',
#     'town', 'ad_code',
#     'temp_max', 'temp_min',
#     'weather',
#     'year',
#     'quarter',
#     'month', 'day', 'dayofyear', 'weekofyear', 'dayofweek', 'is_wknd',
#     'is_month_start', 'is_month_end',

    'ele_price_dis',
#     'ser_price_dis',
#     'after_ser_price_dis',
#     'f1_dis',
    'total_price_dis',

    'f3_dis'
]:
    agg_ = data.groupby([feat])['power'].agg('count').reset_index()
    agg_.columns = [feat, f'{feat}_count']
    agg_[f'{feat}_count'] = agg_[f'{feat}_count'].astype(np.uint32, copy=False)
    data = data.merge(agg_, on=feat, how='left')
del agg_; gc.collect()


# In[ ]:


for feat in [
    'temp_max',
    'temp_min', 'weather', 'ele_price_dis', 'ser_price_dis',
    'after_ser_price_dis', 'f1_dis', 'total_price_dis'
]:
    agg_ = data.groupby(['id_encode'])[feat].agg('nunique').reset_index()
    agg_.columns = ['id_encode', f'id_encode_{feat}_nunique']
    agg_[f'id_encode_{feat}_nunique'] = agg_[f'id_encode_{feat}_nunique'].astype(np.uint8, copy=False)
    data = data.merge(agg_, on='id_encode', how='left')
del agg_; gc.collect()
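# The merge-based aggregate features above (counts and nunique) can also be
# written as single groupby transforms, e.g. (hypothetical one-liner):
# data['id_encode_count'] = data.groupby('id_encode')['power'].transform('count').astype(np.uint32)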
# In[ ]:


# Target encoding: per-group statistics of the target (power).
for feat in ['id_encode', 'parking_free', 'flag',
             'h3', 'year', 'quarter',
             'month', 'day', 'hour',
             'dayofweek', 'is_wknd',

             'dc_equipment_kw',
             'city',
#              'province', 'district', 'town', 'ad_code', 'temp_max', 'temp_min',
             'weather',
             'ele_price_dis', 'ser_price_dis', 'after_ser_price_dis', 'f1_dis', 'total_price_dis',

             'f3_dis'
             ]:
    agg_ = data.groupby(feat).agg({
        'power': ['mean', 'max', 'min', 'std']
    })
    new_col_names = [feat + '_' + f[0] + '_' + f[1] for f in agg_.columns]
    agg_.columns = new_col_names

    agg_[new_col_names] = agg_[new_col_names].astype(np.float16, copy=False)
    agg_ = agg_.reset_index(drop=False)
    data = data.merge(agg_, on=feat, how='left')

for feat1 in [
#     'id_encode',
    'parking_free', 'flag', 'dc_equipment_kw',
    'ele_price_dis', 'ser_price_dis', 'after_ser_price_dis', 'f1_dis', 'total_price_dis',
    'f3_dis'
]:
    for feat2 in [
        'year',
        'quarter',
        'month',
        'day',
        'hour',
        'dayofweek',
        'is_wknd'
    ]:
        agg_ = data.groupby([feat1, feat2]).agg({
            'power': ['mean', 'max', 'min', 'std']
        })
        new_col_names = [feat1 + '_' + feat2 + '_' + f[0] + '_' + f[1] for f in agg_.columns]
        agg_.columns = new_col_names

        agg_[new_col_names] = agg_[new_col_names].astype(np.float16, copy=False)
        agg_ = agg_.reset_index(drop=False)
        data = data.merge(agg_, on=[feat1, feat2], how='left')
del agg_; gc.collect()
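# Note: these group statistics are computed once over the full concatenated
# frame; test rows carry no power label and so drop out of the aggregates,
# but each training row's own target still contributes to its group mean. An
# out-of-fold target encoding would avoid that at the cost of extra bookkeeping.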
# In[ ]:


for feat1 in [
    'id_encode',
#     'h3'
#     'parking_free', 'flag', 'h3'
]:
    for feat2 in [
#         'year',
#         'quarter',
        'month',
        'day',
        'dayofweek',
        'hour',
        'is_wknd'
    ]:
        agg_ = data.groupby([feat1, feat2]).agg({
            'power': ['mean', 'max', 'min', 'std']
        })
        new_col_names = [feat1 + '_' + feat2 + '_' + f[0] + '_' + f[1] for f in agg_.columns]
        agg_.columns = new_col_names
        agg_[new_col_names] = agg_[new_col_names].astype(np.float16, copy=False)
        agg_ = agg_.reset_index(drop=False)
        data = data.merge(agg_, on=[feat1, feat2], how='left')
del agg_; gc.collect()


# In[ ]:


# Drop constant columns (except 'year'), then flag any column whose values
# fall outside the expected numeric range.
for feat in data.columns:
    if feat == 'year': continue
    if data[feat].nunique() == 1:
#         print(feat)
        del data[feat]

for feat in data.columns:
    try:
        if (data[feat].max() > 999999):
            print(feat, data[feat].min(), data[feat].max())
    except Exception:
        print(feat)
# In[ ]:


data.info()


# ## Shift features

# In[ ]:


def get_feature(data, lag=1):

    tmp_cols_3 = []

    for shift in range(lag * 24, (lag + 3) * 24):
        column = f'power_shift{shift - lag * 24}'
        data[column] = data.groupby('id_encode')['power'].shift(shift).astype(np.float16, copy=False)

        tmp_cols_3.append(column)

    diff_cols = [f"{feat}_diff" for feat in tmp_cols_3[1:]]
    data[diff_cols] = data[tmp_cols_3].diff(axis=1).astype(np.float16, copy=False)[tmp_cols_3[1:]]

    data['diff_mean_1_to_3'] = data[diff_cols].mean(axis=1).astype(np.float16, copy=False)

    data['power_window_1_to_3_mean'] = data[tmp_cols_3].mean(axis=1).astype(np.float16, copy=False)
    data['power_window_1_to_3_max'] = data[tmp_cols_3].max(axis=1).astype(np.float16, copy=False)
    data['power_window_1_to_3_min'] = data[tmp_cols_3].min(axis=1).astype(np.float16, copy=False)
    # Fixed: the original computed min() here, duplicating the column above.
    data['power_window_1_to_3_median'] = data[tmp_cols_3].median(axis=1).astype(np.float16, copy=False)
    data['power_window_1_to_3_std'] = data[tmp_cols_3].std(axis=1).astype(np.float16, copy=False)

    return data
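# Reading get_feature above: with lag=1 the loop builds shifts 24..95, i.e.
# (assuming a complete hourly grid per id_encode, sorted by ds and hour) the
# 72 power values observed 24 to 95 hours before the current row, so day D is
# described only by data up to day D-1.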
# In[ ]:


def get_history_info(df_label, history, window):
    id_encode_count_0 = history.groupby(['id_encode'])['power'].agg(count_0).reset_index()
    id_encode_count_0.columns = ['id_encode', f'power_count_0_window_{window}']
    id_encode_count_0[f'power_count_0_window_{window}'] = id_encode_count_0[f'power_count_0_window_{window}'].astype(np.uint16, copy=False)

    df_label = df_label.merge(id_encode_count_0, on=['id_encode'], how='left')

    ## count/nunique
    for feat in [
        'id_encode',
        'parking_free',
        'dc_equipment_kw',
        'ele_price_dis',
        'total_price_dis',
        'f3_dis'
    ]:

        agg_ = history[feat].value_counts().reset_index()
        agg_.columns = [feat, f'{feat}_count_window_{window}']
        agg_[f'{feat}_count_window_{window}'] = agg_[f'{feat}_count_window_{window}'].astype(np.uint32, copy=False)
        df_label = df_label.merge(agg_, on=[feat], how='left')

    for feat in [
        'temp_max',
        'temp_min', 'weather', 'ele_price_dis', 'ser_price_dis',
        'after_ser_price_dis', 'f1_dis', 'total_price_dis'
    ]:
        agg_ = history.groupby(['id_encode'])[feat].agg('nunique').reset_index()
        agg_.columns = ['id_encode', f'id_encode_{feat}_nunique_window_{window}']
        agg_[f'id_encode_{feat}_nunique_window_{window}'] = agg_[f'id_encode_{feat}_nunique_window_{window}'].astype(np.uint8, copy=False)

        df_label = df_label.merge(agg_, on=['id_encode'], how='left')

    ## Target encoding over the history window
    for feat in ['id_encode', 'parking_free', 'flag',
                 'h3', 'year', 'quarter',
                 'month', 'day', 'hour',
                 'dayofweek', 'is_wknd',

                 'dc_equipment_kw',
                 'ele_price_dis', 'ser_price_dis', 'after_ser_price_dis', 'f1_dis', 'total_price_dis',

                 'f3_dis'

                 ]:
        agg_ = history.groupby([feat])['power'].agg('mean').reset_index()
        agg_.columns = [feat, f'{feat}_power_mean_window_{window}']
        agg_[f'{feat}_power_mean_window_{window}'] = agg_[f'{feat}_power_mean_window_{window}'].astype(np.float16, copy=False)
        df_label = df_label.merge(agg_, on=[feat], how='left')


    for feat1 in [
        'id_encode',
        'parking_free', 'flag', 'dc_equipment_kw',
        'ele_price_dis', 'ser_price_dis', 'after_ser_price_dis', 'f1_dis', 'total_price_dis',
        'f3_dis'
    ]:
        for feat2 in [
            # 'year',
            # 'quarter',
            # 'month',
            # 'day',
            'hour',
            'dayofweek',
            'is_wknd'
        ]:
            agg_ = history.groupby([feat1, feat2])['power'].agg('mean').reset_index()
            agg_.columns = [feat1, feat2, f'{feat1}_{feat2}_power_mean_window_{window}']
            agg_[f'{feat1}_{feat2}_power_mean_window_{window}'] = agg_[f'{feat1}_{feat2}_power_mean_window_{window}'].astype(np.float16, copy=False)
            df_label = df_label.merge(agg_, on=[feat1, feat2], how='left')

    return df_label
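# Usage sketch (mirrors the calls made in train()/predict() below): build a
# 7-day slice ending just before some target day and attach its statistics.
# The date below is illustrative only:
# end = datetime.datetime.strptime('20230414', "%Y%m%d")
# history = data[(data['date'] < end) & (data['date'] >= end - datetime.timedelta(7))]
# df_label = get_history_info(df_label, history, 7)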
# In[ ]:


drop_features = ['ds',
                 'power',
                 'date',
                 'id_encode_year_count',
                 'id_encode_quarter_count',
                 'id_encode_month_count',
                 'id_encode_day_count',
                 'id_encode_dayofweek_count',
                 'parking_free_id_encode_nunique',
                 'dc_equipment_kw_id_encode_nunique',
                 'city_power_mean',
                 'city_power_max',
                 'city_power_std',
                 'weather_power_mean',
                 'weather_power_max',
                 'weather_power_std',
                 'parking_free_year_power_mean',
                 'parking_free_year_power_max',
                 'parking_free_year_power_std',
                 'id_encode_month_power_mean',
                 'id_encode_month_power_max',
                 'id_encode_month_power_min',
                 'id_encode_month_power_std',
                 'id_encode_day_power_mean',
                 'id_encode_day_power_max',
                 'id_encode_day_power_min',
                 'id_encode_day_power_std'
                 ]


# In[ ]:


cat_features = [
    'id_encode', 'parking_free', 'flag', 'h3', 'province', 'city', 'district', 'town',
    'weather', 'is_wknd', 'is_month_start', 'is_month_end'
]
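# Note: cat_features is defined here, but cat_train() below builds its Pool
# objects without passing it, so CatBoost treats every column as numeric. To
# use it, one would write Pool(X_trn, y_trn, cat_features=cat_features).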
# In[ ]:


def cat_train(train_df, features):

    X_train = train_df[features]
    Y_train = train_df[['power']]
    scores = []

    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1996)
    for fold_, (train_index, test_index) in enumerate(folds.split(train_df, train_df['id_encode'])):
        print('Fold_{}'.format(fold_))

        X_trn, y_trn = X_train.iloc[train_index], Y_train.iloc[train_index]
        X_val, y_val = X_train.iloc[test_index], Y_train.iloc[test_index]

        trn_dataset = Pool(X_trn, y_trn)
        val_dataset = Pool(X_val, y_val)

        best_n = 3500
        model = CatBoostRegressor(iterations=best_n,
                                  learning_rate=0.1,
                                  depth=6,
                                  loss_function='RMSE',
                                  eval_metric='RMSE',
                                  random_seed=22222,
                                  # bagging_temperature=0.2,
                                  od_type='Iter',
                                  metric_period=500,
                                  # od_wait=500,
                                  # task_type='GPU',
                                  l2_leaf_reg=5,
                                  min_data_in_leaf=500,
                                  # scale_pos_weight=16,
                                  )

        model.fit(trn_dataset,
                  eval_set=val_dataset,
                  use_best_model=True,
                  verbose=500,
                  early_stopping_rounds=50)

        model.save_model(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/cat_test{fold_}.txt')

        val_pred = model.predict(X_val)

        score = mean_squared_error(y_val, val_pred, squared=False)
        scores.append(score)
        print('===rmse===', score)

        del X_trn, y_trn, X_val, y_val, trn_dataset, val_dataset, model, val_pred; gc.collect()

    del train_df; gc.collect()
    score = np.mean(scores)
    print('mean rmse:', score)
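# Note: StratifiedKFold is split on train_df['id_encode'], so folds are
# stratified by station rather than by the (continuous) target; every station
# contributes rows to both the training and validation side of each fold.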
def lgb_train(train_df, features):
    X_train = train_df[features]
    Y_train = train_df[['power']]

    scores = []
    params = {'learning_rate': 0.1,
              'boosting_type': 'gbdt',
              'objective': 'rmse',
              'metric': 'rmse',
              'min_child_samples': 46,
              'min_child_weight': 0.01,
              'feature_fraction': 0.8,
              'bagging_fraction': 0.8,
              'bagging_freq': 5,
              'num_leaves': 32,
              'min_data_in_leaf': 30,
              # 'max_depth': 7,
              'n_jobs': -1,
              'seed': 22222,
              'verbosity': -1,
              }

    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1996)
    for fold_, (train_index, test_index) in enumerate(folds.split(train_df, train_df['id_encode'])):
        print('Fold_{}'.format(fold_))

        X_trn, y_trn = X_train.iloc[train_index], Y_train.iloc[train_index]
        X_val, y_val = X_train.iloc[test_index], Y_train.iloc[test_index]

        trn_data = lgb.Dataset(X_trn, y_trn)
        val_data = lgb.Dataset(X_val, y_val)

        best_n = 3000
        print('best_n:', best_n)
        model = lgb.train(params,
                          trn_data,
                          best_n,
                          valid_sets=[trn_data, val_data],
                          verbose_eval=500,
                          early_stopping_rounds=50)

        model.save_model(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/lgb_test{fold_}.txt')
        val_pred = model.predict(X_val)

        score = mean_squared_error(y_val, val_pred, squared=False)
        scores.append(score)
        print('===rmse===', score)

        del X_trn, y_trn, X_val, y_val, trn_data, val_data, model, val_pred; gc.collect()

    del train_df; gc.collect()
    score = np.mean(scores)
    print('mean rmse:', score)
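# Note: passing verbose_eval / early_stopping_rounds directly to lgb.train()
# is the pre-4.0 LightGBM API. On LightGBM >= 4.0 the equivalent is callbacks:
# callbacks=[lgb.log_evaluation(500), lgb.early_stopping(50)]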
def xgb_train(train_df, features):
    X_train = train_df[features]
    Y_train = train_df[['power']]

    scores = []

    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1996)
    for fold_, (train_index, test_index) in enumerate(folds.split(train_df, train_df['id_encode'])):
        print('Fold_{}'.format(fold_))

        X_trn, y_trn = X_train.iloc[train_index], Y_train.iloc[train_index]
        X_val, y_val = X_train.iloc[test_index], Y_train.iloc[test_index]

        best_n = 3000
        model = xgb.XGBRegressor(base_score=0.5, booster='gbtree',
                                 n_estimators=best_n,
                                 early_stopping_rounds=50,
                                 objective='reg:squarederror',
                                 max_depth=7,
                                 # tree_method='gpu_hist',
                                 learning_rate=0.01)

        model.fit(X_trn, y_trn,
                  eval_set=[(X_trn, y_trn), (X_val, y_val)],
                  verbose=500)

        model.save_model(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/xgb_test{fold_}.txt')

        val_pred = model.predict(X_val)

        score = mean_squared_error(y_val, val_pred, squared=False)
        scores.append(score)
        print('===rmse===', score)

        del X_trn, y_trn, X_val, y_val, model, val_pred; gc.collect()
    del train_df; gc.collect()
    score = np.mean(scores)
    print('mean rmse:', score)

def train(data, date_id, lag, drop_features):

    data = get_feature(data, lag=lag)

    train_df = data[data['ds'] <= date_id].reset_index(drop=True)

    if lag == 1:
        train_dfs = []
        for ds in tqdm(sorted(train_df['ds'].unique(), reverse=True)[:-14]):
            df_label = data[data['ds'] == ds]
            # The original branched on `ds % 10 < 5` here, but both branches
            # were identical, so a single pass suffices.
            for window in [7, 14]:
                end_time = datetime.datetime.strptime(str(ds), "%Y%m%d")
                start_time = end_time - datetime.timedelta(window)
                history = data[(data['date'] < end_time) & (data['date'] >= start_time)]
                df_label = get_history_info(df_label, history, window)
            train_dfs.append(df_label)

            del df_label, history; gc.collect()
        train_df = pd.concat(train_dfs)

    for feat in train_df.columns:
        if str(train_df[feat].dtype) == 'float64':
            if train_df[feat].max() < np.finfo(np.float16).max:
                train_df[feat] = train_df[feat].astype(np.float16, copy=False)
            else:
                train_df[feat] = train_df[feat].astype(np.float32, copy=False)

    del data, train_dfs; gc.collect()

    for feat in train_df.columns:
        if train_df[feat].nunique() == 1:
            del train_df[feat]
    gc.collect()

    features = [feat for feat in train_df.columns if feat not in drop_features]

    for feat in train_df.columns:
        if feat not in features + ['power', 'ds']:
            del train_df[feat]
class="keyword">in</span> features + [<span class="string">'power'</span>, <span class="string">'ds'</span>]:</span><br><span class="line"> <span class="keyword">del</span> train_df[feat]</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> feat <span class="keyword">in</span> features + [<span class="string">'power'</span>, <span class="string">'ds'</span>]:</span><br><span class="line"> <span class="keyword">try</span>:</span><br><span class="line"> <span class="keyword">if</span> (train_df[feat].<span class="built_in">max</span>() > <span class="number">999999</span>):</span><br><span class="line"> print(feat, train_df[feat].<span class="built_in">min</span>(),train_df[feat].<span class="built_in">max</span>())</span><br><span class="line"> <span class="keyword">except</span>:</span><br><span class="line"> print(feat)</span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"> <span class="comment"># gc.collect()</span></span><br><span class="line"> cat_train(train_df, features)</span><br><span class="line"> <span class="comment"># train_df = train_df[train_df['ds'] >= 20220801]</span></span><br><span class="line"> <span class="comment"># gc.collect()</span></span><br><span class="line"> <span class="comment"># lgb_train(train_df, features)</span></span><br><span class="line"> <span class="comment"># train_df.replace([np.inf, -np.inf], np.nan, inplace=True)</span></span><br><span class="line"> <span class="comment"># xgb_train(train_df, features)</span></span><br><span class="line"></span><br><span class="line"> print(<span class="built_in">len</span>(features), train_df.shape)</span><br><span class="line"> <span class="keyword">return</span> features</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">model, data, pred_date_id, lag, features, merge=<span class="literal">False</span>, mode=<span class="string">'cat'</span></span>):</span></span><br><span class="line"> data = get_feature(data, lag=lag)</span><br><span class="line"> </span><br><span class="line"> test_df = data[data[<span class="string">'ds'</span>] == pred_date_id]</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> lag == <span class="number">1</span>:</span><br><span class="line"> <span class="keyword">for</span> window <span class="keyword">in</span> [<span class="number">7</span>, <span class="number">14</span>]:</span><br><span class="line"> end_time = datetime.datetime.strptime(<span class="string">'20230415'</span>,<span class="string">"%Y%m%d"</span>)</span><br><span class="line"> start_time = end_time - datetime.timedelta(window)</span><br><span class="line"> history = data[(data[<span class="string">'date'</span>] < end_time) & (data[<span class="string">'date'</span>] >= start_time)]</span><br><span class="line"></span><br><span class="line"> test_df = get_history_info(test_df, history, window)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># print(len(features), test_df.shape)</span></span><br><span class="line"> </span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> mode == <span class="string">'xgb'</span>:</span><br><span class="line"> test_df.replace([np.inf, -np.inf], np.nan, 
inplace=<span class="literal">True</span>)</span><br><span class="line"> pred = model.predict(xgb.DMatrix(test_df[features]))</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> pred = model.predict(test_df[features])</span><br><span class="line"> </span><br><span class="line"> test_df[<span class="string">'power'</span>] = pred</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> merge == <span class="literal">True</span>:</span><br><span class="line"> data1 = data[data[<span class="string">'ds'</span>] != pred_date_id]</span><br><span class="line"> data2 = data[data[<span class="string">'ds'</span>] == pred_date_id]</span><br><span class="line"></span><br><span class="line"> <span class="keyword">del</span> data2[<span class="string">'power'</span>]</span><br><span class="line"> data2 = data2.merge(test_df[[<span class="string">'id_encode'</span>, <span class="string">'ds'</span>, <span class="string">'hour'</span>, <span class="string">'power'</span>]], </span><br><span class="line"> on=[<span class="string">'id_encode'</span>, <span class="string">'ds'</span>, <span class="string">'hour'</span>], how=<span class="string">'left'</span>)</span><br><span class="line"></span><br><span class="line"> data = pd.concat([data1, data2])</span><br><span class="line"></span><br><span class="line"> data = data.sort_values([<span class="string">'id_encode'</span>, <span class="string">'ds'</span>, <span class="string">'hour'</span>]).reset_index(drop=<span class="literal">True</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> data, pred</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># ### model1</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line">len_use = <span class="built_in">len</span>(data)</span><br><span class="line">features = train(data, <span class="number">20230414</span>, <span class="number">1</span>, drop_features)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># In[ ]:</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">cat_res = []</span><br><span class="line"><span class="keyword">for</span> fold_ <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">5</span>):</span><br><span class="line"> model1 = CatBoostRegressor()</span><br><span class="line"> model1.load_model(<span class="string">f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/cat_test<span class="subst">{fold_}</span>.txt'</span>)</span><br><span class="line"> print(<span class="string">f'load model: /opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/cat_test<span class="subst">{fold_}</span>.txt'</span>)</span><br><span class="line"> data_fold = data.copy()</span><br><span class="line"> </span><br><span class="line"> data_fold, model1_pred_20230415 = predict(model1, data_fold, <span class="number">20230415</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230416 = predict(model1, data_fold, <span class="number">20230416</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span 
class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230417 = predict(model1, data_fold, <span class="number">20230417</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230418 = predict(model1, data_fold, <span class="number">20230418</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230419 = predict(model1, data_fold, <span class="number">20230419</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230420 = predict(model1, data_fold, <span class="number">20230420</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> data_fold, model1_pred_20230421 = predict(model1, data_fold, <span class="number">20230421</span>, <span class="number">1</span>, features, merge=<span class="literal">True</span>, mode=<span class="string">'cat'</span>)</span><br><span class="line"> </span><br><span class="line"> test_df = data_fold[data_fold[<span class="string">'ds'</span>] > <span class="number">20230414</span>].reset_index(drop = <span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"> test_df[<span class="string">'power'</span>] = test_df[<span class="string">'power'</span>].apply(<span class="keyword">lambda</span> x : <span class="number">0</span> <span class="keyword">if</span> x<<span class="number">0</span> <span class="keyword">else</span> x)</span><br><span class="line"></span><br><span class="line"> submit = test_df[[<span class="string">'id_encode'</span>, <span class="string">'date'</span>,<span class="string">'hour'</span>,<span class="string">'power'</span>]]</span><br><span class="line"> submit[<span class="string">'time'</span>] = submit[<span class="string">'date'</span>].astype(<span class="built_in">str</span>)</span><br><span class="line"> submit[<span class="string">'time'</span>] = submit[<span class="string">'time'</span>].apply(<span class="keyword">lambda</span> x: x.replace(<span class="string">'-'</span>, <span class="string">''</span>))</span><br><span class="line"> submit.rename(columns={<span class="string">'time'</span>:<span class="string">'ds'</span>}, inplace=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"> submit[[<span class="string">'id_encode'</span>, <span class="string">'ds'</span>, <span class="string">'hour'</span>,<span class="string">'power'</span>]].to_csv(<span class="string">f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/result/cat_fold<span class="subst">{fold_}</span>.csv'</span>, index=<span class="literal">None</span>)</span><br><span class="line"> cat_res.append(submit[[<span class="string">'id_encode'</span>, <span class="string">'ds'</span>, <span class="string">'hour'</span>, <span class="string">'power'</span>]])</span><br><span class="line"></span><br><span class="line"><span class="comment"># lgb_res = []</span></span><br><span class="line"><span class="comment"># for fold_ in range(5):</span></span><br><span class="line"><span class="comment"># model1 = 
lgb.Booster(model_file=f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/lgb_test{fold_}.txt')</span></span><br><span class="line"><span class="comment"># print(f'load model: /opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/lgb_test{fold_}.txt')</span></span><br><span class="line"><span class="comment"># data_fold = data.copy()</span></span><br><span class="line"> </span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230415 = predict(model1, data_fold, 20230415, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230416 = predict(model1, data_fold, 20230416, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230417 = predict(model1, data_fold, 20230417, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230418 = predict(model1, data_fold, 20230418, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230419 = predict(model1, data_fold, 20230419, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230420 = predict(model1, data_fold, 20230420, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230421 = predict(model1, data_fold, 20230421, 1, features, merge=True, mode='lgb')</span></span><br><span class="line"> </span><br><span class="line"><span class="comment"># test_df = data_fold[data_fold['ds'] > 20230414].reset_index(drop = True)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># test_df['power'] = test_df['power'].apply(lambda x : 0 if x<0 else x)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># submit = test_df[['id_encode', 'date', 'hour','power']]</span></span><br><span class="line"><span class="comment"># submit['time'] = submit['date'].astype(str)</span></span><br><span class="line"><span class="comment"># submit['time'] = submit['time'].apply(lambda x: x.replace('-', ''))</span></span><br><span class="line"><span class="comment"># submit.rename(columns={'time':'ds'}, inplace=True)</span></span><br><span class="line"><span class="comment"># submit[['id_encode', 'ds', 'hour', 'power']].to_csv(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/result/lgb_test{fold_}.csv', index=None)</span></span><br><span class="line"><span class="comment"># del data_fold; gc.collect()</span></span><br><span class="line"><span class="comment"># lgb_res.append(submit[['id_encode', 'ds', 'hour', 'power']])</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># xgb_res = []</span></span><br><span class="line"><span class="comment"># for fold_ in range(5):</span></span><br><span class="line"><span class="comment"># model1 = xgb.Booster()</span></span><br><span class="line"><span class="comment"># model1.load_model(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/xgb_test{fold_}.txt')</span></span><br><span class="line"><span class="comment"># # print(f'load model: /opt/project/workspace/wdbMYCgvUKxijgaxPNUL/model/xgb_test{fold_}.txt')</span></span><br><span class="line"><span class="comment"># data_fold = data.copy()</span></span><br><span class="line"> </span><br><span class="line"><span class="comment"># data_fold, 
model1_pred_20230415 = predict(model1, data_fold, 20230415, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230416 = predict(model1, data_fold, 20230416, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230417 = predict(model1, data_fold, 20230417, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230418 = predict(model1, data_fold, 20230418, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230419 = predict(model1, data_fold, 20230419, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230420 = predict(model1, data_fold, 20230420, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"><span class="comment"># data_fold, model1_pred_20230421 = predict(model1, data_fold, 20230421, 1, features, merge=True, mode='xgb')</span></span><br><span class="line"> </span><br><span class="line"><span class="comment"># test_df = data_fold[data_fold['ds'] > 20230414].reset_index(drop = True)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># test_df['power'] = test_df['power'].apply(lambda x : 0 if x<0 else x)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># submit = test_df[['id_encode', 'date', 'hour','power']]</span></span><br><span class="line"><span class="comment"># submit['time'] = submit['date'].astype(str)</span></span><br><span class="line"><span class="comment"># submit['time'] = submit['time'].apply(lambda x: x.replace('-', ''))</span></span><br><span class="line"><span class="comment"># submit.rename(columns={</span></span><br><span class="line"><span class="comment"># 'time':'ds'</span></span><br><span class="line"><span class="comment"># }, inplace=True)</span></span><br><span class="line"><span class="comment"># submit[['id_encode', 'ds', 'hour', 'power']].to_csv(f'/opt/project/workspace/wdbMYCgvUKxijgaxPNUL/result/xgb_test{fold_}.csv', index=None)</span></span><br><span class="line"><span class="comment"># del data_fold; gc.collect()</span></span><br><span class="line"><span class="comment"># xgb_res.append(submit[['id_encode', 'ds', 'hour', 'power']])</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># # In[ ]:</span></span><br><span class="line"></span><br><span class="line">cat_5fold_res = cat_res[<span class="number">0</span>]</span><br><span class="line">cat_5fold_res[<span class="string">'power'</span>] = (cat_res[<span class="number">0</span>][<span class="string">'power'</span>] + </span><br><span class="line"> cat_res[<span class="number">1</span>][<span class="string">'power'</span>] + </span><br><span class="line"> cat_res[<span class="number">2</span>][<span class="string">'power'</span>] + </span><br><span class="line"> cat_res[<span class="number">3</span>][<span class="string">'power'</span>] + </span><br><span class="line"> cat_res[<span class="number">4</span>][<span class="string">'power'</span>]</span><br><span class="line"> ) / <span class="number">5</span></span><br><span class="line"><span class="comment"># # lgb_5fold_res = lgb_res[0]</span></span><br><span class="line"><span class="comment"># # lgb_5fold_res['power'] = ( lgb_res[0]['power'] + 
</span></span><br><span class="line"><span class="comment"># # lgb_res[1]['power'] + </span></span><br><span class="line"><span class="comment"># # lgb_res[2]['power'] + </span></span><br><span class="line"><span class="comment"># # lgb_res[3]['power'] + </span></span><br><span class="line"><span class="comment"># # lgb_res[4]['power']</span></span><br><span class="line"><span class="comment"># # ) / 5</span></span><br><span class="line"><span class="comment"># xgb_5fold_res = xgb_res[0]</span></span><br><span class="line"><span class="comment"># xgb_5fold_res['power'] = ( xgb_res[0]['power'] + </span></span><br><span class="line"><span class="comment"># xgb_res[1]['power'] + </span></span><br><span class="line"><span class="comment"># xgb_res[2]['power'] + </span></span><br><span class="line"><span class="comment"># xgb_res[3]['power'] + </span></span><br><span class="line"><span class="comment"># xgb_res[4]['power']</span></span><br><span class="line"><span class="comment"># ) / 5</span></span><br><span class="line">submit = cat_5fold_res</span><br><span class="line">submit[<span class="string">'rank'</span>] = submit[<span class="string">'power'</span>].rank(method=<span class="string">'first'</span>).astype(<span class="built_in">int</span>)</span><br><span class="line">submit.loc[submit[<span class="string">'rank'</span>] >= <span class="number">15000</span> , <span class="string">'power'</span>] *= <span class="number">1.05</span></span><br><span class="line">submit.loc[(submit[<span class="string">'rank'</span>] < <span class="number">15000</span>) & (submit[<span class="string">'rank'</span>] >= <span class="number">10000</span>), <span class="string">'power'</span>] *= <span class="number">1.03</span></span><br><span class="line">submit.loc[(submit[<span class="string">'rank'</span>] < <span class="number">10000</span>) & (submit[<span class="string">'rank'</span>] >= <span class="number">5000</span>), <span class="string">'power'</span>] *= <span class="number">1.01</span></span><br><span class="line">submit.loc[(submit[<span class="string">'rank'</span>] < <span class="number">5000</span>), <span class="string">'power'</span>] *= <span class="number">1.01</span></span><br><span class="line"></span><br><span class="line">os.makedirs(<span class="string">'/opt/output/result'</span>, exist_ok=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">submit[[<span class="string">'id_encode'</span>,<span class="string">'ds'</span>,<span class="string">'hour'</span>, <span class="string">'power'</span>]].to_csv(<span class="string">'/opt/output/result/result.csv'</span>, index=<span class="literal">None</span>)</span><br><span class="line"></span><br><span class="line">print(<span class="string">'train: '</span>, len_train, <span class="string">'use: '</span>, len_use)</span><br><span class="line">print(<span class="string">'min_id_encode: '</span>, min_id_encode, <span class="string">'max_id_encode: '</span>, max_id_encode)</span><br><span class="line">print(<span class="string">'min_date: '</span>, min_date, <span class="string">'max_date: '</span>, max_date)</span><br><span class="line">print(<span class="string">'test_min_date: '</span>, test_min_date, <span class="string">'test_max_date: '</span>, test_max_date)</span><br></pre></td></tr></table></figure>
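<p>The tail of the script above blends the five CatBoost folds with a simple mean and then applies a rank-based multiplicative uplift before writing <code>result.csv</code>. A standalone sketch of that post-processing (not part of the original script; <code>blend_and_scale</code> is a hypothetical helper, the column names and thresholds follow the code above, and the two lowest rank bands are merged because both use the 1.01 factor):</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">blend_and_scale</span>(<span class="params">fold_preds</span>):</span></span><br><span class="line">    <span class="comment"># mean-blend the per-fold predictions (frames share ['id_encode', 'ds', 'hour'] row order)</span></span><br><span class="line">    blended = fold_preds[<span class="number">0</span>].copy()</span><br><span class="line">    blended[<span class="string">'power'</span>] = <span class="built_in">sum</span>(df[<span class="string">'power'</span>] <span class="keyword">for</span> df <span class="keyword">in</span> fold_preds) / <span class="built_in">len</span>(fold_preds)</span><br><span class="line">    <span class="comment"># rank-based scaling: the larger the predicted power, the larger the uplift</span></span><br><span class="line">    blended[<span class="string">'rank'</span>] = blended[<span class="string">'power'</span>].rank(method=<span class="string">'first'</span>).astype(<span class="built_in">int</span>)</span><br><span class="line">    blended.loc[blended[<span class="string">'rank'</span>] >= <span class="number">15000</span>, <span class="string">'power'</span>] *= <span class="number">1.05</span></span><br><span class="line">    blended.loc[(blended[<span class="string">'rank'</span>] < <span class="number">15000</span>) & (blended[<span class="string">'rank'</span>] >= <span class="number">10000</span>), <span class="string">'power'</span>] *= <span class="number">1.03</span></span><br><span class="line">    blended.loc[blended[<span class="string">'rank'</span>] < <span class="number">10000</span>, <span class="string">'power'</span>] *= <span class="number">1.01</span></span><br><span class="line">    <span class="keyword">return</span> blended.drop(columns=<span class="string">'rank'</span>)</span><br></pre></td></tr></table></figure>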
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2021/04/09/%E5%9F%BA%E4%BA%8Etensorflow%E7%9A%84%E5%A4%9A%E5%B1%82%E6%84%9F%E7%9F%A5%E6%9C%BA%E7%9A%84%E5%AE%9E%E7%8E%B0/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2021/04/09/%E5%9F%BA%E4%BA%8Etensorflow%E7%9A%84%E5%A4%9A%E5%B1%82%E6%84%9F%E7%9F%A5%E6%9C%BA%E7%9A%84%E5%AE%9E%E7%8E%B0/" class="post-title-link" itemprop="url">基于tensorflow的多层感知机的代码实现</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2021-04-09 21:23:09" itemprop="dateCreated datePublished" datetime="2021-04-09T21:23:09+08:00">2021-04-09</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar-check"></i>
</span>
<span class="post-meta-item-text">更新于</span>
<time title="修改时间:2021-04-10 09:15:30" itemprop="dateModified" datetime="2021-04-10T09:15:30+08:00">2021-04-10</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># coding=utf-8</span></span><br><span class="line"><span class="comment"># author: Shuigs18</span></span><br><span class="line"><span class="comment"># date: 2021-04-07</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 基于tensorflow的多层感知机的实现</span></span><br><span class="line"><span class="comment"># 三层(输入、隐藏、输出)MLP + K-fold + weight decay(L2正则化) + dropout</span></span><br><span class="line"><span class="comment"># Relu(隐藏) + softmax(输出)</span></span><br><span class="line"><span class="comment"># 数据集:fashion—mnist</span></span><br><span class="line"><span class="comment"># 梯度计算利用tensorflow</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line"><span class="keyword">from</span> tensorflow <span class="keyword">import</span> keras</span><br><span class="line"><span class="keyword">from</span> matplotlib <span class="keyword">import</span> pyplot <span class="keyword">as</span> plt</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">from</span> tensorflow.keras.datasets <span class="keyword">import</span> fashion_mnist</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 读取数据并处理(测试集验证集)</span></span><br><span class="line"><span class="comment"># 定义模型参数</span></span><br><span class="line"><span class="comment"># 定义激活函数Relu和softmax</span></span><br><span class="line"><span class="comment"># 定义网络(dropout实现)</span></span><br><span class="line"><span class="comment"># 定义损失函数(加L2正则项 weight decay实现)</span></span><br><span class="line"><span class="comment"># K-fold 函数 </span></span><br><span class="line"><span class="comment"># 先 k-fold 确定训练集和验证集</span></span><br><span class="line"><span class="comment"># 然后在将训练集生成SGD训练的迭代器</span></span><br><span class="line"><span class="comment"># train函数 (输入包含训练集迭代器和验证集)</span></span><br><span class="line"><span class="comment"># 小批量梯度下降</span></span><br><span class="line"><span class="comment"># 返回 fold 0,1,2,3,4,5 训练集验证集的误差</span></span><br><span class="line"><span class="comment"># predict函数 生成结果</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 数据处理</span></span><br><span class="line">(X_train, Y_train), (X_test, Y_test) = fashion_mnist.load_data()</span><br><span class="line">batch_size = <span class="number">256</span></span><br><span class="line">X_train = tf.cast(X_train, tf.float32)</span><br><span class="line">X_test = tf.cast(X_test, tf.float32)</span><br><span class="line">X_train = X_train / <span class="number">255.0</span> <span class="comment"># 颜色的深浅没有关系</span></span><br><span class="line">X_test = X_test / <span class="number">255.0</span></span><br><span class="line"><span class="comment"># 划分批次 </span></span><br><span class="line"><span class="comment"># train_iter = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义模型参数(一层隐藏层)</span></span><br><span class="line">dim_inputs, dim_hiddens, dim_outputs = <span class="number">784</span>, <span class="number">256</span>, <span class="number">10</span></span><br><span class="line">W1 = tf.Variable(tf.random.normal(shape=(dim_inputs, dim_hiddens), mean=<span class="number">0.0</span>, stddev=<span class="number">0.01</span>, dtype=tf.float32))</span><br><span class="line">b1 = tf.Variable(tf.zeros(dim_hiddens, dtype=tf.float32))</span><br><span class="line">W2 = tf.Variable(tf.random.normal(shape=(dim_hiddens, dim_outputs), mean=<span class="number">0.0</span>, stddev=<span class="number">0.01</span>, dtype=tf.float32))</span><br><span class="line">b2 = tf.Variable(tf.random.normal([dim_outputs], mean=<span class="number">0.0</span>, stddev=<span class="number">0.01</span>, dtype=tf.float32))</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义激活函数 ReLu softmax</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">ReLu</span>(<span class="params">X</span>):</span></span><br><span class="line"> <span class="keyword">return</span> tf.math.maximum(X, <span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">softmax</span>(<span class="params">X</span>):</span></span><br><span class="line"> <span class="keyword">return</span> tf.exp(X) / tf.reduce_sum(tf.math.exp(X), axis=<span class="number">1</span>, keepdims=<span class="literal">True</span>)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># dropout(H, drop_prob)</span></span><br><span class="line"><span class="comment"># 网络net()</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">dropout</span>(<span class="params">H, drop_prob</span>):</span></span><br><span class="line"> <span class="keyword">assert</span> <span class="number">0</span> <= drop_prob <= <span class="number">1</span> </span><br><span class="line"> keep_prob = <span class="number">1</span>- drop_prob</span><br><span class="line"> <span class="keyword">if</span> keep_prob == <span class="number">0</span>:</span><br><span class="line"> <span class="keyword">return</span> tf.zeros_like(H)</span><br><span class="line"> mask = tf.random.uniform(shape=H.shape, minval=<span class="number">0</span>, maxval=<span class="number">1</span>) < keep_prob</span><br><span class="line"> <span class="keyword">return</span> tf.cast(mask, dtype=tf.float32) * tf.cast(H, dtype=tf.float32) / keep_prob</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义整个网络</span></span><br><span class="line">drop_prob1 = <span class="number">0.2</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">net</span>(<span class="params">X, training=<span class="literal">False</span></span>):</span></span><br><span class="line"> X = tf.reshape(X, shape=(-<span class="number">1</span>, dim_inputs))</span><br><span class="line"> H1 = ReLu(tf.matmul(X, W1) + b1)</span><br><span class="line"> <span class="keyword">if</span> training:</span><br><span class="line"> H1 = drop_out(H, drop_prob1)</span><br><span class="line"> <span class="keyword">return</span> softmax(tf.matmul(H1, W2) + b2)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义损失函数 交叉熵 L2正则项</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">loss_cross_entropy</span>(<span class="params">y_true, y_pred</span>):</span></span><br><span class="line"> <span class="keyword">return</span> tf.losses.sparse_categorical_crossentropy(y_true, y_pred)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">L2_penalty</span>(<span class="params">W</span>):</span></span><br><span class="line"> <span class="keyword">return</span> tf.reduce_sum(W ** <span class="number">2</span>) / <span class="number">2.0</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义get_K_fold_data函数</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">get_K_fold_data</span>(<span class="params">k, i, X, Y</span>):</span></span><br><span class="line"> fold_size = X.shape[<span class="number">0</span>] // k</span><br><span class="line"> X_train, Y_train = <span class="literal">None</span>, <span class="literal">None</span></span><br><span class="line"> <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(k):</span><br><span class="line"> idx = <span class="built_in">slice</span>(j * fold_size, (j + <span class="number">1</span>) * fold_size)</span><br><span class="line"> X_part, Y_part = X[idx, :], Y[idx]</span><br><span class="line"> <span class="keyword">if</span> j == i:</span><br><span class="line"> X_valid, Y_valid = X_part, Y_part</span><br><span class="line"> <span class="keyword">elif</span> X_train <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line"> X_train, Y_train = X_part, Y_part</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> X_train = tf.concat([X_train, X_part], axis=<span class="number">0</span>)</span><br><span class="line"> Y_train = tf.concat([Y_train, Y_part], axis=<span class="number">0</span>)</span><br><span class="line"> <span class="keyword">return</span> X_train, Y_train, X_valid, Y_valid</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义训练函数</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train</span>(<span class="params">net, train_iter, X_valid, Y_valid, loss, num_epochs, batch_size, params=<span class="literal">None</span>, learning_rate=<span class="literal">None</span></span>):</span> </span><br><span class="line"> train_loss_sum, valid_loss_sum = <span class="number">0.0</span>, <span class="number">0.0</span></span><br><span class="line"> <span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line"> train_loss, valid_loss, n = <span class="number">0.0</span>, <span class="number">0.0</span>, <span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> X_train, Y_train <span class="keyword">in</span> train_iter:</span><br><span class="line"> <span class="keyword">with</span> tf.GradientTape() <span class="keyword">as</span> tape: </span><br><span class="line"> Y_pred = net(X_train)</span><br><span class="line"> l = loss(Y_train, Y_pred)</span><br><span class="line"> <span class="comment"># 计算梯度</span></span><br><span class="line"> grads = tape.gradient(l, params)</span><br><span class="line"> <span class="comment"># 创建一个优化器</span></span><br><span class="line"> opt = tf.keras.optimizers.SGD(learning_rate = learning_rate)</span><br><span class="line"> <span class="comment"># 梯度下降更新参数(批量梯度下降)</span></span><br><span class="line"> opt.apply_gradients(<span class="built_in">zip</span>([grad / batch_size <span class="keyword">for</span> grad <span class="keyword">in</span> grads], params))</span><br><span class="line"> <span class="comment"># 更新训练集损失值</span></span><br><span class="line"> train_loss += l.numpy().<span class="built_in">sum</span>()</span><br><span class="line"> n += Y_train.shape[<span class="number">0</span>]</span><br><span class="line"> valid_loss += (loss(Y_valid, net(X_valid)).numpy().<span class="built_in">sum</span>() / Y_valid.shape[<span class="number">0</span>])</span><br><span class="line"> train_loss /= n</span><br><span 
class="line"> train_loss_sum += train_loss</span><br><span class="line"> valid_loss_sum += valid_loss</span><br><span class="line"> train_loss_sum /= num_epochs</span><br><span class="line"> valid_loss_sum /= num_epochs</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> params, train_loss_sum, valid_loss_sum</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">k_fold</span>(<span class="params">k, net, X_train, Y_train, num_epochs, </span></span></span><br><span class="line"><span class="function"><span class="params"> batch_size, loss_cross_entropy, params=<span class="literal">None</span>, learning_rate=<span class="literal">None</span></span>):</span></span><br><span class="line"> start_time = time.time() </span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(k):</span><br><span class="line"> data = get_K_fold_data(k, i, X_train, Y_train)</span><br><span class="line"> train_iter = tf.data.Dataset.from_tensor_slices((data[<span class="number">0</span>], data[<span class="number">1</span>])).batch(batch_size)</span><br><span class="line"> X_valid = data[<span class="number">2</span>]</span><br><span class="line"> Y_valid = data[<span class="number">3</span>]</span><br><span class="line"> params, train_loss, valid_loss = train(net, train_iter, X_valid, Y_valid, loss_cross_entropy, num_epochs, batch_size, params, learning_rate)</span><br><span class="line"> print(<span class="string">"fold %d: train loss %f, valid loss %f"</span> % (i, train_loss, valid_loss))</span><br><span class="line"> end_time = time.time()</span><br><span class="line"> print(<span class="string">'总用时:%f'</span> % (start_time - end_time))</span><br><span class="line"> <span class="keyword">return</span> params</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 预测函数 predict()</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">net, params, X_test</span>):</span></span><br><span class="line"> Y_pred = net(X_test)</span><br><span class="line"> result = tf.argmax(Y_pred, axis=<span class="number">1</span>)</span><br><span class="line"> <span class="keyword">return</span> result</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">params = [W1, b1, W2, b2]</span><br><span class="line">num_epochs = <span class="number">10</span></span><br><span class="line">params = k_fold(<span class="number">5</span>, net, X_train, Y_train, num_epochs, batch_size, loss_cross_entropy, params, learning_rate=<span class="number">0.1</span>)</span><br><span class="line"><span class="string">'''</span></span><br><span class="line"><span class="string">fold 0: train loss 0.169400, valid loss 0.169741</span></span><br><span class="line"><span class="string">fold 1: train loss 0.156253, valid loss 0.167162</span></span><br><span class="line"><span class="string">fold 2: train loss 0.147749, valid loss 0.159257</span></span><br><span class="line"><span class="string">fold 3: train loss 0.139382, valid loss 0.157192</span></span><br><span class="line"><span class="string">fold 4: train loss 0.130207, valid loss 0.149647</span></span><br><span class="line"><span class="string">总用时:-114.377255</span></span><br><span class="line"><span class="string">'''</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">result = predict(net, params, X_test)</span><br><span class="line"><span class="string">'''</span></span><br><span class="line"><span class="string"><tf.Tensor: shape=(100,), dtype=int64, numpy=</span></span><br><span class="line"><span class="string">array([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 5, 3, 4, 1, 2, 2, 8, 0, 2, 5,</span></span><br><span class="line"><span class="string"> 7, 5, 1, 4, 6, 0, 9, 6, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 0, 1, 6, 7,</span></span><br><span class="line"><span class="string"> 6, 7, 2, 1, 2, 6, 4, 2, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5, 1, 1,</span></span><br><span class="line"><span class="string"> 3, 3, 7, 8, 7, 0, 2, 6, 2, 3, 1, 2, 8, 4, 1, 8, 5, 9, 5, 0, 3, 2,</span></span><br><span class="line"><span class="string"> 0, 2, 5, 3, 6, 7, 1, 8, 0, 1, 2, 2])></span></span><br><span class="line"><span class="string">'''</span></span><br></pre></td></tr></table></figure>
<h1 id="Reference"><a href="#Reference" class="headerlink" title="Reference"></a>Reference</h1>
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2021/03/21/Ensemble/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2021/03/21/Ensemble/" class="post-title-link" itemprop="url">Ensemble</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2021-03-21 14:53:52 / 修改时间:15:05:09" itemprop="dateCreated datePublished" datetime="2021-03-21T14:53:52+08:00">2021-03-21</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<ul>
<li><p>boosting</p>
</li>
<li><p>bagging</p>
</li>
<li><p>stacking</p>
</li>
</ul>
<h1 id="1-boosting"><a href="#1-boosting" class="headerlink" title="1. boosting"></a>1. boosting</h1><ul>
<li><strong>Representative algorithm: AdaBoost</strong></li>
</ul>
<p><strong>Classification only?</strong> The AdaBoost presented below is a binary classifier; boosting as a family (e.g. gradient boosting) also handles regression.</p>
<p>Weak classifier: a rough classification rule, only slightly better than random guessing</p>
<p>Strong classifier: an accurate classification rule</p>
<p>Boosting combines weak classifiers into one strong classifier</p>
<p>Most boosting methods work by changing the probability distribution over the training data (the sample weights), calling the weak learning algorithm on <strong>a different training distribution</strong> in each round</p>
<p>Any boosting method therefore has to answer two questions:</p>
<ol>
<li><p>How should the weights (the probability distribution) over the training data change in each round?</p>
</li>
<li><p>How should the weak classifiers be combined into one strong classifier?</p>
</li>
</ol>
<h2 id="1-1-AdaBoost"><a href="#1-1-AdaBoost" class="headerlink" title="1.1 AdaBoost"></a>1.1 AdaBoost</h2><p><strong>AdaBoost的做法:</strong></p>
<p>针对问题1,提高被前一轮弱分类器错误分类的样本的权值,降低分类正确的样本的权值</p>
<p>针对问题2,组合采用<strong>加权多数表决</strong>的方法</p>
<p><strong>算法8.1 (AdaBoost)</strong></p>
<p>输入:训练数据集 $T$ ;弱学习算法</p>
<p>输出:最终的分类器 $G(x)$</p>
<ol>
<li><p>Initialize the weight distribution over the training data<br>$$<br>D_1 = (w_{11}, \cdots, w_{1i}, \cdots, w_{1N}),\quad w_{1i} = \frac{1}{N}, \quad i = 1,2,…,N<br>$$</p>
</li>
<li><p>For $m = 1,2,…, M$</p>
<p>a. Learn a base classifier from the training data weighted by $D_m$ (fitted by minimizing the weighted classification error, i.e. the total weight of the misclassified samples)<br>$$<br>G_m(x) : \mathcal{X} \rightarrow \lbrace -1, +1 \rbrace<br>$$<br>b. Compute the classification error rate of $G_m(x)$ on the training set<br>$$<br>e_{m}=\sum_{i=1}^{N} P\left(G_{m}\left(x_{i}\right) \neq y_{i}\right)=\sum_{i=1}^{N} w_{mi} I\left(G_{m}\left(x_{i}\right) \neq y_{i}\right)<br>$$<br>c. Compute the coefficient of $G_m(x)$<br>$$<br>\alpha_m = \frac{1}{2} \log(\frac{1-e_m}{e_m})<br>$$<br>where the logarithm is the natural logarithm</p>
<p>d. Update the weight distribution over the training data<br>$$<br>D_{m+1} = (w_{m+1,1}, …, w_{m+1, i}, …, w_{m+1, N}) \\<br>w_{m+1, i} = \frac{w_{mi}}{Z_m} \exp(-\alpha_m y_i G_m(x_i)), \quad i =1,2,…,N<br>$$<br>where $Z_m$ is the normalizing factor<br>$$<br>Z_m = \sum_{i = 1}^{N}w_{mi}\exp(-\alpha_m y_i G_m(x_i))<br>$$<br>Here $y_i G_m(x_i)$ is $+1$ when sample $i$ is classified correctly and $-1$ when it is misclassified, so correct samples are scaled by $e^{-\alpha_m}$ and misclassified ones by $e^{\alpha_m}$</p>
<p><strong>$Z_m$ makes $D_{m+1}$ a probability distribution</strong></p>
</li>
<li><p>Build the linear combination of the base classifiers<br>$$<br>f(x) = \sum_{m=1}^{M}\alpha_mG_m(x)<br>$$<br>which yields the final classifier<br>$$<br>G(x) = \text{sign}(f(x))<br>$$</p>
</li>
</ol>
<blockquote>
<p>Note: a large $e_m$ gives a small $\alpha_m$. Looking at the weight-update formula, the numerator is small for correctly classified samples and large for misclassified ones, which is exactly how the weights of samples misclassified in the previous round get increased and the weights of correctly classified samples get decreased</p>
</blockquote>
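<p>A quick numeric check of the note above (not in the original post; values rounded):</p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line">e_m = <span class="number">0.3</span>                              <span class="comment"># weighted error of the current weak classifier</span></span><br><span class="line">alpha = <span class="number">0.5</span> * np.log((<span class="number">1</span> - e_m) / e_m)  <span class="comment"># ~= 0.4236</span></span><br><span class="line">print(np.exp(-alpha), np.exp(alpha))   <span class="comment"># ~= 0.655 (correct) vs ~= 1.527 (misclassified)</span></span><br><span class="line"><span class="comment"># before renormalization by Z_m, the misclassified/correct weight ratio</span></span><br><span class="line"><span class="comment"># grows by a factor exp(2 * alpha) = (1 - e_m) / e_m = 7/3 each round</span></span><br></pre></td></tr></table></figure>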
<h3 id="Q:Adaboost终止条件是什么?为什么不容易过拟合?"><a href="#Q:Adaboost终止条件是什么?为什么不容易过拟合?" class="headerlink" title="Q:Adaboost终止条件是什么?为什么不容易过拟合?"></a>Q:Adaboost终止条件是什么?为什么不容易过拟合?</h3><p>有些代码将终止条件设置为一个超参数,即弱分类器的个数。</p>
<p>为什么不容易过拟合?</p>
<p><a target="_blank" rel="noopener" href="https://www.zhihu.com/question/41047671/answer/127832345">知乎: 俞扬的回答</a></p>
<p>Adaboost在实验中表现出更加不容易过拟合(相对于别的boosting方法)的现象,尤其是在实验中发现,迭代中经验误差都已经不变化了的情况下,测试误差居然还能下降一段时间。</p>
<p>对于二分类问题,AdaBoost的训练误差以指数速率下降就很具有吸引力</p>
<h2 id="1-2-AdaBoost算法解释"><a href="#1-2-AdaBoost算法解释" class="headerlink" title="1.2 AdaBoost算法解释"></a>1.2 AdaBoost算法解释</h2><p><strong>AdaBoost = 加法模型 + 损失函数(指数函数) + 学习算法(前向分步算法)</strong></p>
<ul>
<li>Additive model</li>
</ul>
<p>$$<br>f(x) = \sum_{m=1}^{M} \beta_m b(x;\gamma_m)<br>$$</p>
<ul>
<li>Loss function</li>
</ul>
<p><br>$$<br>L(y,f(x)) = \exp(-yf(x))\\<br>(\alpha_m, G_m(x)) = \arg \min_{\alpha,G}\sum_{i = 1}^{N}\exp[-y_i(f_{m-1}(x_i) +\alpha G(x_i))]<br>$$</p>
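<p>Sketch of why step c's coefficient follows from this loss (a standard derivation, with the round-$m$ weights $\bar{w}_{mi}$ normalized to sum to 1): fixing $G_m$, the objective splits over correctly and incorrectly classified samples,</p>
<p>$$<br>\sum_{i=1}^{N} \bar{w}_{mi} e^{-\alpha y_i G_m(x_i)} = (1-e_m)e^{-\alpha} + e_m e^{\alpha}<br>$$</p>
<p>and setting the derivative with respect to $\alpha$ to zero gives $e_m e^{\alpha} = (1-e_m)e^{-\alpha}$, i.e. $\alpha_m = \frac{1}{2}\log\frac{1-e_m}{e_m}$, matching Algorithm 8.1.</p>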
<h2 id="1-3-Adaboost代码实现"><a href="#1-3-Adaboost代码实现" class="headerlink" title="1.3 Adaboost代码实现"></a>1.3 Adaboost代码实现</h2><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span 
class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 解决的问题:二分类,Y_label = {-1, 1} 这里简化特征X Xi = {0,1} (二值化处理)</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 算法实现思路</span></span><br><span class="line"><span class="comment">## 计算分类错误率,输入应该是:数据集和训练好的模弱分类器参数,返回预测结果和分类误差率</span></span><br><span class="line"><span class="comment">## 单层提升树: 输入:数据,权值分布 输出:创建的单层提升树</span></span><br><span class="line"><span class="comment">## 生成提升树:输入:数据, 输出:提升树</span></span><br><span class="line"><span class="comment">## 预测结果</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">calc_e_Gx</span>(<span class="params">X_trainArr, Y_trainArr, n, div, D</span>):</span></span><br><span class="line"> <span class="string">'''</span></span><br><span class="line"><span class="string"> 计算分类误差率</span></span><br><span class="line"><span class="string"> n,要操作的特征</span></span><br><span class="line"><span class="string"> div,划分的点</span></span><br><span class="line"><span class="string"> D, 权值分布</span></span><br><span class="line"><span class="string"> return 预测结果, 分类误差率</span></span><br><span class="line"><span class="string"> '''</span></span><br><span class="line"> <span class="comment"># 初始化误差率</span></span><br><span class="line"> e = <span class="number">0</span> </span><br><span class="line"> <span class="comment"># 单独提取X, Y</span></span><br><span class="line"> X_n = X_trainArr[:, n]</span><br><span class="line"> Y_n = Y_trainArr</span><br><span class="line"> predict = []</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(X_n)):</span><br><span class="line"> <span class="keyword">if</span> X_n[i] < div:</span><br><span class="line"> predict[i] = -<span class="number">1</span></span><br><span class="line"> <span class="keyword">if</span> predict[i] != Y_train[i]: e += D[i]</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> predict[i] = <span class="number">1</span></span><br><span class="line"> <span class="keyword">if</span> predict[i] != Y_train[i]: e += D[i]</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> np.array(predict), e</span><br><span class="line"> </span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">creatSingleTree</span>(<span class="params">X_trainArr, Y_trainArr, D</span>):</span></span><br><span class="line"> <span class="comment"># 该函数可以用其他弱分类器代替</span></span><br><span class="line"> <span class="comment"># 获得样本数目及特征数量</span></span><br><span class="line"> m, n = np.shape(X_trainArr)</span><br><span class="line"> </span><br><span 
class="line"> <span class="comment"># 单层树字典,用于存放当前提升树的参数,包括:分割点,预测结果,误差率(由预测结果计算),</span></span><br><span class="line"> <span class="comment"># 该单层树所处理的特征 </span></span><br><span class="line"> singleBoostTree = {} </span><br><span class="line"> <span class="comment"># 初始化误差率,最大为100%</span></span><br><span class="line"> singleBoostTree[<span class="string">'e'</span>] = <span class="number">1</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 遍历每一个特征,寻找用于划分的最合适的特征</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(n):</span><br><span class="line"> <span class="comment"># 由于特征进行了二值化处理,只能为0、1,因此切分点为 -0.5, 0.5,1.5</span></span><br><span class="line"> <span class="keyword">for</span> div <span class="keyword">in</span> [-<span class="number">0.5</span>, <span class="number">0.5</span>, <span class="number">1.5</span>]:</span><br><span class="line"> Gx, e = calc_e_Gx(X_trainArr, Y_trainArr, i, div, D)</span><br><span class="line"> <span class="keyword">if</span> e < singleBoostTree[<span class="string">'e'</span>]:</span><br><span class="line"> singleBoostTree[<span class="string">'e'</span>] = e</span><br><span class="line"> singleBoostTree[<span class="string">'div'</span>] = div</span><br><span class="line"> singleBoostTree[<span class="string">'Gx'</span>] = Gx</span><br><span class="line"> singleBoostTree[<span class="string">'feature'</span>] = i</span><br><span class="line"> <span class="keyword">return</span> singleBoostTree</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">creatBoostingTree</span>(<span class="params">X_trainArr, Y_trainArr, treeNum = <span class="number">50</span></span>):</span></span><br><span class="line"> <span class="string">'''</span></span><br><span class="line"><span class="string"> treeNum: 弱分类器的数目作为一个超参数,可以通过交叉验证挑选一个最好的</span></span><br><span class="line"><span class="string"> return: 提升树</span></span><br><span class="line"><span class="string"> '''</span></span><br><span class="line"> m, n = np.shape(X_trainArr)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 初始化权值分布</span></span><br><span class="line"> D = np.array([<span class="number">1</span> / m] * m)</span><br><span class="line"> <span class="comment"># 初始化树列表</span></span><br><span class="line"> iterationNum = <span class="number">0</span></span><br><span class="line"> tree = []</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(treeNum):</span><br><span class="line"> <span class="comment"># 创建当层的提升树</span></span><br><span class="line"> iterationNum += <span class="number">1</span></span><br><span class="line"> curTree = singleBoostTree(X_trainArr, Y_trainArr, D)</span><br><span class="line"> <span class="comment"># 计算alpha</span></span><br><span class="line"> alpha = <span class="number">1</span> / <span class="number">2</span> * np.log((<span class="number">1</span> - curTree[<span class="string">'e'</span>]) / curTree[<span class="string">'e'</span>])</span><br><span class="line"> Gx = curTree[<span class="string">'Gx'</span>]</span><br><span class="line"> D = np.multiply(D, np.exp( -alpha * np.multiply(Y_trainArr, Gx))) / \</span><br><span class="line"> <span class="built_in">sum</span>(np.multiply(D, np.exp( -alpha * 
np.multiply(Y_trainArr, Gx))))</span><br><span class="line"> curTree[<span class="string">'alpha'</span>] = alpha</span><br><span class="line"> tree.append(curTree)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 当前训练集预测结果</span></span><br><span class="line"> finalpredict += alpha * Gx</span><br><span class="line"> <span class="comment"># 当前预测误差数目</span></span><br><span class="line"> error_count = <span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(Y_trainArr)):</span><br><span class="line"> <span class="keyword">if</span> np.sign(finalpredict)[i] != Y_trainArr[i]:</span><br><span class="line"> error_count += <span class="number">1</span></span><br><span class="line"> </span><br><span class="line"> error_rate = error_count / <span class="built_in">len</span>(Y_trainArr)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 如果误差已经为0了,那么就可以停止了不用再计算了</span></span><br><span class="line"> <span class="keyword">if</span> error_rate == <span class="number">0</span>: <span class="keyword">return</span> tree</span><br><span class="line"> print(<span class="string">'Numbers of iteration: {}, \n error rate: {}'</span>.<span class="built_in">format</span>(iterationNum, error_rate))</span><br><span class="line"> <span class="keyword">return</span> tree</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">Gx_predict</span>(<span class="params">x, div, feature</span>):</span></span><br><span class="line"> <span class="keyword">if</span> x[feature] < div:</span><br><span class="line"> <span class="keyword">return</span> -<span class="number">1</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">model_predict</span>(<span class="params">X_testArr, tree</span>):</span></span><br><span class="line"> prediction = []</span><br><span class="line"> <span class="comment"># 每一层的tree 有div,alpha,feature</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(X_testArr)):</span><br><span class="line"> result = <span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> curTree <span class="keyword">in</span> tree:</span><br><span class="line"> div = curTree[<span class="string">'div'</span>]</span><br><span class="line"> alpha = curTree[<span class="string">'alpha'</span>]</span><br><span class="line"> feature = curTree[<span class="string">'feature'</span>]</span><br><span class="line"> result += alpha * Gx_predict(X_testArr[i], div, feature)</span><br><span class="line"> prediction.append(result)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> prediction</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">model_score</span>(<span class="params">X_testArr, Y_testArr, tree</span>):</span></span><br><span class="line"> prediction = model_predict(X_testArr, tree)</span><br><span class="line"> error_count = 
<span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(Y_testArr)):</span><br><span class="line"> <span class="keyword">if</span> prediction[i] != Y_testArr[i]:</span><br><span class="line"> error_count += <span class="number">1</span></span><br><span class="line"> score = <span class="number">1</span> - (error_count / <span class="built_in">len</span>(Y_testArr))</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> score</span><br></pre></td></tr></table></figure>
<h2 id="1-3-boosting-tree"><a href="#1-3-boosting-tree" class="headerlink" title="1.3 boosting tree"></a>1.3 boosting tree</h2><p>以<strong>决策树</strong>为基函数的提升方法称为<strong>提升树</strong>。可以是二叉分类树,或二叉回归树。</p>
<ul>
<li><strong>Model</strong></li>
</ul>
<p>$$<br>f_M(x) = \sum_{m = 1}^M T(x;\Theta_m)<br>$$</p>
<p>where $T(x;\Theta_m)$ denotes a decision tree, $\Theta_m$ its parameters, and $M$ the number of trees.</p>
<ul>
<li><strong>Parameter estimation</strong> (empirical risk minimization)</li>
</ul>
<p>$$<br>\hat{\Theta_m}=\arg \min_{\Theta_m} \sum_{i=1}^{N} L\left(y_{i}, f_{m-1}\left(x_{i}\right)+T\left(x_{i}; \Theta_{m}\right)\right)<br>$$</p>
<ul>
<li><strong>Loss functions</strong><ul>
<li>squared-error loss for regression problems</li>
<li>exponential loss for classification problems</li>
<li>general loss functions for general decision problems</li>
</ul>
</li>
</ul>
<p>For binary classification, it suffices to restrict the weak classifiers of the AdaBoost algorithm to binary classification trees; that is exactly the boosting tree.</p>
<ul>
<li><p>Boosting tree for regression<br>$$<br>T(x;\Theta)=\sum_{j=1}^{J}c_j I(x \in R_j)<br>$$<br>where the parameter $\Theta = \lbrace (R_1, c_1), (R_2, c_2),…,(R_J, c_J) \rbrace$ describes the partition of the input space into regions and the constant on each region, and $J$ is the complexity of the regression tree, i.e. its number of leaves.</p>
<ul>
<li>With the <strong>squared-error loss</strong><br>$$<br>\begin{align*}<br>L(y,f_{m-1}(x) + T(x; \Theta_m)) &= [y - f_{m-1}(x) - T(x;\Theta_m)]^2\\<br>&=[r-T(x;\Theta_m)]^2<br>\end{align*}<br>$$<br>where $r = y - f_{m-1}(x)$ is the residual of the current model. The loss function shows that the tree at the current step is determined by fitting the residuals of the previous model, so the algorithm can build each layer's tree recursively, simply replacing the training targets at each step with the current residuals (see the sketch after this list).</li>
</ul>
</li>
<li><p><strong>Parameter update (gradient descent)</strong></p>
</li>
</ul>
<p>$$<br>-\left[\frac{\partial L(y_i, f(x_i))}{\partial f(x_i)}\right]_{f(x)=f_{m-1}(x)}<br>$$</p>
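<p>A minimal sketch of this residual-fitting loop, assuming depth-1 trees (stumps) from sklearn's DecisionTreeRegressor as the base trees; the function names are illustrative, not from the post:</p>
<figure class="highlight python"><table><tr><td class="code"><pre># A least-squares boosting tree: each round fits a stump to the residuals
# of the current model, then subtracts its predictions from the residuals.
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def boosting_tree_regression(X, y, M=10):
    trees = []
    r = y.astype(float).copy()        # residuals; initially r = y (f_0 = 0)
    for _ in range(M):
        t = DecisionTreeRegressor(max_depth=1).fit(X, r)
        trees.append(t)
        r -= t.predict(X)             # update residuals: r <- y - f_m(X)
    return trees

def boosting_tree_predict(trees, X):
    # f_M(x) = sum_m T(x; Theta_m)
    return sum(t.predict(X) for t in trees)</pre></td></tr></table></figure>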
<h1 id="2-Bagging-Beriman-1996a"><a href="#2-Bagging-Beriman-1996a" class="headerlink" title="2. Bagging[Beriman, 1996a]"></a>2. Bagging[Beriman, 1996a]</h1><p>采用自主抽样法(<strong>bootstrap sampling</strong>)</p>
<ul>
<li>fix the sample-set size $m$</li>
<li>draw $m$ samples with replacement</li>
<li>repeat to obtain $T$ bootstrap sets, each containing $m$ samples</li>
<li>train one base learner on each sample set</li>
<li>predict by simple voting</li>
</ul>
<p>By recording which samples each learner actually used, the unused samples can provide an "out-of-bag" estimate of the generalization performance.</p>
<p>The code is not hard to write.</p>
<p>Resampling can be implemented with random.sample() or np.random.choice().</p>
<p>random.sample() samples without replacement by default and can draw from multi-dimensional data.</p>
<p>np.random.choice() samples with replacement by default (replace=True) but cannot draw from multi-dimensional data, so it is handy for generating row indices (a minimal sketch follows).</p>
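<p>A minimal bagging sketch, assuming labels in $\lbrace -1, 1\rbrace$ and sklearn's DecisionTreeClassifier as the base learner; np.random.choice draws the bootstrap row indices, and the complement of each index set gives the out-of-bag rows. Names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import numpy as np
from sklearn.tree import DecisionTreeClassifier

def bagging(X, y, T=25):
    m = len(X)
    learners, oob = [], []
    for _ in range(T):
        idx = np.random.choice(m, size=m, replace=True)  # bootstrap indices
        oob.append(np.setdiff1d(np.arange(m), idx))      # rows never drawn
        learners.append(DecisionTreeClassifier().fit(X[idx], y[idx]))
    return learners, oob

def vote(learners, X):
    # simple voting: sign of the summed {-1, 1} predictions
    return np.sign(sum(clf.predict(X) for clf in learners))</pre></td></tr></table></figure>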
<h1 id="3-Random-Forest-RF"><a href="#3-Random-Forest-RF" class="headerlink" title="3. Random Forest(RF)"></a>3. Random Forest(RF)</h1><p>Bagging 的变体,RF在基学习器构建Bagging集成的基础上,进一步在决策树的训练过程中引入了随机属性选择。具体来说, 传统决策树在选择划分属性时是在当前结点的属性集合(假定有d个属性)中选择一个最优属性; 而在RF中, 对基决策树的每个结点, 先从该结点的属性集合中随机选择一个包含k个属性的子集, 然后再从这个子集中选择一个最优属性用于划分。若令$k=d$,则基决策树的构建与传统决策树相同;若令$k = 1$,则是随机选择一个属性用于划分; 一般情况下, 推荐值$k = log_2d$</p>
<p><strong>Algorithm:</strong></p>
<p>Input: training data $D$, a base decision tree learner (CART), resampling size $m$, attribute-subset size $k$</p>
<p>Output: the final classifier</p>
<p>for $t = 1,2,\dots,T$</p>
<p>  bootstrap-resample $m$ samples; record the samples not drawn</p>
<p>  randomly select $k$ attributes</p>
<p>  train a base classifier (e.g. CART) on the resampled data restricted to the selected attributes</p>
<p>end for</p>
<p>Output: the final classifier (majority vote); a minimal sketch follows.</p>
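<p>A minimal sketch of the loop above, assuming sklearn's DecisionTreeClassifier as the base learner and labels in $\lbrace -1, 1\rbrace$. For brevity it draws one attribute subset of size $k \approx \log_2 d$ per tree, whereas RF as described (and sklearn's RandomForestClassifier) re-draws the subset at every node:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import numpy as np
from sklearn.tree import DecisionTreeClassifier

def random_forest(X, y, T=25):
    m, d = X.shape
    k = max(1, int(np.log2(d)))              # recommended subset size
    forest = []
    for _ in range(T):
        rows = np.random.choice(m, size=m, replace=True)   # bootstrap rows
        cols = np.random.choice(d, size=k, replace=False)  # random attributes
        clf = DecisionTreeClassifier().fit(X[rows][:, cols], y[rows])
        forest.append((clf, cols))
    return forest

def rf_predict(forest, X):
    votes = sum(clf.predict(X[:, cols]) for clf, cols in forest)
    return np.sign(votes)                    # majority vote over {-1, 1}</pre></td></tr></table></figure>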
<h1 id="4-Stacking"><a href="#4-Stacking" class="headerlink" title="4. Stacking"></a>4. Stacking</h1><ul>
<li><strong>Introduction</strong></li>
</ul>
<p><strong>Assume 5-fold stacking, with a train set and a test set. A basic stacking framework performs the following steps:</strong></p>
<p>1. Choose the base models: xgboost, lightGBM, RandomForest, SVM, ANN, KNN, LR, or any other basic algorithm you can think of.</p>
<p>2. Split the training set into five disjoint folds, labeled train1 through train5.</p>
<p>3. Start with train1 as the held-out fold: fit on train2-train5 (model1), predict train1, and keep the result; then hold out train2, fit on train1 and train3-train5, predict train2, and keep the result (model2); continue until each of train1-train5 has been predicted once.</p>
<p>4. Put the predictions back in the positions of train1 through train5; this gives the stacking transformation of the entire train set under the first base model.</p>
<p>5. Each of the five models fitted above also predicts the test set; keep these five columns and average them to obtain the first base model's stacking transformation of the test data.</p>
<p>6. Move on to the second layer: the new features from step 4 become the second-layer train data, and the averaged first-layer test predictions from step 5 become the second-layer test data.</p>
<p><strong>7. Logistic regression (LR) is usually used as the second-layer model.</strong></p>
<p>Stacking is rather time-consuming, so in an actual competition 3 or 4 folds may be used instead (a minimal sketch follows).</p>
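<p>A minimal sketch of steps 2-7 for a single base model, assuming a binary task and sklearn; cross_val_predict produces the out-of-fold transformation of the train set, and the per-fold test predictions are averaged. Names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import numpy as np
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

def stack_one_model(base, X_train, y_train, X_test, n_splits=5):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=0)
    # out-of-fold predictions on the train set (steps 2-4)
    train_feat = cross_val_predict(base, X_train, y_train, cv=kf,
                                   method='predict_proba')[:, 1]
    # average the test predictions of the fold models (step 5)
    # (sklearn.base.clone(base) per fold would be cleaner; omitted for brevity)
    test_feat = np.zeros(len(X_test))
    for tr, _ in kf.split(X_train):
        model = base.fit(X_train[tr], y_train[tr])
        test_feat += model.predict_proba(X_test)[:, 1] / n_splits
    return train_feat, test_feat

# second layer (steps 6-7): logistic regression on the stacked features
# tr_f, te_f = stack_one_model(RandomForestClassifier(), X_tr, y_tr, X_te)
# LogisticRegression().fit(tr_f.reshape(-1, 1), y_tr)</pre></td></tr></table></figure>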
<ul>
<li><strong>Why use logistic regression as the stacking output layer?</strong></li>
</ul>
<p>The effectiveness of stacking comes mainly from feature extraction. <strong>And overfitting shadows representation learning everywhere; recall the overfitting problem in deep learning.</strong></p>
<p>In [5], Prof. Zhou Zhihua also restates the overfitting problem in the use of stacking. Because the second-layer features are learned from the first-layer data, the second-layer features should not include the original features, <strong>in order to reduce the risk of overfitting</strong>. Compare:</p>
<ul>
<li>Second-layer features: learned features only</li>
<li>Second-layer features: learned features + original features</li>
</ul>
<p>Another telling sign is that stacking normally relies on cross-validation to avoid overfitting, which shows how serious the problem is.</p>
<p>To reduce overfitting, the second-layer classifier should be relatively simple; a generalized linear model such as logistic regression is a good choice. <strong>Complex non-linear transformations have already been applied during feature extraction, so no complex classifier is needed at the output layer.</strong> Compare the activation functions or output layers of neural networks, which are all very simple functions, partly because complex functions are unnecessary and simplicity keeps the model complexity under control.</p>
<p>So the stacking output layer does not need an overly complex function, and using logistic regression has additional benefits:</p>
<ul>
<li>combined with L1 regularization it further prevents overfitting</li>
<li>L1 regularization also selects effective features, removing unnecessary classifiers from the first layer and saving computation</li>
<li>the outputs of logistic regression can be interpreted as probabilities</li>
</ul>
<h1 id="Reference"><a href="#Reference" class="headerlink" title="Reference"></a>Reference</h1><p>[1]. <a target="_blank" rel="noopener" href="https://github.com/Dod-o/Statistical-Learning-Method_Code/blob/master/AdaBoost/AdaBoost.py">https://github.com/Dod-o/</a></p>
<p>[2]. 李航 (Li Hang), 《统计学习方法》 (Statistical Learning Methods).</p>
<p>[3]. 周志华 (Zhou Zhihua), 《机器学习》 (Machine Learning).</p>
<p>[4]. <a target="_blank" rel="noopener" href="https://cloud.tencent.com/developer/article/1005304">[SPA Contest] Tencent ad-click competition: a basic introduction to stacking</a></p>
<p>[5]. Zhou, Z.H., 2012. <em>Ensemble methods: foundations and algorithms</em>. CRC press.</p>
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2021/03/18/DecisionTree/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2021/03/18/DecisionTree/" class="post-title-link" itemprop="url">DecisionTree</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2021-03-18 15:51:09" itemprop="dateCreated datePublished" datetime="2021-03-18T15:51:09+08:00">2021-03-18</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar-check"></i>
</span>
<span class="post-meta-item-text">更新于</span>
<time title="修改时间:2022-09-19 10:08:15" itemprop="dateModified" datetime="2022-09-19T10:08:15+08:00">2022-09-19</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<p>A basic method for classification and regression. This post focuses on decision trees for classification.</p>
<p>It usually involves three steps: feature selection, tree generation, and tree pruning.</p>
<h1 id="1-模型与学习"><a href="#1-模型与学习" class="headerlink" title="1. 模型与学习"></a>1. 模型与学习</h1><ul>
<li><strong>模型</strong></li>
</ul>
<p>A decision tree model can be viewed as a set of if-then rules, whose important property is that the rules are mutually exclusive and complete.</p>
<ul>
<li><strong>Learning</strong></li>
</ul>
<p>The <strong>goal</strong> is to build a decision tree model from the given training data set that classifies instances correctly. This goal is usually expressed by a loss function (a regularized maximum-likelihood function).</p>
<p>A decision tree learning algorithm is typically a recursive process: select the optimal feature, split the training data by that feature, and repeat so that each subset ends up with the best possible classification.</p>
<p>If there are many features, feature selection can also be performed at the start of learning, keeping only the features with sufficient discriminative power for the training data.</p>
<p>In general, a <strong>decision tree learning algorithm</strong> involves</p>
<ul>
<li>feature selection</li>
<li>tree generation</li>
<li>tree pruning</li>
</ul>
<p>Common algorithms include ID3, C4.5, and CART.</p>
<h1 id="2-特征选择"><a href="#2-特征选择" class="headerlink" title="2. 特征选择"></a>2. 特征选择</h1><p>特征选择在于选取对训练数据具有分类能力的特征。 如果利用一个特征进行分类的结果与随机分类的结果没有很大差别, 则称这个特征是没有分类能力的</p>
<p>通常选择的准则是<strong>信息增益</strong>或<strong>信息增益比</strong>。</p>
<h2 id="2-1-信息增益"><a href="#2-1-信息增益" class="headerlink" title="2.1 信息增益"></a>2.1 信息增益</h2><ul>
<li><strong>熵</strong></li>
</ul>
<p>The entropy of a random variable $X$ is defined as<br>$$<br>H(X) = H(p) = -\sum_{i=1}^{n} p_{i} \log p_{i}<br>$$<br>The larger the entropy, the greater the uncertainty of the random variable.</p>
<ul>
<li><strong>Conditional entropy</strong></li>
</ul>
<p>The uncertainty of random variable $Y$ given random variable $X$:<br>$$<br>H(Y \mid X)=\sum_{i=1}^{n} p_{i} H(Y \mid X=x_{i})<br>$$<br>where $p_i = P(X=x_i)$.</p>
<p>When the probabilities in the entropy and the conditional entropy are estimated from data (in particular by maximum likelihood), the results are called the empirical entropy and the empirical conditional entropy.</p>
<ul>
<li><strong>Information gain</strong></li>
</ul>
<p>It measures how much knowing feature $X$ reduces the uncertainty of the class $Y$.</p>
<p>The information gain $g(D,A)$ of feature $A$ on training set $D$ is defined as the difference between the empirical entropy $H(D)$ of set $D$ and the empirical conditional entropy $H(D|A)$ of $D$ given $A$:<br>$$<br>g(D,A) = H(D) - H(D|A)<br>$$<br>In general, the difference between the entropy $H(Y)$ and the conditional entropy $H(Y|X)$ is called the mutual information; the information gain in decision tree learning equals the mutual information between the classes and a feature in the training data.</p>
<p>Feature selection by the information gain criterion: for the training set (or subset) $D$, compute the information gain of every feature, compare them, and select the feature with the largest gain.</p>
<p>Let $D$ be the training set and $|D|$ its sample size. Suppose there are $K$ classes $C_{k}, k=1,2, \cdots, K$, with $|C_{k}|$ instances in class $C_{k}$ and $\sum_{k=1}^{K}\left|C_{k}\right|=|D|$. Let feature $A$ take $n$ distinct values $\lbrace a_{1}, a_{2}, \cdots, a_{n} \rbrace$, and partition $D$ by the value of $A$ into $n$ subsets $D_{1}, D_{2}, \cdots, D_{n}$, where $|D_{i}|$ is the size of $D_{i}$ and $\sum_{i=1}^{n}\left|D_{i}\right|=|D|$. Write $D_{ik}=D_{i} \cap C_{k}$ for the set of instances of $D_{i}$ belonging to class $C_{k}$, with $|D_{ik}|$ its size. The information gain is then computed as follows.</p>
<p><strong>Algorithm 5.1 (information gain)</strong><br>Input: training data set $D$ and feature $A$;</p>
<p>Output: the information gain $g(D, A)$ of feature $A$ on $D$.<br>(1) Compute the empirical entropy $H(D)$ of $D$:<br>$$<br>H(D)=-\sum_{k=1}^{K} \frac{\left|C_{k}\right|}{|D|} \log_{2} \frac{|C_{k}|}{|D|}<br>$$<br>(2) Compute the empirical conditional entropy $H(D \mid A)$ of feature $A$ on $D$:<br>$$<br>H(D \mid A)=\sum_{i=1}^{n} \frac{\left|D_{i}\right|}{|D|} H\left(D_{i}\right)=-\sum_{i=1}^{n} \frac{\left|D_{i}\right|}{|D|} \sum_{k=1}^{K} \frac{\left|D_{i k}\right|}{\left|D_{i}\right|} \log_{2} \frac{\left|D_{ik}\right|}{\left|D_{i}\right|}<br>$$<br>(3) Compute the information gain:<br>$$<br>g(D, A)=H(D)-H(D \mid A)<br>$$</p>
<h2 id="2-2-信息增益比"><a href="#2-2-信息增益比" class="headerlink" title="2.2 信息增益比"></a>2.2 信息增益比</h2><ul>
<li>定义(信息增益与熵之比)<br>$$<br>g_{R}(D, A)=\frac{g(D, A)}{H_{A}(D)}<br>$$</li>
</ul>
<p>其中,$H_{A}(D)=-\sum_{i=1}^{n} \frac{|D_{i}|}{|D|} \log <em>{2} \frac{|D</em>{i}|}{|D|}, n$ 是特征 $A$ 取值的个数。</p>
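<p>A small sketch of the gain ratio $g_R(D,A) = g(D,A)/H_A(D)$, using the same list-of-rows data format (label in the last column) as the code later in this post; names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import math
from collections import Counter

def entropy_of(values):
    # empirical entropy of a sequence of discrete values
    n = len(values)
    return -sum(c / n * math.log2(c / n) for c in Counter(values).values())

def info_gain_ratio(datasets, axis=0):
    labels = [row[-1] for row in datasets]
    h_d = entropy_of(labels)                    # empirical entropy H(D)
    groups = {}
    for row in datasets:                        # group labels by feature value
        groups.setdefault(row[axis], []).append(row[-1])
    n = len(datasets)
    cond = sum(len(g) / n * entropy_of(g) for g in groups.values())  # H(D|A)
    h_a = entropy_of([row[axis] for row in datasets])                # H_A(D)
    return (h_d - cond) / h_a if h_a > 0 else 0.0</pre></td></tr></table></figure>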
<h1 id="3-决策树的生成"><a href="#3-决策树的生成" class="headerlink" title="3. 决策树的生成"></a>3. 决策树的生成</h1><h2 id="3-1-ID3算法"><a href="#3-1-ID3算法" class="headerlink" title="3.1 ID3算法"></a>3.1 ID3算法</h2><p>具体方法是: 从根结点(root node)开始, 对结点计算所有可能的特征的信息增益, 选择信息增益最大的特征作为结点的特征, 由该特征的不同取值建立子结点;再对子结点递归地调用以上方法, 构建决策树; 直到所有特征的信息增益均很小或没有特征可以选择为止。最后得到一颗决策树。ID3算法相当于用极大似然法进行概率模型的选择。</p>
<p><strong>算法3.1 (ID3算法)</strong></p>
<p>输入:训练数据集 $D$ ,特征集 $A$ 阈值 $\varepsilon$</p>
<p>输出:决策树 $T$ 。</p>
<p>(1)若 $D$ 中所有实例属于同一类 $C_{k},$ 则 $T$ 为单结点树,并将类 $C_{k}$ 作为该结点的类标记,返回 $T$;</p>
<p>(2)若 $A=\varnothing$ ,则 $T$ 为,sh单结点树,并将 $D$ 中实例数最大的类 $C_{k}$ 作为该结点的类标记,返回 $T$;</p>
<p>(3)否则,计算 $A$ 中各特征对 $D$ 的fanjia信息增益,选择信息增益最大的特征 $A_{g}$ ;</p>
<p>(4)如果 $A_{g}$ 的信息增益小于阈值 $\varepsilon$ ,则设置 $T$ 为单结点树,并将 $D$ 中实例数最大的类 $C_{k}$ 作为该结点的类标记,返回 $T$;</p>
<p>(5)否则,对 $A_{g}$ 的每一可能值 $a_{i},$ 依 $A_{g}=a_{i}$ 将 $D$ 分割为若干非空子集 $D_{i},$ 将$D_{i}$ 中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树 $T$ ,返回 $T$;</p>
<p>(6)对第 $i$ 个子结点,以 $D_{i}$ 为训练集,以 $A-\lbrace A_{g}\rbrace$ 为特征集,递归地调用步 (1)$\sim$ 步 $(5),$ 得到子树 $T_{i},$ 返回 $T_{i}$</p>
<p><strong>Drawback</strong>: ID3 only generates the tree and does no pruning, so it overfits easily.</p>
<h2 id="3-2-C4-5算法"><a href="#3-2-C4-5算法" class="headerlink" title="3.2 C4.5算法"></a>3.2 C4.5算法</h2><p>与ID3相比,在生成过程中用<strong>信息增益比</strong>来选择特征。</p>
<p><strong>算法 3.2 (C4.5的生成算法)</strong></p>
<p>输入: 训练数据集 $D$, 特征集 $A$ 阈值 $\varepsilon ;$</p>
<p>输出: 决策树 $T_{\circ}$</p>
<p>(1)如果 $D$ 中所有实例属于同一类 $C_{k},$ 则置 $T$ 为单结点树,并将 $C_{k}$ 作为该结点的类,返回 $T$;</p>
<p>(2)如果 $A=\varnothing,$ 则置 $T$ 为单结点树,并将 $D$ 中实例数最大的类 $C_{k}$ 作为该结的类,返回 $T$;</p>
<p>(3)否则,计算 $A$ 中各特征对 $D$ 的信息增益比,选择信息增益比最大的特征 $A_{g} ;$</p>
<p>(4)如果 $A_{g}$ 的信息增益比小于阈值 $\varepsilon$ ,则置 $T$ 为单结点树,并将 $D$ 中实例数最大的类 $C_{k}$ 作为该结点的类,返回 $T$;</p>
<p>(5)否则,对 $A_{g}$ 的每一可能值 $a_{i},$ 依 $A_{g}=a_{i}$ 将 $D$ 分割为子集若干非空 $D_{i}$ ,将 $D_{i}$ 中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树 $T$ ,返回 $T$;</p>
<p>(6)对结点 $i$,以 $D_{i}$ 为训练集,以 $A-\left{A_{g}\right}$ 为特征集,递归地调用步(1)$\sim$步 $(5),$ 得到子树 $T_{i},$ 返回 $T_{i}$ 。</p>
<h1 id="4-决策树的剪枝"><a href="#4-决策树的剪枝" class="headerlink" title="4. 决策树的剪枝"></a>4. 决策树的剪枝</h1><p>以上两种算法均未考虑树的复杂度,这样产生的决策树容易出现过拟合的问题。因此需要将生成的树进行简化。</p>
<p>决策树的剪枝往往通过极小化决策树整体的损失函数或代价函数来实现。设树 $T$ 的叶节点个数为 $|T|$ ,$t$ 是树 $T$ 的叶节点,该叶节点有 $N_t$ 个样本点,其中 $k$ 类样本点有 $N_{tk}$ 个, $H_t(T)$ 为叶节点 $t$ 上的经验熵。 $\alpha$ 为参数。则决策树的损失函数可以定义为<br>$$<br>\begin{array}{c}<br>C_{\alpha}(T)=\sum_{t=1}^{|T|} N_{t} H_{t}(T)+\alpha|T| \\<br>H_{t}(T)=-\sum_{k} \frac{N_{tk}}{N_{t}} \log \frac{N_{tk}}{N_{t}}<br>\end{array}<br>$$<br>将第一项记做 $C(T)$,这时有<br>$$<br>C_{\alpha}(T) = C(T)+ \alpha|T|<br>$$<br>第一项表示对训练数据的预测误差,第二项表示模型的复杂度。$\alpha \geq 0$ 控制两者之间的影响。由此可以看出,决策树的生成学习局部的模型,而决策树剪枝学习整体的模型。</p>
<p><strong>Algorithm 4.1 (tree pruning)</strong></p>
<p>Input: the full tree $T$ produced by a generation algorithm, and a parameter $\alpha$</p>
<p>Output: the pruned subtree $T_{\alpha}$</p>
<p>(1) Compute the empirical entropy of every node.</p>
<p>(2) Recursively retract upward from the leaves of the tree.</p>
<p>Let $T_B$ and $T_A$ be the whole tree before and after a group of leaves is retracted into its parent node, with loss values $C_{\alpha}(T_B)$ and $C_{\alpha}(T_A)$. If<br>$$<br>C_{\alpha}(T_A) \leq C_{\alpha}(T_B)<br>$$<br>then prune, i.e. turn the parent into a new leaf.</p>
<p>(3) Return to (2) until no further pruning is possible, giving the subtree $T$ with the smallest loss (a small sketch of this criterion follows).</p>
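<p>A small sketch of the pruning criterion $C_{\alpha}(T)=\sum_t N_t H_t(T)+\alpha|T|$, assuming each leaf is summarized by its class counts; since the costs of unaffected leaves cancel, comparing only the leaves under the candidate parent suffices. Names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import math

def leaf_cost(counts):
    # N_t * H_t(T) for one leaf, given {class: N_tk} counts
    n = sum(counts.values())
    return -sum(c * math.log(c / n) for c in counts.values() if c > 0)

def c_alpha(leaves, alpha):
    # C_alpha over a list of per-leaf class-count dicts
    return sum(leaf_cost(c) for c in leaves) + alpha * len(leaves)

def should_prune(child_leaves, alpha):
    # compare the subtree (T_B) with the collapsed parent leaf (T_A)
    merged = {}
    for counts in child_leaves:
        for k, v in counts.items():
            merged[k] = merged.get(k, 0) + v
    return c_alpha([merged], alpha) <= c_alpha(child_leaves, alpha)</pre></td></tr></table></figure>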
<h1 id="5-CART-classification-and-regresstion-tree-算法-–1984"><a href="#5-CART-classification-and-regresstion-tree-算法-–1984" class="headerlink" title="5. CART (classification and regresstion tree)算法 –1984"></a>5. CART (classification and regresstion tree)算法 –1984</h1><p><strong>CART</strong>由<u>特征选择、树的生成及剪枝</u>组成的<strong>二叉树</strong>。既可以用于<strong>分类</strong>也可以用于<strong>回归</strong>。左分支是“是”的分支,右分支是“否”的分支。</p>
<p><strong>CART</strong>是在给定输入随机变量 $X$ 条件下输出随机变量 $Y$ 的<strong>条件概率分布</strong>的学习方法。</p>
<h2 id="5-1-CART的生成"><a href="#5-1-CART的生成" class="headerlink" title="5.1 CART的生成"></a>5.1 CART的生成</h2><p>回归树:平方误差最小化准则</p>
<p>分类树:基尼指数(Gini index) 最小化准则</p>
<ol>
<li><strong>Generating a regression tree</strong></li>
</ol>
<p>The output value on each region $R_m$ of the input space is the mean of the labels of the training instances that fall in it:<br>$$<br>\hat{c}_m = \operatorname{ave}(y_i \mid x_i \in R_m)<br>$$</p>
<p>Algorithm 5.1 (least-squares regression tree generation)<br>Input: training data set $D$;</p>
<p>Output: a regression tree $f(x)$ </p>
<p>In the input space of the training data, recursively split each region into two subregions, determine the output value on each, and build a binary decision tree:</p>
<p>(1) Choose the optimal splitting variable $j$ and split point $s$ by solving<br>$$<br>\min_{j, s}\left[\min_{c_{1}} \sum_{x_{i} \in R_{1}(j, s)}\left(y_{i}-c_{1}\right)^{2}+\min_{c_{2}} \sum_{x_{i} \in R_{2}(j, s)}\left(y_{i}-c_{2}\right)^{2}\right]<br>$$<br><strong>Scan</strong> every variable $j$ and, for each fixed $j$, every split point $s$; choose the pair $(j, s)$ that minimizes the objective above.</p>
<p>(2) Use the chosen pair $(j, s)$ to split the region and set the output values:<br>$$<br>\begin{array}{c}<br>R_{1}(j, s)=\lbrace x \mid x^{(j)} \leqslant s\rbrace, \quad R_{2}(j, s)=\lbrace x \mid x^{(j)}>s\rbrace \\<br>\hat{c}_{m}=\frac{1}{N_{m}} \sum_{x_{i} \in R_{m}(j, s)} y_{i}, \quad x \in R_{m}, \quad m=1,2<br>\end{array}<br>$$<br>(3) Repeat steps (1) and (2) on the two subregions until the stopping condition holds.</p>
<p>(4) Partition the input space into $M$ regions $R_{1}, R_{2}, \cdots, R_{M}$ and output the tree:<br>$$<br>f(x)=\sum_{m=1}^{M} \hat{c}_{m} I\left(x \in R_{m}\right)<br>$$<br>When the data set is large, this exhaustive scan over $(j, s)$ is inefficient (see the sketch below).</p>
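<p>A minimal sketch of the exhaustive scan in step (1), assuming a 2-D numpy array X and a 1-D label vector y, with each side of the split predicted by its mean; names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>import numpy as np

def best_split(X, y):
    best = (None, None, np.inf)          # (j, s, loss)
    for j in range(X.shape[1]):          # scan splitting variables j
        for s in np.unique(X[:, j]):     # candidate split points s
            left, right = y[X[:, j] <= s], y[X[:, j] > s]
            if len(left) == 0 or len(right) == 0:
                continue                 # skip degenerate splits
            # squared error with c1, c2 set to the side means
            loss = ((left - left.mean()) ** 2).sum() + \
                   ((right - right.mean()) ** 2).sum()
            if loss < best[2]:
                best = (j, s, loss)
    return best</pre></td></tr></table></figure>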
<ol start="2">
<li><strong>Generating a classification tree</strong></li>
</ol>
<p>A classification tree uses the Gini index to select the optimal feature and, at the same time, that feature's optimal binary split point.</p>
<p>Definition of the <strong>Gini index</strong>:<br>$$<br>\text{Gini}(p) = \sum_{k = 1}^{K} p_k (1 - p_k) = 1 - \sum_{k=1}^{K}p_k^2<br>$$<br>For a given sample set $D$, $p_k = \frac{|C_k|}{|D|}$.</p>
<p>The Gini index measures the uncertainty of the set $D$: the larger it is, the more uncertain the sample set, much like entropy (a small sketch follows).</p>
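<p>A small sketch of the Gini index and of the Gini index of a binary split on $A=a$, in the same list-of-rows format (label in the last column) as the code at the end of this post; names are illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre>from collections import Counter

def gini(labels):
    # Gini(p) = 1 - sum_k p_k^2
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def gini_split(datasets, axis, value):
    # Gini(D, A=a) = |D1|/|D| Gini(D1) + |D2|/|D| Gini(D2),
    # where D1 answers "yes" to the test A = a and D2 answers "no"
    d1 = [row[-1] for row in datasets if row[axis] == value]
    d2 = [row[-1] for row in datasets if row[axis] != value]
    n = len(datasets)
    return len(d1) / n * gini(d1) + len(d2) / n * gini(d2)</pre></td></tr></table></figure>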
<p><strong>Algorithm 5.6 (CART generation)</strong></p>
<p>Input: training data set $D$ and a stopping condition;</p>
<p>Output: a CART decision tree.</p>
<p>Starting from the root node, recursively perform the following on every node to build a binary decision tree:</p>
<p>(1) Let $D$ be the node's training data, and compute the Gini index of every available feature on it. For each feature $A$ and each of its possible values $a$, split $D$ into $D_{1}$ and $D_{2}$ according to whether the test $A=a$ is "yes" or "no", and compute the Gini index at $A=a$.</p>
<p>(2) Among all possible features $A$ and all their candidate split points $a$, pick the feature and split point with the smallest Gini index as the optimal feature and split point; split the current node into two children accordingly and distribute the training data to them.</p>
<p>(3) Recursively apply (1) and (2) to the two children until the stopping condition holds.</p>
<p>(4) Return the resulting CART decision tree.</p>
<h2 id="5-2-CART剪枝"><a href="#5-2-CART剪枝" class="headerlink" title="5.2 CART剪枝"></a>5.2 CART剪枝</h2><p>具体地, 从整体树 $T_0$ 开始剪枝。 对 $T_0$ 的任意内部结点 $t$ ,以 $t$ 为单结点树的损失函数是<br>$$<br>C_{\alpha}(t) = C(t) + \alpha<br>$$<br>以 $t$ 为根节点的子树 $T_t$ 的损失函数是<br>$$<br>C_{\alpha}(T_t) = C(T_t) + \alpha |T_t|<br>$$<br>当 $\alpha = 0$ 及 $\alpha$ 充分小时,有<br>$$<br>C_{\alpha}(T_t) < C_{\alpha}(t)<br>$$<br>当 $\alpha$ 增大时,在某一 $\alpha$ 有<br>$$<br>C_{\alpha}(T_t) = C_{\alpha}(t)<br>$$<br>所以只要 $\alpha =\frac{C(t)-C\left(T_{t}\right)}{\left|T_{t}\right|-1}$,就可以选择结点更少的t。</p>
<p>而Beriman等人证明:可以用递归的方法对树进行剪枝。将 $\alpha$ 逐渐增大,$0 = \alpha_0 < \alpha_1 <\cdots <\alpha_n < +\infty$ ,产生一系列的区间 $[\alpha_i, \alpha_{i+1}), i = 0,1,…,n$ ;剪枝得到的子树序列对应着区间 $\alpha \in [\alpha_i, \alpha_{i+1})$,的最优子树序列 $\lbrace T_0, T_1, \cdots, T_n \rbrace$。然后利用交叉验证挑出一个最好的</p>
<p><strong>Algorithm 5.7 (CART pruning)</strong></p>
<p>Input: the decision tree $T_{0}$ produced by the CART algorithm;</p>
<p>Output: the optimal decision tree $T_{\alpha}$.</p>
<p>(1) Set $k=0, T=T_{0}$.</p>
<p>(2) Set $\alpha=+\infty$.</p>
<p>(3) Bottom-up, for every internal node $t$ compute $C\left(T_{t}\right)$, $\left|T_{t}\right|$ and<br>$$<br>\begin{aligned}<br>g(t) &=\frac{C(t)-C\left(T_{t}\right)}{\left|T_{t}\right|-1} \\<br>\alpha &=\min (\alpha, g(t))<br>\end{aligned}<br>$$<br>where $T_{t}$ is the subtree rooted at $t$, $C\left(T_{t}\right)$ its prediction error on the training data, and $\left|T_{t}\right|$ its number of leaves.</p>
<p>(4) Prune every internal node $t$ with $g(t)=\alpha$, deciding the class of the resulting leaf $t$ by majority vote;<br>this gives the tree $T$.</p>
<p>(5) Set $k=k+1, \alpha_{k}=\alpha, T_{k}=T$.</p>
<p>(6) If $T_{k}$ is not a tree consisting of the root and two leaves, go back to step (2); otherwise set<br>$T_{k}=T_{n}$</p>
<p>(7) Select the optimal subtree $T_{\alpha}$ among the subtree sequence $T_{0}, T_{1}, \cdots, T_{n}$ by cross-validation.</p>
<h1 id="6-代码实现"><a href="#6-代码实现" class="headerlink" title="6. 代码实现"></a>6. 代码实现</h1><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line">%matplotlib inline</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.datasets <span class="keyword">import</span> load_iris</span><br><span class="line"><span class="keyword">from</span> sklearn.model_selection <span class="keyword">import</span> train_test_split</span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> Counter</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> math</span><br><span class="line"><span class="keyword">from</span> math <span class="keyword">import</span> log</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 经验熵</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">entropy</span>(<span class="params">datasets</span>):</span></span><br><span class="line"> data_length = <span class="built_in">len</span>(datasets)</span><br><span class="line"> label_count = {}</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(data_length):</span><br><span class="line"> label = datasets[i][-<span class="number">1</span>]</span><br><span class="line"> <span class="keyword">if</span> label <span class="keyword">not</span> <span class="keyword">in</span> label_count:</span><br><span class="line"> label_count[label] = <span class="number">0</span></span><br><span class="line"> label_count[label] += <span class="number">1</span></span><br><span class="line"> entropy = - <span class="built_in">sum</span>([(p / data_length) * log(p / data_length, <span class="number">2</span>)</span><br><span class="line"> <span class="keyword">for</span> p <span class="keyword">in</span> label_count.values()])</span><br><span class="line"> <span class="keyword">return</span> entropy</span><br><span class="line"></span><br><span class="line"><span class="comment"># 条件经验熵</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">cond_entropy</span>(<span class="params">datasets, axis=<span class="number">0</span></span>):</span></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> 求数据集datasets中第axis列的条件经验熵</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> data_length = <span class="built_in">len</span>(datasets)</span><br><span class="line"> feature_sets = {}</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(data_length):</span><br><span 
class="line"> feature = datasets[i][axis]</span><br><span class="line"> <span class="keyword">if</span> feature <span class="keyword">not</span> <span class="keyword">in</span> feature_sets:</span><br><span class="line"> feature_sets[feature] = []</span><br><span class="line"> feature_sets[feature].append(datasets[i])</span><br><span class="line"> cond_entropy = <span class="built_in">sum</span>([(<span class="built_in">len</span>(p) / data_length) * entropy(p)</span><br><span class="line"> <span class="keyword">for</span> p <span class="keyword">in</span> feature_sets.values()])</span><br><span class="line"> <span class="keyword">return</span> cond_entropy</span><br><span class="line"></span><br><span class="line"><span class="comment"># 信息增益 = 经验熵 - 条件经验熵</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">info_gain</span>(<span class="params">entropy, cond_entropy</span>):</span></span><br><span class="line"> <span class="keyword">return</span> entropy - cond_entropy</span><br><span class="line"></span><br><span class="line"><span class="comment"># 利用信息增益选择根节点</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">info_gain_train</span>(<span class="params">datasets</span>):</span></span><br><span class="line"> data_dim = <span class="built_in">len</span>(datasets[<span class="number">0</span>]) - <span class="number">1</span></span><br><span class="line"> ent = entropy(datasets)</span><br><span class="line"> info_gain_feature = []</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(data_dim):</span><br><span class="line"> i_info_gain = info_gain(ent, cond_entropy(datasets, axis=i))</span><br><span class="line"> info_gain_feature.append(i_info_gain)</span><br><span class="line"> print(<span class="string">'特征{}的信息增益为:{}'</span>.<span class="built_in">format</span>(i + <span class="number">1</span>, i_info_gain))</span><br><span class="line"> best_feature = <span class="built_in">max</span>(info_gain_feature)</span><br><span class="line"> <span class="keyword">return</span> <span class="string">'特征{}的信息增益最大,选择根节点特征'</span>.<span class="built_in">format</span>(info_gain_feature.index(best_feature) + <span class="number">1</span>)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">datasets = [[<span class="string">'青年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'一般'</span>, <span class="string">'否'</span>],</span><br><span class="line"> [<span class="string">'青年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'好'</span>, <span class="string">'否'</span>],</span><br><span class="line"> [<span class="string">'青年'</span>, <span class="string">'是'</span>, <span class="string">'否'</span>, <span class="string">'好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'青年'</span>, <span class="string">'是'</span>, <span class="string">'是'</span>, <span class="string">'一般'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'青年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'一般'</span>, <span class="string">'否'</span>],</span><br><span class="line"> [<span class="string">'中年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'一般'</span>, <span class="string">'否'</span>],</span><br><span class="line"> [<span class="string">'中年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'好'</span>, <span class="string">'否'</span>],</span><br><span class="line"> [<span class="string">'中年'</span>, <span class="string">'是'</span>, <span class="string">'是'</span>, <span class="string">'好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'中年'</span>, <span class="string">'否'</span>, <span class="string">'是'</span>, <span class="string">'非常好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'中年'</span>, <span class="string">'否'</span>, <span class="string">'是'</span>, <span class="string">'非常好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'老年'</span>, <span class="string">'否'</span>, <span class="string">'是'</span>, <span class="string">'非常好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'老年'</span>, <span class="string">'否'</span>, <span class="string">'是'</span>, <span class="string">'好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'老年'</span>, <span class="string">'是'</span>, <span class="string">'否'</span>, <span class="string">'好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'老年'</span>, <span class="string">'是'</span>, <span class="string">'否'</span>, <span class="string">'非常好'</span>, <span class="string">'是'</span>],</span><br><span class="line"> [<span class="string">'老年'</span>, <span class="string">'否'</span>, 
<span class="string">'否'</span>, <span class="string">'一般'</span>, <span class="string">'否'</span>],</span><br><span class="line"> ]</span><br><span class="line">labels = [<span class="string">u'年龄'</span>, <span class="string">u'有工作'</span>, <span class="string">u'有自己的房子'</span>, <span class="string">u'信贷情况'</span>, <span class="string">u'类别'</span>]</span><br><span class="line">train_data = pd.DataFrame(datasets, columns=labels)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">info_gain_train(np.array(datasets))</span><br></pre></td></tr></table></figure>
<p><strong>ID3 algorithm</strong></p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br><span class="line">69</span><br><span class="line">70</span><br><span class="line">71</span><br><span class="line">72</span><br><span class="line">73</span><br><span class="line">74</span><br><span class="line">75</span><br><span class="line">76</span><br><span class="line">77</span><br><span class="line">78</span><br><span class="line">79</span><br><span class="line">80</span><br><span class="line">81</span><br><span class="line">82</span><br><span class="line">83</span><br><span class="line">84</span><br><span class="line">85</span><br><span class="line">86</span><br><span class="line">87</span><br><span class="line">88</span><br><span class="line">89</span><br><span class="line">90</span><br><span class="line">91</span><br><span class="line">92</span><br><span class="line">93</span><br><span class="line">94</span><br><span class="line">95</span><br><span class="line">96</span><br><span class="line">97</span><br><span class="line">98</span><br><span class="line">99</span><br><span class="line">100</span><br><span class="line">101</span><br><span class="line">102</span><br><span class="line">103</span><br><span class="line">104</span><br><span class="line">105</span><br><span class="line">106</span><br><span class="line">107</span><br><span class="line">108</span><br><span 
class="line">109</span><br><span class="line">110</span><br><span class="line">111</span><br><span class="line">112</span><br><span class="line">113</span><br><span class="line">114</span><br><span class="line">115</span><br><span class="line">116</span><br><span class="line">117</span><br><span class="line">118</span><br><span class="line">119</span><br><span class="line">120</span><br><span class="line">121</span><br><span class="line">122</span><br><span class="line">123</span><br><span class="line">124</span><br><span class="line">125</span><br><span class="line">126</span><br><span class="line">127</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义节点类 二叉树</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Node</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, root=<span class="literal">True</span>, label=<span class="literal">None</span>, feature_name=<span class="literal">None</span>, feature=<span class="literal">None</span></span>):</span></span><br><span class="line"> self.root = root</span><br><span class="line"> self.label = label</span><br><span class="line"> self.feature_name = feature_name</span><br><span class="line"> self.feature = feature</span><br><span class="line"> self.tree = {}</span><br><span class="line"> self.result = {</span><br><span class="line"> <span class="string">'label:'</span>: self.label,</span><br><span class="line"> <span class="string">'feature_name'</span>: self.feature_name,</span><br><span class="line"> <span class="string">'tree'</span>: self.tree</span><br><span class="line"> }</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__repr__</span>(<span class="params">self</span>):</span></span><br><span class="line"> <span class="keyword">return</span> <span class="string">'{}'</span>.<span class="built_in">format</span>(self.result)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">add_node</span>(<span class="params">self, val, node</span>):</span></span><br><span class="line"> self.tree[val] = node</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">self, features</span>):</span></span><br><span class="line"> <span class="keyword">if</span> self.root <span class="keyword">is</span> <span class="literal">True</span>:</span><br><span class="line"> <span class="keyword">return</span> self.label</span><br><span class="line"> <span class="keyword">return</span> self.tree[features[self.feature]].predict(features)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DTree</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, epsilon=<span class="number">0.1</span></span>):</span></span><br><span class="line"> self.epsilon = epsilon</span><br><span class="line"> self._tree = {}</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 熵</span></span><br><span class="line"><span class="meta"> 
@staticmethod</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">calc_ent</span>(<span class="params">datasets</span>):</span></span><br><span class="line"> data_length = <span class="built_in">len</span>(datasets)</span><br><span class="line"> label_count = {}</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(data_length):</span><br><span class="line"> label = datasets[i][-<span class="number">1</span>]</span><br><span class="line"> <span class="keyword">if</span> label <span class="keyword">not</span> <span class="keyword">in</span> label_count:</span><br><span class="line"> label_count[label] = <span class="number">0</span></span><br><span class="line"> label_count[label] += <span class="number">1</span></span><br><span class="line"> ent = -<span class="built_in">sum</span>([(p / data_length) * log(p / data_length, <span class="number">2</span>)</span><br><span class="line"> <span class="keyword">for</span> p <span class="keyword">in</span> label_count.values()])</span><br><span class="line"> <span class="keyword">return</span> ent</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 经验条件熵</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">cond_ent</span>(<span class="params">self, datasets, axis=<span class="number">0</span></span>):</span></span><br><span class="line"> data_length = <span class="built_in">len</span>(datasets)</span><br><span class="line"> feature_sets = {}</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(data_length):</span><br><span class="line"> feature = datasets[i][axis]</span><br><span class="line"> <span class="keyword">if</span> feature <span class="keyword">not</span> <span class="keyword">in</span> feature_sets:</span><br><span class="line"> feature_sets[feature] = []</span><br><span class="line"> feature_sets[feature].append(datasets[i])</span><br><span class="line"> cond_ent = <span class="built_in">sum</span>([(<span class="built_in">len</span>(p) / data_length) * self.calc_ent(p)</span><br><span class="line"> <span class="keyword">for</span> p <span class="keyword">in</span> feature_sets.values()])</span><br><span class="line"> <span class="keyword">return</span> cond_ent</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 信息增益</span></span><br><span class="line"><span class="meta"> @staticmethod</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">info_gain</span>(<span class="params">ent, cond_ent</span>):</span></span><br><span class="line"> <span class="keyword">return</span> ent - cond_ent</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">info_gain_train</span>(<span class="params">self, datasets</span>):</span></span><br><span class="line"> count = <span class="built_in">len</span>(datasets[<span class="number">0</span>]) - <span class="number">1</span></span><br><span class="line"> ent = self.calc_ent(datasets)</span><br><span class="line"> best_feature = []</span><br><span class="line"> <span class="keyword">for</span> c <span class="keyword">in</span> <span class="built_in">range</span>(count):</span><br><span class="line"> c_info_gain = 
self.info_gain(ent, self.cond_ent(datasets, axis=c))</span><br><span class="line"> best_feature.append((c, c_info_gain))</span><br><span class="line"> <span class="comment"># 比较大小</span></span><br><span class="line"> best_ = <span class="built_in">max</span>(best_feature, key=<span class="keyword">lambda</span> x: x[-<span class="number">1</span>])</span><br><span class="line"> <span class="keyword">return</span> best_</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">train</span>(<span class="params">self, train_data</span>):</span></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> input:数据集D(DataFrame格式),特征集A,阈值eta</span></span><br><span class="line"><span class="string"> output:决策树T</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> _, y_train, features = train_data.iloc[:, :</span><br><span class="line"> -<span class="number">1</span>], train_data.iloc[:,</span><br><span class="line"> -<span class="number">1</span>], train_data.columns[:</span><br><span class="line"> -<span class="number">1</span>]</span><br><span class="line"> <span class="comment"># 1,若D中实例属于同一类Ck,则T为单节点树,并将类Ck作为结点的类标记,返回T</span></span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">len</span>(y_train.value_counts()) == <span class="number">1</span>:</span><br><span class="line"> <span class="keyword">return</span> Node(root=<span class="literal">True</span>, label=y_train.iloc[<span class="number">0</span>])</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 2, 若A为空,则T为单节点树,将D中实例树最大的类Ck作为该节点的类标记,返回T</span></span><br><span class="line"> <span class="keyword">if</span> <span class="built_in">len</span>(features) == <span class="number">0</span>:</span><br><span class="line"> <span class="keyword">return</span> Node(</span><br><span class="line"> root=<span class="literal">True</span>,</span><br><span class="line"> label=y_train.value_counts().sort_values(</span><br><span class="line"> ascending=<span class="literal">False</span>).index[<span class="number">0</span>])</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 3,计算最大信息增益 同5.1,Ag为信息增益最大的特征</span></span><br><span class="line"> max_feature, max_info_gain = self.info_gain_train(np.array(train_data))</span><br><span class="line"> max_feature_name = features[max_feature]</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 4,Ag的信息增益小于阈值eta,则置T为单节点树,并将D中是实例数最大的类Ck作为该节点的类标记,返回T</span></span><br><span class="line"> <span class="keyword">if</span> max_info_gain < self.epsilon:</span><br><span class="line"> <span class="keyword">return</span> Node(</span><br><span class="line"> root=<span class="literal">True</span>,</span><br><span class="line"> label=y_train.value_counts().sort_values(</span><br><span class="line"> ascending=<span class="literal">False</span>).index[<span class="number">0</span>])</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 5,构建Ag子集</span></span><br><span class="line"> node_tree = Node(</span><br><span class="line"> root=<span class="literal">False</span>, feature_name=max_feature_name, feature=max_feature)</span><br><span class="line"></span><br><span class="line"> feature_list = train_data[max_feature_name].value_counts().index</span><br><span class="line"> <span class="keyword">for</span> f 
<span class="keyword">in</span> feature_list:</span><br><span class="line"> sub_train_df = train_data.loc[train_data[max_feature_name] ==</span><br><span class="line"> f].drop([max_feature_name], axis=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 6, 递归生成树</span></span><br><span class="line"> sub_tree = self.train(sub_train_df)</span><br><span class="line"> node_tree.add_node(f, sub_tree)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># pprint.pprint(node_tree.tree)</span></span><br><span class="line"> <span class="keyword">return</span> node_tree</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">fit</span>(<span class="params">self, train_data</span>):</span></span><br><span class="line"> self._tree = self.train(train_data)</span><br><span class="line"> <span class="keyword">return</span> self._tree</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">self, X_test</span>):</span></span><br><span class="line"> <span class="keyword">return</span> self._tree.predict(X_test)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">datasets, labels = create_data()</span><br><span class="line">data_df = pd.DataFrame(datasets, columns=labels)</span><br><span class="line">dt = DTree()</span><br><span class="line">tree = dt.fit(data_df)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">tree</span><br><span class="line"><span class="comment"># {'label:': None, 'feature_name': '有自己的房子', 'tree': {'否': {'label:': None, 'feature_name': '有工作', 'tree': {'否': {'label:': '否', 'feature_name': None, 'tree': {}}, '是': {'label:': '是', 'feature_name': None, 'tree': {}}}}, '是': {'label:': '是', 'feature_name': None, 'tree': {}}}}</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">dt.predict([<span class="string">'老年'</span>, <span class="string">'否'</span>, <span class="string">'否'</span>, <span class="string">'一般'</span>])</span><br><span class="line"><span class="comment"># '否'</span></span><br></pre></td></tr></table></figure>
<p><strong>scikit-learn example</strong></p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">## sklearn里并没有实现后剪枝,只有预剪枝操作,例如设置树的最大深度max_depth</span></span><br><span class="line"></span><br><span class="line">iris = load_iris()</span><br><span class="line">df = pd.DataFrame(iris.data, columns=iris.feature_names)</span><br><span class="line">df[<span class="string">'label'</span>] = iris.target</span><br><span class="line">df.columns = [</span><br><span class="line"> <span class="string">'sepal length'</span>, <span class="string">'sepal width'</span>, <span class="string">'petal lenght'</span>, <span class="string">'petal width'</span>, <span class="string">'label'</span></span><br><span class="line">]</span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">shuffle</span>(<span class="params">X,Y</span>):</span></span><br><span class="line"> randomize = np.arange(<span class="built_in">len</span>(X))</span><br><span class="line"> np.random.shuffle(randomize)</span><br><span class="line"> <span class="keyword">return</span> X[randomize], Y[randomize]</span><br><span class="line"></span><br><span class="line">data = np.array(df)</span><br><span class="line">X, Y = data[:, :<span class="number">2</span>], data[:,-<span class="number">1</span>]</span><br><span class="line">X, y = shuffle(X,Y)</span><br><span class="line">X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size = <span class="number">0.3</span>)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sklearn.tree <span class="keyword">import</span> DecisionTreeClassifier</span><br><span class="line">clf = DecisionTreeClassifier()</span><br><span class="line">clf.fit(X_train, Y_train)</span><br><span class="line">clf.score(X_test,Y_test)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sklearn.tree <span class="keyword">import</span> DecisionTreeClassifier</span><br><span class="line"><span class="keyword">from</span> sklearn <span class="keyword">import</span> preprocessing</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np </span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn <span class="keyword">import</span> tree</span><br><span class="line"><span class="keyword">import</span> graphviz</span><br><span class="line"></span><br><span class="line">features = [<span class="string">"年龄"</span>, <span class="string">"有工作"</span>, <span class="string">"有自己的房子"</span>, <span class="string">"信贷情况"</span>]</span><br><span class="line">X_train = pd.DataFrame([</span><br><span class="line"> [<span class="string">"青年"</span>, <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"一般"</span>],</span><br><span class="line"> [<span class="string">"青年"</span>, <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"青年"</span>, <span class="string">"是"</span>, <span class="string">"否"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"青年"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"一般"</span>],</span><br><span class="line"> [<span class="string">"青年"</span>, <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"一般"</span>],</span><br><span class="line"> [<span class="string">"中年"</span>, <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"一般"</span>],</span><br><span class="line"> [<span class="string">"中年"</span>, <span class="string">"否"</span>, <span 
class="string">"否"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"中年"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"中年"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"非常好"</span>],</span><br><span class="line"> [<span class="string">"中年"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"非常好"</span>],</span><br><span class="line"> [<span class="string">"老年"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"非常好"</span>],</span><br><span class="line"> [<span class="string">"老年"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"老年"</span>, <span class="string">"是"</span>, <span class="string">"否"</span>, <span class="string">"好"</span>],</span><br><span class="line"> [<span class="string">"老年"</span>, <span class="string">"是"</span>, <span class="string">"否"</span>, <span class="string">"非常好"</span>],</span><br><span class="line"> [<span class="string">"老年"</span>, <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"一般"</span>]</span><br><span class="line">])</span><br><span class="line">y_train = pd.DataFrame([<span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"否"</span>, </span><br><span class="line"> <span class="string">"否"</span>, <span class="string">"否"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, </span><br><span class="line"> <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"是"</span>, <span class="string">"否"</span>])</span><br><span class="line"></span><br><span class="line"><span class="comment">#数据预处理</span></span><br><span class="line">le_x = preprocessing.LabelEncoder()</span><br><span class="line">le_x.fit(np.unique(X_train)) <span class="comment"># 找出X_tran中的特征值有哪些</span></span><br><span class="line"><span class="comment"># array(['一般', '中年', '否', '好', '是', '老年', '青年', '非常好'], dtype=object)</span></span><br><span class="line"><span class="comment"># 这样 0代表‘一般’,1代表‘中年’,....</span></span><br><span class="line">X_train = X_train.apply(le_x.transform) <span class="comment"># apply()可以传入一个函数</span></span><br><span class="line">le_y = preprocessing.LabelEncoder()</span><br><span class="line">le_y.fit(np.unique(y_train))</span><br><span class="line">y_train = y_train.apply(le_y.transform)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 调用sklearn.DT训练模型</span></span><br><span class="line">model_tree = DecisionTreeClassifier()</span><br><span class="line">model_tree.fit(X_train, y_train)</span><br><span class="line"></span><br></pre></td></tr></table></figure>
<h1 id="Reference"><a href="#Reference" class="headerlink" title="Reference"></a>Reference</h1><p>[1]. <a target="_blank" rel="noopener" href="https://github.com/fengdu78/lihang-code/tree/master/%E7%AC%AC05%E7%AB%A0%20%E5%86%B3%E7%AD%96%E6%A0%91">https://github.com/fengdu78/lihang-code/</a></p>
<p>[2]. 《统计学习方法》 – 李航</p>
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2021/03/15/Naive-Bayes/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2021/03/15/Naive-Bayes/" class="post-title-link" itemprop="url">Naive Bayes</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2021-03-15 16:47:59 / 修改时间:17:29:20" itemprop="dateCreated datePublished" datetime="2021-03-15T16:47:59+08:00">2021-03-15</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<p>Naive Bayes is a generative model.</p>
<p>For a given dataset, it first learns the joint input-output probability distribution under the assumption of <strong>conditional independence of the features</strong>; then, for a given input $x$, it applies Bayes' theorem to output the $y$ with the largest posterior probability. Naive Bayes is simple to implement and highly efficient in both learning and prediction, which makes it a widely used method.</p>
<h1 id="1-学习与分类"><a href="#1-学习与分类" class="headerlink" title="1. 学习与分类"></a>1. 学习与分类</h1><ul>
<li>首先学习先验概率分布以及条件概率分布。</li>
</ul>
<p>$$<br>P(Y=c_{k}), \quad k=1,2, \cdots, K<br>$$</p>
<p>$$<br>P(X=x \mid Y=c_{k})=P(X^{(1)}=x^{(1)}, \cdots, X^{(n)}=x^{(n)} \mid Y=c_{k}), \quad k=1,2, \cdots, K<br>$$</p>
<p>Together these give the joint probability distribution $P(X,Y)$.</p>
<p>Since naive Bayes assumes conditional independence, the conditional probability factorizes as<br>$$<br>\begin{aligned}<br>P(X=x \mid Y=c_{k}) &=P(X^{(1)}=x^{(1)}, \cdots, X^{(n)}=x^{(n)} \mid Y=c_{k}) \\<br>&=\prod_{j=1}^{n} P(X^{(j)}=x^{(j)} \mid Y=c_{k})<br>\end{aligned}<br>$$</p>
<ul>
<li><p><strong>Parameter learning (maximum likelihood)</strong><br>$$<br>P\left(Y=c_{k}\right)=\frac{\sum_{i=1}^{N} I\left(y_{i}=c_{k}\right)}{N}, \quad k=1,2, \cdots, K<br>$$</p>
<p>$$<br>\begin{array}{l}<br>P\left(X^{(j)}=a_{j l} \mid Y=c_{k}\right)=\frac{\sum_{i=1}^{N} I\left(x_{i}^{(j)}=a_{j l}, y_{i}=c_{k}\right)}{\sum_{i=1}^{N} I\left(y_{i}=c_{k}\right)} \\<br>j=1,2, \cdots, n ; \quad l=1,2, \cdots, S_{j} ; \quad k=1,2, \cdots, K<br>\end{array}<br>$$</p>
</li>
<li><p><strong>Classifier</strong></p>
</li>
</ul>
<p>$$<br>y=f(x)=\arg \max_{c_{k}} \frac{P\left(Y=c_{k}\right) \prod_{j} P\left(X^{(j)}=x^{(j)} \mid Y=c_{k}\right)}{\sum_{k} P\left(Y=c_{k}\right) \prod_{j} P\left(X^{(j)}=x^{(j)} \mid Y=c_{k}\right)}<br>$$</p>
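<p>The two maximum-likelihood estimates above are plain frequency counts, and the classifier is a product over features (the denominator is the same for every $c_k$ and can be dropped). A minimal sketch over a made-up categorical dataset, with all variable names and data purely illustrative:</p>
<figure class="highlight python"><table><tr><td class="code"><pre><code>import numpy as np
from collections import Counter, defaultdict

# Toy categorical data: each row is (x1, x2), labels are y
X = [("s", "a"), ("s", "b"), ("m", "a"), ("m", "b"), ("l", "b")]
y = ["no", "no", "yes", "yes", "yes"]

N = len(y)
prior = {c: n / N for c, n in Counter(y).items()}  # P(Y=c_k)

# cond[c][j][v] approximates P(X^(j)=v | Y=c) by frequency counts
cond = defaultdict(lambda: defaultdict(Counter))
for xi, yi in zip(X, y):
    for j, v in enumerate(xi):
        cond[yi][j][v] += 1
for c in cond:
    n_c = sum(1 for yi in y if yi == c)
    for j in cond[c]:
        for v in cond[c][j]:
            cond[c][j][v] /= n_c

def classify(x):
    # arg max over c_k of P(Y=c_k) * prod_j P(X^(j)=x^(j) | Y=c_k)
    scores = {c: prior[c] * np.prod([cond[c][j].get(v, 0.0)
                                     for j, v in enumerate(x)])
              for c in prior}
    return max(scores, key=scores.get)

print(classify(("m", "a")))  # "yes": 0.6 * (2/3 * 1/3) beats 0.4 * 0
</code></pre></td></tr></table></figure>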
<h1 id="2-原理-后验概率最大化-期望风险最小化"><a href="#2-原理-后验概率最大化-期望风险最小化" class="headerlink" title="2. 原理(后验概率最大化 == 期望风险最小化)"></a>2. 原理(后验概率最大化 == 期望风险最小化)</h1><p>假设选择0-1损失函数<br>$$<br>L(Y,f(X)) = \begin{cases}<br>1, & Y \neq f(X) \\<br>0, & Y = f(X)<br>\end{cases}<br>$$<br>期望风险函数为<br>$$<br>R_{\exp }(f)=E[L(Y, f(X))]<br>$$<br>期望是对联合分布 $P(X,Y)$取的,由此条件期望<br>$$<br>R_{\exp }(f)=E_{X} \sum_{k=1}^{K}\left[L\left(c_{k}, f(X)\right)\right] P\left(c_{k} \mid X\right)<br>$$<br>为了使期望风险最小化,只需对 $X=x$ 逐个极小化,<br>$$<br>\begin{aligned}<br>f(x) &=\arg \min_{y \in \mathcal{Y}} \sum_{k=1}^{K} L\left(c_{k}, y\right) P\left(c_{k} \mid X=x\right) \\<br>&=\arg \min_{y \in \mathcal{Y}} \sum_{k=1}^{K} P\left(y \neq c_{k} \mid X=x\right) \\<br>&=\arg \min_{y \in \mathcal{Y}}\left(1-P\left(y=c_{k} \mid X=x\right)\right) \\<br>&=\arg \max_{y \in \mathcal{Y}} P\left(y=c_{k} \mid X=x\right)<br>\end{aligned}<br>$$</p>
<h1 id="3-贝叶斯估计-拉普拉斯平滑"><a href="#3-贝叶斯估计-拉普拉斯平滑" class="headerlink" title="3. 贝叶斯估计(拉普拉斯平滑)"></a>3. 贝叶斯估计(拉普拉斯平滑)</h1><p>用极大似然估计可能会出现所要估计的概率值为0的情况。 这时会影响到后验概率的计算结果, 使分类产生偏差。 解决这一问题的方法是采用贝叶斯估计。具体的条件概率的贝叶斯估计是<br>$$<br>P_{\lambda}\left(X^{(j)}=a_{j l} \mid Y=c_{k}\right)=\frac{\sum_{i=1}^{N} I\left(x_{i}^{(j)}=a_{j l}, y_{i}=c_{k}\right)+\lambda}{\sum_{i=1}^{N} I\left(y_{i}=c_{k}\right)+S_{j} \lambda}<br>$$<br>$l=1,2,\cdots,S_j,k=1,2,…,K$ ,$\lambda \geq 0$,等价于在随机变量各个取值上赋予一个正数。$\lambda = 0$ 是极大似然估计。常取 $\lambda = 1$,称为拉普拉斯平滑。检验可知,概率和为1,并且每个概率大于0</p>
<p>先验概率的贝叶斯估计是<br>$$<br>P_{\lambda}\left(Y=c_{k}\right)=\frac{\sum_{i=1}^{N} I\left(y_{i}=c_{k}\right)+\lambda}{N+K \lambda}<br>$$</p>
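<p>Laplace smoothing is a one-line change to the frequency counts: add $\lambda$ to each numerator and $S_j\lambda$ (or $K\lambda$) to the denominator. A minimal sketch, reusing the toy <code>X</code>, <code>y</code> from the sketch above (all names illustrative):</p>
<figure class="highlight python"><table><tr><td class="code"><pre><code>lam = 1.0  # lambda = 1 gives Laplace smoothing

def smoothed_cond(c, j, v, values_j):
    """P_lambda(X^(j)=v | Y=c) with S_j = len(values_j) possible values."""
    n_c = sum(1 for yi in y if yi == c)
    n_cv = sum(1 for xi, yi in zip(X, y) if yi == c and xi[j] == v)
    return (n_cv + lam) / (n_c + len(values_j) * lam)

def smoothed_prior(c, K):
    return (sum(1 for yi in y if yi == c) + lam) / (len(y) + K * lam)

# "m" was never seen with label "no", yet its probability is now positive
print(smoothed_cond("no", 0, "m", {"s", "m", "l"}))  # (0+1)/(2+3) = 0.2
print(smoothed_prior("no", K=2))                     # (2+1)/(5+2) = 0.4286...
</code></pre></td></tr></table></figure>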
<h1 id="4-代码实现"><a href="#4-代码实现" class="headerlink" title="4. 代码实现"></a>4. 代码实现</h1><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line">%matplotlib inline</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.datasets <span class="keyword">import</span> load_iris</span><br><span class="line"><span class="keyword">from</span> sklearn.model_selection <span class="keyword">import</span> train_test_split</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> Counter</span><br><span class="line"><span class="keyword">import</span> math</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># import data</span></span><br><span class="line">iris = load_iris()</span><br><span class="line">df = pd.DataFrame(iris.data, columns=iris.feature_names)</span><br><span class="line">df[<span class="string">'labels'</span>] = iris.target</span><br><span class="line">df.columns = [</span><br><span class="line"> <span class="string">'sepal length'</span>, <span class="string">'sepal width'</span>, <span class="string">'petal length'</span>, <span class="string">'petal width'</span>, <span class="string">'label'</span></span><br><span class="line">]</span><br><span class="line">data = np.array(df)</span><br><span class="line">X, Y = data[:, :-<span class="number">1</span>], data[:, -<span class="number">1</span>]</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">_shuffle</span>(<span class="params">X, Y</span>):</span></span><br><span class="line"> randomize = np.arange(<span class="built_in">len</span>(X))</span><br><span class="line"> np.random.shuffle(randomize)</span><br><span class="line"> <span class="keyword">return</span> X[randomize], Y[randomize]</span><br><span class="line"></span><br><span class="line">X, Y = _shuffle(X,Y)</span><br><span class="line"></span><br><span class="line">X_train, X_test, Y_train, Y_test = train_test_split(X,Y, test_size = <span class="number">0.3</span>)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br><span class="line">67</span><br><span class="line">68</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">NaiveBayes</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self</span>):</span></span><br><span class="line"> self.model = <span class="literal">None</span></span><br><span class="line"> self.Y_mean = <span class="literal">None</span></span><br><span class="line"> </span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算均值</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">mean</span>(<span class="params">self, X</span>):</span></span><br><span class="line"> <span class="keyword">return</span> <span class="built_in">sum</span>(X) / <span class="built_in">float</span>(<span class="built_in">len</span>(X))</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算标准差</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">std</span>(<span class="params">self, X</span>):</span></span><br><span class="line"> avg = self.mean(X)</span><br><span class="line"> <span class="keyword">return</span> 
math.sqrt(<span class="built_in">sum</span>([<span class="built_in">pow</span>(x - avg, <span class="number">2</span>) <span class="keyword">for</span> x <span class="keyword">in</span> X]) / <span class="built_in">float</span>(<span class="built_in">len</span>(X)))</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 概率密度函数</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">gaussian_probability</span>(<span class="params">self, x, mean, std</span>):</span></span><br><span class="line"> exp = math.exp(-(math.<span class="built_in">pow</span>(x - mean, <span class="number">2</span>) / (<span class="number">2</span> * math.<span class="built_in">pow</span>(std, <span class="number">2</span>))))</span><br><span class="line"> <span class="keyword">return</span> (<span class="number">1</span> / (math.sqrt(<span class="number">2</span> * math.pi)) * std) * exp</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算X_train的mean 和std</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">summarize</span>(<span class="params">self, X_train</span>):</span> <span class="comment"># *X_train 星号作用是解包然后逐个传入 </span></span><br><span class="line"> summaries = [(self.mean(i), self.std(i)) <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">zip</span>(*X_train)] <span class="comment"># 所以zip(*X_train) == zip([1,2,3],[2,3,4],...)</span></span><br><span class="line"> <span class="keyword">return</span> summaries</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">fit</span>(<span class="params">self, X, Y</span>):</span></span><br><span class="line"> labels = <span class="built_in">list</span>(<span class="built_in">set</span>(Y))</span><br><span class="line"> data = {label: [] <span class="keyword">for</span> label <span class="keyword">in</span> labels} <span class="comment"># 初始化</span></span><br><span class="line"> self.Y_mean = np.zeros(<span class="built_in">len</span>(labels))</span><br><span class="line"> <span class="keyword">for</span> x, label <span class="keyword">in</span> <span class="built_in">zip</span>(X, Y):</span><br><span class="line"> data[label].append(x)</span><br><span class="line"> self.Y_mean[<span class="built_in">int</span>(label)] += <span class="number">1.</span></span><br><span class="line"> self.Y_mean /= <span class="built_in">len</span>(Y)</span><br><span class="line"> <span class="comment"># 计算P(x_i|y_k) for i in range...</span></span><br><span class="line"> self.model = {</span><br><span class="line"> label:self.summarize(value)</span><br><span class="line"> <span class="keyword">for</span> label, value <span class="keyword">in</span> data.items()</span><br><span class="line"> }</span><br><span class="line"> <span class="keyword">return</span> <span class="string">'gaussianNB train done'</span></span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">calculate_probabilities</span>(<span class="params">self, input_data</span>):</span></span><br><span class="line"> <span class="comment"># summaries: {0: [(mean1,std1),(mean2,std2),(mean3,std3),(mean4,std4)], 1:...}</span></span><br><span class="line"> probabilities = {}</span><br><span class="line"> <span class="keyword">for</span> label, value 
<span class="keyword">in</span> self.model.items(): <span class="comment"># value --> summaries</span></span><br><span class="line"> probabilities[label] = self.Y_mean[<span class="built_in">int</span>(label)] <span class="comment"># P(C_i)</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(value)):</span><br><span class="line"> mean, std = value[i]</span><br><span class="line"> probabilities[label] *= self.gaussian_probability(input_data[i], mean, std)</span><br><span class="line"> <span class="comment"># 计算P(x|c_i) * P(C_i)的概率</span></span><br><span class="line"> <span class="keyword">return</span> probabilities</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">self, X_test</span>):</span> </span><br><span class="line"> <span class="comment"># 这里X_test为单个实例</span></span><br><span class="line"> result = []</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(X_test)):</span><br><span class="line"> label = <span class="built_in">sorted</span>(self.calculate_probabilities(X_test[i]).items(), key=<span class="keyword">lambda</span> x:x[-<span class="number">1</span>])[-<span class="number">1</span>][<span class="number">0</span>] <span class="comment"># sorted 默认从小到大 返回一个list</span></span><br><span class="line"> <span class="comment"># 如[(1, 75), (0, 85), (2, 95)]</span></span><br><span class="line"> result.append(label)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> result</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">score</span>(<span class="params">self, X_test, Y_test</span>):</span></span><br><span class="line"> right = <span class="number">0</span></span><br><span class="line"> predictions = self.predict(X_test)</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(Y_test)):</span><br><span class="line"> <span class="keyword">if</span> predictions[i] == Y_test[i]:</span><br><span class="line"> right += <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> right / <span class="built_in">float</span>(<span class="built_in">len</span>(Y_test))</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">model = NaiveBayes()</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">model.fit(X_train, Y_train)</span><br><span class="line"><span class="comment"># 'gaussianNB train done'</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">result1 = model.predict(X_test)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">model.score(X_test, Y_test)</span><br></pre></td></tr></table></figure>
<p><strong>scikit-learn example</strong></p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sklearn.naive_bayes <span class="keyword">import</span> GaussianNB</span><br><span class="line">clf = GaussianNB()</span><br><span class="line"><span class="comment">#X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])</span></span><br><span class="line"><span class="comment">#Y = np.array([1, 1, 1, 2, 2, 2])</span></span><br><span class="line">clf.fit(X_train, Y_train)</span><br><span class="line">result2 = clf.predict(X_test)</span><br><span class="line">clf.score(X_test, Y_test)</span><br></pre></td></tr></table></figure>
<h1 id="Reference:"><a href="#Reference:" class="headerlink" title="Reference:"></a>Reference:</h1><p>[1]. <a target="_blank" rel="noopener" href="https://github.com/fengdu78/lihang-code/blob/master/%E7%AC%AC04%E7%AB%A0%20%E6%9C%B4%E7%B4%A0%E8%B4%9D%E5%8F%B6%E6%96%AF/4.NaiveBayes.ipynb">https://github.com/fengdu78/lihang-code</a></p>
<p>[2]. 《统计学习方法》 – 李航</p>
</div>
<footer class="post-footer">
<div class="post-eof"></div>
</footer>
</article>
<article itemscope itemtype="http://schema.org/Article" class="post-block" lang="zh-CN">
<link itemprop="mainEntityOfPage" href="http://example.com/2021/03/14/K-NN/">
<span hidden itemprop="author" itemscope itemtype="http://schema.org/Person">
<meta itemprop="image" content="/images/avatar.gif">
<meta itemprop="name" content="Guangshan Shui">
<meta itemprop="description" content="">
</span>
<span hidden itemprop="publisher" itemscope itemtype="http://schema.org/Organization">
<meta itemprop="name" content="水广山">
</span>
<header class="post-header">
<h2 class="post-title" itemprop="name headline">
<a href="/2021/03/14/K-NN/" class="post-title-link" itemprop="url">K-nearest neighbor(K-NN)</a>
</h2>
<div class="post-meta">
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar"></i>
</span>
<span class="post-meta-item-text">发表于</span>
<time title="创建时间:2021-03-14 19:23:56" itemprop="dateCreated datePublished" datetime="2021-03-14T19:23:56+08:00">2021-03-14</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-calendar-check"></i>
</span>
<span class="post-meta-item-text">更新于</span>
<time title="修改时间:2021-03-15 17:29:49" itemprop="dateModified" datetime="2021-03-15T17:29:49+08:00">2021-03-15</time>
</span>
<span class="post-meta-item">
<span class="post-meta-item-icon">
<i class="far fa-folder"></i>
</span>
<span class="post-meta-item-text">分类于</span>
<span itemprop="about" itemscope itemtype="http://schema.org/Thing">
<a href="/categories/ML/" itemprop="url" rel="index"><span itemprop="name">ML</span></a>
</span>
</span>
</div>
</header>
<div class="post-body" itemprop="articleBody">
<p>A basic classification and regression method.</p>
<p>Given a training dataset and a new input instance, find the k instances in the training set nearest to that instance; the input instance is assigned to the class held by the majority of those k instances.</p>
<h1 id="1-k-近邻模型"><a href="#1-k-近邻模型" class="headerlink" title="1. k 近邻模型"></a>1. k 近邻模型</h1><p>三个基本要素:距离度量,k值选择、分类决策规则</p>
<h2 id="1-1-距离度量"><a href="#1-1-距离度量" class="headerlink" title="1.1 距离度量"></a>1.1 距离度量</h2><p>$x_{i}=\left(x_{i}^{(1)}, x_{i}^{(2)}, \cdots, x_{i}^{(n)}\right)^{\mathrm{T}}$ $L_{p}$ 距离:<br>$$<br>L_{p}\left(x_{i}, x_{j}\right)=\left(\sum_{l=1}^{n}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|^{p}\right)^{\frac{1}{p}}<br>$$<br>当 $p=2$ 为<strong>欧式距离</strong>(常用)</p>
<p>当 $p=1 $ 为曼哈顿距离</p>
<p>当 $p=\infty$ 时,是各个坐标距离的最大值<br>$$<br>L_{\infty}(x_{i}, x_{j})=\max_{l}\left|x_{i}^{(l)}-x_{j}^{(l)}\right|<br>$$</p>
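<p>A quick numerical check of the three cases, using <code>np.linalg.norm</code> with its <code>ord</code> parameter (a minimal sketch; the two points are made up):</p>
<figure class="highlight python"><table><tr><td class="code"><pre><code>import numpy as np

x1 = np.array([1.0, 1.0])
x2 = np.array([4.0, 5.0])

print(np.linalg.norm(x1 - x2, ord=2))       # Euclidean: sqrt(3**2 + 4**2) = 5.0
print(np.linalg.norm(x1 - x2, ord=1))       # Manhattan: 3 + 4 = 7.0
print(np.linalg.norm(x1 - x2, ord=np.inf))  # L_inf: max(3, 4) = 4.0
</code></pre></td></tr></table></figure>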
<h2 id="1-2-k值选择"><a href="#1-2-k值选择" class="headerlink" title="1.2 k值选择"></a>1.2 k值选择</h2><p>k值选择较小时,近似误差小,估计误差大,过拟合</p>
<p>k值选择较大时,近似误差大,可以减少估计误差</p>
<p><strong>Trick:</strong> k值一般取一个比较小的数值。 通常采用交叉验证法来选取最优的k值</p>
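<p>One way to select k by cross-validation, sketched with scikit-learn's <code>GridSearchCV</code> (the candidate range is illustrative):</p>
<figure class="highlight python"><table><tr><td class="code"><pre><code>from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
search = GridSearchCV(KNeighborsClassifier(),
                      {"n_neighbors": list(range(1, 16))},  # candidate k values
                      cv=5)                                 # 5-fold cross-validation
search.fit(iris.data, iris.target)
print(search.best_params_)  # e.g. {'n_neighbors': ...}
</code></pre></td></tr></table></figure>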
<h2 id="1-3-分类决策规则"><a href="#1-3-分类决策规则" class="headerlink" title="1.3 分类决策规则"></a>1.3 分类决策规则</h2><p>往往是<strong>多数表决</strong>:由输入实例的k个邻近的训练实例中的多数类决定输入实例的类</p>
<p>误分类率:<br>$$<br>\frac{1}{k} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i} \neq c_{j}\right)=1-\frac{1}{k} \sum_{x_{i} \in N_{k}(x)} I\left(y_{i}=c_{j}\right)<br>$$</p>
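<p>Majority voting itself is one line with <code>collections.Counter</code> (a minimal sketch; the neighbor labels are made up):</p>
<figure class="highlight python"><table><tr><td class="code"><pre><code>from collections import Counter

neighbor_labels = [1, 0, 1, 1, 2]  # labels of the k = 5 nearest neighbors
label = Counter(neighbor_labels).most_common(1)[0][0]
print(label)  # 1 -- the majority class
</code></pre></td></tr></table></figure>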
<h1 id="2-模型实现:kd树"><a href="#2-模型实现:kd树" class="headerlink" title="2. 模型实现:kd树"></a>2. 模型实现:kd树</h1><p>实现k近邻法时, 主要考虑的问题是如何对训练数据进行快速k近邻搜索。 这点在特征空间的维数大及训练数据容量大时尤其必要</p>
<p>线性扫描不可取,因为遍历所有的数据计算量太大</p>
<p><strong>解决方法:kd树(kd tree)</strong></p>
<p>kd树是二叉树, 平衡的kd树搜索时的效率未必是最优的</p>
<p>算法 3.2 (构造平衡 $k d$ 树)</p>
<p>输入: $k$ 维空间数据集 $T= \lbrace x_{1}, x_{2}, \cdots, x_{N} \rbrace$,其中 $x_{i}=\left(x_{i}^{(1)}, x_{i}^{(2)}, \cdots, x_{i}^{(k)}\right)^{\mathrm{T}}$,$i=1,2, \cdots, N$<br>输出: $k d$ 树。<br> (1) 开始: 构造根结点,根结点对应于包含 $T$ 的 $k$ 维空间的超矩形区域。</p>
<p> $x^{(1)}$ 为坐标轴,以 $T$ 中所有实例的 $x^{(1)}$ 坐标的中位数为切分点,将根结点选择对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴 $x^{(1)}$ 垂直的超平面实现。 </p>
<p> 由根结点生成深度为 1 的左、右子结点: 左子结点对应坐标 $x^{(1)}$ 小于切分点的子$x^{(1)}$ 大于切分点的子区域。 区域,右子结点对应于坐标将落在切分超平面上的实例点保存在根结点。</p>
<p> (2)重复: 对深度为 $j$ 的结点,选择 $x^{(l)}$ 为切分的坐标轴,$l=j(\bmod k)+1,$ 以该结点的区域中所有实例的 $x^{(l)}$ 坐标的中位数为切分点,将该结点对应的超矩形区域切分为两个子区域。切分由通过切分点并与坐标轴 $x^{(l)}$ 垂直的超平面实现。<br> 由该结点生成深度为 $j+1$ 的左、右子结点:左子结点对应坐标 $x^{(l)}$ 小于切分点 的子区域,右子结点对应坐标 $x^{(l)}$ 大于切分点的子区域。</p>
<p> 将落在切分超平面上的实例点保存在该结点。</p>
<p> (3)直到两个子区域没有实例存在时停止。从而形成 $k d$ 树的区域划分.</p>
<p><strong>Searching the kd tree</strong></p>
<p>Given a target point, search for its nearest neighbor: first find the leaf node containing the target point; then back up from that leaf towards the root, repeatedly updating the node nearest to the target, and terminate once no closer node can possibly exist.</p>
<p><strong>Algorithm 3.3 (nearest-neighbor search with a $kd$ tree)</strong></p>
<p>Input: a constructed $kd$ tree and a target point $x$</p>
<p>Output: the nearest neighbor of $x$.</p>
<p>(1) Find the leaf node containing the target point $x$: starting from the root, recursively descend the $kd$ tree. If the coordinate of $x$ in the current splitting dimension is less than that of the split point, move to the left child, otherwise to the right child, <strong>until a leaf node is reached</strong>.</p>
<p>(2) Take this leaf node as the “current nearest point”.</p>
<p>(3) Recursively back up, and at each node perform the following:</p>
<p> (a) If the instance stored at the node is closer to the target than the current nearest point, make it the current nearest point.</p>
<p> (b) The current nearest point necessarily lies in the region of one child of the node. Check whether the region of the other child contains a closer point; concretely, check whether that region intersects the hypersphere centered at the target point with radius equal to the distance between the target and the current nearest point.</p>
<p> If they intersect, a closer point may exist in the other child's region, so move to the other child and continue the nearest-neighbor search recursively;</p>
<p> if they do not intersect, keep backing up.</p>
<p>(4) The search ends upon backing up to the root node; the final “current nearest point” is the nearest neighbor of $x$.</p>
<h1 id="3-代码实现"><a href="#3-代码实现" class="headerlink" title="3. 代码实现"></a>3. 代码实现</h1><figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">## max()函数技巧</span></span><br><span class="line"></span><br><span class="line"><span class="comment">## 初级技巧</span></span><br><span class="line">tmp = <span class="built_in">max</span>(<span class="number">1</span>,<span class="number">2</span>,<span class="number">4</span>)</span><br><span class="line">print(tmp)</span><br><span class="line"><span class="comment">#可迭代对象</span></span><br><span class="line">a = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]</span><br><span class="line">tmp = <span class="built_in">max</span>(a)</span><br><span class="line">print(tmp)</span><br><span class="line"></span><br><span class="line"><span class="comment">## 中级技巧:key属性的使用</span></span><br><span class="line"><span class="comment"># key参数不为空时,就以key的函数对象为判断的标准。</span></span><br><span class="line"><span class="comment"># 如果我们想找出一组数中绝对值最大的数,就可以配合lamda先进行处理,再找出最大值</span></span><br><span class="line">a = [-<span class="number">9</span>, -<span class="number">8</span>, <span class="number">1</span>, <span class="number">3</span>, -<span class="number">4</span>, <span class="number">6</span>]</span><br><span class="line">tmp = <span class="built_in">max</span>(a, key=<span class="keyword">lambda</span> x: <span class="built_in">abs</span>(x))</span><br><span class="line">print(tmp)</span><br><span class="line"></span><br><span class="line"><span class="comment">## 高级技巧:找出字典中值最大的那组数据</span></span><br><span class="line"><span class="comment">#在对字典进行数据操作的时候,默认只会处理key,而不是value</span></span><br><span class="line"><span class="comment">#先使用zip把字典的keys和values翻转过来,再用max取出值最大的那组数据</span></span><br><span class="line"><span class="comment">#这个时候key是值,value是之前的key</span></span><br><span class="line"><span class="comment">#如果有一组商品,其名称和价格都存在一个字典中,可以用下面的方法快速找到价格最贵的那组商品:</span></span><br><span class="line">prices = {</span><br><span class="line"> <span 
class="string">'A'</span>:<span class="number">123</span>,</span><br><span class="line"> <span class="string">'B'</span>:<span class="number">450.1</span>,</span><br><span class="line"> <span class="string">'C'</span>:<span class="number">12</span>,</span><br><span class="line"> <span class="string">'E'</span>:<span class="number">444</span>,</span><br><span class="line">}</span><br><span class="line"><span class="comment"># 在对字典进行数据操作的时候,默认只会处理key,而不是value</span></span><br><span class="line"><span class="comment"># 先使用zip把字典的keys和values翻转过来,再用max取出值最大的那组数据</span></span><br><span class="line">max_prices = <span class="built_in">max</span>(<span class="built_in">zip</span>(prices.values(), prices.keys()))</span><br><span class="line">print(max_prices) </span><br><span class="line"><span class="comment">#这个时候key是值,value是之前的key</span></span><br><span class="line"><span class="comment"># (450.1, 'B')</span></span><br><span class="line"></span><br><span class="line"><span class="comment">#当字典中的value相同的时候,才会比较key</span></span><br><span class="line">prices = {</span><br><span class="line"> <span class="string">'A'</span>: <span class="number">123</span>,</span><br><span class="line"> <span class="string">'B'</span>: <span class="number">123</span>,</span><br><span class="line">}</span><br><span class="line">max_prices = <span class="built_in">max</span>(<span class="built_in">zip</span>(prices.values(), prices.keys()))print(max_prices) <span class="comment"># (123, 'B')</span></span><br><span class="line">min_prices = <span class="built_in">min</span>(<span class="built_in">zip</span>(prices.values(), prices.keys()))print(min_prices) <span class="comment"># (123, 'A')</span></span><br></pre></td></tr></table></figure>
<p><strong>K-NN</strong></p>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line">%matplotlib inline</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.datasets <span class="keyword">import</span> load_iris</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">## load data</span></span><br><span class="line">iris = load_iris() <span class="comment"># np.array 格式</span></span><br><span class="line">df = pd.DataFrame(iris.data, columns = iris.feature_names)</span><br><span class="line">df[<span class="string">'label'</span>] = iris.target <span class="comment"># 添加label列</span></span><br><span class="line">df[:<span class="number">20</span>] <span class="comment"># 查看一下实际数据</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">data = np.array(df.iloc[:,:]) <span class="comment"># 数据由Datafrom转化为array</span></span><br><span class="line">X, Y = data[:, :-<span class="number">1</span>], data[:, -<span class="number">1</span>]</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">_shuffle</span>(<span class="params">X, Y</span>):</span></span><br><span class="line"> randomize = np.arange(<span class="built_in">len</span>(X))</span><br><span class="line"> np.random.shuffle(randomize)</span><br><span class="line"> <span class="keyword">return</span> X[randomize], Y[randomize]</span><br><span class="line">X, Y = _shuffle(X,Y)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">_train_test_split</span>(<span class="params">X, Y, split_ratio = <span class="number">0.1</span></span>):</span></span><br><span class="line"> train_size = <span class="built_in">int</span>(<span class="built_in">len</span>(X) * (<span class="number">1</span> - split_ratio))</span><br><span class="line"> <span class="keyword">return</span> X[:train_size], Y[:train_size], X[train_size:], Y[train_size:]</span><br><span class="line"></span><br><span class="line">X_train, Y_train, X_test, Y_test = _train_test_split(X, Y, split_ratio = <span class="number">0.1</span>)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">## 数据标准化</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">_normalize</span>(<span class="params">X, train_set = <span class="literal">True</span>, specified_columns = <span class="literal">None</span>, X_mean = <span class="literal">None</span>, X_std = <span class="literal">None</span></span>):</span></span><br><span class="line"> <span class="keyword">if</span> specified_columns == <span class="literal">None</span>:</span><br><span class="line"> specified_columns = np.arange(<span class="built_in">len</span>(X[<span class="number">0</span>]))</span><br><span class="line"> <span class="keyword">if</span> train_set:</span><br><span class="line"> X_mean = np.mean(X[:, specified_columns], axis = <span class="number">0</span>)</span><br><span class="line"> X_std = np.std(X[:, specified_columns], axis = <span class="number">0</span>)</span><br><span class="line"> X[:, specified_columns] = (X[:,specified_columns] - X_mean) / (X_std + <span class="number">1e-8</span>)</span><br><span class="line"> <span class="keyword">return</span> X, X_mean, X_std</span><br><span class="line"></span><br><span class="line">X_train, X_mean, X_std = _normalize(X_train, train_set=<span class="literal">True</span>)</span><br><span class="line">X_test, _, _ = _normalize(X_test, train_set = <span class="literal">None</span>, X_mean = X_mean, X_std = X_std)</span><br></pre></td></tr></table></figure>
<ul>
<li><strong>Linear scan for the K nearest points</strong></li>
</ul>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">KNN_LinearSearch</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, X_train, Y_train, k = <span class="number">3</span>, p=<span class="number">2</span></span>):</span></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> k: 临近点的个数</span></span><br><span class="line"><span class="string"> p: 距离度量</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> self.k = k</span><br><span class="line"> self.p = p</span><br><span class="line"> self.X_train = X_train</span><br><span class="line"> self.Y_train = Y_train</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">predict</span>(<span class="params">self, X_test</span>):</span></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> X_test,该函数找出最近的k个点,并按照多数表决规则进行预测</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> result = []</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(X_test)):</span><br><span class="line"> X = X_test[i]</span><br><span class="line"> knn_list = []</span><br><span class="line"> <span class="keyword">for</span> j <span class="keyword">in</span> <span class="built_in">range</span>(self.k):</span><br><span class="line"> dist = np.linalg.norm(X - self.X_train[j], <span class="built_in">ord</span>=self.p) <span class="comment">#np中求线性代数范数的公式</span></span><br><span class="line"> knn_list.append((dist, self.Y_train[j]))</span><br><span class="line"> <span class="keyword">for</span> l <span class="keyword">in</span> <span class="built_in">range</span>(self.k, <span class="built_in">len</span>(self.X_train)):</span><br><span class="line"> max_index = 
knn_list.index(<span class="built_in">max</span>(knn_list, key=<span class="keyword">lambda</span> x: x[<span class="number">0</span>]))</span><br><span class="line"> dist_new = np.linalg.norm(X-self.X_train[l], <span class="built_in">ord</span>=self.p)</span><br><span class="line"> <span class="keyword">if</span> knn_list[max_index][<span class="number">0</span>] > dist_new:</span><br><span class="line"> knn_list[max_index] = (dist, self.Y_train[l])</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 分类决策规则:多数表决</span></span><br><span class="line"> knn_label = [k[-<span class="number">1</span>] <span class="keyword">for</span> k <span class="keyword">in</span> knn_list]</span><br><span class="line"> max_count_label = <span class="built_in">max</span>(knn_label, key=knn_label.count)</span><br><span class="line"> result.append(max_count_label)</span><br><span class="line"> <span class="keyword">return</span> result</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">Correct_rate</span>(<span class="params">self, X_test, Y_test</span>):</span></span><br><span class="line"> right_count = <span class="number">0</span></span><br><span class="line"> predictions = self.predict(X_test)</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="built_in">len</span>(Y_test)):</span><br><span class="line"> <span class="keyword">if</span> predictions[i] == Y_test[i]:</span><br><span class="line"> right_count += <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> right_count / <span class="built_in">len</span>(X_test)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">KNN = KNN_LinearSearch(X_train, Y_train)</span><br><span class="line">KNN.predict(X_test)</span><br><span class="line"><span class="comment"># [1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0]</span></span><br><span class="line">KNN.Correct_rate(X_test, Y_test)</span><br></pre></td></tr></table></figure>
<ul>
<li><strong>kd tree: nearest-neighbor search</strong></li>
</ul>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">KdNode</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, dom_elt, split, left, right</span>):</span></span><br><span class="line"> self.dom_elt = dom_elt <span class="comment"># k维向量节点(k维空间中的一个样本点), 分割的那个样本点 dom_elt = [1,2,3,5,6...]</span></span><br><span class="line"> self.split = split <span class="comment"># 整数(进行分割纬度的序号)</span></span><br><span class="line"> self.left = left <span class="comment"># 该结点分割超平面左子空间构成的kd-tree</span></span><br><span class="line"> self.right = right <span class="comment"># 该结点分割超平面右子空间构成的kd-tree</span></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">KdTree</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span>(<span class="params">self, data</span>):</span></span><br><span class="line"> dim = <span class="built_in">len</span>(data[<span class="number">0</span>]) <span class="comment"># 数据维度</span></span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">CreateNode</span>(<span class="params">split, data_set</span>):</span></span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> data_set:</span><br><span class="line"> <span class="keyword">return</span> <span class="literal">None</span></span><br><span class="line"> <span class="comment"># 按要进行分割的那一切数据排序</span></span><br><span class="line"> data_set.sort(key=<span class="keyword">lambda</span> x: x[split]) <span class="comment"># 这里sort,key的用法和max类似</span></span><br><span class="line"> split_pos = <span class="built_in">len</span>(data_set) // <span class="number">2</span> <span class="comment"># 整数除法</span></span><br><span class="line"> median = data_set[split_pos] <span class="comment"># 中位数分割点</span></span><br><span class="line"> split_next = (split + <span class="number">1</span>) % dim</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 递归的创建kd树</span></span><br><span class="line"> <span class="keyword">return</span> KdNode(median, split, </span><br><span 
class="line"> CreateNode(split_next, data_set[:split_pos]),</span><br><span class="line"> CreateNode(split_next, data_set[split_pos + <span class="number">1</span>:]))</span><br><span class="line"> self.root = CreateNode(<span class="number">0</span>, data) <span class="comment"># 从第0维开始创建kd树</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># kd树的前序遍历</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">preorder</span>(<span class="params">root</span>):</span></span><br><span class="line"> print(root.dom_elt)</span><br><span class="line"> <span class="keyword">if</span> root.left: <span class="comment"># 节点不为空</span></span><br><span class="line"> preorder(root.left)</span><br><span class="line"> <span class="keyword">if</span> root.right:</span><br><span class="line"> preorder(root.right)</span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">## 寻找最近的k个点</span></span><br><span class="line"><span class="keyword">from</span> math <span class="keyword">import</span> sqrt</span><br><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> namedtuple</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义一个namedtuple, 分别存放最近坐标点、最近距离和访问过的节点数</span></span><br><span class="line">result = namedtuple(<span class="string">"Result_tuple"</span>, <span class="string">"nearest_point nearest_dist nodes_visited"</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">find_nearest</span>(<span class="params">tree, point</span>):</span></span><br><span class="line"> dim = <span class="built_in">len</span>(point)</span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">travel</span>(<span class="params">kd_node, target, max_dist</span>):</span></span><br><span class="line"> <span class="comment"># 迭代终止条件</span></span><br><span class="line"> <span class="keyword">if</span> kd_node <span class="keyword">is</span> <span class="literal">None</span>:</span><br><span class="line"> <span class="keyword">return</span> result([<span class="number">0</span>] * dim, <span class="built_in">float</span>(<span class="string">"inf"</span>), <span class="number">0</span>) <span class="comment"># float("inf")正无穷,float("-inf")</span></span><br><span class="line"> </span><br><span class="line"> nodes_visited = <span 
class="number">1</span></span><br><span class="line"> s = kd_node.split <span class="comment"># 分割的纬度</span></span><br><span class="line"> pivot = kd_node.dom_elt <span class="comment"># 进行分割的轴,实例点</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> target[s] <= pivot[s]: <span class="comment"># 如果目标点第s维小于分割点的对应值</span></span><br><span class="line"> nearer_node = kd_node.left <span class="comment">#下一访问节点为左子树根节点</span></span><br><span class="line"> further_node = kd_node.right <span class="comment"># 同时记录右子树以便后期与目标点进行比较</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> nearer_node = kd_node.right</span><br><span class="line"> further_node = kd_node.left</span><br><span class="line"> </span><br><span class="line"> temp1 = travel(nearer_node, target, max_dist) <span class="comment"># 递归遍历</span></span><br><span class="line"> </span><br><span class="line"> nearest = temp1.nearest_point <span class="comment"># 以此叶节点作为当前最近点</span></span><br><span class="line"> dist = temp1.nearest_dist <span class="comment"># 更新最近距离</span></span><br><span class="line"> nodes_visited += temp1.nodes_visited</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> dist < max_dist:</span><br><span class="line"> max_dist = dist</span><br><span class="line"> </span><br><span class="line"> temp_dist = <span class="built_in">abs</span>(pivot[s] - target[s]) <span class="comment"># 第s维上目标点与分割超平面的距离</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> max_dist < temp_dist: <span class="comment"># 判断超球体是否与超平面相交</span></span><br><span class="line"> <span class="keyword">return</span> result(nearest, dist, nodes_visited) <span class="comment"># 不相交则可以直接返回,不用继续判断,也就意味这父节点(分割点)的另一子节点不会比现在的子节点与目标点更近</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 计算目标点和分割点的欧氏距离</span></span><br><span class="line"> temp_dist = sqrt(<span class="built_in">sum</span>((p1 - p2) ** <span class="number">2</span> <span class="keyword">for</span> p1, p2 <span class="keyword">in</span> <span class="built_in">zip</span>(pivot, target)))</span><br><span class="line"> <span class="keyword">if</span> temp_dist < dist: <span class="comment"># 意味着分割点此时为最近点</span></span><br><span class="line"> nearest = pivot <span class="comment"># 更新最近点</span></span><br><span class="line"> dist = temp_dist <span class="comment"># 更新最近距离</span></span><br><span class="line"> max_dist = dist <span class="comment"># 更新超超球体半径</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 检查另一个子节点对应的区域是否有更近的点</span></span><br><span class="line"> temp2 = travel(further_node, target, max_dist)</span><br><span class="line"> </span><br><span class="line"> nodes_visited += temp2.nodes_visited</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> temp2.nearest_dist < dist: <span class="comment"># 如果另一个子节点内存在更近距离</span></span><br><span class="line"> nearest = temp2.nearest_point</span><br><span class="line"> dist = temp2.nearest_dist</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> result(nearest, dist, nodes_visited) </span><br><span class="line"> <span class="keyword">return</span> travel(tree.root, point, <span class="built_in">float</span>(<span class="string">"inf"</span>)) <span class="comment"># 
从根节点递归</span></span><br></pre></td></tr></table></figure>
<figure class="highlight python"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">data = [[<span class="number">2</span>,<span class="number">3</span>],[<span class="number">5</span>,<span class="number">4</span>],[<span class="number">9</span>,<span class="number">6</span>],[<span class="number">4</span>,<span class="number">7</span>],[<span class="number">8</span>,<span class="number">1</span>],[<span class="number">7</span>,<span class="number">2</span>]]</span><br><span class="line">kd = KdTree(data)</span><br><span class="line">preorder(kd.root)</span><br><span class="line">[<span class="number">7</span>, <span class="number">2</span>]</span><br><span class="line">[<span class="number">5</span>, <span class="number">4</span>]</span><br><span class="line">[<span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">[<span class="number">4</span>, <span class="number">7</span>]</span><br><span class="line">[<span class="number">9</span>, <span class="number">6</span>]</span><br><span class="line">[<span class="number">8</span>, <span class="number">1</span>]</span><br></pre></td></tr></table></figure>