Skip to content

Commit

Permalink
Fixed decoding bug in the wire protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
codepr committed Mar 20, 2024
1 parent d738b8c commit ac0a1fb
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 34 deletions.
66 changes: 37 additions & 29 deletions src/protocol.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "protocol.h"
#include <stdio.h>
#include <string.h>

static ssize_t encode_string(uint8_t *dst, const char *src, size_t length) {
size_t i = 0, j = 0;
Expand Down Expand Up @@ -120,8 +121,7 @@ static ssize_t decode_string(const uint8_t *ptr, Response *dst) {
}

ssize_t decode_response(const uint8_t *data, Response *dst) {
const uint8_t *ptr = data;
uint8_t byte = *ptr;
uint8_t byte = *data;
ssize_t length = 0;

dst->type = byte == '#' ? ARRAY_RSP : STRING_RSP;
Expand All @@ -130,54 +130,58 @@ ssize_t decode_response(const uint8_t *data, Response *dst) {
case '$':
case '!':
// Treat error and common strings the same for now
length = decode_string(ptr, dst);
length = decode_string(data, dst);
break;
case '#':
ptr++;
data++;
length++;
// Read length
dst->array_response.length = 0;
while (*ptr != '\r' && *(ptr + 1) != '\n') {
while (*data != '\r' && *(data + 1) != '\n') {
dst->array_response.length *= 10;
dst->array_response.length += *ptr - '0';
ptr++;
dst->array_response.length += *data - '0';
data++;
length++;
}

// Jump over \r\n
ptr += 2;
data += 2;
length += 2;

// Read records
size_t j = 0;
size_t total_records = dst->array_response.length;
uint8_t buf[32];
size_t k = 0;
// TODO arena malloc here
dst->array_response.records =
malloc(total_records * sizeof(*dst->array_response.records));
while (total_records-- > 0) {
// Timestamp
if (*ptr++ == ':') {
while (*ptr != '\r' && *(ptr + 1) != '\n') {
dst->array_response.records[j].timestamp *= 10;
dst->array_response.records[j].timestamp += *ptr - '0';
ptr++;
length++;
}
} else {
// Value
uint8_t buf[32];
size_t k = 0;
while (*ptr != '\r' && *(ptr + 1) != '\n') {
buf[k++] = *ptr;
ptr++;
length++;
}
char *end;
dst->array_response.records[j].value =
strtod((char *)buf, &end);
}
if (*data++ != ':')
goto cleanup;

while (*data != '\r' && *(data + 1) != '\n' && length++)
buf[k++] = *data++;

dst->array_response.records[j].timestamp = atoll((const char *)buf);
memset(buf, 0x00, sizeof(buf));
k = 0;

// Skip CRLF + ;
data += 3;
length += 3;

// Value
while (*data != '\r' && *(data + 1) != '\n' && length++)
buf[k++] = *data++;

buf[k] = '\0';

dst->array_response.records[j].value = strtold((char *)buf, NULL);

// Skip CRLF
ptr += 2;
data += 2;
length += 2;
j++;
}
Expand All @@ -187,6 +191,10 @@ ssize_t decode_response(const uint8_t *data, Response *dst) {
}

return length;

cleanup:
free(dst->array_response.records);
return -1;
}

void free_response(Response *rs) {
Expand Down
22 changes: 17 additions & 5 deletions src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#define BACKLOG 128

// testing dummy
static Timeseries_DB *db;
static Timeseries_DB *db = NULL;

static Response execute_statement(const Statement *statement) {
Response rs;
Expand All @@ -38,6 +38,8 @@ static Response execute_statement(const Statement *statement) {
if (!db)
db = tsdb_init(statement->insert.db_name);
ts = ts_get(db, statement->insert.ts_name);
if (!ts)
goto errdefer;
uint64_t timestamp = 0;
for (size_t i = 0; i < statement->insert.record_len; ++i) {
if (statement->insert.records[i].timestamp == -1) {
Expand All @@ -54,8 +56,12 @@ static Response execute_statement(const Statement *statement) {
break;
case STATEMENT_SELECT:
if (!db)
db = tsdb_init(statement->insert.db_name);
db = tsdb_init(statement->select.db_name);

ts = ts_get(db, statement->select.ts_name);
if (!ts)
goto errdefer;
ts_print(ts);
int err = 0;
Points coll;
vec_new(coll);
Expand All @@ -65,9 +71,7 @@ static Response execute_statement(const Statement *statement) {
if (err < 0) {
log_error("Couldn't find the record %lu",
statement->select.start_time);
rs.type = STRING_RSP;
rs.string_response.length = 9;
strncpy(rs.string_response.message, "Not found", 10);
goto errdefer;
} else {
log_info("Record found: %lu %.2lf", r.timestamp, r.value);
rs.type = ARRAY_RSP;
Expand Down Expand Up @@ -106,6 +110,14 @@ static Response execute_statement(const Statement *statement) {
ts_close(ts);

return rs;

errdefer:

rs.type = STRING_RSP;
rs.string_response.length = 9;
strncpy(rs.string_response.message, "Not found", 10);

return rs;
}

static void on_close(ev_tcp_handle *client, int err) {
Expand Down

0 comments on commit ac0a1fb

Please sign in to comment.