-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
49 lines (38 loc) · 1.13 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# -*- coding: utf-8 -*-
# @Time : 2021/1/10 9:23
# @Author : CHT
# @Site :
# @File : utils.py
# @Software: PyCharm
# @Blog: https://www.zhihu.com/people/xia-gan-yi-dan-chen-hao-tian
# @Function:
import numpy as np
import heapq
from matplotlib import pyplot as plt
from math import inf, nan, log, sqrt
from collections import Counter
# Functions
def argmax(arr, key=lambda x: x):
    """Locate the maximum of ``key(a)`` over *arr*.

    Returns:
        (index, value): position of the first maximum and the mapped
        value ``key(arr[index])`` itself (note: the mapped value, not
        the raw element).

    Raises:
        ValueError: if *arr* is empty (same as the builtin ``max``).
    """
    # max() keeps the first of equal items, so ties resolve to the
    # smallest index — matching list.index() semantics.
    best_index, best_value = max(
        enumerate(key(a) for a in arr),
        key=lambda pair: pair[1],
    )
    return best_index, best_value
# Decision Tree
def entropy(p):
    """Return the Shannon entropy (natural log) of a list of counts.

    Args:
        p: iterable of non-negative counts/weights; they need not sum
           to 1 — they are normalized internally.

    Returns:
        -sum(p_i * ln(p_i)) over the normalized distribution.
        Zero counts contribute nothing (the 0*log(0) := 0 convention),
        and an empty/all-zero input yields 0.0 instead of raising.
    """
    p = list(p)
    total = sum(p)
    # Skip zero counts: log(0) is undefined, and the information-theory
    # convention is that a zero-probability outcome contributes 0.
    return sum(-(c / total) * log(c / total) for c in p if c)
def entropy_of_split(X, Y, col):
    """Conditional entropy of the labels after splitting on feature *col*.

    Args:
        X: list of feature rows (indexable by *col*).
        Y: labels aligned element-wise with X.
        col: index of the feature to split on.

    Returns:
        The weighted average of the label entropy within each branch,
        i.e. H(Y | X[col]).
    """
    n = len(X)
    branch_sizes = Counter(row[col] for row in X)
    ans = 0
    for value, size in branch_sizes.items():
        # Label distribution restricted to the rows falling in this branch.
        branch_labels = Counter(y for row, y in zip(X, Y) if row[col] == value)
        ans += (size / n) * entropy(branch_labels.values())
    return ans
def information_gain(X, Y, col):
    """Information gain of splitting the data on feature *col*.

    Information gain = entropy of the labels minus the conditional
    entropy after the split: H(Y) - H(Y | X[col]).
    """
    label_entropy = entropy(Counter(Y).values())
    return label_entropy - entropy_of_split(X, Y, col)