diff --git a/README.md b/README.md index c42838b0..1e91047f 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,9 @@ functions to marshal as JSON or text for instance. The `transport` provides different way of processing the message. Either sending it via Kafka or send it to a file (or stdout). +The `state` supports external storage for a way of synchronizing templates across multiple GoFlow2 +instances. + GoFlow2 is a wrapper of all the functions and chains thems. You can build your own collector using this base and replace parts: @@ -53,6 +56,7 @@ You can build your own collector using this base and replace parts: * Convert to another format (e.g: Cap'n Proto, Avro, instead of protobuf) * Decode different samples (e.g: not only IP networks, add MPLS) * Different metrics system (e.g: [OpenTelemetry](https://opentelemetry.io/)) +* Other external state storage (e.g: RDBMS, MongoDB, memcached) ### Protocol difference @@ -85,8 +89,8 @@ Production: * Convert to protobuf or json * Prints to the console/file * Sends to Kafka and partition - -Monitoring via Prometheus metrics +* Set up multiple GoFlow2 instances backed by the same external state storage +* Monitoring via Prometheus metrics ## Get started @@ -165,6 +169,12 @@ $ ./goflow2 -listen 'sflow://:6343?count=4,nfl://:2055' More information about workers and resource usage is avaialble on the [Performance page](/docs/performance.md). +When you have multiple GoFlow2 instances, it's important to enable external state storage. +```bash +$ ./goflow2 -state.netflow.templates redis://127.0.0.1:6379/0?prefix=nftemplate -state.sampling redis://127.0.0.1:6379/0?prefix=nfsampling +``` +Details available on [State page](/docs/state_storage.md). + ### Docker You can also run directly with a container: diff --git a/cmd/goflow2/main.go b/cmd/goflow2/main.go index 53f46eb1..d695a9fc 100644 --- a/cmd/goflow2/main.go +++ b/cmd/goflow2/main.go @@ -178,6 +178,17 @@ func main() { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) + err = protoproducer.InitSamplingRate() + if err != nil { + log.Fatal(err) + } + defer protoproducer.CloseSamplingRate() + err = netflow.InitTemplates() + if err != nil { + log.Fatal(err) + } + defer netflow.CloseTemplates() + var receivers []*utils.UDPReceiver var pipes []utils.FlowPipe diff --git a/decoders/netflow/netflow_test.go b/decoders/netflow/netflow_test.go index 8c940072..88e23664 100644 --- a/decoders/netflow/netflow_test.go +++ b/decoders/netflow/netflow_test.go @@ -8,7 +8,9 @@ import ( ) func TestDecodeNetFlowV9(t *testing.T) { - templates := CreateTemplateSystem() + err := InitTemplates() + assert.NoError(t, err) + templates := CreateTemplateSystem("TestDecodeNetFlowV9") // Decode a template template := []byte{ @@ -23,7 +25,7 @@ func TestDecodeNetFlowV9(t *testing.T) { } buf := bytes.NewBuffer(template) var decNfv9 NFv9Packet - err := DecodeMessageVersion(buf, templates, &decNfv9, nil) + err = DecodeMessageVersion(buf, templates, &decNfv9, nil) assert.Nil(t, err) assert.Equal(t, NFv9Packet{ diff --git a/decoders/netflow/templates.go b/decoders/netflow/templates.go index 42be916b..ca9b6560 100644 --- a/decoders/netflow/templates.go +++ b/decoders/netflow/templates.go @@ -1,20 +1,25 @@ package netflow import ( + "encoding/json" + "errors" + "flag" "fmt" + "net/url" + "reflect" + "strings" "sync" + + "github.com/netsampler/goflow2/v2/state" ) var ( ErrorTemplateNotFound = fmt.Errorf("Error template not found") + StateTemplates = flag.String("state.netflow.templates", "memory://", fmt.Sprintf("Define state templates engine URL (available schemes: %s)", strings.Join(state.SupportedSchemes, ", "))) + templatesDB state.State[templatesKey, templatesValue] + templatesInitLock = new(sync.Mutex) ) -type FlowBaseTemplateSet map[uint64]interface{} - -func templateKey(version uint16, obsDomainId uint32, templateId uint16) uint64 { - return (uint64(version) << 48) | (uint64(obsDomainId) << 16) | uint64(templateId) -} - // Store interface that allows storing, removing and retrieving template data type NetFlowTemplateSystem interface { RemoveTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) @@ -22,64 +27,174 @@ type NetFlowTemplateSystem interface { AddTemplate(version uint16, obsDomainId uint32, templateId uint16, template interface{}) error } -func (ts *BasicTemplateSystem) GetTemplates() FlowBaseTemplateSet { - ts.templateslock.RLock() - tmp := ts.templates - ts.templateslock.RUnlock() - return tmp +type templatesKey struct { + Key string `json:"key"` + Version uint16 `json:"ver"` + ObsDomainId uint32 `json:"obs"` + TemplateID uint16 `json:"tid"` } -func (ts *BasicTemplateSystem) AddTemplate(version uint16, obsDomainId uint32, templateId uint16, template interface{}) error { - ts.templateslock.Lock() - defer ts.templateslock.Unlock() +const ( + templateTypeTest = 0 + templateTypeTemplateRecord = 1 + templateTypeIPFIXOptionsTemplateRecord = 2 + templateTypeNFv9OptionsTemplateRecord = 3 +) - /*var templateId uint16 - switch templateIdConv := template.(type) { - case IPFIXOptionsTemplateRecord: - templateId = templateIdConv.TemplateId - case NFv9OptionsTemplateRecord: - templateId = templateIdConv.TemplateId - case TemplateRecord: - templateId = templateIdConv.TemplateId - }*/ - key := templateKey(version, obsDomainId, templateId) - ts.templates[key] = template - return nil +type templatesValue struct { + TemplateType int `json:"ttype"` + Data interface{} `json:"data"` +} + +type templatesValueUnmarshal struct { + TemplateType int `json:"ttype"` + Data json.RawMessage `json:"data"` } -func (ts *BasicTemplateSystem) GetTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) { - ts.templateslock.RLock() - defer ts.templateslock.RUnlock() - key := templateKey(version, obsDomainId, templateId) - if template, ok := ts.templates[key]; ok { - return template, nil +func (t *templatesValue) UnmarshalJSON(bytes []byte) error { + var v templatesValueUnmarshal + err := json.Unmarshal(bytes, &v) + if err != nil { + return err + } + t.TemplateType = v.TemplateType + switch v.TemplateType { + case templateTypeTest: + var data int + err = json.Unmarshal(v.Data, &data) + if err != nil { + return err + } + t.Data = data + case templateTypeTemplateRecord: + var data TemplateRecord + err = json.Unmarshal(v.Data, &data) + if err != nil { + return err + } + t.Data = data + case templateTypeIPFIXOptionsTemplateRecord: + var data IPFIXOptionsTemplateRecord + err = json.Unmarshal(v.Data, &data) + if err != nil { + return err + } + t.Data = data + case templateTypeNFv9OptionsTemplateRecord: + var data NFv9OptionsTemplateRecord + err = json.Unmarshal(v.Data, &data) + if err != nil { + return err + } + t.Data = data + default: + return fmt.Errorf("unknown template type: %d", v.TemplateType) } - return nil, ErrorTemplateNotFound + return nil +} + +type NetflowTemplate struct { + key string } -func (ts *BasicTemplateSystem) RemoveTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) { - ts.templateslock.Lock() - defer ts.templateslock.Unlock() +func (t *NetflowTemplate) RemoveTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) { + if v, err := templatesDB.Pop(templatesKey{ + Key: t.key, + Version: version, + ObsDomainId: obsDomainId, + TemplateID: templateId, + }); err != nil && errors.Is(err, state.ErrorKeyNotFound) { + return nil, ErrorTemplateNotFound + } else if err != nil { + return nil, err + } else { + return v.Data, nil + } +} - key := templateKey(version, obsDomainId, templateId) - if template, ok := ts.templates[key]; ok { - delete(ts.templates, key) - return template, nil +func (t *NetflowTemplate) GetTemplate(version uint16, obsDomainId uint32, templateId uint16) (interface{}, error) { + if v, err := templatesDB.Get(templatesKey{ + Key: t.key, + Version: version, + ObsDomainId: obsDomainId, + TemplateID: templateId, + }); err != nil && errors.Is(err, state.ErrorKeyNotFound) { + return nil, ErrorTemplateNotFound + } else if err != nil { + return nil, err + } else { + return v.Data, nil } - return nil, ErrorTemplateNotFound } -type BasicTemplateSystem struct { - templates FlowBaseTemplateSet - templateslock *sync.RWMutex +func (t *NetflowTemplate) AddTemplate(version uint16, obsDomainId uint32, templateId uint16, template interface{}) error { + k := templatesKey{ + Key: t.key, + Version: version, + ObsDomainId: obsDomainId, + TemplateID: templateId, + } + var err error + switch templatec := template.(type) { + case TemplateRecord: + err = templatesDB.Add(k, templatesValue{ + TemplateType: templateTypeTemplateRecord, + Data: templatec, + }) + case IPFIXOptionsTemplateRecord: + err = templatesDB.Add(k, templatesValue{ + TemplateType: templateTypeIPFIXOptionsTemplateRecord, + Data: templatec, + }) + case NFv9OptionsTemplateRecord: + err = templatesDB.Add(k, templatesValue{ + TemplateType: templateTypeNFv9OptionsTemplateRecord, + Data: templatec, + }) + case int: + err = templatesDB.Add(k, templatesValue{ + TemplateType: templateTypeTest, + Data: templatec, + }) + default: + return fmt.Errorf("unknown template type: %s", reflect.TypeOf(template).String()) + } + return err } -// Creates a basic store for NetFlow and IPFIX templates. -// Everyting is stored in memory. -func CreateTemplateSystem() NetFlowTemplateSystem { - ts := &BasicTemplateSystem{ - templates: make(FlowBaseTemplateSet), - templateslock: &sync.RWMutex{}, +func CreateTemplateSystem(key string) NetFlowTemplateSystem { + ts := &NetflowTemplate{ + key: key, } return ts } + +func InitTemplates() error { + templatesInitLock.Lock() + defer templatesInitLock.Unlock() + if templatesDB != nil { + return nil + } + templatesUrl, err := url.Parse(*StateTemplates) + if err != nil { + return err + } + if !templatesUrl.Query().Has("prefix") { + q := templatesUrl.Query() + q.Set("prefix", "goflow2:nf_templates:") + templatesUrl.RawQuery = q.Encode() + } + templatesDB, err = state.NewState[templatesKey, templatesValue](templatesUrl.String()) + return err +} + +func CloseTemplates() error { + templatesInitLock.Lock() + defer templatesInitLock.Unlock() + if templatesDB == nil { + return nil + } + err := templatesDB.Close() + templatesDB = nil + return err +} diff --git a/decoders/netflow/templates_test.go b/decoders/netflow/templates_test.go index 4ee8f5e8..cdb075c0 100644 --- a/decoders/netflow/templates_test.go +++ b/decoders/netflow/templates_test.go @@ -11,13 +11,15 @@ func benchTemplatesAdd(ts NetFlowTemplateSystem, obs uint32, N int, b *testing.B } func BenchmarkTemplatesAdd(b *testing.B) { - ts := CreateTemplateSystem() + InitTemplates() + ts := CreateTemplateSystem("BenchmarkTemplatesAdd") b.Log("Creating", b.N, "templates") benchTemplatesAdd(ts, uint32(b.N)%0xffff+1, b.N, b) } func BenchmarkTemplatesAddGet(b *testing.B) { - ts := CreateTemplateSystem() + InitTemplates() + ts := CreateTemplateSystem("BenchmarkTemplatesAddGet") templates := 1000 b.Log("Adding", templates, "templates") benchTemplatesAdd(ts, 1, templates, b) diff --git a/docs/state_storage.md b/docs/state_storage.md new file mode 100644 index 00000000..20e4f9ba --- /dev/null +++ b/docs/state_storage.md @@ -0,0 +1,28 @@ +# State Storage + +For protocols with template system (Netflow V9 and IPFIX), GoFlow2 stores the information on memory by default. +When using memory, you will lose the template and sampling rate information if GoFlow2 is restarted. So incoming +flows will fail to decode until the next template/option data is sent from the agent. + +However, you can use an external state storage to overcome this issue. With external storage, you will have these +benefits: +- Supports UDP per-packet load balancer (e.g. with F5 or Envoy Proxy) +- Pod/container auto-scaling to handle traffic surge +- Persistent state, GoFlow2 restarts won't need to wait template/option data + +## Memory +The default method for storing state. It's not synced across multiple GoFlow2 instances and lost on process restart. + +## Redis +The supported URL format for redis +is explained at [uri specifications](https://github.com/redis/redis-specifications/blob/master/uri/redis.txt). +GoFlow2 uses the key-value storage provided by redis for persistence, and pub-sub to broadcast any new template data +to other GoFlow2 instances. +GoFlow2 also have other query parameters specific for redis: +- prefix + - this will override key prefix and channel prefix for pubsub + - e.g. redis://127.0.0.1/0?prefix=goflow +- interval + - specify in seconds on how frequent we should re-retrieve values (in case the pubsub doesn't work for some reason). + defaults to `900` seconds, use `0` to disable + - e.g. redis://127.0.0.1/0?interval=0 diff --git a/go.mod b/go.mod index 717abb42..79673e92 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/libp2p/go-reuseport v0.4.0 github.com/oschwald/geoip2-golang v1.9.0 github.com/prometheus/client_golang v1.18.0 + github.com/redis/go-redis/v9 v9.4.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.8.4 github.com/xdg-go/scram v1.1.2 @@ -18,6 +19,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/eapache/go-resiliency v1.3.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20230111030713-bf00bc1b83b6 // indirect github.com/eapache/queue v1.1.0 // indirect diff --git a/go.sum b/go.sum index 3dbf7b40..570a8fbd 100644 --- a/go.sum +++ b/go.sum @@ -3,12 +3,16 @@ github.com/Shopify/sarama v1.38.1/go.mod h1:iwv9a67Ha8VNa+TifujYoWGxWnu2kNVAQdSd github.com/Shopify/toxiproxy/v2 v2.5.0 h1:i4LPT+qrSlKNtQf5QliVjdP08GyAH8+BUIc9gT0eahc= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/eapache/go-resiliency v1.3.0 h1:RRL0nge+cWGlxXbUzJ7yMcq6w2XBEr19dCN6HECGaT0= github.com/eapache/go-resiliency v1.3.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= github.com/eapache/go-xerial-snappy v0.0.0-20230111030713-bf00bc1b83b6 h1:8yY/I9ndfrgrXUbOGObLHKBR4Fl3nZXwM2c7OYTT8hM= @@ -67,6 +71,8 @@ github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= +github.com/redis/go-redis/v9 v9.4.0 h1:Yzoz33UZw9I/mFhx4MNrB6Fk+XHO1VukNcCa1+lwyKk= +github.com/redis/go-redis/v9 v9.4.0/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= diff --git a/metrics/templates.go b/metrics/templates.go index f5e21993..41367301 100644 --- a/metrics/templates.go +++ b/metrics/templates.go @@ -15,7 +15,7 @@ type PromTemplateSystem struct { // A default Prometheus template generator function to be used by a pipe func NewDefaultPromTemplateSystem(key string) netflow.NetFlowTemplateSystem { - return NewPromTemplateSystem(key, netflow.CreateTemplateSystem()) + return NewPromTemplateSystem(key, netflow.CreateTemplateSystem(key)) } // Creates a Prometheus template system that wraps another template system. diff --git a/producer/proto/producer_nf.go b/producer/proto/producer_nf.go index d57925cc..74c0528b 100644 --- a/producer/proto/producer_nf.go +++ b/producer/proto/producer_nf.go @@ -3,7 +3,10 @@ package protoproducer import ( "bytes" "encoding/binary" + "flag" "fmt" + "net/url" + "strings" "sync" "time" @@ -11,6 +14,7 @@ import ( "github.com/netsampler/goflow2/v2/decoders/utils" flowmessage "github.com/netsampler/goflow2/v2/pb" "github.com/netsampler/goflow2/v2/producer" + "github.com/netsampler/goflow2/v2/state" ) type SamplingRateSystem interface { @@ -18,55 +22,84 @@ type SamplingRateSystem interface { AddSamplingRate(version uint16, obsDomainId uint32, samplingRate uint32) } -type basicSamplingRateKey struct { - version uint16 - obsDomainId uint32 +type SingleSamplingRateSystem struct { + Sampling uint32 } -type basicSamplingRateSystem struct { - sampling map[basicSamplingRateKey]uint32 - samplinglock *sync.RWMutex +func (s *SingleSamplingRateSystem) AddSamplingRate(version uint16, obsDomainId uint32, samplingRate uint32) { } -func CreateSamplingSystem() SamplingRateSystem { - ts := &basicSamplingRateSystem{ - sampling: make(map[basicSamplingRateKey]uint32), - samplinglock: &sync.RWMutex{}, - } - return ts +func (s *SingleSamplingRateSystem) GetSamplingRate(version uint16, obsDomainId uint32) (uint32, error) { + return s.Sampling, nil } -func (s *basicSamplingRateSystem) AddSamplingRate(version uint16, obsDomainId uint32, samplingRate uint32) { - s.samplinglock.Lock() - defer s.samplinglock.Unlock() - s.sampling[basicSamplingRateKey{ - version: version, - obsDomainId: obsDomainId, - }] = samplingRate +var ( + StateSampling = flag.String("state.sampling", "memory://", fmt.Sprintf("Define state sampling rate engine URL (available schemes: %s)", strings.Join(state.SupportedSchemes, ", "))) + samplingRateDB state.State[samplingRateKey, uint32] + samplingRateInitLock = new(sync.Mutex) +) + +type samplingRateKey struct { + Key string `json:"key"` + Version uint16 `json:"ver"` + ObsDomainId uint32 `json:"obs"` } -func (s *basicSamplingRateSystem) GetSamplingRate(version uint16, obsDomainId uint32) (uint32, error) { - s.samplinglock.RLock() - defer s.samplinglock.RUnlock() - if samplingRate, ok := s.sampling[basicSamplingRateKey{ - version: version, - obsDomainId: obsDomainId, - }]; ok { - return samplingRate, nil - } +type SamplingRate struct { + key string +} - return 0, fmt.Errorf("sampling rate not found") +func (s *SamplingRate) GetSamplingRate(version uint16, obsDomainId uint32) (uint32, error) { + return samplingRateDB.Get(samplingRateKey{ + Key: s.key, + Version: version, + ObsDomainId: obsDomainId, + }) } -type SingleSamplingRateSystem struct { - Sampling uint32 +func (s *SamplingRate) AddSamplingRate(version uint16, obsDomainId uint32, samplingRate uint32) { + _ = samplingRateDB.Add(samplingRateKey{ + Key: s.key, + Version: version, + ObsDomainId: obsDomainId, + }, samplingRate) } -func (s *SingleSamplingRateSystem) AddSamplingRate(version uint16, obsDomainId uint32, samplingRate uint32) { +func CreateSamplingSystem(key string) SamplingRateSystem { + ts := &SamplingRate{ + key: key, + } + return ts } -func (s *SingleSamplingRateSystem) GetSamplingRate(version uint16, obsDomainId uint32) (uint32, error) { - return s.Sampling, nil +func InitSamplingRate() error { + samplingRateInitLock.Lock() + defer samplingRateInitLock.Unlock() + if samplingRateDB != nil { + return nil + } + samplingUrl, err := url.Parse(*StateSampling) + if err != nil { + return err + } + if !samplingUrl.Query().Has("prefix") { + q := samplingUrl.Query() + q.Set("prefix", "goflow2:sampling_rate:") + samplingUrl.RawQuery = q.Encode() + } + samplingRateDB, err = state.NewState[samplingRateKey, uint32](samplingUrl.String()) + return err +} + +func CloseSamplingRate() error { + samplingRateInitLock.Lock() + defer samplingRateInitLock.Unlock() + if samplingRateDB == nil { + return nil + } + err := samplingRateDB.Close() + samplingRateDB = nil + return err } func NetFlowLookFor(dataFields []netflow.DataField, typeId uint16) (bool, interface{}) { diff --git a/producer/proto/proto.go b/producer/proto/proto.go index 47e15971..535c4457 100644 --- a/producer/proto/proto.go +++ b/producer/proto/proto.go @@ -14,7 +14,7 @@ type ProtoProducer struct { cfgMapped *producerConfigMapped samplinglock *sync.RWMutex sampling map[string]SamplingRateSystem - samplingRateSystem func() SamplingRateSystem + samplingRateSystem func(key string) SamplingRateSystem } func (p *ProtoProducer) enrich(flowMessageSet []producer.ProducerMessage, cb func(msg *ProtoProducerMessage)) { @@ -33,7 +33,7 @@ func (p *ProtoProducer) getSamplingRateSystem(args *producer.ProduceArgs) Sampli sampling, ok := p.sampling[key] p.samplinglock.RUnlock() if !ok { - sampling = p.samplingRateSystem() + sampling = p.samplingRateSystem(key) p.samplinglock.Lock() p.sampling[key] = sampling p.samplinglock.Unlock() @@ -95,7 +95,7 @@ func (p *ProtoProducer) Commit(flowMessageSet []producer.ProducerMessage) { func (p *ProtoProducer) Close() {} -func CreateProtoProducer(cfg *ProducerConfig, samplingRateSystem func() SamplingRateSystem) (producer.ProducerInterface, error) { +func CreateProtoProducer(cfg *ProducerConfig, samplingRateSystem func(key string) SamplingRateSystem) (producer.ProducerInterface, error) { cfgMapped, err := mapConfig(cfg) return &ProtoProducer{ cfgMapped: cfgMapped, diff --git a/state/memory.go b/state/memory.go new file mode 100644 index 00000000..21ed5e56 --- /dev/null +++ b/state/memory.go @@ -0,0 +1,49 @@ +package state + +import ( + "sync" +) + +type memoryState[K comparable, V any] struct { + data map[K]V + lock *sync.RWMutex +} + +func (m *memoryState[K, V]) Close() error { + return nil +} + +func (m *memoryState[K, V]) Get(key K) (V, error) { + m.lock.RLock() + defer m.lock.RUnlock() + if v, ok := m.data[key]; ok { + return v, nil + } else { + return v, ErrorKeyNotFound + } +} + +func (m *memoryState[K, V]) Add(key K, value V) error { + m.lock.Lock() + defer m.lock.Unlock() + m.data[key] = value + return nil +} + +func (m *memoryState[K, V]) Delete(key K) error { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.data, key) + return nil +} + +func (m *memoryState[K, V]) Pop(key K) (V, error) { + m.lock.Lock() + defer m.lock.Unlock() + if v, ok := m.data[key]; ok { + delete(m.data, key) + return v, nil + } else { + return v, ErrorKeyNotFound + } +} diff --git a/state/redis.go b/state/redis.go new file mode 100644 index 00000000..02837a0a --- /dev/null +++ b/state/redis.go @@ -0,0 +1,205 @@ +package state + +import ( + "context" + "encoding/json" + "fmt" + "github.com/redis/go-redis/v9" + "net/url" + "strconv" + "strings" + "sync" + "time" +) + +type redisState[K comparable, V any] struct { + memory memoryState[K, V] + urlParsed *url.URL + rPrefix string + refreshInterval int + db *redis.Client + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +const ( + redisOpAdd = 1 + redisOpDel = 2 +) + +func (r *redisState[K, V]) populate() error { + iter := r.db.Scan(r.ctx, 0, fmt.Sprintf("%s*", r.rPrefix), 0).Iterator() + for iter.Next(r.ctx) { + kRaw := iter.Val() + res := r.db.Get(r.ctx, kRaw) + vRaw, err := res.Bytes() + if err != nil { + return err + } + var k K + var v V + kRaw, _ = strings.CutPrefix(kRaw, r.rPrefix) + if err = json.Unmarshal([]byte(kRaw), &k); err != nil { + return err + } + if err = json.Unmarshal(vRaw, &v); err != nil { + return err + } + if err = r.memory.Add(k, v); err != nil { + return err + } + } + if err := iter.Err(); err != nil { + return err + } + return nil +} + +func (r *redisState[K, V]) init() error { + q := r.urlParsed.Query() + r.rPrefix = q.Get("prefix") + if r.rPrefix == "" { + return fmt.Errorf("'prefix' name is required on redis state engine, place it on your URL query string") + } + q.Del("prefix") + interval := q.Get("interval") + if interval == "" { + interval = "900" + } + q.Del("interval") + r.urlParsed.RawQuery = q.Encode() + var err error + r.refreshInterval, err = strconv.Atoi(interval) + if err != nil { + return err + } + opts, err := redis.ParseURL(r.urlParsed.String()) + if err != nil { + return err + } + r.db = redis.NewClient(opts) + if err != nil { + return err + } + // pre-populate local memory copy from existing redis data + if err = r.populate(); err != nil { + return err + } + // refresh goroutine + if r.refreshInterval > 0 { + r.wg.Add(1) + go func() { + defer r.wg.Done() + refreshCh := time.After(time.Duration(r.refreshInterval) * time.Second) + mainLoop: + for { + select { + case <-refreshCh: + _ = r.populate() + refreshCh = time.After(time.Duration(r.refreshInterval) * time.Second) + case <-r.ctx.Done(): + break mainLoop + } + } + }() + } + // subscribe to value changes + r.wg.Add(1) + go func() { + defer r.wg.Done() + ps := r.db.PSubscribe(r.ctx, fmt.Sprintf("%s*", r.rPrefix)) + defer ps.Close() + ch := ps.Channel() + mainLoop: + for { + select { + case msgRaw := <-ch: + op, err := strconv.Atoi(msgRaw.Payload) + if err != nil { + continue + } + keyRaw, _ := strings.CutPrefix(msgRaw.Channel, r.rPrefix) + var k K + if err = json.Unmarshal([]byte(keyRaw), &k); err != nil { + continue + } + switch op { + case redisOpAdd: + cmd := r.db.Get(r.ctx, msgRaw.Channel) + vBytes, err := cmd.Bytes() + if err != nil { + continue + } + var v V + if err = json.Unmarshal(vBytes, &v); err != nil { + continue + } + _ = r.memory.Add(k, v) + case redisOpDel: + _ = r.memory.Delete(k) + } + case <-r.ctx.Done(): + break mainLoop + } + } + }() + return nil +} + +func (r *redisState[K, V]) Close() error { + r.cancel() + r.wg.Wait() + return r.db.Close() +} + +func (r *redisState[K, V]) Get(key K) (V, error) { + return r.memory.Get(key) +} + +func (r *redisState[K, V]) Add(key K, value V) error { + k, err := json.Marshal(key) + if err != nil { + return err + } + v, err := json.Marshal(value) + if err != nil { + return err + } + kStr := fmt.Sprintf("%s%s", r.rPrefix, string(k)) + setStatus := r.db.Set(r.ctx, kStr, v, 0) + if err = setStatus.Err(); err != nil { + return err + } + pubStatus := r.db.Publish(r.ctx, kStr, redisOpAdd) + if err = pubStatus.Err(); err != nil { + return err + } + return nil +} + +func (r *redisState[K, V]) Delete(key K) error { + k, err := json.Marshal(key) + if err != nil { + return err + } + kStr := fmt.Sprintf("%s%s", r.rPrefix, string(k)) + delStatus := r.db.Del(r.ctx, kStr) + if err = delStatus.Err(); err != nil { + return err + } + pubStatus := r.db.Publish(r.ctx, kStr, redisOpDel) + if err = pubStatus.Err(); err != nil { + return err + } + return nil +} + +func (r *redisState[K, V]) Pop(key K) (V, error) { + v, err := r.Get(key) + if err != nil { + return v, err + } + err = r.Delete(key) + return v, err +} diff --git a/state/state.go b/state/state.go new file mode 100644 index 00000000..60e59dc0 --- /dev/null +++ b/state/state.go @@ -0,0 +1,52 @@ +package state + +import ( + "context" + "fmt" + "net/url" + "sync" +) + +var ( + SupportedSchemes = []string{"memory", "redis"} + ErrorKeyNotFound = fmt.Errorf("key not found") +) + +type State[K comparable, V any] interface { + Close() error + Get(key K) (V, error) + Add(key K, value V) error + Delete(key K) error + Pop(key K) (V, error) +} + +func NewState[K comparable, V any](rawUrl string) (State[K, V], error) { + urlParsed, err := url.Parse(rawUrl) + if err != nil { + return nil, err + } + memory := memoryState[K, V]{ + data: make(map[K]V), + lock: new(sync.RWMutex), + } + switch urlParsed.Scheme { + case "memory": + return &memory, nil + case "redis", "rediss": + ctx, cancel := context.WithCancel(context.Background()) + rd := &redisState[K, V]{ + memory: memory, + urlParsed: urlParsed, + ctx: ctx, + cancel: cancel, + wg: new(sync.WaitGroup), + } + if err = rd.init(); err != nil { + return nil, err + } else { + return rd, nil + } + default: + return nil, fmt.Errorf("unknown state name %s", urlParsed.Scheme) + } +} diff --git a/utils/templates/templates.go b/utils/templates/templates.go index 0f61c80a..b6038756 100644 --- a/utils/templates/templates.go +++ b/utils/templates/templates.go @@ -10,5 +10,5 @@ type TemplateSystemGenerator func(key string) netflow.NetFlowTemplateSystem // Default template generator func DefaultTemplateGenerator(key string) netflow.NetFlowTemplateSystem { - return netflow.CreateTemplateSystem() + return netflow.CreateTemplateSystem(key) }