diff --git a/cmd/cmd.go b/cmd/cmd.go index 94117a7..5b0dc2a 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -11,6 +11,7 @@ import ( func GetRootCommand() *cobra.Command { rootCMD := root.GetCommand() rootCMD.AddCommand(multifile.GetCommand()) + rootCMD.AddCommand(GetProxyCommand()) rootCMD.AddCommand(version.VersionCMD) return rootCMD } diff --git a/cmd/proxy.go b/cmd/proxy.go new file mode 100644 index 0000000..8f1e77a --- /dev/null +++ b/cmd/proxy.go @@ -0,0 +1,104 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/dustin/go-humanize" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/replicate/pget/pkg/cli" + "github.com/replicate/pget/pkg/client" + "github.com/replicate/pget/pkg/config" + "github.com/replicate/pget/pkg/download" + "github.com/replicate/pget/pkg/proxy" +) + +const longDesc = ` +TODO +` + +func GetProxyCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "proxy [flags] ", + Short: "run as an http proxy server", + Long: longDesc, + PreRunE: proxyPreRunE, + RunE: runProxyCMD, + Args: cobra.ExactArgs(0), + Example: ` pget proxy`, + } + cmd.Flags().String(config.OptListenAddress, "127.0.0.1:9512", "address to listen on") + err := viper.BindPFlags(cmd.Flags()) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + err = viper.BindPFlags(cmd.PersistentFlags()) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + cmd.SetUsageTemplate(cli.UsageTemplate) + return cmd +} + +func proxyPreRunE(cmd *cobra.Command, args []string) error { + if viper.GetBool(config.OptExtract) { + return fmt.Errorf("cannot use --extract with proxy mode") + } + if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor { + return fmt.Errorf("cannot use --output-consumer tar-extractor with proxy mode") + } + return nil +} + +func runProxyCMD(cmd *cobra.Command, args []string) error { + minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize)) + if err != nil { + return err + } + clientOpts := 
client.Options{ + MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost), + ForceHTTP2: viper.GetBool(config.OptForceHTTP2), + MaxRetries: viper.GetInt(config.OptRetries), + ConnectTimeout: viper.GetDuration(config.OptConnTimeout), + } + downloadOpts := download.Options{ + MaxConcurrency: viper.GetInt(config.OptConcurrency), + MinChunkSize: int64(minChunkSize), + Client: clientOpts, + } + + // TODO DRY this + srvName := config.GetCacheSRV() + + if srvName == "" { + return fmt.Errorf("option %s must be specified in proxy mode", config.OptCacheNodesSRVName) + } + + downloadOpts.SliceSize = 500 * humanize.MiByte + // FIXME: make this a config option + downloadOpts.DomainsToCache = []string{"weights.replicate.delivery"} + // TODO: dynamically respond to SRV updates rather than just looking up + // once at startup + downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName) + if err != nil { + return err + } + chMode, err := download.GetConsistentHashingMode(downloadOpts) + if err != nil { + return err + } + + p, err := proxy.New( + chMode, + &proxy.Options{ + Address: viper.GetString(config.OptListenAddress), + }) + if err != nil { + return err + } + return p.Start() +} diff --git a/pkg/config/optnames.go b/pkg/config/optnames.go index b0e4208..b04a10b 100644 --- a/pkg/config/optnames.go +++ b/pkg/config/optnames.go @@ -12,6 +12,7 @@ const ( OptExtract = "extract" OptForce = "force" OptForceHTTP2 = "force-http2" + OptListenAddress = "listen-address" OptLoggingLevel = "log-level" OptMaxChunks = "max-chunks" OptMaxConnPerHost = "max-conn-per-host" diff --git a/pkg/download/buffer.go b/pkg/download/buffer.go index 97459ac..83a6cb8 100644 --- a/pkg/download/buffer.go +++ b/pkg/download/buffer.go @@ -68,6 +68,11 @@ type firstReqResult struct { func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) { logger := logging.GetLogger() + baseReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return 
nil, 0, err + } + br := newBufferedReader(m.minChunkSize()) firstReqResultCh := make(chan firstReqResult) @@ -75,7 +80,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e m.sem.Go(func() error { defer close(firstReqResultCh) defer br.done() - firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url) + firstChunkResp, err := m.DoRequest(baseReq, 0, m.minChunkSize()-1) if err != nil { firstReqResultCh <- firstReqResult{err: err} return err @@ -109,7 +114,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e } fileSize := firstReqResult.fileSize - trueURL := firstReqResult.trueURL + trueURLReq, err := http.NewRequestWithContext(ctx, http.MethodGet, firstReqResult.trueURL, nil) + if err != nil { + return nil, 0, err + } if fileSize <= m.minChunkSize() { // we only need a single chunk: just download it and finish @@ -157,7 +165,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e m.sem.Go(func() error { defer br.done() - resp, err := m.DoRequest(ctx, start, end, trueURL) + resp, err := m.DoRequest(trueURLReq, start, end) if err != nil { return err } @@ -170,18 +178,15 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e return newChanMultiReader(readersCh), fileSize, nil } -func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) { - req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to download %s: %w", trueURL, err) - } +func (m *BufferMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) { + req := origReq.Clone(origReq.Context()) req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) resp, err := m.Client.Do(req) if err != nil { return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err) } if resp.StatusCode == 0 || resp.StatusCode < 200 || 
resp.StatusCode >= 300 { - return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status) + return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String()) } return resp, nil diff --git a/pkg/download/consistent_hashing.go b/pkg/download/consistent_hashing.go index ed147b4..8091cb6 100644 --- a/pkg/download/consistent_hashing.go +++ b/pkg/download/consistent_hashing.go @@ -78,16 +78,46 @@ func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string) return strconv.ParseInt(groups[1], 10, 64) } +var _ http.Handler = &ConsistentHashingMode{} + func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) { - logger := logging.GetLogger() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil) + if err != nil { + return nil, 0, err + } + return m.fetch(req) +} - parsed, err := url.Parse(urlString) +func (m *ConsistentHashingMode) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + // if we want to forward req, we must blank out req.RequestURI + req.RequestURI = "" + // client requests don't have scheme or host set on the request. 
We need to + // restore it for hash consistency + req.URL.Scheme = "https" + req.URL.Host = req.Host + reader, size, err := m.fetch(req) if err != nil { - return nil, -1, err + var httpErr HttpStatusError + if errors.As(err, &httpErr) { + resp.WriteHeader(httpErr.StatusCode) + } else { + resp.WriteHeader(http.StatusInternalServerError) + } + return } + // TODO: http.StatusPartialContent and Content-Range if it was a range request + resp.Header().Set("Content-Length", fmt.Sprint(size)) + resp.WriteHeader(http.StatusOK) + // we ignore errors as it's too late to change status code + _, _ = io.Copy(resp, reader) +} + +func (m *ConsistentHashingMode) fetch(req *http.Request) (io.Reader, int64, error) { + logger := logging.GetLogger() + shouldContinue := false for _, host := range m.DomainsToCache { - if host == parsed.Host { + if host == req.Host { shouldContinue = true break } @@ -95,10 +125,10 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io // Use our fallback mode if we're not downloading from a consistent-hashing enabled domain if !shouldContinue { logger.Debug(). - Str("url", urlString). - Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", parsed.Host)). + Str("url", req.URL.String()). + Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", req.Host)). 
Msg("fallback strategy") - return m.FallbackStrategy.Fetch(ctx, urlString) + return m.FallbackStrategy.Fetch(req.Context(), req.URL.String()) } br := newBufferedReader(m.minChunkSize()) @@ -107,7 +137,8 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io m.sem.Go(func() error { defer close(firstReqResultCh) defer br.done() - firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString) + // TODO: respect Range header in the original request + firstChunkResp, err := m.DoRequest(req, 0, m.minChunkSize()-1) if err != nil { firstReqResultCh <- firstReqResult{err: err} return err @@ -135,11 +166,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io if errors.Is(firstReqResult.err, client.ErrStrategyFallback) { // TODO(morgan): we should indicate the fallback strategy we're using in the logs logger.Info(). - Str("url", urlString). + Str("url", req.URL.String()). Str("type", "file"). - Err(err). + Err(firstReqResult.err). Msg("consistent hash fallback") - return m.FallbackStrategy.Fetch(ctx, urlString) + return m.FallbackStrategy.Fetch(req.Context(), req.URL.String()) } return nil, -1, firstReqResult.err } @@ -172,7 +203,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io readersCh := make(chan io.Reader, m.maxConcurrency()+1) readersCh <- br - logger.Debug().Str("url", urlString). + logger.Debug().Str("url", req.URL.String()). Int64("size", fileSize). Int("concurrency", m.maxConcurrency()). Ints64("chunks_per_slice", chunksPerSlice). 
@@ -214,7 +245,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io m.sem.Go(func() error { defer br.done() logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request") - resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString) + resp, err := m.DoRequest(req, chunkStart, chunkEnd) if err != nil { // in the case that an error indicating an issue with the cache server, networking, etc is returned, // this will use the fallback strategy. This is a case where the whole file will perform the fall-back @@ -222,11 +253,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io if errors.Is(err, client.ErrStrategyFallback) { // TODO(morgan): we should indicate the fallback strategy we're using in the logs logger.Info(). - Str("url", urlString). + Str("url", req.URL.String()). Str("type", "chunk"). Err(err). Msg("consistent hash fallback") - resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString) + resp, err = m.FallbackStrategy.DoRequest(req, chunkStart, chunkEnd) } if err != nil { return err @@ -244,36 +275,30 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io return newChanMultiReader(readersCh), fileSize, nil } -func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) { +func (m *ConsistentHashingMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) { logger := logging.GetLogger() - chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true) - req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) - if err != nil { - return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err) - } + chContext := context.WithValue(origReq.Context(), config.ConsistentHashingStrategyKey, true) + req := origReq.Clone(chContext) cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end) if err 
!= nil { return nil, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") + logger.Debug().Str("url", origReq.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request") resp, err := m.Client.Do(req) if err != nil { if errors.Is(err, client.ErrStrategyFallback) { origErr := err - req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil) - if err != nil { - return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err) - } + req = origReq.Clone(chContext) _, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex) if err != nil { // return origErr so that we can use our regular fallback strategy return nil, origErr } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request") + logger.Debug().Str("url", origReq.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request") resp, err = m.Client.Do(req) if err != nil { @@ -285,7 +310,11 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, } } if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 { - return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status) + if resp.StatusCode >= 400 { + return nil, HttpStatusError{StatusCode: resp.StatusCode} + } + + return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String()) } return resp, nil diff --git a/pkg/download/consistent_hashing_test.go b/pkg/download/consistent_hashing_test.go index b6a3456..adfae69 100644 --- a/pkg/download/consistent_hashing_test.go +++ 
b/pkg/download/consistent_hashing_test.go @@ -326,14 +326,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64, return io.NopCloser(strings.NewReader("00")), -1, nil } -func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) { +func (s *testStrategy) DoRequest(req *http.Request, start, end int64) (*http.Response, error) { s.mut.Lock() s.doRequestCalledCount++ s.mut.Unlock() - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return nil, err - } resp := &http.Response{ Request: req, Body: io.NopCloser(strings.NewReader("00")), @@ -362,7 +358,7 @@ func TestConsistentHashingFileFallback(t *testing.T) { responseStatus: http.StatusNotFound, fetchCalledCount: 0, doRequestCalledCount: 0, - expectedError: download.ErrUnexpectedHTTPStatus, + expectedError: download.ErrUnexpectedHTTPStatus(http.StatusNotFound), }, } diff --git a/pkg/download/errors.go b/pkg/download/errors.go new file mode 100644 index 0000000..584f4e8 --- /dev/null +++ b/pkg/download/errors.go @@ -0,0 +1,19 @@ +package download + +import ( + "fmt" +) + +type HttpStatusError struct { + StatusCode int +} + +func ErrUnexpectedHTTPStatus(statusCode int) error { + return HttpStatusError{StatusCode: statusCode} +} + +var _ error = &HttpStatusError{} + +func (c HttpStatusError) Error() string { + return fmt.Sprintf("Status code %d", c.StatusCode) +} diff --git a/pkg/download/strategy.go b/pkg/download/strategy.go index a430dd2..2a1818c 100644 --- a/pkg/download/strategy.go +++ b/pkg/download/strategy.go @@ -2,20 +2,17 @@ package download import ( "context" - "errors" "io" "net/http" ) -var ErrUnexpectedHTTPStatus = errors.New("unexpected http status") - type Strategy interface { // Fetch retrieves the content from a given URL and returns it as an io.Reader along with the file size. // If an error occurs during the process, it returns nil for the reader, 0 for the fileSize, and the error itself. 
// This is the primary method that should be called to initiate a download of a file. Fetch(ctx context.Context, url string) (result io.Reader, fileSize int64, err error) - // DoRequest sends an HTTP GET request with a specified range of bytes to the given URL using the provided context. + // DoRequest executes an HTTP request with a specified range of bytes. // It returns the HTTP response and any error encountered during the request. It is intended that Fetch calls DoRequest // and that each chunk is downloaded with a call to DoRequest. DoRequest is exposed so that consistent-hashing can // utilize any strategy as a fall-back for chunk downloading. @@ -23,6 +20,5 @@ type Strategy interface { // If the request fails to download or execute, an error is returned. // // The start and end parameters specify the byte range to request. - // The trueURL parameter is the actual URL after any redirects. - DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) + DoRequest(req *http.Request, start, end int64) (*http.Response, error) } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go new file mode 100644 index 0000000..cd07d99 --- /dev/null +++ b/pkg/proxy/proxy.go @@ -0,0 +1,39 @@ +package proxy + +import ( + "net/http" + "time" + + "github.com/replicate/pget/pkg/download" + "github.com/replicate/pget/pkg/logging" +) + +type Proxy struct { + httpServer *http.Server + chMode *download.ConsistentHashingMode + opts *Options +} + +type Options struct { + Address string +} + +func New(chMode *download.ConsistentHashingMode, opts *Options) (*Proxy, error) { + return &Proxy{ + chMode: chMode, + opts: opts, + }, nil +} + +func (p *Proxy) Start() error { + logger := logging.GetLogger() + logger.Debug().Str("address", p.opts.Address).Msg("Listening on") + p.httpServer = &http.Server{ + Addr: p.opts.Address, + Handler: p.chMode, + ReadTimeout: 15 * time.Second, + ReadHeaderTimeout: 5 * time.Second, + WriteTimeout: 15 * time.Second, + } + return p.httpServer.ListenAndServe() +}