Skip to content
This repository has been archived by the owner on Dec 17, 2024. It is now read-only.

Commit

Permalink
fix: queueID was taken wrong sporadically (#2)
Browse files Browse the repository at this point in the history
* fix: queueID was taken wrong sporadically

* fix: handler types
  • Loading branch information
jlorgal authored May 1, 2018
1 parent 3b85f77 commit 8b4dece
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 15 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Each packet that is processed by a netfilter queue is encapsulated in the type `
// Packet struct provides the packet data and methods to accept, drop or modify the packet.
type Packet struct {
Buffer []byte
id C.uint32_t
id uint32
q *Queue
}
Expand Down Expand Up @@ -95,6 +95,7 @@ func NewQueue(id uint16) *Queue {
}
queueCfg := &nfqueue.QueueConfig{
MaxPackets: 1000,
BufferSize: 16 * 1024 * 1024,
QueueFlags: []nfqueue.QueueFlag{nfqueue.FailOpen},
}
// Pass as packet handler the current instance because it implements nfqueue.PacketHandler interface
Expand Down
10 changes: 6 additions & 4 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ import (
)

//export handle
func handle(id C.uint, buffer *C.uchar, len C.int, cbData *unsafe.Pointer) int {
queueID := (*uint16)(unsafe.Pointer(cbData))
q := queueRegistry.Get(*queueID)
func handle(id uint32, buffer *C.uchar, len C.int, queueID int) int {
q := queueRegistry.Get(uint16(queueID))
if q == nil {
return 0
}
packet := &Packet{
id: uint32(id),
id: id,
Buffer: C.GoBytes(unsafe.Pointer(buffer), len),
q: q,
}
Expand Down
9 changes: 3 additions & 6 deletions nfqueue.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,8 @@ func (q *Queue) Start() error {
// It is not possible to pass the queue as callback data due to error:
// runtime error: cgo argument has Go pointer to Go pointer
// As a result, we have to pass the queue ID and use the registry to retrieve the queue.
qid := q.ID
cb := (*C.nfq_callback)(C.nfqueue_cb)
if q.qh = C.nfq_create_queue(q.h, C.u_int16_t(q.ID), cb, unsafe.Pointer(&qid)); q.qh == nil {
return errors.New("Error in nfq_create_queue")
if q.qh = C.nfqueue_create_queue(q.h, C.u_int16_t(q.ID)); q.qh == nil {
return errors.New("Error in nfqueue_create_queue")
}

// Configure mode (packet copy) and the packet size. Note that this is not configurable on purpose.
Expand All @@ -156,8 +154,7 @@ func (q *Queue) Start() error {
}
}

q.fd = C.nfq_fd(q.h)
if q.fd < 0 {
if q.fd = C.nfq_fd(q.h); q.fd < 0 {
return errors.New("Error in nfq_fd")
}

Expand Down
9 changes: 6 additions & 3 deletions nfqueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,20 @@ const uint MAX_PACKET_SIZE = 65535;
// - id (packet identifier)
// - buffer (pointer to the packet data starting from IP layer)
// - len (buffer length)
// - cb_data (pointer to the queue identifier)
extern int handle(uint32_t id, unsigned char* buffer, int len, void *cb_data);
// - queue_id (queue identifier)
extern int handle(uint32_t id, unsigned char* buffer, int len, int queue_id);

int nfqueue_cb(struct nfq_q_handle *qh, struct nfgenmsg *nfmsg, struct nfq_data *nfa, void *cb_data)
{
unsigned char *buffer = NULL;
struct nfqnl_msg_packet_hdr *ph = nfq_get_msg_packet_hdr(nfa);
uint32_t id = ntohl(ph->packet_id);
int ret = nfq_get_payload(nfa, &buffer);
return handle(id, buffer, ret, (intptr_t)cb_data);
}

return handle(id, buffer, ret, cb_data);
static struct nfq_q_handle *nfqueue_create_queue(struct nfq_handle *h, u_int16_t queue_id) {
return nfq_create_queue(h, queue_id, &nfqueue_cb, (void *)(intptr_t)queue_id);
}

static int nfqueue_loop(struct nfq_handle *h, int fd)
Expand Down
5 changes: 4 additions & 1 deletion registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,8 @@ func (r *QueueRegistry) Unregister(queueID uint16) {

// Get returns a queue from the registry based on the queueID.
func (r *QueueRegistry) Get(queueID uint16) *Queue {
return r.queues[queueID]
if len(r.queues) > int(queueID) {
return r.queues[queueID]
}
return nil
}

0 comments on commit 8b4dece

Please sign in to comment.