Skip to content

Commit

Permalink
Converged centroids now accessible in api
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamil committed Apr 29, 2014
1 parent ede7f73 commit b805d28
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 27 deletions.
28 changes: 28 additions & 0 deletions Cluster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,20 @@ bool Cluster::step() {
return true; // return true if no change
}

float** Cluster::centroids() {
assert(*_centroids);
int numParams = (*_centroids)->size();
float** c = new float*[_k];
for (int i = 0; i < _k; i++) {
c[i] = new float[numParams];
float* features = _centroids[i]->features();
for (int j = 0; j < numParams; j++) {
c[i][j] = features[j];
}
}
return c;
}

void Cluster::updateCentroids() {
// Examine first training example to get feature count, etc
if (!_data.size())
Expand Down Expand Up @@ -211,3 +225,17 @@ void Cluster::updateUntilConvergence() {
}
}

int Cluster::classify(Data *data) {
float minDist = FLT_MAX;
int classification = -1;
for (int i = 0; i < _k; i++) {
float dist = distance(data, _centroids[i]);
if (dist < minDist) {
minDist = dist;
classification = i;
}
}
cout << "Classification: " << classification << endl;
return classification;
}

2 changes: 2 additions & 0 deletions Cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class Cluster {
Cluster(int k, vector<Data*> data);
~Cluster();
void update();
int classify(Data *data);
float** centroids();

private:
vector<Data*> _data; // Training set (unsupervised)
Expand Down
2 changes: 1 addition & 1 deletion Data.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <vector>
#include <cassert>

#define PRINT_LOG
//#define PRINT_LOG

#ifdef PRINT_LOG
#define LOG(string, ...) printf(string, __VA_ARGS__)
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ CC = g++
SRC = main.cpp
OBJ = Data.o SupervisedData.o BatchDescentLearner.o StochasticLearner.o Learner.o Cluster.o
CFLAGS = -Werror -g -ggdb
LDFLAGS =
LDFLAGS = -lreadline

all: vasco

Expand Down
40 changes: 15 additions & 25 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,25 @@ int main() {
if (!data.size())
return 1;

Cluster learner(3, data);
learner.update();
// Repeatedly classify

bool quit = false;
while (!quit) {
char* line = (char*)malloc(sizeof(char)*100);
size_t nbytes = 99;
printf(">> ");
getline(&line, &nbytes, stdin);
for (int i = 0; i < strlen(line); i++)
if (line[i] == '\n')
line[i] = '\0';
if (!strcmp(line, "quit")) {
quit = true;
}
else {
float params[4] = {0};
stringstream ss(line);
for (int i = 0; i < 4; i++) {
ss >> params[i];
}
for (int p = 0; p < 10; p++) {
Cluster learner(3, data);
learner.update();
float** centroids = learner.centroids();

Data *newData = new Data(num_features);
newData->setFeatures(num_features, params);
learner.classify(newData);
free(newData);
for (int i = 0; i < 3; i++) {
cout << "Centroid " << i << endl;
for (int j = 0; j < 4; j++) {
cout << "\t" << centroids[i][j];
}
cout << endl;
}

free(line);
for (int i = 0; i < 3; i++)
free(centroids[i]);
free(centroids);
cout << endl << endl;
}

for (int i = 0; i < set_size; i++) {
Expand Down

0 comments on commit b805d28

Please sign in to comment.