Skip to content

Commit

Permalink
fi
Browse files Browse the repository at this point in the history
  • Loading branch information
radare committed Feb 9, 2025
1 parent dbef1ba commit 5b8a8a1
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 188 deletions.
28 changes: 14 additions & 14 deletions src/r2ai.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
#define R_LOG_ORIGIN "r2ai"

#include "r2ai.h"
static R_TH_LOCAL RVDB *db = NULL;
static R_TH_LOCAL RVdb *db = NULL;

#define VDBDIM 16

static void refresh_embeddings(RCore *core) {
RListIter *iter, *iter2;
char *line;
char *file;
// refresh embeddings database
r_vdb_free (db);
db = r_vdb_new (4);
db = r_vdb_new (VDBDIM);
// enumerate .txt files in directory
const char *path = r_config_get (core->config, "r2ai.data.path");
RList *files = r_sys_dir (path);
Expand Down Expand Up @@ -127,7 +129,7 @@ R_IPI char *r2ai(RCore *core, const char *input, char **error, bool dorag) {
} else {
continue;
}
RVDBResultSet *rs = r_vdb_query (db, line, K);
RVdbResultSet *rs = r_vdb_query (db, line, K);
#if 0
eprintf ("-------------VDB\n");
eprintf ("%s\n", vdb_input);
Expand All @@ -138,7 +140,7 @@ R_IPI char *r2ai(RCore *core, const char *input, char **error, bool dorag) {
#if 0
RStrBuf *sb = r_strbuf_new ("");
for (int i = 0; i < rs->size; i++) {
RVDBResult *r = &rs->results[i];
RVdbResult *r = &rs->results[i];
KDNode *n = r->node;
r_strbuf_appendf (sb, "- %s\n", n->text);
}
Expand All @@ -165,7 +167,7 @@ R_IPI char *r2ai(RCore *core, const char *input, char **error, bool dorag) {
content = r_str_newf ("## Prompt\n%s.\n## Context\n%s", input, res);
free (res);
} else {
RVDBResultSet *rs = r_vdb_query (db, content, K);
RVdbResultSet *rs = r_vdb_query (db, content, K);
#if 0
eprintf ("-------------VDB\n");
eprintf ("%s\n", vdb_input);
Expand All @@ -176,7 +178,7 @@ R_IPI char *r2ai(RCore *core, const char *input, char **error, bool dorag) {
RStrBuf *sb = r_strbuf_new ("");
int i;
for (i = 0; i < rs->size; i++) {
RVDBResult *r = &rs->results[i];
RVdbResult *r = &rs->results[i];
KDNode *n = r->node;
r_strbuf_appendf (sb, "- %s\n", n->text);
}
Expand Down Expand Up @@ -357,14 +359,14 @@ static void cmd_r2ai_R(RCore *core, const char *q) {
refresh_embeddings (core);
}
const int K = r_config_get_i (core->config, "r2ai.data.nth");
RVDBResultSet *rs = r_vdb_query (db, q, K);
RVdbResultSet *rs = r_vdb_query (db, q, K);

if (rs) {
R_LOG_INFO ("Query: \"%s\"", q);
R_LOG_INFO ("Found up to %d neighbors (actual found: %d)", K, rs->size);
int i;
for (i = 0; i < rs->size; i++) {
RVDBResult *r = &rs->results[i];
RVdbResult *r = &rs->results[i];
KDNode *n = r->node;
float dist_sq = r->dist_sq;
float cos_sim = 1.0f - (dist_sq * 0.5f); // for normalized vectors
Expand Down Expand Up @@ -522,12 +524,10 @@ static void cmd_r2ai_m(RCore *core, const char *input) {
r_cons_printf ("Model set to %s\n", input);
}

static void load_embeddings(RCore *core, RVDB *db) {
static void load_embeddings(RCore *core, RVdb *db) {
RListIter *iter, *iter2;
char *line;
char *file;
// refresh embeddings database
// db = r_vdb_new (4);
// enumerate .txt files in directory
const char *path = r_config_get (core->config, "r2ai.data.path");
RList *files = r_sys_dir (path);
Expand Down Expand Up @@ -579,18 +579,18 @@ static void cmd_r2ai(RCore *core, const char *input) {
cmd_r2ai_s (core);
} else if (r_str_startswith (input, "-S")) {
if (db == NULL) {
db = r_vdb_new (4);
db = r_vdb_new (VDBDIM);
load_embeddings (core, db);
}
const char *arg = r_str_trim_head_ro (input + 2);
const int K = 10;
eprintf ("vector search\n");
RVDBResultSet *rs = r_vdb_query (db, arg, K);
RVdbResultSet *rs = r_vdb_query (db, arg, K);
if (rs) {
int i;
eprintf ("Found up to %d neighbors (actual found: %d).\n", K, rs->size);
for (i = 0; i < rs->size; i++) {
RVDBResult *r = &rs->results[i];
RVdbResult *r = &rs->results[i];
KDNode *n = r->node;
r_cons_printf ("- (%.4f) %s\n", r->dist_sq, n->text);
}
Expand Down
77 changes: 47 additions & 30 deletions src/r_vdb.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,54 +3,71 @@

/* Vector of floats with `dim` dimensions */
typedef struct {
float *data;
int dim;
float *data;
int dim;
} Vector;

/* KD-node that stores a Vector and associated text */
typedef struct KDNode {
Vector point;
char *text;
struct KDNode *left;
struct KDNode *right;
int split_dim;
Vector point;
char *text;
struct KDNode *left;
struct KDNode *right;
int split_dim;
} KDNode;

/* Singly linked list node: a token and its df count (the k-d tree "database" itself is RVdb) */
typedef struct token_df {
char *token;
int df;
struct token_df *next;
char *token;
int df; // count
struct token_df *next;
} token_df;

typedef struct RVDB {
KDNode *root;
int dimension;
int size;
int total_docs; // initialize to 0
token_df *df_table; // initialize to NULL
} RVDB;
/*
 * One token record. The `token` string is owned by the record: token_free()
 * releases it together with the record itself.
 * NOTE(review): `count` / `df` look like an occurrence count and a
 * document-frequency weight -- confirm against r_vdb.c.
 */
typedef struct {
	char *token;
	int count;
	float df;
} RVdbToken;

/*
 * Release an RVdbToken and the token string it owns.
 *
 * p: pointer to a heap-allocated RVdbToken, or NULL (in which case nothing
 *    happens).
 *
 * NOTE(review): the void * parameter presumably lets this serve as the free
 * callback of the RVdb `tokens` list -- confirm at the call site.
 */
static inline void token_free(void *p) {
	RVdbToken *tok = (RVdbToken *)p;
	if (!tok) {
		return;
	}
	free (tok->token);
	free (tok);
}

/*
 * Vector database handle: the k-d tree of stored KDNode entries plus the
 * token bookkeeping fields below. Created by r_vdb_new() and released by
 * r_vdb_free().
 */
typedef struct {
// root node of the k-d tree
KDNode *root;
// NOTE(review): presumably the length of every stored Vector, i.e. the
// `dim` passed to r_vdb_new() -- confirm
int dimension;
// NOTE(review): presumably the number of nodes currently in the tree -- confirm
int size;
// NOTE(review): elements look like RVdbToken records (see token_free) -- confirm
RList *tokens; // global tokens count
int total_docs; // initialize to 0
// token_df *df_table; // initialize to NULL
} RVdb;

/* Each k-NN result: pointer to KDNode + distance (squared). */
typedef struct {
KDNode *node;
float dist_sq;
} RVDBResult;
float dist_sq;
} RVdbResult;

/*
/*
* A max-heap of up to k results, sorted by dist_sq DESC.
* That way, the root is the *worst* (largest) distance in the set,
* making it easy to pop it when we find a better (smaller-dist) candidate.
*/
typedef struct {
RVDBResult *results;
int capacity;
int size;
} RVDBResultSet;

RVDB *r_vdb_new(int dim);
void r_vdb_insert(RVDB *db, const char *text);
RVDBResultSet *r_vdb_query(RVDB *db, const char *text, int k);
void r_vdb_free(RVDB *db);
void r_vdb_result_free(RVDBResultSet *rs);
RVdbResult *results;
int capacity;
int size;
} RVdbResultSet;

/*
 * Create a new, empty database.
 * NOTE(review): `dim` is presumably the dimension of the vectors the
 * database stores -- confirm.
 */
RVdb *r_vdb_new(int dim);
/* Add the document `text` to the database. */
void r_vdb_insert(RVdb *db, const char *text); // add_document
// expose api to add_token
/*
 * Look up at most `k` nearest neighbours of `text`.
 * The returned set holds `size` results, each with a KDNode and its squared
 * distance. Callers in r2ai.c check the result for NULL before using it.
 * NOTE(review): presumably the caller releases it with r_vdb_result_free() -- confirm.
 */
RVdbResultSet *r_vdb_query(RVdb *db, const char *text, int k);
/*
 * Release a database created with r_vdb_new().
 * NOTE(review): r2ai.c calls this on a `db` that may still be NULL -- verify
 * the implementation tolerates NULL.
 */
void r_vdb_free(RVdb *db);
/* Release a result set returned by r_vdb_query(). */
void r_vdb_result_free(RVdbResultSet *rs);

#endif
Loading

0 comments on commit 5b8a8a1

Please sign in to comment.