diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 000000000..d921d0ffd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: +- package-ecosystem: gomod + directory: "/" + schedule: + interval: daily + open-pull-requests-limit: 10 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 000000000..278241b5d --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,53 @@ +name: CI +on: + push: + tags: + - v* + branches: + - master + - main + pull_request: + branches: + - master + - main +jobs: + golangci: + name: lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + version: v1.41 + + build: + name: build + runs-on: ubuntu-latest + strategy: + matrix: + go: [1.16] + fix-version: + - FIX_TEST= + - FIX_TEST=fix40 + - FIX_TEST=fix41 + - FIX_TEST=fix42 + - FIX_TEST=fix43 + - FIX_TEST=fix44 + - FIX_TEST=fix50 + - FIX_TEST=fix50sp1 + - FIX_TEST=fix50sp2 + steps: + - name: Setup + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Check out source + uses: actions/checkout@v2 + - name: Run Mongo + run: docker run -d -p 27017:27017 mongo + - name: Test + env: + GO111MODULE: "on" + MONGODB_TEST_CXN: "localhost" + run: make generate; if [ -z "$FIX_TEST" ]; then make build; make; else make build_accept; make $FIX_TEST; fi \ No newline at end of file diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 55955e15c..000000000 --- a/.travis.yml +++ /dev/null @@ -1,33 +0,0 @@ -language: go -sudo: false - -go: - - 1.9 - - tip - -services: - - mongodb - -env: - global: - - MONGODB_TEST_CXN=localhost - matrix: - - FIX_TEST= - - FIX_TEST=fix40 - - FIX_TEST=fix41 - - FIX_TEST=fix42 - - FIX_TEST=fix43 - - FIX_TEST=fix44 - - FIX_TEST=fix50 - - FIX_TEST=fix50sp1 - - FIX_TEST=fix50sp2 - -matrix: - allow_failures: - - go: tip - -install: - - go get -u github.com/golang/dep/cmd/dep - - dep ensure - -script: make generate; if [ -z "$FIX_TEST" ]; then make build; make; else make build_accept; make $FIX_TEST; fi diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index 1504285e2..000000000 --- a/Gopkg.lock +++ /dev/null @@ -1,97 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. - - -[[projects]] - digest = "1:56c130d885a4aacae1dd9c7b71cfe39912c7ebc1ff7d2b46083c8812996dc43b" - name = "github.com/davecgh/go-spew" - packages = ["spew"] - pruneopts = "" - revision = "346938d642f2ec3594ed81d874461961cd0faa76" - version = "v1.1.0" - -[[projects]] - branch = "master" - digest = "1:e9ffb9315dce0051beb757d0f0fc25db57c4da654efc4eada4ea109c2d9da815" - name = "github.com/globalsign/mgo" - packages = [ - ".", - "bson", - "internal/json", - "internal/sasl", - "internal/scram", - ] - pruneopts = "" - revision = "eeefdecb41b842af6dc652aaea4026e8403e62df" - -[[projects]] - digest = "1:1cc12f4618ce8d71ca28ef3708f4e98e1318ab6f06ecfffb6781b893f271c89c" - name = "github.com/mattn/go-sqlite3" - packages = ["."] - pruneopts = "" - revision = "ca5e3819723d8eeaf170ad510e7da1d6d2e94a08" - version = "v1.2.0" - -[[projects]] - digest = "1:256484dbbcd271f9ecebc6795b2df8cad4c458dd0f5fd82a8c2fa0c29f233411" - name = "github.com/pmezard/go-difflib" - packages = ["difflib"] - pruneopts = "" - revision = "792786c7400a136282c1664665ae0a8db921c6c2" - version = "v1.0.0" - -[[projects]] - branch = "master" - digest = "1:68a81aa25065b50a4bf1ffd115ff3634704f61f675d0140b31492e9fcca55421" - name = "github.com/shopspring/decimal" - packages = ["."] - pruneopts = "" - revision = "aed1bfe463fa3c9cc268d60dcc1491db613bff7e" - -[[projects]] - branch = "master" - digest = "1:ed7ac53c7d59041f27964d3f04e021b45ecb5f23c842c84d778a7f1fb67e2ce9" - name = "github.com/stretchr/objx" - packages = ["."] - pruneopts = "" - revision = "1a9d0bb9f541897e62256577b352fdbc1fb4fd94" - -[[projects]] - digest = "1:3926a4ec9a4ff1a072458451aa2d9b98acd059a45b38f7335d31e06c3d6a0159" - name = "github.com/stretchr/testify" - packages = [ - "assert", - "mock", - "require", - "suite", - ] - pruneopts = "" - revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0" - version = "v1.1.4" - -[[projects]] - branch = "master" - digest = "1:898bc7c802c1e0c20cecd65811e90b7b9bc5651b4a07aefd159451bfb200b2b3" - name = "golang.org/x/net" - packages = [ - "context", - "proxy", - ] - pruneopts = "" - revision = "a04bdaca5b32abe1c069418fb7088ae607de5bd0" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - input-imports = [ - "github.com/globalsign/mgo", - "github.com/globalsign/mgo/bson", - "github.com/mattn/go-sqlite3", - "github.com/shopspring/decimal", - "github.com/stretchr/testify/assert", - "github.com/stretchr/testify/mock", - "github.com/stretchr/testify/require", - "github.com/stretchr/testify/suite", - "golang.org/x/net/proxy", - ] - solver-name = "gps-cdcl" - solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index 3385a5e25..000000000 --- a/Gopkg.toml +++ /dev/null @@ -1,34 +0,0 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - name = "github.com/mattn/go-sqlite3" - version = "1.2.0" - -[[constraint]] - name = "github.com/shopspring/decimal" - branch = "master" - -[[constraint]] - name = "github.com/stretchr/testify" - version = "1.1.4" diff --git a/Makefile b/Makefile index d513ded8a..8b7d0ba2d 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,11 @@ all: vet test -generate: +clean: + rm -rf gen + +generate: clean mkdir -p gen; cd gen; go run ../cmd/generate-fix/generate-fix.go ../spec/*.xml + go get -u all generate-dist: cd ..; go run quickfix/cmd/generate-fix/generate-fix.go quickfix/spec/*.xml diff --git a/README.md b/README.md index f9a59056f..fbfd16f81 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ QuickFIX/Go =========== -[![GoDoc](https://godoc.org/github.com/quickfixgo/quickfix?status.png)](https://godoc.org/github.com/quickfixgo/quickfix) [![Build Status](https://travis-ci.org/quickfixgo/quickfix.svg?branch=master)](https://travis-ci.org/quickfixgo/quickfix) [![Go Report Card](https://goreportcard.com/badge/github.com/quickfixgo/quickfix)](https://goreportcard.com/report/github.com/quickfixgo/quickfix) +[![Build Status](https://github.com/quickfixgo/quickfix/workflows/CI/badge.svg)](https://github.com/quickfixgo/quickfix/actions) [![GoDoc](https://godoc.org/github.com/quickfixgo/quickfix?status.png)](https://godoc.org/github.com/quickfixgo/quickfix) [![Go Report Card](https://goreportcard.com/badge/github.com/quickfixgo/quickfix)](https://goreportcard.com/report/github.com/quickfixgo/quickfix) - Website: http://www.quickfixgo.org - Mailing list: [Google Groups](https://groups.google.com/forum/#!forum/quickfixgo) @@ -54,24 +54,16 @@ Following installation, `generate-fix` is installed to `$GOPATH/bin/generate-fix Developing QuickFIX/Go ---------------------- -If you wish to work on QuickFIX/Go itself, you will first need [Go](http://www.golang.org) installed on your machine (version 1.6+ is *required*). +If you wish to work on QuickFIX/Go itself, you will first need [Go](http://www.golang.org) installed and configured on your machine (version 1.13+ is preferred, but the minimum required version is 1.6). -For local dev first make sure Go is properly installed, including setting up a [GOPATH](http://golang.org/doc/code.html#GOPATH). - -Next, using [Git](https://git-scm.com/), clone this repository into `$GOPATH/src/github.com/quickfixgo/quickfix`. +Next, using [Git](https://git-scm.com/), clone the repository via `git clone git@github.com:quickfixgo/quickfix.git` ### Installing Dependencies -QuickFIX/Go uses [dep](https://github.com/golang/dep) to manage the vendored dependencies. Install dep with `go get`: - -```sh -$ go get -u github.com/golang/dep/cmd/dep -``` - -Run `dep ensure` to install the correct versioned dependencies into `vendor/`, which Go 1.6+ automatically recognizes and loads. +As of Go version 1.13, QuickFIX/Go uses [modules](https://github.com/golang/go/wiki/Modules) to manage dependencies. You may require `GO111MODULE=on`. To install dependencies, run ```sh -$ $GOPATH/bin/dep ensure +go mod download ``` **Note:** No vendored dependencies are included in the QuickFIX/Go source. @@ -117,37 +109,22 @@ To run acceptance tests, If you are developing QuickFIX/Go, there are a few tasks you might need to perform related to dependencies. -#### Adding a dependency - -If you are adding a dependency, you will need to update the dep manifest in the same Pull Request as the code that depends on it. You should do this in a separate commit from your code, as this makes PR review easier and Git history simpler to read in the future. +#### Adding/updating a dependency -To add a dependency: - -1. Add the dependency using `dep`: -```bash -$ dep ensure -add github.com/foo/bar -``` -2. Review the changes in git and commit them. +If you are adding or updating a dependency, you will need to update the `go.mod` and `go.sum` in the same Pull Request as the code that depends on it. You should do this in a separate commit from your code, as this makes PR review easier and Git history simpler to read in the future. -#### Updating a dependency - -To update a dependency to the latest version allowed by constraints in `Gopkg.toml`: - -1. Run: -```bash -$ dep ensure -update github.com/foo/bar +1. Add or update the dependency like usual: +```sh +go get -u github.com/foo/bar ``` -2. Review the changes in git and commit them. - -To change the allowed version/branch/revision of a dependency: - -1. Manually edit `Gopkg.toml` -2. Run: -```bash -$ dep ensure +2. Update the module-related files: +```sh +go mod tidy ``` 3. Review the changes in git and commit them. +Note that to specify a specific revision, you can manually edit the `go.mod` file and run `go mod tidy` + Licensing --------- diff --git a/accepter_test.go b/accepter_test.go new file mode 100644 index 000000000..6924447fe --- /dev/null +++ b/accepter_test.go @@ -0,0 +1,56 @@ +package quickfix + +import ( + "net" + "testing" + + "github.com/armon/go-proxyproto" + "github.com/stretchr/testify/assert" +) + +func TestAcceptor_Start(t *testing.T) { + settingsWithTCPProxy := NewSettings() + settingsWithTCPProxy.GlobalSettings().Set("UseTCPProxy", "Y") + + settingsWithNoTCPProxy := NewSettings() + settingsWithNoTCPProxy.GlobalSettings().Set("UseTCPProxy", "N") + + genericSettings := NewSettings() + + const ( + GenericListener = iota + ProxyListener + ) + + acceptorStartTests := []struct { + name string + settings *Settings + listenerType int + }{ + {"with TCP proxy set", settingsWithTCPProxy, ProxyListener}, + {"with no TCP proxy set", settingsWithNoTCPProxy, GenericListener}, + {"no TCP proxy configuration set", genericSettings, GenericListener}, + } + + for _, tt := range acceptorStartTests { + t.Run(tt.name, func(t *testing.T) { + tt.settings.GlobalSettings().Set("SocketAcceptPort", "5001") + + acceptor := &Acceptor{settings: tt.settings} + if err := acceptor.Start(); err != nil { + assert.NotNil(t, err) + } + if tt.listenerType == ProxyListener { + _, ok := acceptor.listener.(*proxyproto.Listener) + assert.True(t, ok) + } + + if tt.listenerType == GenericListener { + _, ok := acceptor.listener.(*net.TCPListener) + assert.True(t, ok) + } + + acceptor.Stop() + }) + } +} \ No newline at end of file diff --git a/acceptor.go b/acceptor.go index 1fd1be6f8..ac0ab6e52 100644 --- a/acceptor.go +++ b/acceptor.go @@ -10,25 +10,37 @@ import ( "strconv" "sync" + "github.com/armon/go-proxyproto" "github.com/quickfixgo/quickfix/config" ) //Acceptor accepts connections from FIX clients and manages the associated sessions. type Acceptor struct { - app Application - settings *Settings - logFactory LogFactory - storeFactory MessageStoreFactory - globalLog Log - sessions map[SessionID]*session - sessionGroup sync.WaitGroup - listener net.Listener - listenerShutdown sync.WaitGroup - dynamicSessions bool - dynamicSessionChan chan *session + app Application + settings *Settings + logFactory LogFactory + storeFactory MessageStoreFactory + globalLog Log + sessions map[SessionID]*session + sessionGroup sync.WaitGroup + listener net.Listener + listenerShutdown sync.WaitGroup + dynamicSessions bool + dynamicQualifier bool + dynamicQualifierCount int + dynamicSessionChan chan *session + sessionAddr map[SessionID]net.Addr + connectionValidator ConnectionValidator sessionFactory } +// ConnectionValidator is an interface allowing to implement a custom authentication logic. +type ConnectionValidator interface { + // Validate the connection for validity. This can be a part of authentication process. + // For example, you may tie up a SenderCompID to an IP range, or to a specific TLS certificate as a part of mTLS. + Validate(netConn net.Conn, session SessionID) error +} + //Start accepting connections. func (a *Acceptor) Start() error { socketAcceptHost := "" @@ -49,11 +61,24 @@ func (a *Acceptor) Start() error { return err } + var useTCPProxy bool + if a.settings.GlobalSettings().HasSetting(config.UseTCPProxy) { + if useTCPProxy, err = a.settings.GlobalSettings().BoolSetting(config.UseTCPProxy); err != nil { + return err + } + } + address := net.JoinHostPort(socketAcceptHost, strconv.Itoa(socketAcceptPort)) if tlsConfig != nil { if a.listener, err = tls.Listen("tcp", address, tlsConfig); err != nil { return err } + } else if useTCPProxy { + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + a.listener = &proxyproto.Listener{Listener: listener} } else { if a.listener, err = net.Listen("tcp", address); err != nil { return err @@ -98,6 +123,12 @@ func (a *Acceptor) Stop() { a.sessionGroup.Wait() } +//Get remote IP address for a given session. +func (a *Acceptor) RemoteAddr(sessionID SessionID) (net.Addr, bool) { + addr, ok := a.sessionAddr[sessionID] + return addr, ok +} + //NewAcceptor creates and initializes a new Acceptor. func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Settings, logFactory LogFactory) (a *Acceptor, err error) { a = &Acceptor{ @@ -106,11 +137,18 @@ func NewAcceptor(app Application, storeFactory MessageStoreFactory, settings *Se settings: settings, logFactory: logFactory, sessions: make(map[SessionID]*session), + sessionAddr: make(map[SessionID]net.Addr), } if a.settings.GlobalSettings().HasSetting(config.DynamicSessions) { if a.dynamicSessions, err = settings.globalSettings.BoolSetting(config.DynamicSessions); err != nil { return } + + if a.settings.GlobalSettings().HasSetting(config.DynamicQualifier) { + if a.dynamicQualifier, err = settings.globalSettings.BoolSetting(config.DynamicQualifier); err != nil { + return + } + } } if a.globalLog, err = logFactory.Create(); err != nil { @@ -237,6 +275,19 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { SenderCompID: string(targetCompID), SenderSubID: string(targetSubID), SenderLocationID: string(targetLocationID), TargetCompID: string(senderCompID), TargetSubID: string(senderSubID), TargetLocationID: string(senderLocationID), } + + // We have a session ID and a network connection. This seems to be a good place for any custom authentication logic. + if a.connectionValidator != nil { + if err := a.connectionValidator.Validate(netConn, sessID); err != nil { + a.globalLog.OnEventf("Unable to validate a connection %v", err.Error()) + return + } + } + + if a.dynamicQualifier { + a.dynamicQualifierCount++ + sessID.Qualifier = strconv.Itoa(a.dynamicQualifierCount) + } session, ok := a.sessions[sessID] if !ok { if !a.dynamicSessions { @@ -253,6 +304,7 @@ func (a *Acceptor) handleConnection(netConn net.Conn) { defer session.stop() } + a.sessionAddr[sessID] = netConn.RemoteAddr() msgIn := make(chan fixIn) msgOut := make(chan []byte) @@ -297,7 +349,13 @@ LOOP: complete <- sessionID }() case id := <-complete: - delete(sessions, id) + session, ok := sessions[id] + if ok { + delete(a.sessionAddr, session.sessionID) + delete(sessions, id) + } else { + a.globalLog.OnEventf("Missing dynamic session %v!", id) + } } } @@ -312,3 +370,12 @@ LOOP: } } } + +// SetConnectionValidator sets an optional connection validator. +// Use it when you need a custom authentication logic that includes lower level interactions, +// like mTLS auth or IP whitelistening. +// To remove a previously set validator call it with a nil value: +// a.SetConnectionValidator(nil) +func (a *Acceptor) SetConnectionValidator(validator ConnectionValidator) { + a.connectionValidator = validator +} diff --git a/cmd/generate-fix/internal/generate.go b/cmd/generate-fix/internal/generate.go index 9fbefea29..3dc428980 100644 --- a/cmd/generate-fix/internal/generate.go +++ b/cmd/generate-fix/internal/generate.go @@ -13,6 +13,7 @@ import ( var ( useFloat = flag.Bool("use-float", false, "By default, FIX float fields are represented as arbitrary-precision fixed-point decimal numbers. Set to 'true' to instead generate FIX float fields as float64 values.") + pkgRoot = flag.String("pkg-root", "github.com/quickfixgo", "Set a string here to provide a custom import path for generated packages.") tabWidth = 8 printerMode = printer.UseSpaces | printer.TabIndent ) diff --git a/cmd/generate-fix/internal/helpers.go b/cmd/generate-fix/internal/helpers.go index dba773d89..8fe4cd4c6 100644 --- a/cmd/generate-fix/internal/helpers.go +++ b/cmd/generate-fix/internal/helpers.go @@ -1,21 +1,6 @@ package internal -import ( - "os" - "path/filepath" - "strings" -) - // getImportPathRoot returns the root path to use in import statements. -// The root path is determined by stripping "$GOPATH/src/" from the current -// working directory. For example, when generating code within the QuickFIX/Go -// source tree, the returned root path will be "github.com/quickfixgo/quickfix". func getImportPathRoot() string { - pwd, err := os.Getwd() - if err != nil { - panic(err) - } - goSrcPath := filepath.Join(os.Getenv("GOPATH"), "src") - importPathRoot := filepath.ToSlash(strings.Replace(pwd, goSrcPath, "", 1)) - return strings.TrimLeft(importPathRoot, "/") + return *pkgRoot } diff --git a/cmd/generate-fix/internal/templates.go b/cmd/generate-fix/internal/templates.go index 95aa5276e..cae2d567d 100644 --- a/cmd/generate-fix/internal/templates.go +++ b/cmd/generate-fix/internal/templates.go @@ -57,7 +57,7 @@ Set{{ .Name }}(f {{ .Name }}RepeatingGroup){ {{ define "setters" }} {{ range .Fields }} -//Set{{ .Name }} sets {{ .Name }}, Tag {{ .Tag }} +// Set{{ .Name }} sets {{ .Name }}, Tag {{ .Tag }}. func ({{ template "receiver" }} {{ $.Name }}) {{ if .IsGroup }}{{ template "groupsetter" . }}{{ else }}{{ template "fieldsetter" . }}{{ end }} {{ end }}{{ end }} @@ -97,13 +97,13 @@ Get{{ .Name }}() ({{ .Name }}RepeatingGroup, quickfix.MessageRejectError) { {{ define "getters" }} {{ range .Fields }} -//Get{{ .Name }} gets {{ .Name }}, Tag {{ .Tag }} +// Get{{ .Name }} gets {{ .Name }}, Tag {{ .Tag }}. func ({{ template "receiver" }} {{ $.Name }}) {{if .IsGroup}}{{ template "groupgetter" . }}{{ else }}{{ template "fieldvaluegetter" .}}{{ end }} {{ end }}{{ end }} {{ define "hasers" }} {{range .Fields}} -//Has{{ .Name}} returns true if {{ .Name}} is present, Tag {{ .Tag}} +// Has{{ .Name}} returns true if {{ .Name}} is present, Tag {{ .Tag}}. func ({{ template "receiver" }} {{ $.Name }}) Has{{ .Name}}() bool { return {{ template "receiver" }}.Has(tag.{{ .Name}}) } @@ -121,7 +121,7 @@ quickfix.GroupTemplate{ {{ define "groups" }} {{ range .Fields }} {{ if .IsGroup }} -//{{ .Name }} is a repeating group element, Tag {{ .Tag }} +// {{ .Name }} is a repeating group element, Tag {{ .Tag }}. type {{ .Name }} struct { *quickfix.Group } @@ -131,24 +131,24 @@ type {{ .Name }} struct { {{ template "hasers" . }} {{ template "groups" . }} -//{{ .Name }}RepeatingGroup is a repeating group, Tag {{ .Tag }} +// {{ .Name }}RepeatingGroup is a repeating group, Tag {{ .Tag }}. type {{ .Name }}RepeatingGroup struct { *quickfix.RepeatingGroup } -//New{{ .Name }}RepeatingGroup returns an initialized, {{ .Name }}RepeatingGroup +// New{{ .Name }}RepeatingGroup returns an initialized, {{ .Name }}RepeatingGroup. func New{{ .Name }}RepeatingGroup() {{ .Name }}RepeatingGroup { return {{ .Name }}RepeatingGroup{ quickfix.NewRepeatingGroup(tag.{{ .Name }}, {{ template "group_template" .Fields }})} } -//Add create and append a new {{ .Name }} to this group +// Add create and append a new {{ .Name }} to this group. func ({{ template "receiver" }} {{ .Name }}RepeatingGroup) Add() {{ .Name }} { g := {{ template "receiver" }}.RepeatingGroup.Add() return {{ .Name }}{g} } -//Get returns the ith {{ .Name }} in the {{ .Name }}RepeatinGroup +// Get returns the ith {{ .Name }} in the {{ .Name }}RepeatinGroup. func ({{ template "receiver" }} {{ .Name}}RepeatingGroup) Get(i int) {{ .Name }} { return {{ .Name }}{ {{ template "receiver" }}.RepeatingGroup.Get(i) } } @@ -174,12 +174,12 @@ import( "{{ importRootPath }}/tag" ) -//Header is the {{ .Package }} Header type +// Header is the {{ .Package }} Header type. type Header struct { *quickfix.Header } -//NewHeader returns a new, initialized Header instance +// NewHeader returns a new, initialized Header instance. func NewHeader(header *quickfix.Header) (h Header) { h.Header = header h.SetBeginString("{{ beginString .FIXSpec }}") @@ -209,7 +209,7 @@ import( "{{ importRootPath }}/tag" ) -//Trailer is the {{ .Package }} Trailer type +// Trailer is the {{ .Package }} Trailer type. type Trailer struct { *quickfix.Trailer } @@ -238,7 +238,7 @@ import( "{{ importRootPath }}/tag" ) -//{{ .Name }} is the {{ .FIXPackage }} {{ .Name }} type, MsgType = {{ .MsgType }} +// {{ .Name }} is the {{ .FIXPackage }} {{ .Name }} type, MsgType = {{ .MsgType }}. type {{ .Name }} struct { {{ .TransportPackage }}.Header *quickfix.Body @@ -246,7 +246,7 @@ type {{ .Name }} struct { Message *quickfix.Message } -//FromMessage creates a {{ .Name }} from a quickfix.Message instance +// FromMessage creates a {{ .Name }} from a quickfix.Message instance. func FromMessage(m *quickfix.Message) {{ .Name }} { return {{ .Name }}{ Header: {{ .TransportPackage}}.Header{&m.Header}, @@ -256,13 +256,13 @@ func FromMessage(m *quickfix.Message) {{ .Name }} { } } -//ToMessage returns a quickfix.Message instance +// ToMessage returns a quickfix.Message instance. func (m {{ .Name }}) ToMessage() *quickfix.Message { return m.Message } {{ $required_fields := requiredFields .MessageDef -}} -//New returns a {{ .Name }} initialized with the required fields for {{ .Name }} +// New returns a {{ .Name }} initialized with the required fields for {{ .Name }}. func New({{template "field_args" $required_fields }}) (m {{ .Name }}) { m.Message = quickfix.NewMessage() m.Header = {{ .TransportPackage }}.NewHeader(&m.Message.Header) @@ -277,10 +277,10 @@ func New({{template "field_args" $required_fields }}) (m {{ .Name }}) { return } -//A RouteOut is the callback type that should be implemented for routing Message +// A RouteOut is the callback type that should be implemented for routing Message. type RouteOut func(msg {{ .Name }}, sessionID quickfix.SessionID) quickfix.MessageRejectError -//Route returns the beginstring, message type, and MessageRoute for this Message type +// Route returns the beginstring, message type, and MessageRoute for this Message type. func Route(router RouteOut) (string, string, quickfix.MessageRoute) { r:=func(msg *quickfix.Message, sessionID quickfix.SessionID) quickfix.MessageRejectError { return router(FromMessage(msg), sessionID) @@ -319,28 +319,28 @@ import( {{- $base_type := quickfixType . -}} {{ if and .Enums (ne $base_type "FIXBoolean") }} -//{{ .Name }}Field is a enum.{{ .Name }} field +// {{ .Name }}Field is a enum.{{ .Name }} field. type {{ .Name }}Field struct { quickfix.FIXString } {{ else }} -//{{ .Name }}Field is a {{ .Type }} field +// {{ .Name }}Field is a {{ .Type }} field. type {{ .Name }}Field struct { quickfix.{{ $base_type }} } {{ end }} -//Tag returns tag.{{ .Name }} ({{ .Tag }}) +// Tag returns tag.{{ .Name }} ({{ .Tag }}). func (f {{ .Name }}Field) Tag() quickfix.Tag { return tag.{{ .Name }} } {{ if eq $base_type "FIXUTCTimestamp" }} -//New{{ .Name }} returns a new {{ .Name }}Field initialized with val +// New{{ .Name }} returns a new {{ .Name }}Field initialized with val. func New{{ .Name }}(val time.Time) {{ .Name }}Field { return New{{ .Name }}WithPrecision(val, quickfix.Millis) } -//New{{ .Name }}NoMillis returns a new {{ .Name }}Field initialized with val without millisecs +// New{{ .Name }}NoMillis returns a new {{ .Name }}Field initialized with val without millisecs. func New{{ .Name }}NoMillis(val time.Time) {{ .Name }}Field { return New{{ .Name }}WithPrecision(val, quickfix.Seconds) } -//New{{ .Name }}WithPrecision returns a new {{ .Name }}Field initialized with val of specified precision +// New{{ .Name }}WithPrecision returns a new {{ .Name }}Field initialized with val of specified precision. func New{{ .Name }}WithPrecision(val time.Time, precision quickfix.TimestampPrecision) {{ .Name }}Field { return {{ .Name }}Field{ quickfix.FIXUTCTimestamp{ Time: val, Precision: precision } } } @@ -350,12 +350,12 @@ func New{{ .Name }}(val enum.{{ .Name }}) {{ .Name }}Field { return {{ .Name }}Field{ quickfix.FIXString(val) } } {{ else if eq $base_type "FIXDecimal" }} -//New{{ .Name }} returns a new {{ .Name }}Field initialized with val and scale +// New{{ .Name }} returns a new {{ .Name }}Field initialized with val and scale. func New{{ .Name }}(val decimal.Decimal, scale int32) {{ .Name }}Field { return {{ .Name }}Field{ quickfix.FIXDecimal{ Decimal: val, Scale: scale} } } {{ else }} -//New{{ .Name }} returns a new {{ .Name }}Field initialized with val +// New{{ .Name }} returns a new {{ .Name }}Field initialized with val. func New{{ .Name }}(val {{ quickfixValueType $base_type }}) {{ .Name }}Field { return {{ .Name }}Field{ quickfix.{{ $base_type }}(val) } } @@ -386,7 +386,7 @@ func (f {{ .Name }}Field) Value() ({{ quickfixValueType $base_type }}) { package enum {{ range $ft := . }} {{ if $ft.Enums }} -//Enum values for {{ $ft.Name }} +// {{ $ft.Name }} field enumeration values. type {{ $ft.Name }} string const( {{ range $ft.Enums }} diff --git a/config/configuration.go b/config/configuration.go index afa442a33..adfd9ff54 100644 --- a/config/configuration.go +++ b/config/configuration.go @@ -20,6 +20,7 @@ const ( SocketCertificateFile string = "SocketCertificateFile" SocketCAFile string = "SocketCAFile" SocketInsecureSkipVerify string = "SocketInsecureSkipVerify" + SocketServerName string = "SocketServerName" SocketMinimumTLSVersion string = "SocketMinimumTLSVersion" SocketTimeout string = "SocketTimeout" SocketUseSSL string = "SocketUseSSL" @@ -28,6 +29,7 @@ const ( ProxyPort string = "ProxyPort" ProxyUser string = "ProxyUser" ProxyPassword string = "ProxyPassword" + UseTCPProxy string = "UseTCPProxy" DefaultApplVerID string = "DefaultApplVerID" StartTime string = "StartTime" EndTime string = "EndTime" @@ -42,6 +44,8 @@ const ( ResetOnLogout string = "ResetOnLogout" ResetOnDisconnect string = "ResetOnDisconnect" ReconnectInterval string = "ReconnectInterval" + LogoutTimeout string = "LogoutTimeout" + LogonTimeout string = "LogonTimeout" HeartBtInt string = "HeartBtInt" FileLogPath string = "FileLogPath" FileStorePath string = "FileStorePath" @@ -59,4 +63,5 @@ const ( PersistMessages string = "PersistMessages" RejectInvalidMessage string = "RejectInvalidMessage" DynamicSessions string = "DynamicSessions" + DynamicQualifier string = "DynamicQualifier" ) diff --git a/config/doc.go b/config/doc.go index d02491c16..eff831de9 100644 --- a/config/doc.go +++ b/config/doc.go @@ -230,6 +230,18 @@ Time between reconnection attempts in seconds. Only used for initiators. Valu Defaults to 30 +LogoutTimeout + +Session setting for logout timeout in seconds. Only used for initiators. Value must be positive integer. + +Defaults to 2 + +LogonTimeout + +Session setting for logon timeout in seconds. Only used for initiators. Value must be positive integer. + +Defaults to 10 + HeartBtInt Heartbeat interval in seconds. Only used for initiators. Value must be positive integer. @@ -280,6 +292,10 @@ SocketCAFile Optional root CA to use for secure TLS connections. For acceptors, client certificates will be verified against this CA. For initiators, clients will use the CA to verify the server certificate. If not configurated, initiators will verify the server certificate using the host's root CA set. +SocketServerName + +The expected server name on a returned certificate, unless SocketInsecureSkipVerify is true. This is for the TLS Server Name Indication extension. Initiator only. + SocketMinimumTLSVersion Specify the Minimum TLS version to use when creating a secure connection. The valid choices are SSL30, TLS10, TLS11, TLS12. Defaults to TLS12. @@ -309,6 +325,12 @@ ProxyPassword Proxy password +UseTCPProxy + +Use TCP proxy for servers listening behind HAProxy of Amazon ELB load balancers. The server can then receive the address of the client instead of the load balancer's. Valid Values: + Y + N + PersistMessages If set to N, no messages will be persisted. This will force QuickFIX/Go to always send GapFills instead of resending messages. Use this if you know you never want to resend a message. Useful for market data streams. Valid Values: diff --git a/datadictionary/datadictionary.go b/datadictionary/datadictionary.go index 410ca8ee1..49ba19015 100644 --- a/datadictionary/datadictionary.go +++ b/datadictionary/datadictionary.go @@ -3,6 +3,7 @@ package datadictionary import ( "encoding/xml" + "io" "os" ) @@ -197,26 +198,7 @@ func (f FieldDef) childTags() []int { for _, f := range f.Fields { tags = append(tags, f.Tag()) - for _, t := range f.childTags() { - tags = append(tags, t) - } - } - - return tags -} - -func (f FieldDef) requiredChildTags() []int { - var tags []int - - for _, f := range f.Fields { - if !f.Required() { - continue - } - - tags = append(tags, f.Tag()) - for _, t := range f.requiredChildTags() { - tags = append(tags, t) - } + tags = append(tags, f.childTags()...) } return tags @@ -324,15 +306,20 @@ func Parse(path string) (*DataDictionary, error) { } defer xmlFile.Close() + return ParseSrc(xmlFile) +} + +//ParseSrc loads and and build a datadictionary instance from an xml source. +func ParseSrc(xmlSrc io.Reader) (*DataDictionary, error) { doc := new(XMLDoc) - decoder := xml.NewDecoder(xmlFile) + decoder := xml.NewDecoder(xmlSrc) if err := decoder.Decode(doc); err != nil { return nil, err } b := new(builder) - var dict *DataDictionary - if dict, err = b.build(doc); err != nil { + dict, err := b.build(doc) + if err != nil { return nil, err } diff --git a/datadictionary/datadictionary_test.go b/datadictionary/datadictionary_test.go index 4179716c5..11d98bfca 100644 --- a/datadictionary/datadictionary_test.go +++ b/datadictionary/datadictionary_test.go @@ -82,7 +82,7 @@ func TestFieldsByTag(t *testing.T) { func TestEnumFieldsByTag(t *testing.T) { d, _ := dict() - f, _ := d.FieldTypeByTag[658] + f := d.FieldTypeByTag[658] var tests = []struct { Value string @@ -141,7 +141,7 @@ func TestDataDictionaryTrailer(t *testing.T) { func TestMessageRequiredTags(t *testing.T) { d, _ := dict() - nos, _ := d.Messages["D"] + nos := d.Messages["D"] var tests = []struct { *MessageDef @@ -169,7 +169,7 @@ func TestMessageRequiredTags(t *testing.T) { func TestMessageTags(t *testing.T) { d, _ := dict() - nos, _ := d.Messages["D"] + nos := d.Messages["D"] var tests = []struct { *MessageDef diff --git a/dialer.go b/dialer.go index 610fc1aeb..1645076bf 100644 --- a/dialer.go +++ b/dialer.go @@ -2,9 +2,10 @@ package quickfix import ( "fmt" + "net" + "github.com/quickfixgo/quickfix/config" "golang.org/x/net/proxy" - "net" ) func loadDialerConfig(settings *SessionSettings) (dialer proxy.Dialer, err error) { diff --git a/dialer_test.go b/dialer_test.go index 5e06b82cc..510c18f04 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -1,11 +1,12 @@ package quickfix import ( - "github.com/quickfixgo/quickfix/config" - "github.com/stretchr/testify/suite" "net" "testing" "time" + + "github.com/quickfixgo/quickfix/config" + "github.com/stretchr/testify/suite" ) type DialerTestSuite struct { diff --git a/errors.go b/errors.go index af949ebd4..19ca16e50 100644 --- a/errors.go +++ b/errors.go @@ -33,6 +33,7 @@ type MessageRejectError interface { //RejectReason, tag 373 for session rejects, tag 380 for business rejects. RejectReason() int + BusinessRejectRefID() string RefTagID() *Tag IsBusinessReject() bool } @@ -50,20 +51,25 @@ func (RejectLogon) RefTagID() *Tag { return nil } //RejectReason implements MessageRejectError func (RejectLogon) RejectReason() int { return 0 } +//BusinessRejectRefID implements MessageRejectError +func (RejectLogon) BusinessRejectRefID() string { return "" } + //IsBusinessReject implements MessageRejectError func (RejectLogon) IsBusinessReject() bool { return false } type messageRejectError struct { - rejectReason int - text string - refTagID *Tag - isBusinessReject bool + rejectReason int + text string + businessRejectRefID string + refTagID *Tag + isBusinessReject bool } -func (e messageRejectError) Error() string { return e.text } -func (e messageRejectError) RefTagID() *Tag { return e.refTagID } -func (e messageRejectError) RejectReason() int { return e.rejectReason } -func (e messageRejectError) IsBusinessReject() bool { return e.isBusinessReject } +func (e messageRejectError) Error() string { return e.text } +func (e messageRejectError) RefTagID() *Tag { return e.refTagID } +func (e messageRejectError) RejectReason() int { return e.rejectReason } +func (e messageRejectError) BusinessRejectRefID() string { return e.businessRejectRefID } +func (e messageRejectError) IsBusinessReject() bool { return e.isBusinessReject } //NewMessageRejectError returns a MessageRejectError with the given error message, reject reason, and optional reftagid func NewMessageRejectError(err string, rejectReason int, refTagID *Tag) MessageRejectError { @@ -76,6 +82,12 @@ func NewBusinessMessageRejectError(err string, rejectReason int, refTagID *Tag) return messageRejectError{text: err, rejectReason: rejectReason, refTagID: refTagID, isBusinessReject: true} } +//NewBusinessMessageRejectErrorWithRefID returns a MessageRejectError with the given error mesage, reject reason, refID, and optional reftagid. +//Reject is treated as a business level reject +func NewBusinessMessageRejectErrorWithRefID(err string, rejectReason int, businessRejectRefID string, refTagID *Tag) MessageRejectError { + return messageRejectError{text: err, rejectReason: rejectReason, refTagID: refTagID, businessRejectRefID: businessRejectRefID, isBusinessReject: true} +} + //IncorrectDataFormatForValue returns an error indicating a field that cannot be parsed as the type required. func IncorrectDataFormatForValue(tag Tag) MessageRejectError { return NewMessageRejectError("Incorrect data format for value", rejectReasonIncorrectDataFormatForValue, &tag) diff --git a/errors_test.go b/errors_test.go index 56554af51..e612d3242 100644 --- a/errors_test.go +++ b/errors_test.go @@ -52,6 +52,33 @@ func TestNewBusinessMessageRejectError(t *testing.T) { } } +func TestNewBusinessMessageRejectErrorWithRefID(t *testing.T) { + var ( + expectedErrorString = "Custom error" + expectedRejectReason = 5 + expectedbusinessRejectRefID = "1" + expectedRefTagID Tag = 44 + expectedIsBusinessReject = true + ) + msgRej := NewBusinessMessageRejectErrorWithRefID(expectedErrorString, expectedRejectReason, expectedbusinessRejectRefID, &expectedRefTagID) + + if strings.Compare(msgRej.Error(), expectedErrorString) != 0 { + t.Errorf("expected: %s, got: %s\n", expectedErrorString, msgRej.Error()) + } + if msgRej.RejectReason() != expectedRejectReason { + t.Errorf("expected: %d, got: %d\n", expectedRejectReason, msgRej.RejectReason()) + } + if strings.Compare(msgRej.BusinessRejectRefID(), expectedbusinessRejectRefID) != 0 { + t.Errorf("expected: %s, got: %s\n", expectedbusinessRejectRefID, msgRej.BusinessRejectRefID()) + } + if *msgRej.RefTagID() != expectedRefTagID { + t.Errorf("expected: %d, got: %d\n", expectedRefTagID, msgRej.RefTagID()) + } + if msgRej.IsBusinessReject() != expectedIsBusinessReject { + t.Error("Expected IsBusinessReject to be true\n") + } +} + func TestIncorrectDataFormatForValue(t *testing.T) { var ( expectedErrorString = "Incorrect data format for value" diff --git a/field_map_test.go b/field_map_test.go index c984f7c45..f07f6f396 100644 --- a/field_map_test.go +++ b/field_map_test.go @@ -157,7 +157,7 @@ func TestFieldMap_CopyInto(t *testing.T) { assert.Equal(t, "a", s) // old fields cleared - s, err = fMapB.GetString(3) + _, err = fMapB.GetString(3) assert.NotNil(t, err) // check that ordering is overwritten diff --git a/file_log_test.go b/file_log_test.go index b807dca42..3358dc20a 100644 --- a/file_log_test.go +++ b/file_log_test.go @@ -11,7 +11,7 @@ import ( func TestFileLog_NewFileLogFactory(t *testing.T) { - factory, err := NewFileLogFactory(NewSettings()) + _, err := NewFileLogFactory(NewSettings()) if err == nil { t.Error("Should expect error when settings have no file log path") @@ -39,7 +39,7 @@ SessionQualifier=BS stringReader := strings.NewReader(cfg) settings, _ := ParseSettings(stringReader) - factory, err = NewFileLogFactory(settings) + factory, err := NewFileLogFactory(settings) if err != nil { t.Error("Did not expect error", err) diff --git a/filestore.go b/filestore.go index aa2f7cde3..6d010f897 100644 --- a/filestore.go +++ b/filestore.go @@ -2,12 +2,14 @@ package quickfix import ( "fmt" + "io" "io/ioutil" "os" "path" "strconv" "time" + "github.com/pkg/errors" "github.com/quickfixgo/quickfix/config" ) @@ -81,9 +83,12 @@ func newFileStore(sessionID SessionID, dirname string) (*fileStore, error) { // Reset deletes the store files and sets the seqnums back to 1 func (store *fileStore) Reset() error { - store.cache.Reset() + if err := store.cache.Reset(); err != nil { + return errors.Wrap(err, "cache reset") + } + if err := store.Close(); err != nil { - return err + return errors.Wrap(err, "close") } if err := removeFile(store.bodyFname); err != nil { return err @@ -105,7 +110,10 @@ func (store *fileStore) Reset() error { // Refresh closes the store files and then reloads from them func (store *fileStore) Refresh() (err error) { - store.cache.Reset() + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } if err = store.Close(); err != nil { return err @@ -138,8 +146,13 @@ func (store *fileStore) Refresh() (err error) { } } - store.SetNextSenderMsgSeqNum(store.NextSenderMsgSeqNum()) - store.SetNextTargetMsgSeqNum(store.NextTargetMsgSeqNum()) + if err := store.SetNextSenderMsgSeqNum(store.NextSenderMsgSeqNum()); err != nil { + return errors.Wrap(err, "set next sender") + } + + if err := store.SetNextTargetMsgSeqNum(store.NextTargetMsgSeqNum()); err != nil { + return errors.Wrap(err, "set next target") + } return nil } @@ -166,13 +179,17 @@ func (store *fileStore) populateCache() (creationTimePopulated bool, err error) if senderSeqNumBytes, err := ioutil.ReadFile(store.senderSeqNumsFname); err == nil { if senderSeqNum, err := strconv.Atoi(string(senderSeqNumBytes)); err == nil { - store.cache.SetNextSenderMsgSeqNum(senderSeqNum) + if err = store.cache.SetNextSenderMsgSeqNum(senderSeqNum); err != nil { + return creationTimePopulated, errors.Wrap(err, "cache set next sender") + } } } if targetSeqNumBytes, err := ioutil.ReadFile(store.targetSeqNumsFname); err == nil { if targetSeqNum, err := strconv.Atoi(string(targetSeqNumBytes)); err == nil { - store.cache.SetNextTargetMsgSeqNum(targetSeqNum) + if err = store.cache.SetNextTargetMsgSeqNum(targetSeqNum); err != nil { + return creationTimePopulated, errors.Wrap(err, "cache set next target") + } } } @@ -180,7 +197,7 @@ func (store *fileStore) populateCache() (creationTimePopulated bool, err error) } func (store *fileStore) setSession() error { - if _, err := store.sessionFile.Seek(0, os.SEEK_SET); err != nil { + if _, err := store.sessionFile.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("unable to rewind file: %s: %s", store.sessionFname, err.Error()) } @@ -198,7 +215,7 @@ func (store *fileStore) setSession() error { } func (store *fileStore) setSeqNum(f *os.File, seqNum int) error { - if _, err := f.Seek(0, os.SEEK_SET); err != nil { + if _, err := f.Seek(0, io.SeekStart); err != nil { return fmt.Errorf("unable to rewind file: %s: %s", f.Name(), err.Error()) } if _, err := fmt.Fprintf(f, "%019d", seqNum); err != nil { @@ -222,25 +239,33 @@ func (store *fileStore) NextTargetMsgSeqNum() int { // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent func (store *fileStore) SetNextSenderMsgSeqNum(next int) error { - store.cache.SetNextSenderMsgSeqNum(next) + if err := store.cache.SetNextSenderMsgSeqNum(next); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.senderSeqNumsFile, next) } // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received func (store *fileStore) SetNextTargetMsgSeqNum(next int) error { - store.cache.SetNextTargetMsgSeqNum(next) + if err := store.cache.SetNextTargetMsgSeqNum(next); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.targetSeqNumsFile, next) } // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *fileStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.senderSeqNumsFile, store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *fileStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache") + } return store.setSeqNum(store.targetSeqNumsFile, store.cache.NextTargetMsgSeqNum()) } diff --git a/fileutil.go b/fileutil.go index 9fa11c964..5334f271c 100644 --- a/fileutil.go +++ b/fileutil.go @@ -4,6 +4,8 @@ import ( "fmt" "os" "strings" + + "github.com/pkg/errors" ) func sessionIDFilenamePrefix(s SessionID) string { @@ -44,9 +46,8 @@ func closeFile(f *os.File) error { // removeFile behaves like os.Remove, except that no error is returned if the file does not exist func removeFile(fname string) error { - err := os.Remove(fname) - if (err != nil) && !os.IsNotExist(err) { - return err + if err := os.Remove(fname); (err != nil) && !os.IsNotExist(err) { + return errors.Wrapf(err, "remove %v", fname) } return nil } diff --git a/fileutil_test.go b/fileutil_test.go index 4817c7862..f634651df 100644 --- a/fileutil_test.go +++ b/fileutil_test.go @@ -53,6 +53,7 @@ func TestOpenOrCreateFile(t *testing.T) { // Then it should be created f, err := openOrCreateFile(fname, 0664) + require.Nil(t, err) requireFileExists(t, fname) // When the file already exists diff --git a/fix_int_test.go b/fix_int_test.go index 173cb5ee5..64142c045 100644 --- a/fix_int_test.go +++ b/fix_int_test.go @@ -32,6 +32,6 @@ func BenchmarkFIXInt_Read(b *testing.B) { var field FIXInt for i := 0; i < b.N; i++ { - field.Read(intBytes) + _ = field.Read(intBytes) } } diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..bb753f61f --- /dev/null +++ b/go.mod @@ -0,0 +1,17 @@ +module github.com/quickfixgo/quickfix + +go 1.15 + +require ( + github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a + github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 + github.com/kr/text v0.2.0 // indirect + github.com/mattn/go-sqlite3 v1.14.7 + github.com/pkg/errors v0.9.1 + github.com/shopspring/decimal v1.2.0 + github.com/stretchr/objx v0.3.0 // indirect + github.com/stretchr/testify v1.7.0 + golang.org/x/net v0.0.0-20210614182718-04defd469f4e + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..277635b4b --- /dev/null +++ b/go.sum @@ -0,0 +1,41 @@ +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a h1:AP/vsCIvJZ129pdm9Ek7bH7yutN3hByqsMoNrWAxRQc= +github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8 h1:DujepqpGd1hyOd7aW59XpK7Qymp8iy83xq74fLr21is= +github.com/globalsign/mgo v0.0.0-20181015135952-eeefdecb41b8/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= +github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.3.0 h1:NGXK3lHquSN08v5vWalVI/L8XU9hdzE/G6xsrze47As= +github.com/stretchr/objx v0.3.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q= +golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/in_session.go b/in_session.go index d4dfc701d..3ca45d7bd 100644 --- a/in_session.go +++ b/in_session.go @@ -239,7 +239,7 @@ func (state inSession) resendMessages(session *session, beginSeqNo, endSeqNo int session.log.OnEventf("Resending Message: %v", sentMessageSeqNum) msgBytes = msg.build() - session.sendBytes(msgBytes) + session.EnqueueBytesAndSend(msgBytes) seqNum = sentMessageSeqNum + 1 nextSeqNum = seqNum @@ -382,7 +382,7 @@ func (state *inSession) generateSequenceReset(session *session, beginSeqNo int, msgBytes := sequenceReset.build() - session.sendBytes(msgBytes) + session.EnqueueBytesAndSend(msgBytes) session.log.OnEventf("Sent SequenceReset TO: %v", endSeqNo) return diff --git a/initiator.go b/initiator.go index 2ce02ace5..3b6672395 100644 --- a/initiator.go +++ b/initiator.go @@ -3,9 +3,11 @@ package quickfix import ( "bufio" "crypto/tls" - "golang.org/x/net/proxy" + "strings" "sync" "time" + + "golang.org/x/net/proxy" ) //Initiator initiates connections and processes messages for all sessions. @@ -107,7 +109,7 @@ func (i *Initiator) waitForInSessionTime(session *session) bool { return true } -//watiForReconnectInterval returns true if a reconnect should be re-attempted, false if handler should stop +//waitForReconnectInterval returns true if a reconnect should be re-attempted, false if handler should stop func (i *Initiator) waitForReconnectInterval(reconnectInterval time.Duration) bool { select { case <-time.After(reconnectInterval): @@ -150,6 +152,15 @@ func (i *Initiator) handleConnection(session *session, tlsConfig *tls.Config, di session.log.OnEventf("Failed to connect: %v", err) goto reconnect } else if tlsConfig != nil { + // Unless InsecureSkipVerify is true, server name config is required for TLS + // to verify the received certificate + if !tlsConfig.InsecureSkipVerify && len(tlsConfig.ServerName) == 0 { + serverName := address + if c := strings.LastIndex(serverName, ":"); c > 0 { + serverName = serverName[:c] + } + tlsConfig.ServerName = serverName + } tlsConn := tls.Client(netConn, tlsConfig) if err = tlsConn.Handshake(); err != nil { session.log.OnEventf("Failed handshake: %v", err) diff --git a/internal/session_settings.go b/internal/session_settings.go index f5d7f9a35..4c82b000d 100644 --- a/internal/session_settings.go +++ b/internal/session_settings.go @@ -22,5 +22,7 @@ type SessionSettings struct { //specific to initiators ReconnectInterval time.Duration + LogoutTimeout time.Duration + LogonTimeout time.Duration SocketConnectAddress []string } diff --git a/internal/time_range_test.go b/internal/time_range_test.go index df392118a..a28a495f5 100644 --- a/internal/time_range_test.go +++ b/internal/time_range_test.go @@ -374,6 +374,7 @@ func TestTimeRangeIsInSameRangeWithDay(t *testing.T) { time1 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) time2 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) + assert.True(t, NewUTCWeekRange(startTime, endTime, startDay, endDay).IsInSameRange(time1, time2)) time1 = time.Date(2004, time.July, 26, 10, 0, 0, 0, time.UTC) time2 = time.Date(2004, time.July, 27, 3, 0, 0, 0, time.UTC) diff --git a/logon_state.go b/logon_state.go index 351fa629c..d254b92b3 100644 --- a/logon_state.go +++ b/logon_state.go @@ -24,23 +24,15 @@ func (s logonState) FixMsgIn(session *session, msg *Message) (nextState sessionS if err := session.handleLogon(msg); err != nil { switch err := err.(type) { case RejectLogon: - session.log.OnEvent(err.Text) - logout := session.buildLogout(err.Text) + return shutdownWithReason(session, msg, true, err.Error()) - if err := session.dropAndSendInReplyTo(logout, msg); err != nil { - session.logError(err) - } - - if err := session.store.IncrNextTargetMsgSeqNum(); err != nil { - session.logError(err) - } - - return latentState{} + case targetTooLow: + return shutdownWithReason(session, msg, false, err.Error()) case targetTooHigh: var tooHighErr error if nextState, tooHighErr = session.doTargetTooHigh(err); tooHighErr != nil { - return handleStateError(session, tooHighErr) + return shutdownWithReason(session, msg, false, tooHighErr.Error()) } return @@ -64,3 +56,20 @@ func (s logonState) Timeout(session *session, e internal.Event) (nextState sessi func (s logonState) Stop(session *session) (nextState sessionState) { return latentState{} } + +func shutdownWithReason(session *session, msg *Message, incrNextTargetMsgSeqNum bool, reason string) (nextState sessionState) { + session.log.OnEvent(reason) + logout := session.buildLogout(reason) + + if err := session.dropAndSendInReplyTo(logout, msg); err != nil { + session.logError(err) + } + + if incrNextTargetMsgSeqNum { + if err := session.store.IncrNextTargetMsgSeqNum(); err != nil { + session.logError(err) + } + } + + return latentState{} +} diff --git a/logon_state_test.go b/logon_state_test.go index 08da2e710..3ba29af1d 100644 --- a/logon_state_test.go +++ b/logon_state_test.go @@ -302,3 +302,31 @@ func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooHigh() { s.State(inSession{}) s.NextTargetMsgSeqNum(7) } + +func (s *LogonStateTestSuite) TestFixMsgInLogonSeqNumTooLow() { + s.IncrNextSenderMsgSeqNum() + s.IncrNextTargetMsgSeqNum() + + logon := s.Logon() + logon.Body.SetField(tagHeartBtInt, FIXInt(32)) + logon.Header.SetInt(tagMsgSeqNum, 1) + + s.MockApp.On("ToAdmin") + s.NextTargetMsgSeqNum(2) + s.fixMsgIn(s.session, logon) + + s.State(latentState{}) + s.NextTargetMsgSeqNum(2) + + s.MockApp.AssertNumberOfCalls(s.T(), "ToAdmin", 1) + msgBytesSent, ok := s.Receiver.LastMessage() + s.Require().True(ok) + sentMessage := NewMessage() + err := ParseMessage(sentMessage, bytes.NewBuffer(msgBytesSent)) + s.Require().Nil(err) + s.MessageType(string(msgTypeLogout), sentMessage) + + s.session.sendQueued() + s.MessageType(string(msgTypeLogout), s.MockApp.lastToAdmin) + s.FieldEquals(tagText, "MsgSeqNum too low, expecting 2 but received 1", s.MockApp.lastToAdmin.Body) +} diff --git a/mongostore.go b/mongostore.go index 76789dbe9..e50af2c16 100644 --- a/mongostore.go +++ b/mongostore.go @@ -6,7 +6,7 @@ import ( "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" - + "github.com/pkg/errors" "github.com/quickfixgo/quickfix/config" ) @@ -66,7 +66,11 @@ func newMongoStore(sessionID SessionID, mongoURL string, mongoDatabase string, m messagesCollection: messagesCollection, sessionsCollection: sessionsCollection, } - store.cache.Reset() + + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } if store.db, err = mgo.Dial(mongoURL); err != nil { return @@ -139,27 +143,43 @@ func (store *mongoStore) Refresh() error { return store.populateCache() } -func (store *mongoStore) populateCache() (err error) { +func (store *mongoStore) populateCache() error { msgFilter := generateMessageFilter(&store.sessionID) query := store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Find(msgFilter) - if cnt, err := query.Count(); err == nil && cnt > 0 { + cnt, err := query.Count() + if err != nil { + return errors.Wrap(err, "count") + } + + if cnt > 0 { // session record found, load it sessionData := &mongoQuickFixEntryData{} - err = query.One(&sessionData) - if err == nil { - store.cache.creationTime = sessionData.CreationTime - store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum) - store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum) + if err = query.One(&sessionData); err != nil { + return errors.Wrap(err, "query one") + } + + store.cache.creationTime = sessionData.CreationTime + if err = store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum); err != nil { + return errors.Wrap(err, "cache set next target") } - } else if err == nil && cnt == 0 { - // session record not found, create it - msgFilter.CreationTime = store.cache.creationTime - msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() - msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() - err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter) + + if err = store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum); err != nil { + return errors.Wrap(err, "cache set next sender") + } + + return nil } - return + + // session record not found, create it + msgFilter.CreationTime = store.cache.creationTime + msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() + msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() + + if err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter); err != nil { + return errors.Wrap(err, "insert") + } + return nil } // NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent @@ -200,13 +220,17 @@ func (store *mongoStore) SetNextTargetMsgSeqNum(next int) error { // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *mongoStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *mongoStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum()) } diff --git a/mongostore_test.go b/mongostore_test.go index 21ca84a8d..29d43d71d 100644 --- a/mongostore_test.go +++ b/mongostore_test.go @@ -2,12 +2,13 @@ package quickfix import ( "fmt" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" "log" "os" "strings" "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" ) // MongoStoreTestSuite runs all tests in the MessageStoreTestSuite against the MongoStore implementation diff --git a/parser_test.go b/parser_test.go index ab9c086a7..533691986 100644 --- a/parser_test.go +++ b/parser_test.go @@ -13,7 +13,7 @@ func BenchmarkParser_ReadMessage(b *testing.B) { for i := 0; i < b.N; i++ { reader := strings.NewReader(stream) parser := newParser(reader) - parser.ReadMessage() + _, _ = parser.ReadMessage() } } diff --git a/repeating_group.go b/repeating_group.go index d3bc4a58c..811379c37 100644 --- a/repeating_group.go +++ b/repeating_group.go @@ -109,7 +109,7 @@ func (f *RepeatingGroup) Add() *Group { //Write returns tagValues for all Items in the repeating group ordered by //Group sequence and Group template order func (f RepeatingGroup) Write() []TagValue { - tvs := make([]TagValue, 1, 1) + tvs := make([]TagValue, 1) tvs[0].init(f.tag, []byte(strconv.Itoa(len(f.groups)))) for _, group := range f.groups { diff --git a/session.go b/session.go index 204fa99d5..5e77c3e55 100644 --- a/session.go +++ b/session.go @@ -342,6 +342,14 @@ func (s *session) dropQueued() { s.toSend = s.toSend[:0] } +func (s *session) EnqueueBytesAndSend(msg []byte) { + s.sendMutex.Lock() + defer s.sendMutex.Unlock() + + s.toSend = append(s.toSend, msg) + s.sendQueued() +} + func (s *session) sendBytes(msg []byte) { s.log.OnOutgoing(msg) s.messageOut <- msg @@ -465,8 +473,7 @@ func (s *session) initiateLogoutInReplyTo(reason string, inReplyTo *Message) (er return } s.log.OnEvent("Inititated logout request") - time.AfterFunc(time.Duration(2)*time.Second, func() { s.sessionEvent <- internal.LogoutTimeout }) - + time.AfterFunc(s.LogoutTimeout, func() { s.sessionEvent <- internal.LogoutTimeout }) return } @@ -623,6 +630,9 @@ func (s *session) doReject(msg *Message, rej MessageRejectError) error { if rej.IsBusinessReject() { reply.Header.SetField(tagMsgType, FIXString("j")) reply.Body.SetField(tagBusinessRejectReason, FIXInt(rej.RejectReason())) + if refID := rej.BusinessRejectRefID(); refID != "" { + reply.Body.SetField(tagBusinessRejectRefID, FIXString(refID)) + } } else { reply.Header.SetField(tagMsgType, FIXString("3")) switch { @@ -703,6 +713,15 @@ func (s *session) onAdmin(msg interface{}) { return } + if !s.IsSessionTime() { + s.handleDisconnectState(s) + if msg.err != nil { + msg.err <- errors.New("Connection outside of session time") + close(msg.err) + } + return + } + if msg.err != nil { close(msg.err) } diff --git a/session_factory.go b/session_factory.go index 36e294837..f4f3fc0ef 100644 --- a/session_factory.go +++ b/session_factory.go @@ -331,6 +331,36 @@ func (f sessionFactory) buildInitiatorSettings(session *session, settings *Sessi session.ReconnectInterval = time.Duration(interval) * time.Second } + session.LogoutTimeout = 2 * time.Second + if settings.HasSetting(config.LogoutTimeout) { + + timeout, err := settings.IntSetting(config.LogoutTimeout) + if err != nil { + return err + } + + if timeout <= 0 { + return errors.New("LogoutTimeout must be greater than zero") + } + + session.LogoutTimeout = time.Duration(timeout) * time.Second + } + + session.LogonTimeout = 10 * time.Second + if settings.HasSetting(config.LogonTimeout) { + + timeout, err := settings.IntSetting(config.LogonTimeout) + if err != nil { + return err + } + + if timeout <= 0 { + return errors.New("LogonTimeout must be greater than zero") + } + + session.LogonTimeout = time.Duration(timeout) * time.Second + } + return f.configureSocketConnectAddress(session, settings) } diff --git a/session_factory_test.go b/session_factory_test.go index 7db6cfd87..88f08f1aa 100644 --- a/session_factory_test.go +++ b/session_factory_test.go @@ -129,7 +129,7 @@ func (s *SessionFactorySuite) TestResendRequestChunkSize() { s.Equal(2500, session.ResendRequestChunkSize) s.SessionSettings.Set(config.ResendRequestChunkSize, "notanint") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err) } @@ -353,6 +353,8 @@ func (s *SessionFactorySuite) TestNewSessionBuildInitiators() { s.True(session.InitiateLogon) s.Equal(34*time.Second, session.HeartBtInt) s.Equal(30*time.Second, session.ReconnectInterval) + s.Equal(10*time.Second, session.LogonTimeout) + s.Equal(2*time.Second, session.LogoutTimeout) s.Equal("127.0.0.1:5000", session.SocketConnectAddress[0]) } @@ -399,6 +401,54 @@ func (s *SessionFactorySuite) TestNewSessionBuildInitiatorsValidReconnectInterva s.NotNil(err, "ReconnectInterval must be greater than zero") } +func (s *SessionFactorySuite) TestNewSessionBuildInitiatorsValidLogoutTimeout() { + s.sessionFactory.BuildInitiators = true + s.SessionSettings.Set(config.HeartBtInt, "34") + s.SessionSettings.Set(config.SocketConnectHost, "127.0.0.1") + s.SessionSettings.Set(config.SocketConnectPort, "3000") + + s.SessionSettings.Set(config.LogoutTimeout, "45") + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.Nil(err) + s.Equal(45*time.Second, session.LogoutTimeout) + + s.SessionSettings.Set(config.LogoutTimeout, "not a number") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogoutTimeout must be a number") + + s.SessionSettings.Set(config.LogoutTimeout, "0") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogoutTimeout must be greater than zero") + + s.SessionSettings.Set(config.LogoutTimeout, "-20") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogoutTimeout must be greater than zero") +} + +func (s *SessionFactorySuite) TestNewSessionBuildInitiatorsValidLogonTimeout() { + s.sessionFactory.BuildInitiators = true + s.SessionSettings.Set(config.HeartBtInt, "34") + s.SessionSettings.Set(config.SocketConnectHost, "127.0.0.1") + s.SessionSettings.Set(config.SocketConnectPort, "3000") + + s.SessionSettings.Set(config.LogonTimeout, "45") + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.Nil(err) + s.Equal(45*time.Second, session.LogonTimeout) + + s.SessionSettings.Set(config.LogonTimeout, "not a number") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogonTimeout must be a number") + + s.SessionSettings.Set(config.LogonTimeout, "0") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogonTimeout must be greater than zero") + + s.SessionSettings.Set(config.LogonTimeout, "-20") + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + s.NotNil(err, "LogonTimeout must be greater than zero") +} + func (s *SessionFactorySuite) TestConfigureSocketConnectAddress() { sess := new(session) err := s.configureSocketConnectAddress(sess, s.SessionSettings) @@ -468,7 +518,7 @@ func (s *SessionFactorySuite) TestConfigureSocketConnectAddressMulti() { func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { s.SessionSettings.Set(config.TimeStampPrecision, "blah") - session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err) var tests = []struct { @@ -483,7 +533,7 @@ func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { for _, test := range tests { s.SessionSettings.Set(config.TimeStampPrecision, test.config) - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.Nil(err) s.Equal(session.timestampPrecision, test.precision) @@ -492,19 +542,19 @@ func (s *SessionFactorySuite) TestNewSessionTimestampPrecision() { func (s *SessionFactorySuite) TestNewSessionMaxLatency() { s.SessionSettings.Set(config.MaxLatency, "not a number") - session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be a number") s.SessionSettings.Set(config.MaxLatency, "-20") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be positive") s.SessionSettings.Set(config.MaxLatency, "0") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + _, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.NotNil(err, "MaxLatency must be positive") s.SessionSettings.Set(config.MaxLatency, "20") - session, err = s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) + session, err := s.newSession(s.SessionID, s.MessageStoreFactory, s.SessionSettings, s.LogFactory, s.App) s.Nil(err) s.Equal(session.MaxLatency, 20*time.Second) } diff --git a/session_settings_test.go b/session_settings_test.go index a28f60529..34687ae2c 100644 --- a/session_settings_test.go +++ b/session_settings_test.go @@ -1,8 +1,9 @@ package quickfix import ( - "github.com/quickfixgo/quickfix/config" "testing" + + "github.com/quickfixgo/quickfix/config" ) func TestSessionSettings_StringSettings(t *testing.T) { diff --git a/session_state.go b/session_state.go index 1f076f221..c8b1f42a4 100644 --- a/session_state.go +++ b/session_state.go @@ -22,28 +22,27 @@ func (sm *stateMachine) Start(s *session) { } func (sm *stateMachine) Connect(session *session) { - if !sm.IsSessionTime() { - session.log.OnEvent("Connection outside of session time") - sm.handleDisconnectState(session) + // No special logon logic needed for FIX Acceptors. + if !session.InitiateLogon { + sm.setState(session, logonState{}) return } - if session.InitiateLogon { - if session.RefreshOnLogon { - if err := session.store.Refresh(); err != nil { - session.logError(err) - return - } - } - - session.log.OnEvent("Sending logon request") - if err := session.sendLogon(); err != nil { + if session.RefreshOnLogon { + if err := session.store.Refresh(); err != nil { session.logError(err) return } } + session.log.OnEvent("Sending logon request") + if err := session.sendLogon(); err != nil { + session.logError(err) + return + } sm.setState(session, logonState{}) + // Fire logon timeout event after the pre-configured delay period. + time.AfterFunc(session.LogonTimeout, func() { session.sessionEvent <- internal.LogonTimeout }) } func (sm *stateMachine) Stop(session *session) { diff --git a/session_test.go b/session_test.go index 33c7cf417..293b94805 100644 --- a/session_test.go +++ b/session_test.go @@ -276,8 +276,8 @@ func (s *SessionSuite) TestShouldSendReset() { s.session.ResetOnDisconnect = test.ResetOnDisconnect s.session.ResetOnLogout = test.ResetOnLogout - s.MockStore.SetNextSenderMsgSeqNum(test.NextSenderMsgSeqNum) - s.MockStore.SetNextTargetMsgSeqNum(test.NextTargetMsgSeqNum) + s.Require().Nil(s.MockStore.SetNextSenderMsgSeqNum(test.NextSenderMsgSeqNum)) + s.Require().Nil(s.MockStore.SetNextTargetMsgSeqNum(test.NextTargetMsgSeqNum)) s.Equal(s.shouldSendReset(), test.Expected) } @@ -944,7 +944,7 @@ func (suite *SessionSendTestSuite) TestDropAndSendDropsQueueWithReset() { suite.NoMessageSent() suite.MockApp.On("ToAdmin") - suite.MockStore.Reset() + suite.Require().Nil(suite.MockStore.Reset()) require.Nil(suite.T(), suite.dropAndSend(suite.Logon())) suite.MockApp.AssertExpectations(suite.T()) msg := suite.MockApp.lastToAdmin diff --git a/sqlstore.go b/sqlstore.go index b16b0a74d..3569359af 100644 --- a/sqlstore.go +++ b/sqlstore.go @@ -3,8 +3,10 @@ package quickfix import ( "database/sql" "fmt" + "regexp" "time" + "github.com/pkg/errors" "github.com/quickfixgo/quickfix/config" ) @@ -19,6 +21,27 @@ type sqlStore struct { sqlDataSourceName string sqlConnMaxLifetime time.Duration db *sql.DB + placeholder placeholderFunc +} + +type placeholderFunc func(int) string + +var rePlaceholder = regexp.MustCompile(`\?`) + +func sqlString(raw string, placeholder placeholderFunc) string { + if placeholder == nil { + return raw + } + idx := 0 + return rePlaceholder.ReplaceAllStringFunc(raw, func(s string) string { + new := placeholder(idx) + idx += 1 + return new + }) +} + +func postgresPlaceholder(i int) string { + return fmt.Sprintf("$%d", i+1) } // NewSQLStoreFactory returns a sql-based implementation of MessageStoreFactory @@ -58,7 +81,14 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn sqlDataSourceName: dataSourceName, sqlConnMaxLifetime: connMaxLifetime, } - store.cache.Reset() + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } + + if store.sqlDriver == "postgres" { + store.placeholder = postgresPlaceholder + } if store.db, err = sql.Open(store.sqlDriver, store.sqlDataSourceName); err != nil { return nil, err @@ -78,10 +108,10 @@ func newSQLStore(sessionID SessionID, driver string, dataSourceName string, conn // Reset deletes the store records and sets the seqnums back to 1 func (store *sqlStore) Reset() error { s := store.sessionID - _, err := store.db.Exec(`DELETE FROM messages + _, err := store.db.Exec(sqlString(`DELETE FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -93,11 +123,11 @@ func (store *sqlStore) Reset() error { return err } - _, err = store.db.Exec(`UPDATE sessions + _, err = store.db.Exec(sqlString(`UPDATE sessions SET creation_time=?, incoming_seqnum=?, outgoing_seqnum=? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), store.cache.CreationTime(), store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -114,26 +144,30 @@ func (store *sqlStore) Refresh() error { return store.populateCache() } -func (store *sqlStore) populateCache() (err error) { +func (store *sqlStore) populateCache() error { s := store.sessionID var creationTime time.Time var incomingSeqNum, outgoingSeqNum int - row := store.db.QueryRow(`SELECT creation_time, incoming_seqnum, outgoing_seqnum + row := store.db.QueryRow(sqlString(`SELECT creation_time, incoming_seqnum, outgoing_seqnum FROM sessions WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) - err = row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum) + err := row.Scan(&creationTime, &incomingSeqNum, &outgoingSeqNum) // session record found, load it if err == nil { store.cache.creationTime = creationTime - store.cache.SetNextTargetMsgSeqNum(incomingSeqNum) - store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum) + if err = store.cache.SetNextTargetMsgSeqNum(incomingSeqNum); err != nil { + return errors.Wrap(err, "cache set next target") + } + if err = store.cache.SetNextSenderMsgSeqNum(outgoingSeqNum); err != nil { + return errors.Wrap(err, "cache set next sender") + } return nil } @@ -143,12 +177,12 @@ func (store *sqlStore) populateCache() (err error) { } // session record not found, create it - _, err = store.db.Exec(`INSERT INTO sessions ( + _, err = store.db.Exec(sqlString(`INSERT INTO sessions ( creation_time, incoming_seqnum, outgoing_seqnum, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), store.cache.creationTime, store.cache.NextTargetMsgSeqNum(), store.cache.NextSenderMsgSeqNum(), @@ -172,10 +206,10 @@ func (store *sqlStore) NextTargetMsgSeqNum() int { // SetNextSenderMsgSeqNum sets the next MsgSeqNum that will be sent func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET outgoing_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -188,10 +222,10 @@ func (store *sqlStore) SetNextSenderMsgSeqNum(next int) error { // SetNextTargetMsgSeqNum sets the next MsgSeqNum that should be received func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error { s := store.sessionID - _, err := store.db.Exec(`UPDATE sessions SET incoming_seqnum = ? + _, err := store.db.Exec(sqlString(`UPDATE sessions SET incoming_seqnum = ? WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? - AND targetcompid=? AND targetsubid=? AND targetlocid=?`, + AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder), next, s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID) @@ -203,13 +237,17 @@ func (store *sqlStore) SetNextTargetMsgSeqNum(next int) error { // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *sqlStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr next") + } return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *sqlStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr next") + } return store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum()) } @@ -221,12 +259,12 @@ func (store *sqlStore) CreationTime() time.Time { func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { s := store.sessionID - _, err := store.db.Exec(`INSERT INTO messages ( + _, err := store.db.Exec(sqlString(`INSERT INTO messages ( msgseqnum, message, beginstring, session_qualifier, sendercompid, sendersubid, senderlocid, targetcompid, targetsubid, targetlocid) - VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, store.placeholder), seqNum, string(msg), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, @@ -238,12 +276,12 @@ func (store *sqlStore) SaveMessage(seqNum int, msg []byte) error { func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) { s := store.sessionID var msgs [][]byte - rows, err := store.db.Query(`SELECT message FROM messages + rows, err := store.db.Query(sqlString(`SELECT message FROM messages WHERE beginstring=? AND session_qualifier=? AND sendercompid=? AND sendersubid=? AND senderlocid=? AND targetcompid=? AND targetsubid=? AND targetlocid=? AND msgseqnum>=? AND msgseqnum<=? - ORDER BY msgseqnum`, + ORDER BY msgseqnum`, store.placeholder), s.BeginString, s.Qualifier, s.SenderCompID, s.SenderSubID, s.SenderLocationID, s.TargetCompID, s.TargetSubID, s.TargetLocationID, diff --git a/sqlstore_test.go b/sqlstore_test.go index b252b2c25..da2c8a5a4 100644 --- a/sqlstore_test.go +++ b/sqlstore_test.go @@ -60,6 +60,11 @@ TargetCompID=%s`, sqlDriver, sqlDsn, sessionID.BeginString, sessionID.SenderComp require.Nil(suite.T(), err) } +func (suite *SQLStoreTestSuite) TestSqlPlaceholderReplacement() { + got := sqlString("A ? B ? C ?", postgresPlaceholder) + suite.Equal("A $1 B $2 C $3", got) +} + func (suite *SQLStoreTestSuite) TearDownTest() { suite.msgStore.Close() os.RemoveAll(suite.sqlStoreRootPath) diff --git a/store.go b/store.go index 837bbca13..41b6bc0c0 100644 --- a/store.go +++ b/store.go @@ -1,6 +1,10 @@ package quickfix -import "time" +import ( + "time" + + "github.com/pkg/errors" +) //The MessageStore interface provides methods to record and retrieve messages for resend purposes type MessageStore interface { @@ -107,7 +111,9 @@ type memoryStoreFactory struct{} func (f memoryStoreFactory) Create(sessionID SessionID) (MessageStore, error) { m := new(memoryStore) - m.Reset() + if err := m.Reset(); err != nil { + return m, errors.Wrap(err, "reset") + } return m, nil } diff --git a/store_test.go b/store_test.go index a61ccbb05..185704e52 100644 --- a/store_test.go +++ b/store_test.go @@ -30,61 +30,55 @@ func TestMemoryStoreTestSuite(t *testing.T) { suite.Run(t, new(MemoryStoreTestSuite)) } -func (suite *MessageStoreTestSuite) TestMessageStore_SetNextMsgSeqNum_Refresh_IncrNextMsgSeqNum() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_SetNextMsgSeqNum_Refresh_IncrNextMsgSeqNum() { // Given a MessageStore with the following sender and target seqnums - suite.msgStore.SetNextSenderMsgSeqNum(867) - suite.msgStore.SetNextTargetMsgSeqNum(5309) + s.Require().Nil(s.msgStore.SetNextSenderMsgSeqNum(867)) + s.Require().Nil(s.msgStore.SetNextTargetMsgSeqNum(5309)) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 867, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5309, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(867, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5309, s.msgStore.NextTargetMsgSeqNum()) // When the sender and target seqnums are incremented - require.Nil(t, suite.msgStore.IncrNextSenderMsgSeqNum()) - require.Nil(t, suite.msgStore.IncrNextTargetMsgSeqNum()) + s.Require().Nil(s.msgStore.IncrNextSenderMsgSeqNum()) + s.Require().Nil(s.msgStore.IncrNextTargetMsgSeqNum()) // Then the sender and target seqnums should be - assert.Equal(t, 868, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5310, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(868, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5310, s.msgStore.NextTargetMsgSeqNum()) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 868, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 5310, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(868, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(5310, s.msgStore.NextTargetMsgSeqNum()) } -func (suite *MessageStoreTestSuite) TestMessageStore_Reset() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_Reset() { // Given a MessageStore with the following sender and target seqnums - suite.msgStore.SetNextSenderMsgSeqNum(1234) - suite.msgStore.SetNextTargetMsgSeqNum(5678) + s.Require().Nil(s.msgStore.SetNextSenderMsgSeqNum(1234)) + s.Require().Nil(s.msgStore.SetNextTargetMsgSeqNum(5678)) // When the store is reset - require.Nil(t, suite.msgStore.Reset()) + s.Require().Nil(s.msgStore.Reset()) // Then the sender and target seqnums should be - assert.Equal(t, 1, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 1, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(1, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(1, s.msgStore.NextTargetMsgSeqNum()) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // Then the sender and target seqnums should still be - assert.Equal(t, 1, suite.msgStore.NextSenderMsgSeqNum()) - assert.Equal(t, 1, suite.msgStore.NextTargetMsgSeqNum()) + s.Equal(1, s.msgStore.NextSenderMsgSeqNum()) + s.Equal(1, s.msgStore.NextTargetMsgSeqNum()) } -func (suite *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { - t := suite.T() - +func (s *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { // Given the following saved messages expectedMsgsBySeqNum := map[int]string{ 1: "In the frozen land of Nador", @@ -92,31 +86,31 @@ func (suite *MessageStoreTestSuite) TestMessageStore_SaveMessage_GetMessage() { 3: "and there was much rejoicing", } for seqNum, msg := range expectedMsgsBySeqNum { - require.Nil(t, suite.msgStore.SaveMessage(seqNum, []byte(msg))) + s.Require().Nil(s.msgStore.SaveMessage(seqNum, []byte(msg))) } // When the messages are retrieved from the MessageStore - actualMsgs, err := suite.msgStore.GetMessages(1, 3) - require.Nil(t, err) + actualMsgs, err := s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) // Then the messages should be - require.Len(t, actualMsgs, 3) - assert.Equal(t, expectedMsgsBySeqNum[1], string(actualMsgs[0])) - assert.Equal(t, expectedMsgsBySeqNum[2], string(actualMsgs[1])) - assert.Equal(t, expectedMsgsBySeqNum[3], string(actualMsgs[2])) + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) // When the store is refreshed from its backing store - suite.msgStore.Refresh() + s.Require().Nil(s.msgStore.Refresh()) // And the messages are retrieved from the MessageStore - actualMsgs, err = suite.msgStore.GetMessages(1, 3) - require.Nil(t, err) + actualMsgs, err = s.msgStore.GetMessages(1, 3) + s.Require().Nil(err) // Then the messages should still be - require.Len(t, actualMsgs, 3) - assert.Equal(t, expectedMsgsBySeqNum[1], string(actualMsgs[0])) - assert.Equal(t, expectedMsgsBySeqNum[2], string(actualMsgs[1])) - assert.Equal(t, expectedMsgsBySeqNum[3], string(actualMsgs[2])) + s.Require().Len(actualMsgs, 3) + s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0])) + s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1])) + s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2])) } func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_EmptyStore() { @@ -163,12 +157,12 @@ func (suite *MessageStoreTestSuite) TestMessageStore_GetMessages_VariousRanges() } } -func (suite *MessageStoreTestSuite) TestMessageStore_CreationTime() { - assert.False(suite.T(), suite.msgStore.CreationTime().IsZero()) +func (s *MessageStoreTestSuite) TestMessageStore_CreationTime() { + s.False(s.msgStore.CreationTime().IsZero()) t0 := time.Now() - suite.msgStore.Reset() + s.Require().Nil(s.msgStore.Reset()) t1 := time.Now() - require.True(suite.T(), suite.msgStore.CreationTime().After(t0)) - require.True(suite.T(), suite.msgStore.CreationTime().Before(t1)) + s.Require().True(s.msgStore.CreationTime().After(t0)) + s.Require().True(s.msgStore.CreationTime().Before(t1)) } diff --git a/tag.go b/tag.go index 800375e34..20fcda0b7 100644 --- a/tag.go +++ b/tag.go @@ -43,6 +43,7 @@ const ( tagBusinessRejectReason Tag = 380 tagSessionRejectReason Tag = 373 tagRefMsgType Tag = 372 + tagBusinessRejectRefID Tag = 379 tagRefTagID Tag = 371 tagRefSeqNum Tag = 45 tagEncryptMethod Tag = 98 diff --git a/tls.go b/tls.go index ef6846f61..951ac7e87 100644 --- a/tls.go +++ b/tls.go @@ -18,6 +18,14 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) } } + var serverName string + if settings.HasSetting(config.SocketServerName) { + serverName, err = settings.Setting(config.SocketServerName) + if err != nil { + return + } + } + insecureSkipVerify := false if settings.HasSetting(config.SocketInsecureSkipVerify) { insecureSkipVerify, err = settings.BoolSetting(config.SocketInsecureSkipVerify) @@ -29,7 +37,9 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) if !settings.HasSetting(config.SocketPrivateKeyFile) && !settings.HasSetting(config.SocketCertificateFile) { if allowSkipClientCerts { tlsConfig = defaultTLSConfig() + tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify + setMinVersionExplicit(settings, tlsConfig) } return } @@ -46,26 +56,9 @@ func loadTLSConfig(settings *SessionSettings) (tlsConfig *tls.Config, err error) tlsConfig = defaultTLSConfig() tlsConfig.Certificates = make([]tls.Certificate, 1) + tlsConfig.ServerName = serverName tlsConfig.InsecureSkipVerify = insecureSkipVerify - - minVersion := "TLS12" - if settings.HasSetting(config.SocketMinimumTLSVersion) { - minVersion, err = settings.Setting(config.SocketMinimumTLSVersion) - if err != nil { - return - } - - switch minVersion { - case "SSL30": - tlsConfig.MinVersion = tls.VersionSSL30 - case "TLS10": - tlsConfig.MinVersion = tls.VersionTLS10 - case "TLS11": - tlsConfig.MinVersion = tls.VersionTLS11 - case "TLS12": - tlsConfig.MinVersion = tls.VersionTLS12 - } - } + setMinVersionExplicit(settings, tlsConfig) if tlsConfig.Certificates[0], err = tls.LoadX509KeyPair(certificateFile, privateKeyFile); err != nil { return @@ -112,3 +105,24 @@ func defaultTLSConfig() *tls.Config { }, } } + +func setMinVersionExplicit(settings *SessionSettings, tlsConfig *tls.Config) { + if settings.HasSetting(config.SocketMinimumTLSVersion) { + minVersion, err := settings.Setting(config.SocketMinimumTLSVersion) + if err != nil { + return + } + + switch minVersion { + case "SSL30": + //nolint:staticcheck // SA1019 min version ok + tlsConfig.MinVersion = tls.VersionSSL30 + case "TLS10": + tlsConfig.MinVersion = tls.VersionTLS10 + case "TLS11": + tlsConfig.MinVersion = tls.VersionTLS11 + case "TLS12": + tlsConfig.MinVersion = tls.VersionTLS12 + } + } +} diff --git a/tls_test.go b/tls_test.go index 3ddbddeaf..fe6745a19 100644 --- a/tls_test.go +++ b/tls_test.go @@ -87,6 +87,27 @@ func (s *TLSTestSuite) TestLoadTLSWithCA() { s.Equal(tls.RequireAndVerifyClientCert, tlsConfig.ClientAuth) } +func (s *TLSTestSuite) TestServerNameUseSSL() { + s.settings.GlobalSettings().Set(config.SocketUseSSL, "Y") + s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameUseSSL") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + s.Equal("DummyServerNameUseSSL", tlsConfig.ServerName) +} + +func (s *TLSTestSuite) TestServerNameWithCerts() { + s.settings.GlobalSettings().Set(config.SocketPrivateKeyFile, s.PrivateKeyFile) + s.settings.GlobalSettings().Set(config.SocketCertificateFile, s.CertificateFile) + s.settings.GlobalSettings().Set(config.SocketServerName, "DummyServerNameWithCerts") + + tlsConfig, err := loadTLSConfig(s.settings.GlobalSettings()) + s.Nil(err) + s.NotNil(tlsConfig) + s.Equal("DummyServerNameWithCerts", tlsConfig.ServerName) +} + func (s *TLSTestSuite) TestInsecureSkipVerify() { s.settings.GlobalSettings().Set(config.SocketInsecureSkipVerify, "Y") @@ -129,6 +150,7 @@ func (s *TLSTestSuite) TestMinimumTLSVersion() { s.Nil(err) s.NotNil(tlsConfig) + //nolint:staticcheck s.Equal(tlsConfig.MinVersion, uint16(tls.VersionSSL30)) // TLS10 diff --git a/validation.go b/validation.go index dc17abc58..59edfac31 100644 --- a/validation.go +++ b/validation.go @@ -116,7 +116,7 @@ func validateFIXT(transportDD, appDD *datadictionary.DataDictionary, settings va } func validateMsgType(d *datadictionary.DataDictionary, msgType string, msg *Message) MessageRejectError { - if _, validMsgType := d.Messages[msgType]; validMsgType == false { + if _, validMsgType := d.Messages[msgType]; !validMsgType { return InvalidMessageType() } return nil