Skip to content

Commit

Permalink
Fixed some protocol decoding bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
codepr committed Mar 19, 2024
1 parent b2eeadc commit d738b8c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/logging.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ static const char *llevels = "+#!";

void rr_log(rr_log_level level, const char *fmt, ...) {
FILE *fp = level > R_INFO ? stderr : stdout;
fprintf(fp, "%c %lu ", llevels[level], time(NULL));
fprintf(fp, "%lu %c ", time(NULL), llevels[level]);
va_list args;
va_start(args, fmt);
vfprintf(fp, fmt, args);
Expand Down
24 changes: 13 additions & 11 deletions src/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ssize_t encode_response(const Response *r, uint8_t *dst) {
// Array response
dst[0] = '#';
ssize_t i = 1;
size_t length = r->array_response.length, j = 0;
size_t j = 0;

// Array length
size_t n = snprintf((char *)dst + i, 20, "%lu", r->array_response.length);
Expand All @@ -73,7 +73,7 @@ ssize_t encode_response(const Response *r, uint8_t *dst) {
dst[i++] = '\n';

// Records
while (length-- > 0) {
while (j < r->array_response.length) {
// Timestamp
dst[i++] = ':';
n = snprintf((char *)dst + i, 21, "%lu",
Expand All @@ -83,18 +83,14 @@ ssize_t encode_response(const Response *r, uint8_t *dst) {
dst[i++] = '\n';
// Value
dst[i++] = ';';
n = snprintf((char *)dst + i, 21, "%.20lf",
n = snprintf((char *)dst + i, 21, "%lf",
r->array_response.records[j].value);
i += n;
dst[i++] = '\r';
dst[i++] = '\n';
j++;
}

// CRLF
dst[i++] = '\r';
dst[i++] = '\n';

return i;
}

Expand Down Expand Up @@ -128,16 +124,19 @@ ssize_t decode_response(const uint8_t *data, Response *dst) {
uint8_t byte = *ptr;
ssize_t length = 0;

dst->type = byte == '*' ? ARRAY_RSP : STRING_RSP;
dst->type = byte == '#' ? ARRAY_RSP : STRING_RSP;

switch (byte) {
case '$':
case '!':
// Treat error and common strings the same for now
length = decode_string(ptr, dst);
break;
case '*':
case '#':
ptr++;
length++;
// Read length
dst->array_response.length = 0;
while (*ptr != '\r' && *(ptr + 1) != '\n') {
dst->array_response.length *= 10;
dst->array_response.length += *ptr - '0';
Expand Down Expand Up @@ -166,7 +165,7 @@ ssize_t decode_response(const uint8_t *data, Response *dst) {
}
} else {
// Value
uint8_t buf[20];
uint8_t buf[32];
size_t k = 0;
while (*ptr != '\r' && *(ptr + 1) != '\n') {
buf[k++] = *ptr;
Expand All @@ -190,4 +189,7 @@ ssize_t decode_response(const uint8_t *data, Response *dst) {
return length;
}

void free_response(Response *rs) { free(rs->array_response.records); }
void free_response(Response *rs) {
if (rs->type == ARRAY_RSP)
free(rs->array_response.records);
}
40 changes: 28 additions & 12 deletions src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@ static Response execute_statement(const Statement *statement) {

switch (statement->type) {
case STATEMENT_CREATE:
if (statement->create.mask == 0)
if (statement->create.mask == 0) {
db = tsdb_init(statement->create.db_name);
else
} else {
if (!db)
db = tsdb_init(statement->create.db_name);

(void)ts_create(db, statement->create.ts_name, 0, DP_IGNORE);
}
rs.type = STRING_RSP;
rs.string_response.rc = 0;
strncpy(rs.string_response.message, "Ok", 3);
rs.string_response.length = 3;
break;
case STATEMENT_INSERT:
if (!db)
db = tsdb_init(statement->insert.db_name);
ts = ts_get(db, statement->insert.ts_name);
uint64_t timestamp = 0;
for (size_t i = 0; i < statement->insert.record_len; ++i) {
Expand All @@ -47,23 +53,30 @@ static Response execute_statement(const Statement *statement) {
rs.string_response.length = 3;
break;
case STATEMENT_SELECT:
if (!db)
db = tsdb_init(statement->insert.db_name);
ts = ts_get(db, statement->select.ts_name);
int err = 0;
Points coll;
vec_new(coll);

if (statement->select.mask & SM_SINGLE) {
err = ts_find(ts, statement->select.start_time, &r);
if (err < 0)
if (err < 0) {
log_error("Couldn't find the record %lu",
statement->select.start_time);
else
rs.type = STRING_RSP;
rs.string_response.length = 9;
strncpy(rs.string_response.message, "Not found", 10);
} else {
log_info("Record found: %lu %.2lf", r.timestamp, r.value);
rs.array_response.length = 1;
rs.array_response.records =
calloc(1, sizeof(*rs.array_response.records));
rs.array_response.records[0].timestamp = r.timestamp;
rs.array_response.records[0].value = r.value;
rs.type = ARRAY_RSP;
rs.array_response.length = 1;
rs.array_response.records =
calloc(1, sizeof(*rs.array_response.records));
rs.array_response.records[0].timestamp = r.timestamp;
rs.array_response.records[0].value = r.value;
}
} else if (statement->select.mask & SM_RANGE) {
err = ts_range(ts, statement->select.start_time,
statement->select.end_time, &coll);
Expand Down Expand Up @@ -112,7 +125,6 @@ static void on_write(ev_tcp_handle *client) {
static void on_data(ev_tcp_handle *client) {
if (client->buffer.size == 0)
return;
log_info("Data: %s", client->buffer.buf);
Request rq;
Response rs;
ssize_t n = decode_request((const uint8_t *)client->buffer.buf, &rq);
Expand All @@ -129,7 +141,12 @@ static void on_data(ev_tcp_handle *client) {
rs = execute_statement(&statement);
}

(void)encode_response(&rs, (uint8_t *)client->buffer.buf);
ev_tcp_zero_buffer(client);

n = encode_response(&rs, (uint8_t *)client->buffer.buf);
client->buffer.size = n;
log_info("Data: %s", client->buffer.buf);
free_response(&rs);

ev_tcp_queue_write(client);
}
Expand All @@ -148,7 +165,6 @@ static void on_connection(ev_tcp_handle *server) {
}

int roachdb_server_run(const char *host, int port) {
db = tsdb_init("testdb");
ev_context *ctx = ev_get_context();
ev_tcp_server server;
ev_tcp_server_init(&server, ctx, BACKLOG);
Expand Down

0 comments on commit d738b8c

Please sign in to comment.