From 25029927194174df7e64aeab902015377674c2d2 Mon Sep 17 00:00:00 2001 From: samjtro Date: Fri, 25 Oct 2024 10:11:08 -0500 Subject: [PATCH] major: ref: #67, v0.9.0 desc: - refactor utils.go to reflect removal of previous oauth implementation in favor of x/oauth2, vis-a-vis go-schwab/oauth2ns - check -> isErrNil --- accounts-trading.go | 78 ++++++++-------- go.mod | 11 ++- go.sum | 21 +++++ market-data.go | 52 ++++++----- utils.go | 213 +++++++++++++++++--------------------------- 5 files changed, 185 insertions(+), 190 deletions(-) diff --git a/accounts-trading.go b/accounts-trading.go index e0f7c40..882cf31 100644 --- a/accounts-trading.go +++ b/accounts-trading.go @@ -9,6 +9,10 @@ import ( "github.com/bytedance/sonic" ) +/* TODO: +[ ] http.NewRequest -> agent.client +*/ + var ( accountEndpoint string = "https://api.schwabapi.com/trader/v1" @@ -16,16 +20,16 @@ var ( endpointAccountNumbers string = accountEndpoint + "/accounts/accountNumbers" endpointAccounts string = accountEndpoint + "/accounts" endpointAccount string = accountEndpoint + "/accounts/%s" - //endpointUserPreference string = accountEndpoint + "/userPreference" + // endpointUserPreference string = accountEndpoint + "/userPreference" // Orders endpointOrders string = accountEndpoint + "/orders" endpointAccountOrders string = accountEndpoint + "/accounts/%s/orders" endpointAccountOrder string = accountEndpoint + "/accounts/%s/orders/%s" - //endpointPreviewOrder string = accountEndpoint + "/accounts/%s/previewOrder" + // endpointPreviewOrder string = accountEndpoint + "/accounts/%s/previewOrder" // Transactions - //endpointTransactions string = accountEndpoint + "/accounts/%s/transactions" + // endpointTransactions string = accountEndpoint + "/accounts/%s/transactions" endpointTransaction string = accountEndpoint + "/accounts/%s/transactions/%s" ) @@ -315,8 +319,10 @@ type SimpleOrderInstrument struct { AssetType string // EQUITY } -type SingleLegOrderComposition func(order *SingleLegOrder) -type MultiLegSimpleOrderComposition func(order *MultiLegOrder) +type ( + SingleLegOrderComposition func(order *SingleLegOrder) + MultiLegSimpleOrderComposition func(order *MultiLegOrder) +) // Create a new Market order func CreateSingleLegOrder(opts ...SingleLegOrderComposition) *SingleLegOrder { @@ -431,25 +437,25 @@ func marshalMultiLegOrder(order *MultiLegOrder) string { func (agent *Agent) SubmitSingleLegOrder(hashValue string, order *SingleLegOrder) error { orderJson := marshalSingleLegOrder(order) req, err := http.NewRequest("POST", fmt.Sprintf(endpointAccountOrders, hashValue), strings.NewReader(orderJson)) - check(err) + isErrNil(err) req.Header.Set("Content-Type", "application/json") _, err = agent.Handler(req) - check(err) + isErrNil(err) return nil } // Get a specific order by account number & order ID func (agent *Agent) GetOrder(accountNumber, orderID string) (FullOrder, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccountOrder, accountNumber, orderID), nil) - check(err) + isErrNil(err) resp, err := agent.Handler(req) - check(err) + isErrNil(err) var order FullOrder defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) err = sonic.Unmarshal(body, &order) - check(err) + isErrNil(err) return order, nil } @@ -457,19 +463,19 @@ func (agent *Agent) GetOrder(accountNumber, orderID string) (FullOrder, error) { // yyyy-MM-ddTHH:mm:ss.SSSZ func (agent *Agent) GetAccountOrders(accountNumber, fromEnteredTime, toEnteredTime string) ([]FullOrder, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccountOrders, accountNumber), nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("fromEnteredTime", fromEnteredTime) q.Add("toEnteredTime", toEnteredTime) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) var orders []FullOrder defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) err = sonic.Unmarshal(body, &orders) - check(err) + isErrNil(err) return orders, nil } @@ -478,19 +484,19 @@ func (agent *Agent) GetAccountOrders(accountNumber, fromEnteredTime, toEnteredTi // yyyy-MM-ddTHH:mm:ss.SSSZ func (agent *Agent) GetAllOrders(fromEnteredTime, toEnteredTime string) ([]FullOrder, error) { req, err := http.NewRequest("GET", endpointOrders, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("fromEnteredTime", fromEnteredTime) q.Add("toEnteredTime", toEnteredTime) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var orders []FullOrder /*err = sonic.Unmarshal(body, &orders) - check(err)*/ + isErrNil(err)*/ fmt.Println(body) return orders, nil } @@ -498,62 +504,62 @@ func (agent *Agent) GetAllOrders(fromEnteredTime, toEnteredTime string) ([]FullO // Get encrypted account numbers for trading func (agent *Agent) GetAccountNumbers() ([]AccountNumbers, error) { req, err := http.NewRequest("GET", endpointAccountNumbers, nil) - check(err) + isErrNil(err) resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var accountNumbers []AccountNumbers err = sonic.Unmarshal(body, &accountNumbers) - check(err) + isErrNil(err) return accountNumbers, nil } // Get all accounts associated with the user logged in func (agent *Agent) GetAccounts() ([]Account, error) { req, err := http.NewRequest("GET", endpointAccounts, nil) - check(err) + isErrNil(err) resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var accounts []Account err = sonic.Unmarshal(body, &accounts) - check(err) + isErrNil(err) return accounts, nil } // Get account by encrypted account id func (agent *Agent) GetAccount(id string) (Account, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointAccount, id), nil) - check(err) + isErrNil(err) resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var account Account err = sonic.Unmarshal(body, &account) - check(err) + isErrNil(err) return account, nil } // Get all transactions for the user logged in -//func (agent *Agent) GetTransactions() ([]Transaction, error) {} +// func (agent *Agent) GetTransactions() ([]Transaction, error) {} // Get a transaction for a specific account id func (agent *Agent) GetTransaction(accountNumber, transactionId string) (Transaction, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointTransaction, accountNumber, transactionId), nil) - check(err) + isErrNil(err) resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var transaction Transaction err = sonic.Unmarshal(body, &transaction) - check(err) + isErrNil(err) return transaction, nil } diff --git a/go.mod b/go.mod index a4bb5aa..f65d26f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-schwab/trader -go 1.22.4 +go 1.23.2 require github.com/joho/godotenv v1.5.1 @@ -9,7 +9,16 @@ require ( github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/fatih/color v1.17.0 // indirect + github.com/go-schwab/oauth2ns v0.0.0-20241015193425-e8abfd05a439 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/nmrshll/rndm-go v0.0.0-20170430161430-8da3024e53de // indirect + github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177 // indirect + github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + golang.org/x/oauth2 v0.23.0 // indirect + golang.org/x/sys v0.18.0 // indirect ) diff --git a/go.sum b/go.sum index 1f1d15a..c77375a 100644 --- a/go.sum +++ b/go.sum @@ -13,13 +13,28 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= +github.com/go-schwab/oauth2ns v0.0.0-20241015193425-e8abfd05a439 h1:y/AmrAYZNvqI7IaMLllCjKlqOBbtYTe/KKtE+gip2c8= +github.com/go-schwab/oauth2ns v0.0.0-20241015193425-e8abfd05a439/go.mod h1:69d3XNxDSeVmN6g+8a0RT+a13ZxDc5LRKODEZB3d5go= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/nmrshll/rndm-go v0.0.0-20170430161430-8da3024e53de h1:j+mSQhCm1H2d7apFbM5ODqrTultUvF3jt//DcRNkxVM= +github.com/nmrshll/rndm-go v0.0.0-20170430161430-8da3024e53de/go.mod h1:OeEnWnbCrUWnPl1xSCGM5/qtWqZ4L15KOAjR/wmxhXc= +github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177 h1:nRlQD0u1871kaznCnn1EvYiMbum36v7hw1DLPEjds4o= +github.com/palantir/stacktrace v0.0.0-20161112013806-78658fd2d177/go.mod h1:ao5zGxj8Z4x60IOVYZUbDSmt3R8Ddo080vEgPosHpak= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA= +github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -32,6 +47,12 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs= +golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/market-data.go b/market-data.go index 0125222..7480a86 100644 --- a/market-data.go +++ b/market-data.go @@ -9,6 +9,10 @@ import ( "github.com/bytedance/sonic" ) +/* TODO: +[ ] http.NewRequest -> agent.client +*/ + var ( endpoint string = "https://api.schwabapi.com/marketdata/v1" @@ -197,19 +201,19 @@ type Contract struct { // ticker = "AAPL", etc. func (agent *Agent) GetQuote(symbol string) (Quote, error) { req, err := http.NewRequest("GET", endpointQuotes, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("symbols", symbol) q.Add("fields", "quote") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var quote Quote err = sonic.Unmarshal([]byte(strings.Join(strings.Split(strings.Split(string(body), fmt.Sprintf("\"%s\":", symbol))[1], "\"quote\":{"), "")[:len(strings.Join(strings.Split(strings.Split(string(body), fmt.Sprintf("\"%s\":", symbol))[1], "\"quote\":{"), ""))-2]), "e) - check(err) + isErrNil(err) return quote, err } @@ -217,19 +221,19 @@ func (agent *Agent) GetQuote(symbol string) (Quote, error) { // It takes one param: func (agent *Agent) SearchInstrumentSimple(symbols string) (SimpleInstrument, error) { req, err := http.NewRequest("GET", endpointSearchInstrument, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("symbol", symbols) q.Add("projection", "symbol-search") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var instrument SimpleInstrument err = sonic.Unmarshal([]byte(strings.Split(string(body), "[")[1][:len(strings.Split(string(body), "[")[1])-2]), &instrument) - check(err) + isErrNil(err) return instrument, nil } @@ -237,21 +241,21 @@ func (agent *Agent) SearchInstrumentSimple(symbols string) (SimpleInstrument, er // It takes one param: func (agent *Agent) SearchInstrumentFundamental(symbol string) (FundamentalInstrument, error) { req, err := http.NewRequest("GET", endpointSearchInstrument, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("symbol", symbol) q.Add("projection", "fundamental") req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var instrument FundamentalInstrument split0 := strings.Split(string(body), "[{\"fundamental\":")[1] split := strings.Split(split0, "}") err = sonic.Unmarshal([]byte(fmt.Sprintf("%s}", strings.Join(split[:2], ""))), &instrument) - check(err) + isErrNil(err) return instrument, nil } @@ -274,7 +278,7 @@ func (agent *Agent) SearchInstrumentFundamental(symbol string) (FundamentalInstr // endDate = func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, frequency, startDate, endDate string) ([]Candle, error) { req, err := http.NewRequest("GET", endpointPriceHistory, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("symbol", symbol) q.Add("periodType", periodType) @@ -285,13 +289,13 @@ func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, f q.Add("endDate", endDate) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var candles []Candle err = sonic.Unmarshal([]byte(fmt.Sprintf("[%s]", strings.Split(strings.Split(string(body), "[")[1], "]")[0])), &candles) - check(err) + isErrNil(err) return candles, nil } @@ -302,20 +306,20 @@ func (agent *Agent) GetPriceHistory(symbol, periodType, period, frequencyType, f // change = "percent" or "value" func (agent *Agent) GetMovers(index, direction, change string) ([]Screener, error) { req, err := http.NewRequest("GET", fmt.Sprintf(endpointMovers, index), nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("direction", direction) q.Add("change", change) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var movers []Screener stringToParse := fmt.Sprintf("[%s]", strings.Split(string(body), "[")[1][:len(strings.Split(string(body), "[")[1])-2]) err = sonic.Unmarshal([]byte(stringToParse), &movers) - check(err) + isErrNil(err) return movers, nil } @@ -337,7 +341,7 @@ func (agent *Agent) GetMovers(index, direction, change string) ([]Screener, erro // This returns 5 AAPL CALL contracts both above and below the at the money price, with no preference as to the status of the contract ("ALL"), expiring before 2022-07-01 func (agent *Agent) Single(ticker, contractType, strikeRange, strikeCount, toDate string) ([]Contract, error) { req, err := http.NewRequest("GET", endpointOptions, nil) - check(err) + isErrNil(err) q := req.URL.Query() q.Add("symbol", ticker) q.Add("contractType", contractType) @@ -346,14 +350,14 @@ func (agent *Agent) Single(ticker, contractType, strikeRange, strikeCount, toDat q.Add("toDate", toDate) req.URL.RawQuery = q.Encode() resp, err := agent.Handler(req) - check(err) + isErrNil(err) defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) + isErrNil(err) var chain []Contract // WIP err = sonic.Unmarshal(body, &chain) - check(err) + isErrNil(err) return chain, nil } diff --git a/utils.go b/utils.go index 5e587bd..5a5c50b 100644 --- a/utils.go +++ b/utils.go @@ -1,66 +1,65 @@ package trader import ( - "bytes" - "encoding/base64" - "encoding/json" + "context" "errors" "fmt" "io" "log" "net/http" - "net/url" "os" - "os/exec" - "runtime" "strings" "time" + "github.com/bytedance/sonic" + "github.com/go-schwab/oauth2ns" + o "github.com/go-schwab/oauth2ns" "github.com/joho/godotenv" + "golang.org/x/oauth2" ) -func init() { - err := godotenv.Load("config.env") - check(err) -} +type Agent struct{ client *o.AuthorizedClient } -type Agent struct { - tokens Token +type DB struct { + AccessToken string + RefreshToken string + TokenType string + Expiry time.Time + ExpiresIn int64 } -type Token struct { - RefreshExpiration time.Time - Refresh string - BearerExpiration time.Time - Bearer string -} +var Tokens DB -// Helper: parse access token response -func parseAccessTokenResponse(s string) Token { - token := Token{ - RefreshExpiration: time.Now().Add(time.Hour * 168), - BearerExpiration: time.Now().Add(time.Minute * 30), - } - for _, x := range strings.Split(s, ",") { - for i1, x1 := range strings.Split(x, ":") { - if trimOneFirstOneLast(x1) == "refresh_token" { - token.Refresh = trimOneFirstOneLast(strings.Split(x, ":")[i1+1]) - } else if trimOneFirstOneLast(x1) == "access_token" { - token.Bearer = trimOneFirstOneLast(strings.Split(x, ":")[i1+1]) - } - } - } - return token +func init() { + err := godotenv.Load("*.env") + isErrNil(err) } // Read in tokens from ~/.trade/bar.json -func readDB() Token { - var tokens Token +func readDB() *oauth2ns.AuthorizedClient { body, err := os.ReadFile(fmt.Sprintf("%s/.trade/bar.json", homeDir())) - check(err) - err = json.Unmarshal(body, &tokens) - check(err) - return tokens + isErrNil(err) + var ctx context.Context + err = sonic.Unmarshal(body, &Tokens) + isErrNil(err) + token := new(oauth2.Token) + token.AccessToken = Tokens.AccessToken + token.RefreshToken = Tokens.RefreshToken + token.TokenType = Tokens.TokenType + token.Expiry = Tokens.Expiry + token.ExpiresIn = Tokens.ExpiresIn + c := &oauth2.Config{ + ClientID: os.Getenv("APPKEY"), + ClientSecret: os.Getenv("SECRET"), + Endpoint: oauth2.Endpoint{ + AuthURL: "https://api.schwabapi.com/v1/oauth/authorize", + TokenURL: "https://api.schwabapi.com/v1/oauth/token", + }, + } + return &o.AuthorizedClient{ + c.Client(ctx, token), + token, + } } // Credit: https://go.dev/play/p/C2sZRYC15XN @@ -77,24 +76,8 @@ func getStringInBetween(str string, start string, end string) (result string) { return str[s : s+e] } -// Credit: https://gist.github.com/hyg/9c4afcd91fe24316cbf0 -func openBrowser(url string) { - var err error - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - default: - log.Fatalf("Unsupported platform.") - } - check(err) -} - -// Generic error checking, will be implementing more robust error/exception handling >v0.9.0 -func check(err error) { +// is the err nil? +func isErrNil(err error) { if err != nil { log.Fatalf("[ERR] %s", err.Error()) } @@ -151,101 +134,73 @@ func trimOneFirstThreeLast(s string) string { // wrapper for os.UserHomeDir() func homeDir() string { dir, err := os.UserHomeDir() - check(err) + isErrNil(err) return dir } -// Initiate the Schwab oAuth process to retrieve bearer/refresh tokens func Initiate() *Agent { - agent := Agent{} + var agent Agent if _, err := os.Stat(fmt.Sprintf("%s/.trade", homeDir())); errors.Is(err, os.ErrNotExist) { - err := os.Mkdir(fmt.Sprintf("%s/.trade", homeDir()), os.ModePerm) - check(err) - // oAuth Leg 1 - Authorization Code - openBrowser(fmt.Sprintf("https://api.schwabapi.com/v1/oauth/authorize?client_id=%s&redirect_uri=%s", os.Getenv("APPKEY"), os.Getenv("CBURL"))) - fmt.Printf("Log into your Schwab brokerage account. Copy Error404 URL and paste it here: ") - var urlInput string - fmt.Scanln(&urlInput) - authCodeEncoded := getStringInBetween(urlInput, "?code=", "&session=") - authCode, err := url.QueryUnescape(authCodeEncoded) - check(err) - // oAuth Leg 2 - Refresh, Bearer Tokens - authStringLegTwo := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", os.Getenv("APPKEY"), os.Getenv("SECRET"))))) - client := http.Client{} - payload := fmt.Sprintf("grant_type=authorization_code&code=%s&redirect_uri=%s", string(authCode), os.Getenv("CBURL")) - req, err := http.NewRequest("POST", "https://api.schwabapi.com/v1/oauth/token", bytes.NewBuffer([]byte(payload))) - check(err) - req.Header = http.Header{ - "Authorization": {authStringLegTwo}, - "Content-Type": {"application/x-www-form-urlencoded"}, - } - res, err := client.Do(req) - check(err) - defer res.Body.Close() - bodyBytes, err := io.ReadAll(res.Body) - check(err) - agent.tokens = parseAccessTokenResponse(string(bodyBytes)) - tokensJson, err := json.Marshal(agent.tokens) - check(err) - err = os.WriteFile(fmt.Sprintf("%s/.trade/bar.json", homeDir()), tokensJson, 0777) - check(err) + err = os.Mkdir(fmt.Sprintf("%s/.trade", homeDir()), os.ModePerm) + isErrNil(err) + agent.client, err = o.Run() + isErrNil(err) + Tokens.AccessToken = agent.client.Token.AccessToken + Tokens.RefreshToken = agent.client.Token.RefreshToken + Tokens.TokenType = agent.client.Token.TokenType + Tokens.Expiry = agent.client.Token.Expiry + Tokens.ExpiresIn = agent.client.Token.ExpiresIn + bytes, err := sonic.Marshal(Tokens) + err = os.WriteFile(fmt.Sprintf("%s/.trade/bar.json", homeDir()), bytes, 0777) + isErrNil(err) } else { - agent.tokens = readDB() - if agent.tokens.Bearer == "" { + agent.client = readDB() + if Tokens.AccessToken == "" { err := os.RemoveAll(fmt.Sprintf("%s/.trade", homeDir())) - check(err) - log.Fatalf("[err] please reinitiate, something went wrong\n") + isErrNil(err) + log.Fatalf("[err] something went wrong - please reinitiate with 'Initiate'") } } return &agent } -// Use refresh token to generate a new bearer token for authentication -func (agent *Agent) refresh() { - oldTokens := readDB() - authStringRefresh := fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", os.Getenv("APPKEY"), os.Getenv("SECRET"))))) - client := http.Client{} - req, err := http.NewRequest("POST", "https://api.schwabapi.com/v1/oauth/token", bytes.NewBuffer([]byte(fmt.Sprintf("grant_type=refresh_token&refresh_token=%s", oldTokens.Refresh)))) - check(err) - req.Header = http.Header{ - "Authorization": {authStringRefresh}, - "Content-Type": {"application/x-www-form-urlencoded"}, - } - res, err := client.Do(req) - check(err) - defer res.Body.Close() - bodyBytes, err := io.ReadAll(res.Body) - check(err) - agent.tokens = parseAccessTokenResponse(string(bodyBytes)) -} - // Handler is the general purpose request function for the td-ameritrade api, all functions will be routed through this handler function, which does all of the API calling work // It performs a GET request after adding the apikey found in the config.env file in the same directory as the program calling the function, // then returns the body of the GET request's return. // It takes one parameter: // req = a request of type *http.Request func (agent *Agent) Handler(req *http.Request) (*http.Response, error) { - if (&Agent{}) == agent { - log.Fatal("[ERR] empty agent - call 'Agent.Initiate' before making any API function calls.") - } - if !time.Now().Before(agent.tokens.BearerExpiration) { - agent.refresh() - } - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", agent.tokens.Bearer)) + var err error + if Tokens.AccessToken == "" { + log.Fatalf("[err] no access token found, please reinitiate with 'Initiate'") + } + if ((&Agent{}) == agent) || ((&o.AuthorizedClient{}) == agent.client) { + agent.client, err = o.Run() + isErrNil(err) + Tokens.AccessToken = agent.client.Token.AccessToken + Tokens.RefreshToken = agent.client.Token.RefreshToken + Tokens.TokenType = agent.client.Token.TokenType + Tokens.Expiry = agent.client.Token.Expiry + Tokens.ExpiresIn = agent.client.Token.ExpiresIn + } + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", agent.client.Token.AccessToken)) client := http.Client{} resp, err := client.Do(req) if err != nil { return resp, err } - if resp.StatusCode == 401 { - err := os.Remove(fmt.Sprintf("%s/.trade", homeDir())) - check(err) - } - if resp.StatusCode < 200 || resp.StatusCode > 300 { + // TODO: test this block + var statusErr error + switch true { + case resp.StatusCode == 401: + err := os.RemoveAll(fmt.Sprintf("%s/.trade", homeDir())) + isErrNil(err) + statusErr = errors.New("[err] invalid token - please reinitiate with 'Initiate'") + case resp.StatusCode < 200, resp.StatusCode > 300: defer resp.Body.Close() body, err := io.ReadAll(resp.Body) - check(err) - log.Fatalf("[ERR] %d, %s", resp.StatusCode, body) + isErrNil(err) + statusErr = errors.New(fmt.Sprintf("[err] %d - %s", resp.StatusCode, body)) } - return resp, nil + return resp, statusErr }