package httpcache import ( "bufio" hash "crypto/md5" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "sort" "strings" ) var ( UserAgent string ModifyResponse func(url string, entry CacheEntry, resp *http.Response) *http.Response CheckRedirect func(req *http.Request, via []*http.Request) error ) type CacheEntry string var memCache = map[string]CacheEntry{} type httpStatusError struct { StatusCode int Status string } // Is implements the interface for [errors.Is]. func (e *httpStatusError) Is(target error) bool { switch target { case os.ErrNotExist: return e.StatusCode == http.StatusNotFound default: return false } } // Error implements [error]. func (e *httpStatusError) Error() string { return fmt.Sprintf("unexpected HTTP status: %v", e.Status) } type transport struct{} func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { // Return an error for things that are the fault of things // not-on-this-box. Panic for things that are the fault of // this box. // Initialize. if err := os.Mkdir(".http-cache", 0777); err != nil && !os.IsExist(err) { panic(err) } // Calculate cache-key. u := req.URL.String() cacheKey := url.QueryEscape(u) hdrKeys := make([]string, 0, len(req.Header)) for k := range req.Header { switch k { case "User-Agent": case "Referer": default: hdrKeys = append(hdrKeys, http.CanonicalHeaderKey(k)) } } sort.Strings(hdrKeys) for _, k := range hdrKeys { cacheKey += "|" + url.QueryEscape(k) + ":" + url.QueryEscape(req.Header[k][0]) } if len(cacheKey) >= 255 { prefix := cacheKey[:255-(hash.Size*2)] csum := hash.Sum([]byte(cacheKey)) suffix := hex.EncodeToString(csum[:]) cacheKey = prefix + suffix } cacheFile := filepath.Join(".http-cache", cacheKey) // Check the mem cache. if _, ok := memCache[cacheKey]; ok { fmt.Printf("GET|CACHE|MEM %q...", u) goto end } // Check the file cache. if bs, err := os.ReadFile(cacheFile); err == nil { str := string(bs) if strings.HasPrefix(str, "HTTP/") || strings.HasPrefix(str, "CLIENT/") { fmt.Printf("GET|CACHE|FILE %q...", u) memCache[cacheKey] = CacheEntry(str) goto end } } // Do the request for real. fmt.Printf("GET|NET %q...", u) if resp, err := http.DefaultTransport.RoundTrip(req); err == nil { var buf strings.Builder if err := resp.Write(&buf); err != nil { panic(err) } memCache[cacheKey] = CacheEntry(buf.String()) } else { memCache[cacheKey] = CacheEntry("CLIENT/" + err.Error()) } // Record the response to the file cache. if err := os.WriteFile(cacheFile, []byte(memCache[cacheKey]), 0666); err != nil { panic(err) } end: // Turn the cache entry into an http.Response (or error) var ret_resp *http.Response var ret_err error entry := memCache[cacheKey] switch { case strings.HasPrefix(string(entry), "HTTP/"): var err error ret_resp, err = http.ReadResponse(bufio.NewReader(strings.NewReader(string(entry))), nil) if err != nil { panic(fmt.Errorf("invalid cache entry: %v", err)) } if ModifyResponse != nil { ret_resp = ModifyResponse(u, entry, ret_resp) } case strings.HasPrefix(string(entry), "CLIENT/"): ret_err = errors.New(string(entry)[len("CLIENT/"):]) default: panic("invalid cache entry: invalid prefix") } // Return. if ret_err != nil { fmt.Printf(" err\n") } else { fmt.Printf(" http %v\n", ret_resp.StatusCode) } return ret_resp, ret_err } func Get(u string, hdr map[string]string) (string, error) { if UserAgent == "" { panic("main() must set the user agent string") } req, err := http.NewRequest(http.MethodGet, u, nil) if err != nil { panic(fmt.Errorf("should not happen: http.NewRequest: %v", err)) } req.Header.Set("User-Agent", UserAgent) for k, v := range hdr { req.Header.Add(k, v) } client := &http.Client{ Transport: &transport{}, CheckRedirect: CheckRedirect, } resp, err := client.Do(req) if err != nil { return "", err } if resp.StatusCode != http.StatusOK { return "", &httpStatusError{StatusCode: resp.StatusCode, Status: resp.Status} } bs, err := io.ReadAll(resp.Body) if err != nil { panic(fmt.Errorf("should not happen: strings.Reader.Read: %v", err)) } return string(bs), nil } func GetJSON(u string, hdr map[string]string, out any) error { str, err := Get(u, hdr) if err != nil { return err } return json.Unmarshal([]byte(str), out) } func GetPaginatedJSON[T any](uStr string, hdr map[string]string, out *[]T, pageFn func(i int) url.Values) error { u, err := url.Parse(uStr) if err != nil { return err } query := u.Query() for i := 0; true; i++ { pageParams := pageFn(i) for k, v := range pageParams { query[k] = v } u.RawQuery = query.Encode() var resp []T if err := GetJSON(u.String(), hdr, &resp); err != nil { return err } fmt.Printf(" -> %d records\n", len(resp)) if len(resp) == 0 { break } *out = append(*out, resp...) } return nil }