From fbb8840b8f24cc8df4c20b1c072c6f02876d94ef Mon Sep 17 00:00:00 2001 From: mono Date: Mon, 19 Aug 2024 00:05:39 +0900 Subject: [PATCH] Use context when scraping --- scraper/source.go | 3 ++- scraper/source/googlecalendar/source.go | 7 +++---- scraper/source/impresswatchcolumn/source.go | 9 +++++++-- scraper/source/impresswatchcolumn/source_test.go | 3 ++- scraper/source/kittychaninfo/source.go | 9 +++++++-- scraper/source/kittychaninfo/source_test.go | 3 ++- scraper/source/lalapiroomevent/source.go | 9 +++++++-- scraper/source/lalapiroomevent/source_test.go | 3 ++- scraper/source/yuyakekoyakenews/source.go | 9 +++++++-- scraper/source/yuyakekoyakenews/source_test.go | 3 ++- server/server.go | 2 +- 11 files changed, 42 insertions(+), 18 deletions(-) diff --git a/scraper/source.go b/scraper/source.go index 270011a..b4c0c87 100644 --- a/scraper/source.go +++ b/scraper/source.go @@ -1,6 +1,7 @@ package scraper import ( + "context" "net/url" "github.com/gorilla/feeds" @@ -8,5 +9,5 @@ import ( type Source interface { Name() string - Scrape(query url.Values) (*feeds.Feed, error) + Scrape(ctx context.Context, query url.Values) (*feeds.Feed, error) } diff --git a/scraper/source/googlecalendar/source.go b/scraper/source/googlecalendar/source.go index 2c6bf05..dc76c12 100644 --- a/scraper/source/googlecalendar/source.go +++ b/scraper/source/googlecalendar/source.go @@ -43,25 +43,24 @@ func (s *source) Name() string { return "google-calendar" } -func (s *source) Scrape(query url.Values) (*feeds.Feed, error) { +func (s *source) Scrape(ctx context.Context, query url.Values) (*feeds.Feed, error) { calendarID := query.Get("id") if calendarID == "" { return &feeds.Feed{}, nil } - events, err := s.fetch(calendarID) + events, err := s.fetch(ctx, calendarID) if err != nil { return nil, err } return s.render(events, calendarID) } -func (s *source) fetch(calendarID string) (*calendar.Events, error) { +func (s *source) fetch(ctx context.Context, calendarID string) (*calendar.Events, error) { config, err := google.JWTConfigFromJSON(([]byte)(os.Getenv("GOOGLE_CLIENT_CREDENTIALS")), calendar.CalendarReadonlyScope) if err != nil { return nil, fmt.Errorf("%w", err) } - ctx := context.Background() ctx = context.WithValue(ctx, oauth2.HTTPClient, s.httpClient) client := config.Client(ctx) diff --git a/scraper/source/impresswatchcolumn/source.go b/scraper/source/impresswatchcolumn/source.go index efd3bd5..25b8a73 100644 --- a/scraper/source/impresswatchcolumn/source.go +++ b/scraper/source/impresswatchcolumn/source.go @@ -1,6 +1,7 @@ package impresswatchcolumn import ( + "context" "fmt" "html" "net/http" @@ -42,7 +43,7 @@ func (*source) Name() string { return "impress-watch-column" } -func (s *source) Scrape(query url.Values) (*feeds.Feed, error) { +func (s *source) Scrape(ctx context.Context, query url.Values) (*feeds.Feed, error) { site := query.Get("site") column := query.Get("column") if site == "" || column == "" { @@ -54,7 +55,11 @@ func (s *source) Scrape(query url.Values) (*feeds.Feed, error) { r := strings.NewReplacer("{site}", site, "{column}", column) - res, err := s.httpClient.Get(r.Replace(s.baseURL + endpoint)) + req, err := http.NewRequestWithContext(ctx, "GET", r.Replace(s.baseURL+endpoint), nil) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + res, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("%w", err) } diff --git a/scraper/source/impresswatchcolumn/source_test.go b/scraper/source/impresswatchcolumn/source_test.go index 2bf8e76..ed28b7c 100644 --- a/scraper/source/impresswatchcolumn/source_test.go +++ b/scraper/source/impresswatchcolumn/source_test.go @@ -1,6 +1,7 @@ package impresswatchcolumn import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -31,7 +32,7 @@ func TestScrape(t *testing.T) { v := url.Values{} v.Set("site", "k-tai") v.Set("column", "stapa") - feed, err := source.Scrape(v) + feed, err := source.Scrape(context.Background(), v) if err != nil { t.Fatal(err) } diff --git a/scraper/source/kittychaninfo/source.go b/scraper/source/kittychaninfo/source.go index f2dd7dc..354f3fe 100644 --- a/scraper/source/kittychaninfo/source.go +++ b/scraper/source/kittychaninfo/source.go @@ -1,6 +1,7 @@ package kittychaninfo import ( + "context" "fmt" "io" "net/http" @@ -48,8 +49,12 @@ func (s *source) Name() string { return "kittychan-info" } -func (s *source) Scrape(url.Values) (*feeds.Feed, error) { - res, err := s.httpClient.Get(s.baseURL + endpoint) +func (s *source) Scrape(ctx context.Context, _ url.Values) (*feeds.Feed, error) { + req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL+endpoint, nil) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + res, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("%w", err) } diff --git a/scraper/source/kittychaninfo/source_test.go b/scraper/source/kittychaninfo/source_test.go index c438eca..32a468a 100644 --- a/scraper/source/kittychaninfo/source_test.go +++ b/scraper/source/kittychaninfo/source_test.go @@ -1,6 +1,7 @@ package kittychaninfo import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -33,7 +34,7 @@ func TestScrape(t *testing.T) { t.Fatal(err) } - feed, err := source.Scrape(url.Values{}) + feed, err := source.Scrape(context.Background(), url.Values{}) if err != nil { t.Fatal(err) } diff --git a/scraper/source/lalapiroomevent/source.go b/scraper/source/lalapiroomevent/source.go index b550cba..0f9bf88 100644 --- a/scraper/source/lalapiroomevent/source.go +++ b/scraper/source/lalapiroomevent/source.go @@ -1,6 +1,7 @@ package lalapiroomevent import ( + "context" "fmt" "net/http" "net/url" @@ -38,8 +39,12 @@ func (*source) Name() string { return "lalapi-room-event" } -func (s *source) Scrape(query url.Values) (*feeds.Feed, error) { - res, err := s.httpClient.Get(s.baseURL + endpoint) +func (s *source) Scrape(ctx context.Context, query url.Values) (*feeds.Feed, error) { + req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL+endpoint, nil) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + res, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("%w", err) } diff --git a/scraper/source/lalapiroomevent/source_test.go b/scraper/source/lalapiroomevent/source_test.go index d0468f2..9042573 100644 --- a/scraper/source/lalapiroomevent/source_test.go +++ b/scraper/source/lalapiroomevent/source_test.go @@ -1,6 +1,7 @@ package lalapiroomevent import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -28,7 +29,7 @@ func TestScrape(t *testing.T) { source := NewSource(server.Client()) source.baseURL = server.URL - feed, err := source.Scrape(url.Values{}) + feed, err := source.Scrape(context.Background(), url.Values{}) if err != nil { t.Fatal(err) } diff --git a/scraper/source/yuyakekoyakenews/source.go b/scraper/source/yuyakekoyakenews/source.go index 1ba1bad..b4ec37e 100644 --- a/scraper/source/yuyakekoyakenews/source.go +++ b/scraper/source/yuyakekoyakenews/source.go @@ -1,6 +1,7 @@ package yuyakekoyakenews import ( + "context" "fmt" "net/http" "net/url" @@ -36,8 +37,12 @@ func (s *source) Name() string { return "yuyakekoyake-news" } -func (s *source) Scrape(url.Values) (*feeds.Feed, error) { - res, err := s.httpClient.Get(s.baseURL + endpoint) +func (s *source) Scrape(ctx context.Context, _ url.Values) (*feeds.Feed, error) { + req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL+endpoint, nil) + if err != nil { + return nil, fmt.Errorf("%w", err) + } + res, err := s.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("%w", err) } diff --git a/scraper/source/yuyakekoyakenews/source_test.go b/scraper/source/yuyakekoyakenews/source_test.go index 611fac7..b4f40da 100644 --- a/scraper/source/yuyakekoyakenews/source_test.go +++ b/scraper/source/yuyakekoyakenews/source_test.go @@ -1,6 +1,7 @@ package yuyakekoyakenews import ( + "context" "net/http" "net/http/httptest" "net/url" @@ -28,7 +29,7 @@ func TestScrape(t *testing.T) { source := NewSource(server.Client()) source.baseURL = server.URL - feed, err := source.Scrape(url.Values{}) + feed, err := source.Scrape(context.Background(), url.Values{}) if err != nil { t.Fatal(err) } diff --git a/server/server.go b/server/server.go index e66492d..7a6db8e 100644 --- a/server/server.go +++ b/server/server.go @@ -47,7 +47,7 @@ func NewHandler(sources []scraper.Source) (http.Handler, error) { return } - feed, err := source.Scrape(r.URL.Query()) + feed, err := source.Scrape(r.Context(), r.URL.Query()) if err != nil { log.Printf("%v: %+v\n", reflect.TypeOf(source), err) w.WriteHeader(http.StatusServiceUnavailable)