-
Notifications
You must be signed in to change notification settings - Fork 0
/
gmm.h
115 lines (97 loc) · 2.89 KB
/
gmm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
/*
* gmm.h
*
* Contains declarations of functions for training
* Gaussian Mixture Models
*
* Copyright (C) 2015 Sai Nitish Satyavolu
*/
#ifndef GMM_H
#define GMM_H
#ifdef __cplusplus
extern "C" {
#endif
/*
* Selects how the GMM parameters are initialized before training.
* Enumerator values are spelled out explicitly; they match the
* implicit numbering (RANDOM is 0, KMEANS is 1).
*/
typedef enum {
    RANDOM = 0,
    KMEANS = 1
} InitMethod;
/*
* Selects the form of covariance matrix used by the GMM components.
* Enumerator values are spelled out explicitly; they match the
* implicit numbering (DIAGONAL is 0, SPHERICAL is 1).
*/
typedef enum {
    DIAGONAL = 0,
    SPHERICAL = 1
} CovType;
/*
* The GMM structure
*
* Groups the training settings, the mixture parameters and the auxiliary
* working storage of one Gaussian Mixture Model.
*
* NOTE(review): the struct tag _GMM starts with an underscore followed by an
* uppercase letter, which the C standard reserves for the implementation.
* It is kept unchanged here because code outside this header may use it.
*/
typedef struct _GMM
{
/* --------------------------------- Settings */
int M; // Number of mixture components
int D; // Number of features (dimensionality of one data point)
int num_max_iter; // Maximum number of EM iterations
int converged; // Convergence status; presumably non-zero once converged -- TODO confirm in implementation
double tol; // Convergence tolerance
double reg; // Regularization value for the covariance matrix
InitMethod init_method; // Parameter initialization method (RANDOM or KMEANS)
CovType cov_type; // Covariance type (DIAGONAL or SPHERICAL)
/* --------------------------- GMM Parameters */
double *weights; // Component weights; presumably M entries -- TODO confirm
double **means; // Component means; presumably M rows of D values -- TODO confirm
double **covars; // Component covariances; per-component storage presumably depends on cov_type -- TODO confirm
/* ---------------------- Auxiliary variables */
double **P_k_giv_xt; // Membership probability matrix; name reads as P(component k | point x_t), dimensions not visible here
} GMM;
/*
* Function for initializing a new GMM
*
* M        : number of mixture components
* D        : number of features per data point
* cov_type : string naming the covariance type; presumably selects one of
*            the CovType values -- accepted spellings are defined by the
*            implementation, TODO confirm
*
* Returns a pointer to the new GMM. Presumably the caller owns it and
* releases it with gmm_free(); failure behaviour (e.g. NULL on error) is
* not visible from this header -- verify against the implementation.
*/
GMM* gmm_new(int M, int D, const char *cov_type);
/*
* Function to set maximum number of EM iterations
*
* gmm          : model whose setting is changed
* num_max_iter : new iteration limit; presumably stored in
*                gmm->num_max_iter -- TODO confirm
*/
void gmm_set_max_iter(GMM *gmm, int num_max_iter);
/*
* Function to set EM convergence tolerance
*
* gmm : model whose setting is changed
* tol : new convergence tolerance; presumably stored in gmm->tol --
*       TODO confirm
*/
void gmm_set_convergence_tol(GMM *gmm, double tol);
/*
* Function to set regularization value of covariance matrix
*
* gmm : model whose setting is changed
* reg : new regularization value; presumably stored in gmm->reg --
*       TODO confirm
*/
void gmm_set_regularization_value(GMM *gmm, double reg);
/*
* Function to set GMM parameter initialization method
*
* gmm    : model whose setting is changed
* method : string naming the initialization method; presumably selects one
*          of the InitMethod values (RANDOM, KMEANS) -- accepted spellings
*          are defined by the implementation, TODO confirm
*/
void gmm_set_initialization_method(GMM *gmm, const char *method);
/*
* Function to fit a GMM on a given set of data points
*
* gmm : model to train; modified by the call
* X   : data points, read-only; presumably a flat array holding N points
*       of gmm->D values each -- layout is not visible here, TODO confirm
* N   : presumably the number of data points in X -- TODO confirm
*/
void gmm_fit(GMM *gmm, const double *X, int N);
/*
* Function to score a set of data points using the GMM
*
* gmm : model used for scoring
* X   : data points, read-only; presumably a flat array holding N points
*       of gmm->D values each -- layout is not visible here, TODO confirm
* N   : presumably the number of data points in X -- TODO confirm
*
* Returns the score as a double. NOTE(review): whether this is a total or
* an average log-likelihood is not visible from this header -- confirm in
* the implementation.
*/
double gmm_score(GMM *gmm, const double *X, int N);
/*
* Function to print the GMM parameters
*
* gmm : model to print; taken as a pointer to const, so it is not modified
*       through this pointer. The output destination is not visible from
*       this header.
*/
void gmm_print_params(const GMM *gmm);
/*
* Function to free the GMM
*
* gmm : model to release; must not be used after this call.
*       NOTE(review): whether a NULL pointer is accepted is not visible
*       from this header -- confirm in the implementation.
*/
void gmm_free(GMM *gmm);
/*
* Internal functions (do not call them!)
*
* The per-function notes below are inferred from names and signatures only
* and must be confirmed against the implementation.
*
* NOTE(review): identifiers that begin with an underscore are reserved at
* file scope by the C standard. If these helpers are only used by the
* implementation file, consider making them static there instead of
* declaring them in this header.
*/
/* Parameter initialization entry point; presumably dispatches on gmm->init_method -- TODO confirm */
void _gmm_init_params(GMM *gmm, const double *X, int N);
/* Presumably the initializer used for the RANDOM method -- TODO confirm */
void _gmm_init_params_random(GMM *gmm, const double *X, int N);
/* Presumably the initializer used for the KMEANS method -- TODO confirm */
void _gmm_init_params_kmeans(GMM *gmm, const double *X, int N);
/* One EM iteration; meaning of the returned double is not visible here (presumably a log-likelihood) -- TODO confirm */
double _gmm_em_step(GMM *gmm, const double *X, int N);
/* Presumably fills gmm->P_k_giv_xt; meaning of the returned double is not visible here -- TODO confirm */
double _gmm_compute_membership_prob(GMM *gmm, const double *X, int N);
/* Presumably re-estimates weights, means and covariances from the membership probabilities -- TODO confirm */
void _gmm_update_params(GMM *gmm, const double *X, int N);
/* Presumably the log of a D-dimensional Gaussian density at x for the given mean, covariance and covariance type -- TODO confirm */
double _gmm_log_gaussian_pdf(const double *x, const double *mean, const double *covar, int D, CovType cov_type);
/* Presumably an L2 distance between D-element vectors x and y; squared or not is not visible here -- TODO confirm */
double _gmm_vec_l2_dist(const double *x, const double *y, int D);
/* Updates x in place from x and y with scalars a and b (presumably x = a*x + b*y) over D elements -- TODO confirm */
void _gmm_vec_add(double *x, const double *y, double a, double b, int D);
/* Presumably divides each of the D elements of x by a, in place -- TODO confirm */
void _gmm_vec_divide_by_scalar(double *x, double a, int D);
/* Presumably the dot product of D-element vectors x and y -- TODO confirm */
double _gmm_vec_dot_prod(const double *x, const double *y, int D);
/* Presumably returns x squared -- TODO confirm */
double _gmm_pow2(double x);
#ifdef __cplusplus
}
#endif
#endif