diff --git a/e2e/tests/dynamic/page/cookies/delete.fqlx b/e2e/tests/dynamic/page/cookies/delete.fql similarity index 88% rename from e2e/tests/dynamic/page/cookies/delete.fqlx rename to e2e/tests/dynamic/page/cookies/delete.fql index a7cf278f..9a58a5d3 100644 --- a/e2e/tests/dynamic/page/cookies/delete.fqlx +++ b/e2e/tests/dynamic/page/cookies/delete.fql @@ -15,7 +15,7 @@ COOKIE_DEL(doc, COOKIE_GET(doc, "x-e2e"), "x-e2e-2") LET cookie1 = COOKIE_GET(doc, "x-e2e") LET cookie2 = COOKIE_GET(doc, "x-e2e-2") -T::EQ(cookie1, "none") -T::EQ(cookie2, "none") +T::EQ(cookie1, NONE) +T::EQ(cookie2, NONE) RETURN NONE \ No newline at end of file diff --git a/e2e/tests/dynamic/page/cookies/get.fqlx b/e2e/tests/dynamic/page/cookies/get.fqlx index 7895f5bc..06bb2b90 100644 --- a/e2e/tests/dynamic/page/cookies/get.fqlx +++ b/e2e/tests/dynamic/page/cookies/get.fqlx @@ -7,6 +7,6 @@ LET cookiesPath = LENGTH(doc.cookies) > 0 ? "ok" : "false" LET cookie = COOKIE_GET(doc, "x-ferret") LET expected = "ok e2e" -T::LEN(doc.cookies +T::LEN(doc.cookies, 1) RETURN T::EQ(cookiesPath + " " + cookie.value, expected) \ No newline at end of file diff --git a/e2e/tests/dynamic/page/cookies/load_with.fqlx b/e2e/tests/dynamic/page/cookies/load_with.fql similarity index 56% rename from e2e/tests/dynamic/page/cookies/load_with.fqlx rename to e2e/tests/dynamic/page/cookies/load_with.fql index a25dd253..e4ba2b2e 100644 --- a/e2e/tests/dynamic/page/cookies/load_with.fqlx +++ b/e2e/tests/dynamic/page/cookies/load_with.fql @@ -7,8 +7,9 @@ LET doc = DOCUMENT(url, { }] }) -LET cookiesPath = LENGTH(doc.cookies) > 0 ? "ok" : "false" LET cookie = COOKIE_GET(doc, "x-e2e") -LET expected = "ok test" -RETURN T::EQ(cookiesPath + " " + cookie.value, expected) \ No newline at end of file +T::NOT::NONE(cookie) +T::EQ(cookie.value, "test") + +RETURN NONE \ No newline at end of file diff --git a/e2e/tests/dynamic/page/cookies/set.fqlx b/e2e/tests/dynamic/page/cookies/set.fql similarity index 100% rename from e2e/tests/dynamic/page/cookies/set.fqlx rename to e2e/tests/dynamic/page/cookies/set.fql diff --git a/examples/headers.fql b/examples/headers.fql new file mode 100644 index 00000000..aa565ca6 --- /dev/null +++ b/examples/headers.fql @@ -0,0 +1,5 @@ +LET proxy_header = {"Proxy-Authorization": ["Basic e40b7d5eff464a4fb51efed2d1a19a24"]} + +LET doc = DOCUMENT("https://google.com", { headers: proxy_header}) + +RETURN doc \ No newline at end of file diff --git a/pkg/drivers/cdp/driver.go b/pkg/drivers/cdp/driver.go index a340f88b..32e0ee51 100644 --- a/pkg/drivers/cdp/driver.go +++ b/pkg/drivers/cdp/driver.go @@ -36,7 +36,7 @@ type Driver struct { func NewDriver(opts ...Option) *Driver { drv := new(Driver) - drv.options = newOptions(opts) + drv.options = NewOptions(opts) drv.dev = devtool.New(drv.options.Address) return drv @@ -137,43 +137,11 @@ func (drv *Driver) createConnection(ctx context.Context, keepCookies bool) (*rpc } func (drv *Driver) setDefaultParams(params drivers.Params) drivers.Params { - if params.UserAgent == "" { - params.UserAgent = drv.options.UserAgent - } - if params.Viewport == nil { params.Viewport = defaultViewport } - if drv.options.Headers != nil && params.Headers == nil { - params.Headers = make(drivers.HTTPHeaders) - } - - // set default headers - for k, v := range drv.options.Headers { - _, exists := params.Headers[k] - - // do not override user's set values - if !exists { - params.Headers[k] = v - } - } - - if drv.options.Cookies != nil && params.Cookies == nil { - params.Cookies = make(drivers.HTTPCookies) - } - - // set default cookies - for k, v := range drv.options.Cookies { - _, exists := params.Cookies[k] - - // do not override user's set values - if !exists { - params.Cookies[k] = v - } - } - - return params + return drivers.SetDefaultParams(drv.options.Options, params) } func (drv *Driver) init(ctx context.Context) error { diff --git a/pkg/drivers/cdp/helpers.go b/pkg/drivers/cdp/helpers.go index d702de49..cadc4bad 100644 --- a/pkg/drivers/cdp/helpers.go +++ b/pkg/drivers/cdp/helpers.go @@ -51,12 +51,6 @@ func enableFeatures(ctx context.Context, client *cdp.Client, params drivers.Para func() error { ua := common.GetUserAgent(params.UserAgent) - //logger. - // Debug(). - // Timestamp(). - // Str("user-agent", ua). - // Msg("using User-Agent") - // do not use custom user agent if ua == "" { return nil diff --git a/pkg/drivers/cdp/network/manager.go b/pkg/drivers/cdp/network/manager.go index edb222fb..696731c6 100644 --- a/pkg/drivers/cdp/network/manager.go +++ b/pkg/drivers/cdp/network/manager.go @@ -34,7 +34,7 @@ type ( mu sync.Mutex logger *zerolog.Logger client *cdp.Client - headers drivers.HTTPHeaders + headers *drivers.HTTPHeaders eventLoop *events.Loop cancel context.CancelFunc responseListenerID events.ListenerID @@ -53,12 +53,12 @@ func New( m := new(Manager) m.logger = logger m.client = client - m.headers = make(drivers.HTTPHeaders) + m.headers = drivers.NewHTTPHeaders() m.eventLoop = events.NewLoop() m.cancel = cancel m.response = new(sync.Map) - if len(options.Cookies) > 0 { + if options.Cookies != nil && len(options.Cookies) > 0 { for url, cookies := range options.Cookies { if err := m.setCookiesInternal(ctx, url, cookies); err != nil { return nil, err @@ -66,7 +66,7 @@ func New( } } - if len(options.Headers) > 0 { + if options.Headers != nil && options.Headers.Length() > 0 { if err := m.setHeadersInternal(ctx, options.Headers); err != nil { return nil, err } @@ -104,7 +104,7 @@ func New( m.responseListenerID = m.eventLoop.AddListener(responseReceived, m.onResponse) - if len(options.Filter.Patterns) > 0 { + if options.Filter != nil && len(options.Filter.Patterns) > 0 { el2 := events.NewLoop() err = m.client.Fetch.Enable(ctx, toFetchArgs(options.Filter.Patterns)) @@ -147,87 +147,100 @@ func (m *Manager) Close() error { return nil } -func (m *Manager) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { +func (m *Manager) GetCookies(ctx context.Context) (*drivers.HTTPCookies, error) { repl, err := m.client.Network.GetAllCookies(ctx) if err != nil { return nil, errors.Wrap(err, "failed to get cookies") } - cookies := make(drivers.HTTPCookies) + cookies := drivers.NewHTTPCookies() if repl.Cookies == nil { return cookies, nil } for _, c := range repl.Cookies { - cookies[c.Name] = toDriverCookie(c) + cookies.Set(toDriverCookie(c)) } return cookies, nil } -func (m *Manager) SetCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { +func (m *Manager) SetCookies(ctx context.Context, url string, cookies *drivers.HTTPCookies) error { m.mu.Lock() defer m.mu.Unlock() return m.setCookiesInternal(ctx, url, cookies) } -func (m *Manager) setCookiesInternal(ctx context.Context, url string, cookies drivers.HTTPCookies) error { - if len(cookies) == 0 { +func (m *Manager) setCookiesInternal(ctx context.Context, url string, cookies *drivers.HTTPCookies) error { + if cookies == nil { + return errors.Wrap(core.ErrMissedArgument, "cookies") + } + + if cookies.Length() == 0 { return nil } - params := make([]network.CookieParam, 0, len(cookies)) + params := make([]network.CookieParam, 0, cookies.Length()) - for _, c := range cookies { - params = append(params, fromDriverCookie(url, c)) - } + cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool { + params = append(params, fromDriverCookie(url, value)) + + return true + }) return m.client.Network.SetCookies(ctx, network.NewSetCookiesArgs(params)) } -func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies drivers.HTTPCookies) error { +func (m *Manager) DeleteCookies(ctx context.Context, url string, cookies *drivers.HTTPCookies) error { m.mu.Lock() defer m.mu.Unlock() - if len(cookies) == 0 { + if cookies == nil { + return errors.Wrap(core.ErrMissedArgument, "cookies") + } + + if cookies.Length() == 0 { return nil } var err error - for _, c := range cookies { - err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, c)) + cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool { + err = m.client.Network.DeleteCookies(ctx, fromDriverCookieDelete(url, value)) if err != nil { - break + return false } - } + + return true + }) return err } -func (m *Manager) GetHeaders(_ context.Context) (drivers.HTTPHeaders, error) { - copied := make(drivers.HTTPHeaders) +func (m *Manager) GetHeaders(_ context.Context) (*drivers.HTTPHeaders, error) { + m.mu.Lock() + defer m.mu.Unlock() - for k, v := range m.headers { - copied[k] = v + if m.headers == nil { + return drivers.NewHTTPHeaders(), nil } - return copied, nil + return m.headers.Clone().(*drivers.HTTPHeaders), nil } -func (m *Manager) SetHeaders(ctx context.Context, headers drivers.HTTPHeaders) error { +func (m *Manager) SetHeaders(ctx context.Context, headers *drivers.HTTPHeaders) error { m.mu.Lock() defer m.mu.Unlock() return m.setHeadersInternal(ctx, headers) } -func (m *Manager) setHeadersInternal(ctx context.Context, headers drivers.HTTPHeaders) error { - if len(headers) == 0 { +func (m *Manager) setHeadersInternal(ctx context.Context, headers *drivers.HTTPHeaders) error { + if headers.Length() == 0 { return nil } @@ -461,7 +474,7 @@ func (m *Manager) onResponse(_ context.Context, message interface{}) (out bool) response := drivers.HTTPResponse{ StatusCode: msg.Response.Status, Status: msg.Response.StatusText, - Headers: make(drivers.HTTPHeaders), + Headers: drivers.NewHTTPHeaders(), } deserialized := make(map[string]string) diff --git a/pkg/drivers/cdp/network/options.go b/pkg/drivers/cdp/network/options.go index b013b119..4ce49401 100644 --- a/pkg/drivers/cdp/network/options.go +++ b/pkg/drivers/cdp/network/options.go @@ -6,7 +6,7 @@ import ( ) type ( - Cookies map[string]drivers.HTTPCookies + Cookies map[string]*drivers.HTTPCookies Filter struct { Patterns []drivers.ResourceFilter @@ -14,8 +14,8 @@ type ( Options struct { Cookies Cookies - Headers drivers.HTTPHeaders - Filter Filter + Headers *drivers.HTTPHeaders + Filter *Filter } ) diff --git a/pkg/drivers/cdp/options.go b/pkg/drivers/cdp/options.go index 76afe21b..dd3c5b0b 100644 --- a/pkg/drivers/cdp/options.go +++ b/pkg/drivers/cdp/options.go @@ -4,13 +4,9 @@ import "github.com/MontFerret/ferret/pkg/drivers" type ( Options struct { - Name string - Proxy string - UserAgent string + *drivers.Options Address string KeepCookies bool - Headers drivers.HTTPHeaders - Cookies drivers.HTTPCookies } Option func(opts *Options) @@ -18,8 +14,9 @@ type ( const DefaultAddress = "http://127.0.0.1:9222" -func newOptions(setters []Option) *Options { +func NewOptions(setters []Option) *Options { opts := new(Options) + opts.Options = new(drivers.Options) opts.Name = DriverName opts.Address = DefaultAddress @@ -40,13 +37,13 @@ func WithAddress(address string) Option { func WithProxy(address string) Option { return func(opts *Options) { - opts.Proxy = address + drivers.WithProxy(address)(opts.Options) } } func WithUserAgent(value string) Option { return func(opts *Options) { - opts.UserAgent = value + drivers.WithUserAgent(value)(opts.Options) } } @@ -58,50 +55,30 @@ func WithKeepCookies() Option { func WithCustomName(name string) Option { return func(opts *Options) { - opts.Name = name + drivers.WithCustomName(name)(opts.Options) } } -func WithHeader(name string, value []string) Option { +func WithHeader(name string, header []string) Option { return func(opts *Options) { - if opts.Headers == nil { - opts.Headers = make(drivers.HTTPHeaders) - } - - opts.Headers[name] = value + drivers.WithHeader(name, header)(opts.Options) } } -func WithHeaders(headers drivers.HTTPHeaders) Option { +func WithHeaders(headers *drivers.HTTPHeaders) Option { return func(opts *Options) { - if opts.Headers == nil { - opts.Headers = make(drivers.HTTPHeaders) - } - - for k, v := range headers { - opts.Headers[k] = v - } + drivers.WithHeaders(headers)(opts.Options) } } func WithCookie(cookie drivers.HTTPCookie) Option { return func(opts *Options) { - if opts.Cookies == nil { - opts.Cookies = make(drivers.HTTPCookies) - } - - opts.Cookies[cookie.Name] = cookie + drivers.WithCookie(cookie)(opts.Options) } } func WithCookies(cookies []drivers.HTTPCookie) Option { return func(opts *Options) { - if opts.Cookies == nil { - opts.Cookies = make(drivers.HTTPCookies) - } - - for _, c := range cookies { - opts.Cookies[c.Name] = c - } + drivers.WithCookies(cookies)(opts.Options) } } diff --git a/pkg/drivers/cdp/options_test.go b/pkg/drivers/cdp/options_test.go new file mode 100644 index 00000000..e4aa018d --- /dev/null +++ b/pkg/drivers/cdp/options_test.go @@ -0,0 +1,71 @@ +package cdp_test + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + + "github.com/MontFerret/ferret/pkg/drivers" + "github.com/MontFerret/ferret/pkg/drivers/cdp" +) + +func TestNewOptions(t *testing.T) { + Convey("Should create driver options with initial values", t, func() { + opts := cdp.NewOptions([]cdp.Option{}) + So(opts.Options, ShouldNotBeNil) + So(opts.Name, ShouldEqual, cdp.DriverName) + So(opts.Address, ShouldEqual, cdp.DefaultAddress) + }) + + Convey("Should use setters to set values", t, func() { + expectedName := cdp.DriverName + "2" + expectedAddress := "0.0.0.0:9222" + expectedUA := "Mozilla" + expectedProxy := "https://proxy.com" + + opts := cdp.NewOptions([]cdp.Option{ + cdp.WithCustomName(expectedName), + cdp.WithAddress(expectedAddress), + cdp.WithUserAgent(expectedUA), + cdp.WithProxy(expectedProxy), + cdp.WithKeepCookies(), + cdp.WithCookie(drivers.HTTPCookie{ + Name: "Session", + Value: "fsdfsdfs", + Path: "dfsdfsd", + Domain: "sfdsfs", + Expires: time.Time{}, + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }), + cdp.WithCookies([]drivers.HTTPCookie{ + { + Name: "Use", + Value: "Foos", + Path: "", + Domain: "", + Expires: time.Time{}, + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }, + }), + cdp.WithHeader("Authorization", []string{"Bearer dfsd7f98sd9fsd9fsd"}), + cdp.WithHeaders(drivers.NewHTTPHeadersWith(map[string][]string{ + "x-correlation-id": {"232483833833839"}, + })), + }) + So(opts.Options, ShouldNotBeNil) + So(opts.Name, ShouldEqual, expectedName) + So(opts.Address, ShouldEqual, expectedAddress) + So(opts.UserAgent, ShouldEqual, expectedUA) + So(opts.Proxy, ShouldEqual, expectedProxy) + So(opts.KeepCookies, ShouldBeTrue) + So(opts.Cookies.Length(), ShouldEqual, 2) + So(opts.Headers.Length(), ShouldEqual, 2) + }) +} diff --git a/pkg/drivers/cdp/page.go b/pkg/drivers/cdp/page.go index 064d41e7..6fa66408 100644 --- a/pkg/drivers/cdp/page.go +++ b/pkg/drivers/cdp/page.go @@ -2,7 +2,6 @@ package cdp import ( "context" - "github.com/MontFerret/ferret/pkg/drivers/cdp/templates" "hash/fnv" "io" "regexp" @@ -18,6 +17,7 @@ import ( "github.com/MontFerret/ferret/pkg/drivers/cdp/dom" "github.com/MontFerret/ferret/pkg/drivers/cdp/input" net "github.com/MontFerret/ferret/pkg/drivers/cdp/network" + "github.com/MontFerret/ferret/pkg/drivers/cdp/templates" "github.com/MontFerret/ferret/pkg/drivers/common" "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/logging" @@ -73,15 +73,13 @@ func LoadHTMLPage( Headers: params.Headers, } - if len(params.Cookies) > 0 { - netOpts.Cookies = make(map[string]drivers.HTTPCookies) + if params.Cookies != nil && params.Cookies.Length() > 0 { + netOpts.Cookies = make(map[string]*drivers.HTTPCookies) netOpts.Cookies[params.URL] = params.Cookies } - if params.Ignore != nil { - if len(params.Ignore.Resources) > 0 { - netOpts.Filter.Patterns = params.Ignore.Resources - } + if params.Ignore != nil && len(params.Ignore.Resources) > 0 { + netOpts.Filter.Patterns = params.Ignore.Resources } netManager, err := net.New(logger, client, netOpts) @@ -358,21 +356,21 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er return frames.Get(idx), nil } -func (p *HTMLPage) GetCookies(ctx context.Context) (drivers.HTTPCookies, error) { +func (p *HTMLPage) GetCookies(ctx context.Context) (*drivers.HTTPCookies, error) { p.mu.Lock() defer p.mu.Unlock() return p.network.GetCookies(ctx) } -func (p *HTMLPage) SetCookies(ctx context.Context, cookies drivers.HTTPCookies) error { +func (p *HTMLPage) SetCookies(ctx context.Context, cookies *drivers.HTTPCookies) error { p.mu.Lock() defer p.mu.Unlock() return p.network.SetCookies(ctx, p.getCurrentDocument().GetURL().String(), cookies) } -func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies drivers.HTTPCookies) error { +func (p *HTMLPage) DeleteCookies(ctx context.Context, cookies *drivers.HTTPCookies) error { p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/drivers/cookies.go b/pkg/drivers/cookies.go index ed8343ec..4798e1bd 100644 --- a/pkg/drivers/cookies.go +++ b/pkg/drivers/cookies.go @@ -13,21 +13,27 @@ import ( "github.com/wI2L/jettison" ) -type HTTPCookies map[string]HTTPCookie +type HTTPCookies struct { + values map[string]HTTPCookie +} + +func NewHTTPCookies() *HTTPCookies { + return NewHTTPCookiesWith(make(map[string]HTTPCookie)) +} -func NewHTTPCookies() HTTPCookies { - return make(HTTPCookies) +func NewHTTPCookiesWith(values map[string]HTTPCookie) *HTTPCookies { + return &HTTPCookies{values} } -func (c HTTPCookies) MarshalJSON() ([]byte, error) { - return jettison.MarshalOpts(map[string]HTTPCookie(c), jettison.NoHTMLEscaping()) +func (c *HTTPCookies) MarshalJSON() ([]byte, error) { + return jettison.MarshalOpts(c.values, jettison.NoHTMLEscaping()) } -func (c HTTPCookies) Type() core.Type { +func (c *HTTPCookies) Type() core.Type { return HTTPCookiesType } -func (c HTTPCookies) String() string { +func (c *HTTPCookies) String() string { j, err := c.MarshalJSON() if err != nil { @@ -37,21 +43,21 @@ func (c HTTPCookies) String() string { return string(j) } -func (c HTTPCookies) Compare(other core.Value) int64 { +func (c *HTTPCookies) Compare(other core.Value) int64 { if other.Type() != HTTPCookiesType { return Compare(HTTPCookiesType, other.Type()) } - oc := other.(HTTPCookies) + oc := other.(*HTTPCookies) switch { - case len(c) > len(oc): + case len(c.values) > len(oc.values): return 1 - case len(c) < len(oc): + case len(c.values) < len(oc.values): return -1 } - for name := range c { + for name := range c.values { cEl, cExists := c.Get(values.NewString(name)) if !cExists { @@ -74,20 +80,20 @@ func (c HTTPCookies) Compare(other core.Value) int64 { return 0 } -func (c HTTPCookies) Unwrap() interface{} { - return map[string]HTTPCookie(c) +func (c *HTTPCookies) Unwrap() interface{} { + return c.values } -func (c HTTPCookies) Hash() uint64 { +func (c *HTTPCookies) Hash() uint64 { hash := fnv.New64a() hash.Write([]byte(c.Type().String())) hash.Write([]byte(":")) hash.Write([]byte("{")) - keys := make([]string, 0, len(c)) + keys := make([]string, 0, len(c.values)) - for key := range c { + for key := range c.values { keys = append(keys, key) } @@ -100,7 +106,7 @@ func (c HTTPCookies) Hash() uint64 { hash.Write([]byte(key)) hash.Write([]byte(":")) - el := c[key] + el := c.values[key] bytes := make([]byte, 8) binary.LittleEndian.PutUint64(bytes, el.Hash()) @@ -117,47 +123,59 @@ func (c HTTPCookies) Hash() uint64 { return hash.Sum64() } -func (c HTTPCookies) Copy() core.Value { - copied := make(HTTPCookies) +func (c *HTTPCookies) Copy() core.Value { + return NewHTTPCookiesWith(c.values) +} + +func (c *HTTPCookies) Clone() core.Cloneable { + clone := make(map[string]HTTPCookie) - for k, v := range c { - copied[k] = v + for _, cookie := range c.values { + clone[cookie.Name] = cookie } - return copied + return NewHTTPCookiesWith(clone) } -func (c HTTPCookies) Length() values.Int { - return values.NewInt(len(c)) +func (c *HTTPCookies) Length() values.Int { + return values.NewInt(len(c.values)) } -func (c HTTPCookies) Keys() []values.String { - keys := make([]values.String, 0, len(c)) +func (c *HTTPCookies) Keys() []values.String { + result := make([]values.String, 0, len(c.values)) - for k := range c { - keys = append(keys, values.NewString(k)) + for k := range c.values { + result = append(result, values.NewString(k)) } - return keys + return result } -func (c HTTPCookies) Get(key values.String) (core.Value, values.Boolean) { - value, found := c[key.String()] +func (c *HTTPCookies) Values() []HTTPCookie { + result := make([]HTTPCookie, 0, len(c.values)) + + for _, v := range c.values { + result = append(result, v) + } + + return result +} + +func (c *HTTPCookies) Get(key values.String) (HTTPCookie, values.Boolean) { + value, found := c.values[key.String()] if found { return value, values.True } - return values.None, values.False + return HTTPCookie{}, values.False } -func (c HTTPCookies) Set(key values.String, value core.Value) { - if cookie, ok := value.(HTTPCookie); ok { - c[key.String()] = cookie - } +func (c *HTTPCookies) Set(cookie HTTPCookie) { + c.values[cookie.Name] = cookie } -func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (c *HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { if len(path) == 0 { return values.None, nil } @@ -170,7 +188,7 @@ func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, return values.None, err } - cookie, found := c[segment.String()] + cookie, found := c.values[segment.String()] if found { if len(path) == 1 { @@ -182,3 +200,11 @@ func (c HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, return values.None, nil } + +func (c *HTTPCookies) ForEach(predicate func(value HTTPCookie, key values.String) bool) { + for key, val := range c.values { + if !predicate(val, values.NewString(key)) { + break + } + } +} diff --git a/pkg/drivers/cookies_test.go b/pkg/drivers/cookies_test.go new file mode 100644 index 00000000..e2958850 --- /dev/null +++ b/pkg/drivers/cookies_test.go @@ -0,0 +1,65 @@ +package drivers_test + +import ( + "fmt" + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + "github.com/wI2L/jettison" + + "github.com/MontFerret/ferret/pkg/drivers" +) + +func TestHTTPCookies(t *testing.T) { + Convey("HTTPCookies", t, func() { + Convey(".MarshalJSON", func() { + Convey("Should serialize cookies", func() { + expires := time.Now() + headers := drivers.NewHTTPCookiesWith(map[string]drivers.HTTPCookie{ + "Session": { + Name: "Session", + Value: "asdfg", + Path: "/", + Domain: "www.google.com", + Expires: expires, + MaxAge: 0, + Secure: true, + HTTPOnly: true, + SameSite: drivers.SameSiteLaxMode, + }, + }) + + out, err := headers.MarshalJSON() + + t, e := expires.MarshalJSON() + So(e, ShouldBeNil) + + expected := fmt.Sprintf(`{"Session":{"domain":"www.google.com","expires":%s,"http_only":true,"max_age":0,"name":"Session","path":"/","same_site":"Lax","secure":true,"value":"asdfg"}}`, string(t)) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, expected) + }) + + Convey("Should set proper values", func() { + headers := drivers.NewHTTPCookies() + + headers.Set(drivers.HTTPCookie{ + Name: "Authorization", + Value: "e40b7d5eff464a4fb51efed2d1a19a24", + Path: "/", + Domain: "www.google.com", + Expires: time.Now(), + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }) + + _, err := jettison.MarshalOpts(headers, jettison.NoHTMLEscaping()) + + So(err, ShouldBeNil) + }) + }) + }) +} diff --git a/pkg/drivers/driver.go b/pkg/drivers/driver.go index 5eee7c2d..695e4ce3 100644 --- a/pkg/drivers/driver.go +++ b/pkg/drivers/driver.go @@ -11,7 +11,7 @@ type ( ctxKey struct{} ctxValue struct { - opts *options + opts *globalOptions drivers map[string]Driver } @@ -23,7 +23,7 @@ type ( } ) -func WithContext(ctx context.Context, drv Driver, opts ...Option) context.Context { +func WithContext(ctx context.Context, drv Driver, opts ...GlobalOption) context.Context { ctx, value := resolveValue(ctx) value.drivers[drv.Name()] = drv @@ -63,7 +63,7 @@ func resolveValue(ctx context.Context) (context.Context, *ctxValue) { if !ok { value = &ctxValue{ - opts: &options{}, + opts: &globalOptions{}, drivers: make(map[string]Driver), } diff --git a/pkg/drivers/headers.go b/pkg/drivers/headers.go index 12bd50a9..c89a5118 100644 --- a/pkg/drivers/headers.go +++ b/pkg/drivers/headers.go @@ -17,40 +17,50 @@ import ( ) // HTTPHeaders HTTP header object -type HTTPHeaders map[string][]string +type HTTPHeaders struct { + values map[string][]string +} + +func NewHTTPHeaders() *HTTPHeaders { + return NewHTTPHeadersWith(make(map[string][]string)) +} + +func NewHTTPHeadersWith(values map[string][]string) *HTTPHeaders { + return &HTTPHeaders{values} +} -func NewHTTPHeaders(values map[string][]string) HTTPHeaders { - return HTTPHeaders(values) +func (h *HTTPHeaders) Length() values.Int { + return values.NewInt(len(h.values)) } -func (h HTTPHeaders) Type() core.Type { +func (h *HTTPHeaders) Type() core.Type { return HTTPHeaderType } -func (h HTTPHeaders) String() string { +func (h *HTTPHeaders) String() string { var buf bytes.Buffer - for k := range h { + for k := range h.values { buf.WriteString(fmt.Sprintf("%s=%s;", k, h.Get(k))) } return buf.String() } -func (h HTTPHeaders) Compare(other core.Value) int64 { +func (h *HTTPHeaders) Compare(other core.Value) int64 { if other.Type() != HTTPHeaderType { return Compare(HTTPHeaderType, other.Type()) } - oh := other.(HTTPHeaders) + oh := other.(*HTTPHeaders) - if len(h) > len(oh) { + if len(h.values) > len(oh.values) { return 1 - } else if len(h) < len(oh) { + } else if len(h.values) < len(oh.values) { return -1 } - for k := range h { + for k := range h.values { c := strings.Compare(h.Get(k), oh.Get(k)) if c != 0 { @@ -61,20 +71,20 @@ func (h HTTPHeaders) Compare(other core.Value) int64 { return 0 } -func (h HTTPHeaders) Unwrap() interface{} { - return h +func (h *HTTPHeaders) Unwrap() interface{} { + return h.values } -func (h HTTPHeaders) Hash() uint64 { +func (h *HTTPHeaders) Hash() uint64 { hash := fnv.New64a() hash.Write([]byte(h.Type().String())) hash.Write([]byte(":")) hash.Write([]byte("{")) - keys := make([]string, 0, len(h)) + keys := make([]string, 0, len(h.values)) - for key := range h { + for key := range h.values { keys = append(keys, key) } @@ -101,18 +111,28 @@ func (h HTTPHeaders) Hash() uint64 { return hash.Sum64() } -func (h HTTPHeaders) Copy() core.Value { - return *(&h) +func (h *HTTPHeaders) Copy() core.Value { + return &HTTPHeaders{h.values} +} + +func (h *HTTPHeaders) Clone() core.Cloneable { + cp := make(map[string][]string) + + for k, v := range h.values { + cp[k] = v + } + + return &HTTPHeaders{cp} } -func (h HTTPHeaders) MarshalJSON() ([]byte, error) { +func (h *HTTPHeaders) MarshalJSON() ([]byte, error) { headers := map[string]string{} - for key, val := range h { + for key, val := range h.values { headers[key] = strings.Join(val, ", ") } - out, err := jettison.MarshalOpts(headers, jettison.NoHTMLEscaping()) + out, err := jettison.MarshalOpts(headers) if err != nil { return nil, err @@ -121,15 +141,25 @@ func (h HTTPHeaders) MarshalJSON() ([]byte, error) { return out, err } -func (h HTTPHeaders) Set(key, value string) { - textproto.MIMEHeader(h).Set(key, value) +func (h *HTTPHeaders) Set(key, value string) { + textproto.MIMEHeader(h.values).Set(key, value) } -func (h HTTPHeaders) Get(key string) string { - return textproto.MIMEHeader(h).Get(key) +func (h *HTTPHeaders) SetArr(key string, value []string) { + h.values[key] = value } -func (h HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, error) { +func (h *HTTPHeaders) Get(key string) string { + _, found := h.values[key] + + if !found { + return "" + } + + return textproto.MIMEHeader(h.values).Get(key) +} + +func (h *HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, error) { if len(path) == 0 { return values.None, nil } @@ -144,3 +174,11 @@ func (h HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, er return values.NewString(h.Get(segment.String())), nil } + +func (h *HTTPHeaders) ForEach(predicate func(value []string, key string) bool) { + for key, val := range h.values { + if !predicate(val, key) { + break + } + } +} diff --git a/pkg/drivers/headers_test.go b/pkg/drivers/headers_test.go index 5248597d..f8765781 100644 --- a/pkg/drivers/headers_test.go +++ b/pkg/drivers/headers_test.go @@ -4,24 +4,35 @@ import ( "testing" . "github.com/smartystreets/goconvey/convey" + "github.com/wI2L/jettison" "github.com/MontFerret/ferret/pkg/drivers" ) -func TestHTTPHeader(t *testing.T) { +func TestHTTPHeaders(t *testing.T) { Convey("HTTPHeaders", t, func() { Convey(".MarshalJSON", func() { Convey("Should serialize header values", func() { - headers := make(drivers.HTTPHeaders) - - headers["Content-Encoding"] = []string{"gzip"} - headers["Content-Type"] = []string{"text/html", "charset=utf-8"} + headers := drivers.NewHTTPHeadersWith(map[string][]string{ + "Content-Encoding": []string{"gzip"}, + "Content-Type": []string{"text/html", "charset=utf-8"}, + }) out, err := headers.MarshalJSON() So(err, ShouldBeNil) So(string(out), ShouldEqual, `{"Content-Encoding":"gzip","Content-Type":"text/html, charset=utf-8"}`) }) + + Convey("Should set proper values", func() { + headers := drivers.NewHTTPHeaders() + + headers.Set("Authorization", `["Basic e40b7d5eff464a4fb51efed2d1a19a24"]`) + + _, err := jettison.MarshalOpts(headers, jettison.NoHTMLEscaping()) + + So(err, ShouldBeNil) + }) }) }) } diff --git a/pkg/drivers/helpers.go b/pkg/drivers/helpers.go index 434baf56..c9f8a139 100644 --- a/pkg/drivers/helpers.go +++ b/pkg/drivers/helpers.go @@ -2,6 +2,7 @@ package drivers import ( "github.com/MontFerret/ferret/pkg/runtime/core" + "github.com/MontFerret/ferret/pkg/runtime/values" ) func ToPage(value core.Value) (HTMLPage, error) { @@ -46,3 +47,48 @@ func ToElement(value core.Value) (HTMLElement, error) { ) } } + +func SetDefaultParams(opts *Options, params Params) Params { + if params.Headers == nil && opts.Headers != nil { + params.Headers = NewHTTPHeaders() + } + + // set default headers + if opts.Headers != nil { + opts.Headers.ForEach(func(value []string, key string) bool { + val := params.Headers.Get(key) + + // do not override user's set values + if val == "" { + params.Headers.SetArr(key, value) + } + + return true + }) + } + + if params.Cookies == nil && opts.Cookies != nil { + params.Cookies = NewHTTPCookies() + } + + // set default cookies + if opts.Cookies != nil { + opts.Cookies.ForEach(func(value HTTPCookie, key values.String) bool { + _, exists := params.Cookies.Get(key) + + // do not override user's set values + if !exists { + params.Cookies.Set(value) + } + + return true + }) + } + + // set default user agent + if opts.UserAgent != "" && params.UserAgent == "" { + params.UserAgent = opts.UserAgent + } + + return params +} diff --git a/pkg/drivers/helpers_test.go b/pkg/drivers/helpers_test.go new file mode 100644 index 00000000..5d0a0fd8 --- /dev/null +++ b/pkg/drivers/helpers_test.go @@ -0,0 +1,41 @@ +package drivers_test + +import ( + "testing" + "time" + + . "github.com/smartystreets/goconvey/convey" + + "github.com/MontFerret/ferret/pkg/drivers" +) + +func TestSetDefaultParams(t *testing.T) { + Convey("Should take values from Options if not present in Params", t, func() { + opts := &drivers.Options{ + Name: "Test", + UserAgent: "Mozilla", + Headers: drivers.NewHTTPHeadersWith(map[string][]string{ + "Accept": {"application/json"}, + }), + Cookies: drivers.NewHTTPCookiesWith(map[string]drivers.HTTPCookie{ + "Session": drivers.HTTPCookie{ + Name: "Session", + Value: "fsfsdfsd", + Path: "", + Domain: "", + Expires: time.Time{}, + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }, + }), + } + + params := drivers.SetDefaultParams(opts, drivers.Params{}) + + So(params.UserAgent, ShouldEqual, opts.UserAgent) + So(params.Headers, ShouldNotBeNil) + So(params.Cookies, ShouldNotBeNil) + }) +} diff --git a/pkg/drivers/http/driver.go b/pkg/drivers/http/driver.go index 1b4929b5..eaf7bc1f 100644 --- a/pkg/drivers/http/driver.go +++ b/pkg/drivers/http/driver.go @@ -25,7 +25,7 @@ type Driver struct { func NewDriver(opts ...Option) *Driver { drv := new(Driver) - drv.options = newOptions(opts) + drv.options = NewOptions(opts) drv.client = newHTTPClient(drv.options) drv.client.Concurrency = drv.options.Concurrency @@ -96,63 +96,11 @@ func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTM req.Header.Set("Cache-Control", "no-cache") req.Header.Set("Pragma", "no-cache") - if drv.options.Headers != nil && params.Headers == nil { - params.Headers = make(drivers.HTTPHeaders) - } - - // Set default headers - for k, v := range drv.options.Headers { - _, exists := params.Headers[k] - - // do not override user's set values - if !exists { - params.Headers[k] = v - } - } - - for k := range params.Headers { - req.Header.Add(k, params.Headers.Get(k)) - - logger. - Debug(). - Timestamp(). - Str("header", k). - Msg("set header") - } - - if drv.options.Cookies != nil && params.Cookies == nil { - params.Cookies = make(drivers.HTTPCookies) - } - - // set default cookies - for k, v := range drv.options.Cookies { - _, exists := params.Cookies[k] - - // do not override user's set values - if !exists { - params.Cookies[k] = v - } - } - - for _, c := range params.Cookies { - req.AddCookie(fromDriverCookie(c)) - - logger. - Debug(). - Timestamp(). - Str("cookie", c.Name). - Msg("set cookie") - } + params = drivers.SetDefaultParams(drv.options.Options, params) req = req.WithContext(ctx) - var ua string - - if params.UserAgent != "" { - ua = common.GetUserAgent(params.UserAgent) - } else { - ua = common.GetUserAgent(drv.options.UserAgent) - } + ua := common.GetUserAgent(params.UserAgent) logger. Debug(). @@ -197,7 +145,7 @@ func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTM r := drivers.HTTPResponse{ StatusCode: resp.StatusCode, Status: resp.Status, - Headers: drivers.HTTPHeaders(resp.Header), + Headers: drivers.NewHTTPHeadersWith(resp.Header), } return NewHTMLPage(doc, params.URL, r, cookies) diff --git a/pkg/drivers/http/driver_test.go b/pkg/drivers/http/driver_test.go index bb117516..7d9e2aa8 100644 --- a/pkg/drivers/http/driver_test.go +++ b/pkg/drivers/http/driver_test.go @@ -2,6 +2,7 @@ package http import ( "crypto/tls" + "github.com/MontFerret/ferret/pkg/drivers" "net/http" "reflect" "testing" @@ -25,14 +26,18 @@ func Test_newHTTPClientWithTransport(t *testing.T) { { name: "check transport exist with pester.New()", args: args{options: &Options{ - Proxy: "http://0.0.0.|", + Options: &drivers.Options{ + Proxy: "http://0.0.0.|", + }, HTTPTransport: httpTransport, }}, }, { name: "check transport exist with pester.NewExtendedClient()", args: args{options: &Options{ - Proxy: "http://0.0.0.0", + Options: &drivers.Options{ + Proxy: "http://0.0.0.0", + }, HTTPTransport: httpTransport, }}, }, @@ -69,7 +74,9 @@ func Test_newHTTPClient(t *testing.T) { convey.Convey("pester.New()", t, func() { var ( client = newHTTPClient(&Options{ - Proxy: "http://0.0.0.|", + Options: &drivers.Options{ + Proxy: "http://0.0.0.|", + }, }) rValue = reflect.ValueOf(client).Elem() @@ -85,7 +92,9 @@ func Test_newHTTPClient(t *testing.T) { convey.Convey("pester.NewExtend()", t, func() { var ( client = newHTTPClient(&Options{ - Proxy: "http://0.0.0.0", + Options: &drivers.Options{ + Proxy: "http://0.0.0.0", + }, }) rValue = reflect.ValueOf(client).Elem() diff --git a/pkg/drivers/http/helpers.go b/pkg/drivers/http/helpers.go index cc9d15e4..2a506a9a 100644 --- a/pkg/drivers/http/helpers.go +++ b/pkg/drivers/http/helpers.go @@ -54,8 +54,8 @@ func outerHTML(s *goquery.Selection) (string, error) { return buf.String(), nil } -func toDriverCookies(cookies []*HTTP.Cookie) (drivers.HTTPCookies, error) { - res := make(drivers.HTTPCookies) +func toDriverCookies(cookies []*HTTP.Cookie) (*drivers.HTTPCookies, error) { + res := drivers.NewHTTPCookies() for _, c := range cookies { dc, err := toDriverCookie(c) @@ -64,7 +64,7 @@ func toDriverCookies(cookies []*HTTP.Cookie) (drivers.HTTPCookies, error) { return nil, err } - res[dc.Name] = dc + res.Set(dc) } return res, nil diff --git a/pkg/drivers/http/options.go b/pkg/drivers/http/options.go index 5f9c53c6..5adc81b5 100644 --- a/pkg/drivers/http/options.go +++ b/pkg/drivers/http/options.go @@ -1,11 +1,17 @@ package http import ( - "github.com/gobwas/glob" stdhttp "net/http" - "github.com/MontFerret/ferret/pkg/drivers" + "github.com/gobwas/glob" "github.com/sethgrid/pester" + + "github.com/MontFerret/ferret/pkg/drivers" +) + +var ( + DefaultConcurrency = 3 + DefaultMaxRetries = 5 ) type ( @@ -17,25 +23,22 @@ type ( } Options struct { - Name string + *drivers.Options Backoff pester.BackoffStrategy MaxRetries int Concurrency int - Proxy string - UserAgent string - Headers drivers.HTTPHeaders - Cookies drivers.HTTPCookies HTTPCodesFilter []compiledStatusCodeFilter HTTPTransport *stdhttp.Transport } ) -func newOptions(setters []Option) *Options { +func NewOptions(setters []Option) *Options { opts := new(Options) + opts.Options = new(drivers.Options) opts.Name = DriverName opts.Backoff = pester.ExponentialBackoff - opts.Concurrency = 3 - opts.MaxRetries = 5 + opts.Concurrency = DefaultConcurrency + opts.MaxRetries = DefaultMaxRetries opts.HTTPCodesFilter = make([]compiledStatusCodeFilter, 0, 5) for _, setter := range setters { @@ -77,63 +80,43 @@ func WithConcurrency(value int) Option { func WithProxy(address string) Option { return func(opts *Options) { - opts.Proxy = address + drivers.WithProxy(address)(opts.Options) } } func WithUserAgent(value string) Option { return func(opts *Options) { - opts.UserAgent = value + drivers.WithUserAgent(value)(opts.Options) } } func WithCustomName(name string) Option { return func(opts *Options) { - opts.Name = name + drivers.WithCustomName(name)(opts.Options) } } func WithHeader(name string, value []string) Option { return func(opts *Options) { - if opts.Headers == nil { - opts.Headers = make(drivers.HTTPHeaders) - } - - opts.Headers[name] = value + drivers.WithHeader(name, value)(opts.Options) } } -func WithHeaders(headers drivers.HTTPHeaders) Option { +func WithHeaders(headers *drivers.HTTPHeaders) Option { return func(opts *Options) { - if opts.Headers == nil { - opts.Headers = make(drivers.HTTPHeaders) - } - - for k, v := range headers { - opts.Headers[k] = v - } + drivers.WithHeaders(headers)(opts.Options) } } func WithCookie(cookie drivers.HTTPCookie) Option { return func(opts *Options) { - if opts.Cookies == nil { - opts.Cookies = make(drivers.HTTPCookies) - } - - opts.Cookies[cookie.Name] = cookie + drivers.WithCookie(cookie)(opts.Options) } } func WithCookies(cookies []drivers.HTTPCookie) Option { return func(opts *Options) { - if opts.Cookies == nil { - opts.Cookies = make(drivers.HTTPCookies) - } - - for _, c := range cookies { - opts.Cookies[c.Name] = c - } + drivers.WithCookies(cookies)(opts.Options) } } diff --git a/pkg/drivers/http/options_test.go b/pkg/drivers/http/options_test.go new file mode 100644 index 00000000..49902931 --- /dev/null +++ b/pkg/drivers/http/options_test.go @@ -0,0 +1,85 @@ +package http_test + +import ( + stdhttp "net/http" + "testing" + "time" + + "github.com/sethgrid/pester" + . "github.com/smartystreets/goconvey/convey" + + "github.com/MontFerret/ferret/pkg/drivers" + "github.com/MontFerret/ferret/pkg/drivers/http" +) + +func TestNewOptions(t *testing.T) { + Convey("Should create driver options with initial values", t, func() { + opts := http.NewOptions([]http.Option{}) + So(opts.Options, ShouldNotBeNil) + So(opts.Name, ShouldEqual, http.DriverName) + So(opts.Backoff, ShouldEqual, pester.ExponentialBackoff) + So(opts.Concurrency, ShouldEqual, http.DefaultConcurrency) + So(opts.MaxRetries, ShouldEqual, http.DefaultMaxRetries) + So(opts.HTTPCodesFilter, ShouldHaveLength, 0) + }) + + Convey("Should use setters to set values", t, func() { + expectedName := http.DriverName + "2" + expectedUA := "Mozilla" + expectedProxy := "https://proxy.com" + expectedMaxRetries := 2 + expectedConcurrency := 10 + expectedTransport := &stdhttp.Transport{} + + opts := http.NewOptions([]http.Option{ + http.WithCustomName(expectedName), + http.WithUserAgent(expectedUA), + http.WithProxy(expectedProxy), + http.WithCookie(drivers.HTTPCookie{ + Name: "Session", + Value: "fsdfsdfs", + Path: "dfsdfsd", + Domain: "sfdsfs", + Expires: time.Time{}, + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }), + http.WithCookies([]drivers.HTTPCookie{ + { + Name: "Use", + Value: "Foos", + Path: "", + Domain: "", + Expires: time.Time{}, + MaxAge: 0, + Secure: false, + HTTPOnly: false, + SameSite: 0, + }, + }), + http.WithHeader("Authorization", []string{"Bearer dfsd7f98sd9fsd9fsd"}), + http.WithHeaders(drivers.NewHTTPHeadersWith(map[string][]string{ + "x-correlation-id": {"232483833833839"}, + })), + http.WithDefaultBackoff(), + http.WithMaxRetries(expectedMaxRetries), + http.WithConcurrency(expectedConcurrency), + http.WithAllowedHTTPCode(401), + http.WithAllowedHTTPCodes([]int{403, 404}), + http.WithCustomTransport(expectedTransport), + }) + So(opts.Options, ShouldNotBeNil) + So(opts.Name, ShouldEqual, expectedName) + So(opts.UserAgent, ShouldEqual, expectedUA) + So(opts.Proxy, ShouldEqual, expectedProxy) + So(opts.Cookies.Length(), ShouldEqual, 2) + So(opts.Headers.Length(), ShouldEqual, 2) + So(opts.Backoff, ShouldEqual, pester.DefaultBackoff) + So(opts.MaxRetries, ShouldEqual, expectedMaxRetries) + So(opts.Concurrency, ShouldEqual, expectedConcurrency) + So(opts.HTTPCodesFilter, ShouldHaveLength, 3) + So(opts.HTTPTransport, ShouldEqual, expectedTransport) + }) +} diff --git a/pkg/drivers/http/page.go b/pkg/drivers/http/page.go index ceb50965..c69e782c 100644 --- a/pkg/drivers/http/page.go +++ b/pkg/drivers/http/page.go @@ -14,7 +14,7 @@ import ( type HTMLPage struct { document *HTMLDocument - cookies drivers.HTTPCookies + cookies *drivers.HTTPCookies frames *values.Array response drivers.HTTPResponse } @@ -23,7 +23,7 @@ func NewHTMLPage( qdoc *goquery.Document, url string, response drivers.HTTPResponse, - cookies drivers.HTTPCookies, + cookies *drivers.HTTPCookies, ) (*HTMLPage, error) { doc, err := NewRootHTMLDocument(qdoc, url) @@ -84,10 +84,10 @@ func (p *HTMLPage) Hash() uint64 { } func (p *HTMLPage) Copy() core.Value { - cookies := make(drivers.HTTPCookies) + var cookies *drivers.HTTPCookies - for k, v := range p.cookies { - cookies[k] = v + if p.cookies != nil { + cookies = p.cookies.Copy().(*drivers.HTTPCookies) } page, err := NewHTMLPage( @@ -168,11 +168,15 @@ func (p *HTMLPage) GetFrame(ctx context.Context, idx values.Int) (core.Value, er return p.frames.Get(idx), nil } -func (p *HTMLPage) GetCookies(_ context.Context) (drivers.HTTPCookies, error) { - res := make(drivers.HTTPCookies) +func (p *HTMLPage) GetCookies(_ context.Context) (*drivers.HTTPCookies, error) { + res := drivers.NewHTTPCookies() - for n, v := range p.cookies { - res[n] = v + if p.cookies != nil { + p.cookies.ForEach(func(value drivers.HTTPCookie, _ values.String) bool { + res.Set(value) + + return true + }) } return res, nil @@ -182,11 +186,11 @@ func (p *HTMLPage) GetResponse(_ context.Context) (drivers.HTTPResponse, error) return p.response, nil } -func (p *HTMLPage) SetCookies(_ context.Context, _ drivers.HTTPCookies) error { +func (p *HTMLPage) SetCookies(_ context.Context, _ *drivers.HTTPCookies) error { return core.ErrNotSupported } -func (p *HTMLPage) DeleteCookies(_ context.Context, _ drivers.HTTPCookies) error { +func (p *HTMLPage) DeleteCookies(_ context.Context, _ *drivers.HTTPCookies) error { return core.ErrNotSupported } diff --git a/pkg/drivers/options.go b/pkg/drivers/options.go index 357d1503..8eeb89c7 100644 --- a/pkg/drivers/options.go +++ b/pkg/drivers/options.go @@ -1,15 +1,89 @@ package drivers type ( - options struct { + globalOptions struct { defaultDriver string } - Option func(drv Driver, opts *options) + GlobalOption func(drv Driver, opts *globalOptions) + + Options struct { + Name string + Proxy string + UserAgent string + Headers *HTTPHeaders + Cookies *HTTPCookies + } + + Option func(opts *Options) ) -func AsDefault() Option { - return func(drv Driver, opts *options) { +func AsDefault() GlobalOption { + return func(drv Driver, opts *globalOptions) { opts.defaultDriver = drv.Name() } } + +func WithProxy(address string) Option { + return func(opts *Options) { + opts.Proxy = address + } +} + +func WithUserAgent(value string) Option { + return func(opts *Options) { + opts.UserAgent = value + } +} + +func WithCustomName(name string) Option { + return func(opts *Options) { + opts.Name = name + } +} + +func WithHeader(name string, value []string) Option { + return func(opts *Options) { + if opts.Headers == nil { + opts.Headers = NewHTTPHeaders() + } + + opts.Headers.SetArr(name, value) + } +} + +func WithHeaders(headers *HTTPHeaders) Option { + return func(opts *Options) { + if opts.Headers == nil { + opts.Headers = NewHTTPHeaders() + } + + headers.ForEach(func(value []string, key string) bool { + opts.Headers.SetArr(key, value) + + return true + }) + } +} + +func WithCookie(cookie HTTPCookie) Option { + return func(opts *Options) { + if opts.Cookies == nil { + opts.Cookies = NewHTTPCookies() + } + + opts.Cookies.Set(cookie) + } +} + +func WithCookies(cookies []HTTPCookie) Option { + return func(opts *Options) { + if opts.Cookies == nil { + opts.Cookies = NewHTTPCookies() + } + + for _, c := range cookies { + opts.Cookies.Set(c) + } + } +} diff --git a/pkg/drivers/params.go b/pkg/drivers/params.go index e39c707f..f85033df 100644 --- a/pkg/drivers/params.go +++ b/pkg/drivers/params.go @@ -28,8 +28,8 @@ type ( URL string UserAgent string KeepCookies bool - Cookies HTTPCookies - Headers HTTPHeaders + Cookies *HTTPCookies + Headers *HTTPHeaders Viewport *Viewport Ignore *Ignore } @@ -37,8 +37,8 @@ type ( ParseParams struct { Content []byte KeepCookies bool - Cookies HTTPCookies - Headers HTTPHeaders + Cookies *HTTPCookies + Headers *HTTPHeaders Viewport *Viewport } ) diff --git a/pkg/drivers/response.go b/pkg/drivers/response.go index e617c4aa..bf891ef2 100644 --- a/pkg/drivers/response.go +++ b/pkg/drivers/response.go @@ -14,7 +14,7 @@ import ( type HTTPResponse struct { StatusCode int Status string - Headers HTTPHeaders + Headers *HTTPHeaders } func (resp *HTTPResponse) Type() core.Type { @@ -60,9 +60,9 @@ func (resp *HTTPResponse) Hash() uint64 { // responseMarshal is a structure that repeats HTTPResponse. It allows // easily Marshal the HTTPResponse object. type responseMarshal struct { - StatusCode int `json:"status_code"` - Status string `json:"status"` - Headers HTTPHeaders `json:"headers"` + StatusCode int `json:"status_code"` + Status string `json:"status"` + Headers *HTTPHeaders `json:"headers"` } func (resp *HTTPResponse) MarshalJSON() ([]byte, error) { diff --git a/pkg/drivers/value.go b/pkg/drivers/value.go index 16ab6d06..97cf7935 100644 --- a/pkg/drivers/value.go +++ b/pkg/drivers/value.go @@ -196,11 +196,11 @@ type ( GetFrame(ctx context.Context, idx values.Int) (core.Value, error) - GetCookies(ctx context.Context) (HTTPCookies, error) + GetCookies(ctx context.Context) (*HTTPCookies, error) - SetCookies(ctx context.Context, cookies HTTPCookies) error + SetCookies(ctx context.Context, cookies *HTTPCookies) error - DeleteCookies(ctx context.Context, cookies HTTPCookies) error + DeleteCookies(ctx context.Context, cookies *HTTPCookies) error GetResponse(ctx context.Context) (HTTPResponse, error) diff --git a/pkg/stdlib/html/cookie_del.go b/pkg/stdlib/html/cookie_del.go index e5225fec..e63d1a07 100644 --- a/pkg/stdlib/html/cookie_del.go +++ b/pkg/stdlib/html/cookie_del.go @@ -26,8 +26,8 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) { } inputs := args[1:] - var currentCookies drivers.HTTPCookies - cookies := make(drivers.HTTPCookies) + var currentCookies *drivers.HTTPCookies + cookies := drivers.NewHTTPCookies() for _, c := range inputs { switch cookie := c.(type) { @@ -42,14 +42,14 @@ func CookieDel(ctx context.Context, args ...core.Value) (core.Value, error) { currentCookies = current } - found, isFound := currentCookies[cookie.String()] + found, isFound := currentCookies.Get(cookie) if isFound { - cookies[cookie.String()] = found + cookies.Set(found) } case drivers.HTTPCookie: - cookies[cookie.Name] = cookie + cookies.Set(cookie) default: return values.None, core.TypeError(c.Type(), types.String, drivers.HTTPCookieType) } diff --git a/pkg/stdlib/html/cookie_get.go b/pkg/stdlib/html/cookie_get.go index e97d8d12..a202705f 100644 --- a/pkg/stdlib/html/cookie_get.go +++ b/pkg/stdlib/html/cookie_get.go @@ -40,7 +40,7 @@ func CookieGet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - cookie, found := cookies[name.String()] + cookie, found := cookies.Get(name) if found { return cookie, nil diff --git a/pkg/stdlib/html/cookie_set.go b/pkg/stdlib/html/cookie_set.go index f7ea34fd..4eaf0b6d 100644 --- a/pkg/stdlib/html/cookie_set.go +++ b/pkg/stdlib/html/cookie_set.go @@ -24,7 +24,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - cookies := make(drivers.HTTPCookies) + cookies := drivers.NewHTTPCookies() for _, c := range args[1:] { cookie, err := parseCookie(c) @@ -33,7 +33,7 @@ func CookieSet(ctx context.Context, args ...core.Value) (core.Value, error) { return values.None, err } - cookies[cookie.Name] = cookie + cookies.Set(cookie) } return values.None, page.SetCookies(ctx, cookies) diff --git a/pkg/stdlib/html/document.go b/pkg/stdlib/html/document.go index a19273d0..312ddb96 100644 --- a/pkg/stdlib/html/document.go +++ b/pkg/stdlib/html/document.go @@ -168,7 +168,7 @@ func newPageLoadParams(url values.String, arg core.Value) (PageLoadParams, error res.Cookies = cookies default: - res.Cookies = make(drivers.HTTPCookies) + res.Cookies = drivers.NewHTTPCookies() } } @@ -220,9 +220,13 @@ func newPageLoadParams(url values.String, arg core.Value) (PageLoadParams, error return res, nil } -func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) { +func parseCookieObject(obj *values.Object) (*drivers.HTTPCookies, error) { + if obj == nil { + return nil, errors.Wrap(core.ErrMissedArgument, "cookies") + } + var err error - res := make(drivers.HTTPCookies) + res := drivers.NewHTTPCookies() obj.ForEach(func(value core.Value, _ string) bool { cookie, e := parseCookie(value) @@ -233,7 +237,7 @@ func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) { return false } - res[cookie.Name] = cookie + res.Set(cookie) return true }) @@ -241,9 +245,13 @@ func parseCookieObject(obj *values.Object) (drivers.HTTPCookies, error) { return res, err } -func parseCookieArray(arr *values.Array) (drivers.HTTPCookies, error) { +func parseCookieArray(arr *values.Array) (*drivers.HTTPCookies, error) { + if arr == nil { + return nil, errors.Wrap(core.ErrMissedArgument, "cookies") + } + var err error - res := make(drivers.HTTPCookies) + res := drivers.NewHTTPCookies() arr.ForEach(func(value core.Value, _ int) bool { cookie, e := parseCookie(value) @@ -254,7 +262,7 @@ func parseCookieArray(arr *values.Array) (drivers.HTTPCookies, error) { return false } - res[cookie.Name] = cookie + res.Set(cookie) return true }) @@ -350,11 +358,25 @@ func parseCookie(value core.Value) (drivers.HTTPCookie, error) { return cookie, err } -func parseHeader(headers *values.Object) drivers.HTTPHeaders { - res := make(drivers.HTTPHeaders) +func parseHeader(headers *values.Object) *drivers.HTTPHeaders { + res := drivers.NewHTTPHeaders() headers.ForEach(func(value core.Value, key string) bool { - res.Set(key, value.String()) + if value.Type() == types.Array { + value := value.(*values.Array) + + keyValues := make([]string, 0, value.Length()) + + value.ForEach(func(v core.Value, idx int) bool { + keyValues = append(keyValues, v.String()) + + return true + }) + + res.SetArr(key, keyValues) + } else { + res.Set(key, value.String()) + } return true }) diff --git a/pkg/stdlib/html/parse.go b/pkg/stdlib/html/parse.go index 00e29118..110999ef 100644 --- a/pkg/stdlib/html/parse.go +++ b/pkg/stdlib/html/parse.go @@ -132,7 +132,7 @@ func parseParseParams(content []byte, arg *values.Object) (ParseParams, error) { res.Cookies = cookies default: - res.Cookies = make(drivers.HTTPCookies) + res.Cookies = drivers.NewHTTPCookies() } }