From 17b9dcd3583f40813001f790b046ac73d4933a59 Mon Sep 17 00:00:00 2001
From: Jeremy Edwards <1312331+jeremyje@users.noreply.github.com>
Date: Fri, 3 Mar 2023 04:54:32 +0000
Subject: [PATCH] Add support for local file system in HTTP driver.

---
Review fixes folded into this patch:
- driver_test.go: an Open failure is now reported with t.Fatalf and %v.
  t.Error does not expand format verbs (%w is only meaningful to
  fmt.Errorf), and continuing after the failure dereferenced a nil page.
- driver_test.go: the test passes its own ctx to Open.
- driver.go: doc comments on readLocalFile and Open; Header/Trailer use
  http.Header{}.
- The index lines are carried over unchanged from the original patch.

 pkg/drivers/http/driver.go      | 50 +++++++++++++++++++++++++++++++---
 pkg/drivers/http/driver_test.go | 42 ++++++++++++++++++++++++++--
 2 files changed, 86 insertions(+), 6 deletions(-)

diff --git a/pkg/drivers/http/driver.go b/pkg/drivers/http/driver.go
index 4da5c2ee..389ce8a8 100644
--- a/pkg/drivers/http/driver.go
+++ b/pkg/drivers/http/driver.go
@@ -6,6 +6,7 @@ import (
 	"io"
 	"net/http"
 	"net/url"
+	"os"
 
 	"github.com/gobwas/glob"
 
@@ -88,17 +89,58 @@ func (drv *Driver) Name() string {
 	return drv.options.Name
 }
 
-func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTMLPage, error) {
+// readLocalFile serves a "file" URL from the local file system. It opens
+// u.Path and wraps the open file in a synthetic "200 OK" response, so the
+// caller can handle it like a response from the HTTP client. The caller
+// owns the returned Body and must close it.
+func (drv *Driver) readLocalFile(u *url.URL, params drivers.Params) (*http.Response, error) {
 	req, err := http.NewRequest(http.MethodGet, params.URL, nil)
 	if err != nil {
 		return nil, err
 	}
+	f, err := os.Open(u.Path)
+	if err != nil {
+		return nil, err
+	}
+	stat, err := f.Stat()
+	if err != nil {
+		f.Close() // the Stat error is the one reported; a Close error is dropped
+		return nil, err
+	}
+	return &http.Response{
+		Status:        "200 OK",
+		StatusCode:    http.StatusOK,
+		Proto:         "HTTP/1.1",
+		ProtoMajor:    1,
+		ProtoMinor:    1,
+		Header:        http.Header{},
+		Body:          f,
+		ContentLength: stat.Size(),
+		Trailer:       http.Header{},
+		Request:       req,
+	}, nil
+}
 
-	params = drivers.SetDefaultParams(drv.options.Options, params)
+// Open loads the document at params.URL. URLs with the "file" scheme are
+// read from the local file system; any other URL is requested through the
+// driver's HTTP client.
+func (drv *Driver) Open(ctx context.Context, params drivers.Params) (drivers.HTMLPage, error) {
+	var resp *http.Response
+	var err error
+	u, err := url.Parse(params.URL)
+	if err == nil && u.Scheme == "file" {
+		resp, err = drv.readLocalFile(u, params)
+	} else {
+		req, reqErr := http.NewRequest(http.MethodGet, params.URL, nil)
+		if reqErr != nil {
+			return nil, reqErr
+		}
 
-	drv.makeRequest(ctx, req, params)
+		params = drivers.SetDefaultParams(drv.options.Options, params)
+		drv.makeRequest(ctx, req, params)
+		resp, err = drv.client.Do(req)
+	}
 
-	resp, err := drv.client.Do(req)
 	if err != nil {
 		return nil, errors.Wrapf(err, "failed to retrieve a document %s", params.URL)
 	}
diff --git a/pkg/drivers/http/driver_test.go b/pkg/drivers/http/driver_test.go
index 2f19d4ad..c710508c 100644
--- a/pkg/drivers/http/driver_test.go
+++ b/pkg/drivers/http/driver_test.go
@@ -6,6 +6,8 @@ import (
 	"crypto/tls"
 	"io"
 	"net/http"
+	"os"
+	"path/filepath"
 	"reflect"
 	"testing"
 	"unsafe"
@@ -17,6 +19,11 @@ import (
 	"github.com/MontFerret/ferret/pkg/drivers"
 )
 
+const (
+	testInnerText    = `феррет`
+	testHTMLDocument = `феррет`
+)
+
 func Test_newHTTPClientWithTransport(t *testing.T) {
 	httpTransport := (http.DefaultTransport).(*http.Transport)
 	httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
@@ -87,6 +94,37 @@ func Test_newHTTPClient(t *testing.T) {
 	})
 }
 
+// TestDriver_readFromFile writes an HTML document to a temporary directory
+// and checks that Open serves it through a "file" URL.
+func TestDriver_readFromFile(t *testing.T) {
+	ctx := context.Background()
+	dir := t.TempDir()
+	localFile := filepath.Join(dir, "test.html")
+	if err := os.WriteFile(localFile, []byte(testHTMLDocument), 0644); err != nil {
+		t.Fatal(err)
+	}
+	localFileURL := "file://" + localFile
+	drv := &Driver{}
+	page, err := drv.Open(ctx, drivers.Params{
+		URL: localFileURL,
+	})
+	if err != nil {
+		// Fatalf, not Error: Open returns a nil page on failure, and Error
+		// does not expand format verbs.
+		t.Fatalf("cannot read local file, %v", err)
+	}
+	if localFileURL != string(page.GetURL()) {
+		t.Errorf("got: '%s', want: '%s'", page.GetURL(), localFileURL)
+	}
+	doc := page.GetMainFrame()
+	gotInnerText, err := doc.GetElement().GetInnerText(ctx)
+	if err != nil {
+		t.Errorf("cannot get inner text of document, %s", err)
+	} else if testInnerText != gotInnerText {
+		t.Errorf("got: '%s', want: '%s'", gotInnerText, testInnerText)
+	}
+}
+
 func TestDriver_convertToUTF8(t *testing.T) {
 	type args struct {
 		inputData  string
@@ -102,11 +140,11 @@ func TestDriver_convertToUTF8(t *testing.T) {
 		{
 			name: "should convert to expected state",
 			args: args{
-				inputData:  `феррет`,
+				inputData:  testHTMLDocument,
 				srcCharset: "windows-1251",
 			},
 			wantErr:  false,
-			expected: `феррет`,
+			expected: testHTMLDocument,
 		},
 	}
 	for _, tt := range tests {