diff --git a/rpcserver/jsonrpc_server.go b/rpcserver/jsonrpc_server.go index 2350518..27d9ed6 100644 --- a/rpcserver/jsonrpc_server.go +++ b/rpcserver/jsonrpc_server.go @@ -277,8 +277,32 @@ func (h *JSONRPCHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } ctx = context.WithValue(ctx, signerKey{}, signer) } - // r.URL - ctx = context.WithValue(ctx, urlKey{}, r.URL) + // Extract URL from headers (Stage 2) or use r.URL directly (Stage 1) + // Proxyd may send X-Original-Path and X-Original-Query independently + reqURL := r.URL + originalPath := r.Header.Get("X-Original-Path") + originalQuery := r.Header.Get("X-Original-Query") + + // Only create new URL if at least one header is present + if originalPath != "" || originalQuery != "" { + // Start with actual URL values + path := r.URL.Path + query := r.URL.RawQuery + + // Replace with header values if present + if originalPath != "" { + path = originalPath + } + if originalQuery != "" { + query = originalQuery + } + + reqURL = &url.URL{ + Path: path, + RawQuery: query, + } + } + ctx = context.WithValue(ctx, urlKey{}, reqURL) // read request var req jsonRPCRequest diff --git a/rpcserver/jsonrpc_server_test.go b/rpcserver/jsonrpc_server_test.go index e6bed00..ab4d1db 100644 --- a/rpcserver/jsonrpc_server_test.go +++ b/rpcserver/jsonrpc_server_test.go @@ -208,3 +208,77 @@ func TestJSONRPCServerReadyzError(t *testing.T) { fmt.Println(rr.Body.String()) require.Equal(t, "not ready\n", rr.Body.String()) } + +func TestURLExtraction(t *testing.T) { + // Handler that captures URL from context + var capturedURL string + handlerMethod := func(ctx context.Context) (string, error) { + url := GetURL(ctx) + capturedURL = url.Path + "?" + url.RawQuery + return capturedURL, nil + } + + handler, err := NewJSONRPCHandler(map[string]interface{}{ + "test": handlerMethod, + }, JSONRPCHandlerOpts{}) + require.NoError(t, err) + + t.Run("No headers: uses r.URL (backward compat)", func(t *testing.T) { + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"test","params":[]}`)) + request, err := http.NewRequest(http.MethodPost, "/fast?hint=calldata", body) + require.NoError(t, err) + request.Header.Add("Content-Type", "application/json") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "/fast?hint=calldata", capturedURL) + }) + + t.Run("Both headers: reconstructs URL (works with any path)", func(t *testing.T) { + // Test with /whatever instead of /fast to prove it's not hardcoded + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"test","params":[]}`)) + request, err := http.NewRequest(http.MethodPost, "/", body) + require.NoError(t, err) + request.Header.Add("Content-Type", "application/json") + request.Header.Add("X-Original-Path", "/whatever") + request.Header.Add("X-Original-Query", "hint=hash&builder=flashbots") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "/whatever?hint=hash&builder=flashbots", capturedURL) + }) + + t.Run("Only query header: uses r.URL.Path", func(t *testing.T) { + // Proxyd doesn't send X-Original-Path when path is "/" + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"test","params":[]}`)) + request, err := http.NewRequest(http.MethodPost, "/", body) + require.NoError(t, err) + request.Header.Add("Content-Type", "application/json") + request.Header.Add("X-Original-Query", "hint=hash") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "/?hint=hash", capturedURL) + }) + + t.Run("Only path header: uses r.URL.RawQuery", func(t *testing.T) { + // Proxyd doesn't send X-Original-Query when there's no query string + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"test","params":[]}`)) + request, err := http.NewRequest(http.MethodPost, "/api", body) + require.NoError(t, err) + request.Header.Add("Content-Type", "application/json") + request.Header.Add("X-Original-Path", "/fast") + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, request) + + require.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, "/fast?", capturedURL) + }) +}