From f6b4cabef592d83491c3b9509e07ce879a850d23 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:07:32 -0800 Subject: [PATCH 01/12] Add AWS OIDC outgoing auth plugin --- app/auth/plugins/aws_oidc/outgoing.go | 177 +++++++++++++++++++++ app/auth/plugins/aws_oidc/outgoing_test.go | 163 +++++++++++++++++++ app/auth/plugins/plugins.go | 1 + docs/auth-plugins.md | 14 ++ 4 files changed, 355 insertions(+) create mode 100644 app/auth/plugins/aws_oidc/outgoing.go create mode 100644 app/auth/plugins/aws_oidc/outgoing_test.go diff --git a/app/auth/plugins/aws_oidc/outgoing.go b/app/auth/plugins/aws_oidc/outgoing.go new file mode 100644 index 0000000..f8b4b50 --- /dev/null +++ b/app/auth/plugins/aws_oidc/outgoing.go @@ -0,0 +1,177 @@ +package awsoidc + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + authplugins "github.com/winhowes/AuthTranslator/app/auth" +) + +// awsOIDCParams configures the AWS OIDC plugin. +type awsOIDCParams struct { + Audience string `json:"audience"` + Header string `json:"header"` + Prefix string `json:"prefix"` +} + +// AWSOIDC fetches an ID token from the AWS Instance Metadata Service (IMDSv2) +// and adds it to outgoing requests. +type AWSOIDC struct{} + +// MetadataHost is the base URL for the AWS metadata service. It can be +// overridden in tests. +var MetadataHost = "http://169.254.169.254" + +// HTTPClient is used for all metadata requests. +var HTTPClient = &http.Client{Timeout: 5 * time.Second} + +var tokenCache = struct { + sync.Mutex + m map[string]cachedToken +}{m: make(map[string]cachedToken)} + +type cachedToken struct { + token string + exp time.Time +} + +func (a *AWSOIDC) Name() string { return "aws_oidc" } + +func (a *AWSOIDC) RequiredParams() []string { return []string{"audience"} } + +func (a *AWSOIDC) OptionalParams() []string { return []string{"header", "prefix"} } + +func (a *AWSOIDC) ParseParams(m map[string]interface{}) (interface{}, error) { + p, err := authplugins.ParseParams[awsOIDCParams](m) + if err != nil { + return nil, err + } + if p.Audience == "" { + return nil, fmt.Errorf("missing audience") + } + if p.Header == "" { + p.Header = "Authorization" + } + if p.Prefix == "" { + p.Prefix = "Bearer " + } + return p, nil +} + +func (a *AWSOIDC) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { + cfg, ok := params.(*awsOIDCParams) + if !ok { + return fmt.Errorf("invalid config") + } + tok, exp := getCachedToken(cfg.Audience) + if tok == "" || time.Now().After(exp.Add(-1*time.Minute)) { + var err error + tok, exp, err = fetchToken(ctx, cfg.Audience) + if err != nil { + return err + } + setCachedToken(cfg.Audience, tok, exp) + } + r.Header.Set(cfg.Header, cfg.Prefix+tok) + return nil +} + +func fetchToken(ctx context.Context, aud string) (string, time.Time, error) { + metaToken, err := fetchMetadataToken(ctx) + if err != nil { + return "", time.Time{}, err + } + + metaURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/oidc?audience=%s", MetadataHost, url.QueryEscape(aud)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, metaURL, nil) + if err != nil { + return "", time.Time{}, err + } + req.Header.Set("X-aws-ec2-metadata-token", metaToken) + + resp, err := HTTPClient.Do(req) + if err != nil { + return "", time.Time{}, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", time.Time{}, fmt.Errorf("status %s: %s", resp.Status, body) + } + + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", time.Time{}, err + } + tok := string(tokenBytes) + return tok, parseExpiry(tok), nil +} + +func fetchMetadataToken(ctx context.Context) (string, error) { + tokenURL := fmt.Sprintf("%s/latest/api/token", MetadataHost) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, tokenURL, nil) + if err != nil { + return "", err + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") + + resp, err := HTTPClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token fetch status %s: %s", resp.Status, body) + } + + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(tokenBytes), nil +} + +func parseExpiry(tok string) time.Time { + parts := strings.Split(tok, ".") + if len(parts) < 2 { + return time.Now().Add(time.Minute) + } + data, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Now().Add(time.Minute) + } + var c struct { + Exp int64 `json:"exp"` + } + if err := json.Unmarshal(data, &c); err != nil || c.Exp == 0 { + return time.Now().Add(time.Minute) + } + return time.Unix(c.Exp, 0) +} + +func getCachedToken(aud string) (string, time.Time) { + tokenCache.Lock() + defer tokenCache.Unlock() + ct, ok := tokenCache.m[aud] + if !ok { + return "", time.Time{} + } + return ct.token, ct.exp +} + +func setCachedToken(aud, tok string, exp time.Time) { + tokenCache.Lock() + tokenCache.m[aud] = cachedToken{token: tok, exp: exp} + tokenCache.Unlock() +} + +func init() { authplugins.RegisterOutgoing(&AWSOIDC{}) } diff --git a/app/auth/plugins/aws_oidc/outgoing_test.go b/app/auth/plugins/aws_oidc/outgoing_test.go new file mode 100644 index 0000000..49cb1a4 --- /dev/null +++ b/app/auth/plugins/aws_oidc/outgoing_test.go @@ -0,0 +1,163 @@ +package awsoidc + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +type testClaims struct { + Exp int64 `json:"exp"` +} + +func makeJWT(exp time.Time) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`)) + claims := base64.RawURLEncoding.EncodeToString([]byte(`{"exp":` + fmt.Sprintf("%d", exp.Unix()) + `}`)) + return strings.Join([]string{header, claims, "sig"}, ".") +} + +func TestAddAuthFetchesAndCachesToken(t *testing.T) { + now := time.Now().Add(2 * time.Minute) + jwt := makeJWT(now) + + metaToken := "meta123" + aud := "urn:test" + var requestCount int + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + requestCount++ + if r.Method != http.MethodPut { + t.Fatalf("expected PUT for token, got %s", r.Method) + } + if ttl := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds"); ttl == "" { + t.Fatalf("missing TTL header") + } + w.Write([]byte(metaToken)) + case "/latest/meta-data/iam/security-credentials/oidc": + requestCount++ + if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken { + t.Fatalf("expected metadata token %q, got %q", metaToken, got) + } + if got := r.URL.Query().Get("audience"); got != aud { + t.Fatalf("expected audience %s, got %s", aud, got) + } + w.Write([]byte(jwt)) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + MetadataHost = srv.URL + HTTPClient = srv.Client() + tokenCache.m = map[string]cachedToken{} + + plugin := &AWSOIDC{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { + t.Fatalf("AddAuth: %v", err) + } + + if got := req.Header.Get("Authorization"); got != "Bearer "+jwt { + t.Fatalf("unexpected header: %s", got) + } + if requestCount != 2 { + t.Fatalf("expected 2 metadata requests, got %d", requestCount) + } + + // Second call should use cache. + req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { + t.Fatalf("AddAuth second: %v", err) + } + if requestCount != 2 { + t.Fatalf("expected cached token, still %d requests", requestCount) + } +} + +func TestExpiresSoonTriggersRefresh(t *testing.T) { + expSoon := time.Now().Add(30 * time.Second) + jwt1 := makeJWT(expSoon) + jwt2 := makeJWT(time.Now().Add(10 * time.Minute)) + metaToken := "meta123" + aud := "urn:test" + var stage int + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + w.Write([]byte(metaToken)) + case "/latest/meta-data/iam/security-credentials/oidc": + if stage == 0 { + w.Write([]byte(jwt1)) + } else { + w.Write([]byte(jwt2)) + } + stage++ + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + MetadataHost = srv.URL + HTTPClient = srv.Client() + tokenCache.m = map[string]cachedToken{} + + plugin := &AWSOIDC{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { + t.Fatalf("AddAuth: %v", err) + } + req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { + t.Fatalf("AddAuth second: %v", err) + } + if stage != 2 { + t.Fatalf("expected token refresh, stage %d", stage) + } +} + +func TestErrorResponses(t *testing.T) { + aud := "urn:test" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("no token")) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + MetadataHost = srv.URL + HTTPClient = srv.Client() + tokenCache.m = map[string]cachedToken{} + + plugin := &AWSOIDC{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err == nil { + t.Fatalf("expected error from metadata token fetch") + } +} diff --git a/app/auth/plugins/plugins.go b/app/auth/plugins/plugins.go index 70c886c..971ba1c 100644 --- a/app/auth/plugins/plugins.go +++ b/app/auth/plugins/plugins.go @@ -1,6 +1,7 @@ package plugins import ( + _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_oidc" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_oidc" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace" diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index 55e64ea..cdb8dbb 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -33,6 +33,7 @@ AuthTranslator’s behaviour is extended by **plugins** – small Go packages th | Outbound | `google_oidc` | Attaches a Google identity token from the metadata service. | | Outbound | `gcp_token` | Uses a metadata service access token. | | Outbound | `azure_oidc` | Retrieves an Azure access token from the Instance Metadata Service. | +| Outbound | `aws_oidc` | Retrieves an AWS OIDC token from the Instance Metadata Service (IMDSv2). | | Outbound | `hmac_signature` | Computes an HMAC for the request. | | Outbound | `jwt` | Adds a signed JWT to the request. | | Outbound | `mtls` | Sends a client certificate and exposes the CN via header. | @@ -99,6 +100,19 @@ outgoing_auth: Obtains an access token from the Azure Instance Metadata Service for the specified `resource`, caches it, and attaches it to the configured header on each outgoing request. +### Outbound `aws_oidc` + +```yaml +outgoing_auth: + - type: aws_oidc + params: + audience: urn:example + header: Authorization # optional (default: Authorization) + prefix: "Bearer " # optional (default: "Bearer ") +``` + +Retrieves an ID token from the AWS Instance Metadata Service v2 for the provided `audience`, caches it until shortly before expiry, and attaches it to the chosen header on each outgoing request. + --- ## Writing your own plugin From d2472f6c99fb4e554989259c81ea53e493115fe9 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:26:24 -0800 Subject: [PATCH 02/12] Fix AWS OIDC plugin to use IMDS role credentials --- app/auth/plugins/aws_oidc/outgoing.go | 93 +++++++++++++++------- app/auth/plugins/aws_oidc/outgoing_test.go | 75 ++++++++--------- docs/auth-plugins.md | 2 +- 3 files changed, 106 insertions(+), 64 deletions(-) diff --git a/app/auth/plugins/aws_oidc/outgoing.go b/app/auth/plugins/aws_oidc/outgoing.go index f8b4b50..a25890f 100644 --- a/app/auth/plugins/aws_oidc/outgoing.go +++ b/app/auth/plugins/aws_oidc/outgoing.go @@ -2,12 +2,10 @@ package awsoidc import ( "context" - "encoding/base64" "encoding/json" "fmt" "io" "net/http" - "net/url" "strings" "sync" "time" @@ -22,8 +20,8 @@ type awsOIDCParams struct { Prefix string `json:"prefix"` } -// AWSOIDC fetches an ID token from the AWS Instance Metadata Service (IMDSv2) -// and adds it to outgoing requests. +// AWSOIDC fetches the IAM role session token from the AWS Instance Metadata +// Service (IMDSv2) and adds it to outgoing requests. type AWSOIDC struct{} // MetadataHost is the base URL for the AWS metadata service. It can be @@ -90,29 +88,26 @@ func fetchToken(ctx context.Context, aud string) (string, time.Time, error) { return "", time.Time{}, err } - metaURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/oidc?audience=%s", MetadataHost, url.QueryEscape(aud)) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, metaURL, nil) + roleName, err := fetchRoleName(ctx, metaToken) if err != nil { return "", time.Time{}, err } - req.Header.Set("X-aws-ec2-metadata-token", metaToken) - resp, err := HTTPClient.Do(req) + credentials, err := fetchRoleCredentials(ctx, metaToken, roleName) if err != nil { return "", time.Time{}, err } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return "", time.Time{}, fmt.Errorf("status %s: %s", resp.Status, body) + + if credentials.Token == "" { + return "", time.Time{}, fmt.Errorf("empty session token from IMDS for role %s", roleName) } - tokenBytes, err := io.ReadAll(resp.Body) + exp, err := time.Parse(time.RFC3339, credentials.Expiration) if err != nil { - return "", time.Time{}, err + return "", time.Time{}, fmt.Errorf("parse expiration: %w", err) } - tok := string(tokenBytes) - return tok, parseExpiry(tok), nil + + return credentials.Token, exp, nil } func fetchMetadataToken(ctx context.Context) (string, error) { @@ -140,22 +135,66 @@ func fetchMetadataToken(ctx context.Context) (string, error) { return string(tokenBytes), nil } -func parseExpiry(tok string) time.Time { - parts := strings.Split(tok, ".") - if len(parts) < 2 { - return time.Now().Add(time.Minute) +func fetchRoleName(ctx context.Context, metaToken string) (string, error) { + roleURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/", MetadataHost) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, roleURL, nil) + if err != nil { + return "", err } - data, err := base64.RawURLEncoding.DecodeString(parts[1]) + req.Header.Set("X-aws-ec2-metadata-token", metaToken) + + resp, err := HTTPClient.Do(req) if err != nil { - return time.Now().Add(time.Minute) + return "", err } - var c struct { - Exp int64 `json:"exp"` + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("role name status %s: %s", resp.Status, body) + } + + roleBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + roleName := strings.TrimSpace(string(roleBytes)) + if roleName == "" { + return "", fmt.Errorf("empty role name from IMDS") + } + return roleName, nil +} + +type roleCredentials struct { + Expiration string `json:"Expiration"` + Token string `json:"Token"` +} + +func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*roleCredentials, error) { + credsURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/%s", MetadataHost, roleName) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, credsURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("X-aws-ec2-metadata-token", metaToken) + + resp, err := HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("role credentials status %s: %s", resp.Status, body) + } + + var rc roleCredentials + if err := json.NewDecoder(resp.Body).Decode(&rc); err != nil { + return nil, err } - if err := json.Unmarshal(data, &c); err != nil || c.Exp == 0 { - return time.Now().Add(time.Minute) + if rc.Expiration == "" { + return nil, fmt.Errorf("missing expiration in role credentials") } - return time.Unix(c.Exp, 0) + return &rc, nil } func getCachedToken(aud string) (string, time.Time) { diff --git a/app/auth/plugins/aws_oidc/outgoing_test.go b/app/auth/plugins/aws_oidc/outgoing_test.go index 49cb1a4..1d5dbda 100644 --- a/app/auth/plugins/aws_oidc/outgoing_test.go +++ b/app/auth/plugins/aws_oidc/outgoing_test.go @@ -2,30 +2,18 @@ package awsoidc import ( "context" - "encoding/base64" - "fmt" + "encoding/json" "net/http" "net/http/httptest" - "strings" "testing" "time" ) -type testClaims struct { - Exp int64 `json:"exp"` -} - -func makeJWT(exp time.Time) string { - header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none"}`)) - claims := base64.RawURLEncoding.EncodeToString([]byte(`{"exp":` + fmt.Sprintf("%d", exp.Unix()) + `}`)) - return strings.Join([]string{header, claims, "sig"}, ".") -} - func TestAddAuthFetchesAndCachesToken(t *testing.T) { - now := time.Now().Add(2 * time.Minute) - jwt := makeJWT(now) - + expires := time.Now().Add(2 * time.Minute).UTC().Truncate(time.Second) + sessionToken := "sts-session-token" metaToken := "meta123" + roleName := "example-role" aud := "urn:test" var requestCount int @@ -40,15 +28,21 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { t.Fatalf("missing TTL header") } w.Write([]byte(metaToken)) - case "/latest/meta-data/iam/security-credentials/oidc": + case "/latest/meta-data/iam/security-credentials/": requestCount++ if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken { t.Fatalf("expected metadata token %q, got %q", metaToken, got) } - if got := r.URL.Query().Get("audience"); got != aud { - t.Fatalf("expected audience %s, got %s", aud, got) + w.Write([]byte(roleName)) + case "/latest/meta-data/iam/security-credentials/" + roleName: + requestCount++ + if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken { + t.Fatalf("expected metadata token %q, got %q", metaToken, got) } - w.Write([]byte(jwt)) + json.NewEncoder(w).Encode(map[string]interface{}{ + "Token": sessionToken, + "Expiration": expires.Format(time.RFC3339), + }) default: t.Fatalf("unexpected path %s", r.URL.Path) } @@ -69,11 +63,11 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { t.Fatalf("AddAuth: %v", err) } - if got := req.Header.Get("Authorization"); got != "Bearer "+jwt { + if got := req.Header.Get("Authorization"); got != "Bearer "+sessionToken { t.Fatalf("unexpected header: %s", got) } - if requestCount != 2 { - t.Fatalf("expected 2 metadata requests, got %d", requestCount) + if requestCount != 3 { + t.Fatalf("expected 3 metadata requests, got %d", requestCount) } // Second call should use cache. @@ -81,30 +75,36 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { t.Fatalf("AddAuth second: %v", err) } - if requestCount != 2 { + if requestCount != 3 { t.Fatalf("expected cached token, still %d requests", requestCount) } } func TestExpiresSoonTriggersRefresh(t *testing.T) { - expSoon := time.Now().Add(30 * time.Second) - jwt1 := makeJWT(expSoon) - jwt2 := makeJWT(time.Now().Add(10 * time.Minute)) + expSoon := time.Now().Add(30 * time.Second).UTC().Truncate(time.Second) + expLater := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) metaToken := "meta123" + roleName := "role" + sessionTokens := []string{"first", "second"} aud := "urn:test" - var stage int + var credIndex int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/latest/api/token": w.Write([]byte(metaToken)) - case "/latest/meta-data/iam/security-credentials/oidc": - if stage == 0 { - w.Write([]byte(jwt1)) - } else { - w.Write([]byte(jwt2)) + case "/latest/meta-data/iam/security-credentials/": + w.Write([]byte(roleName)) + case "/latest/meta-data/iam/security-credentials/" + roleName: + exp := expSoon + if credIndex > 0 { + exp = expLater } - stage++ + json.NewEncoder(w).Encode(map[string]interface{}{ + "Token": sessionTokens[credIndex], + "Expiration": exp.Format(time.RFC3339), + }) + credIndex++ default: t.Fatalf("unexpected path %s", r.URL.Path) } @@ -128,8 +128,8 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { t.Fatalf("AddAuth second: %v", err) } - if stage != 2 { - t.Fatalf("expected token refresh, stage %d", stage) + if credIndex != 2 { + t.Fatalf("expected token refresh, stage %d", credIndex) } } @@ -141,6 +141,9 @@ func TestErrorResponses(t *testing.T) { case "/latest/api/token": w.WriteHeader(http.StatusBadRequest) w.Write([]byte("no token")) + case "/latest/meta-data/iam/security-credentials/": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("no role")) default: t.Fatalf("unexpected path %s", r.URL.Path) } diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index cdb8dbb..f55b988 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -111,7 +111,7 @@ outgoing_auth: prefix: "Bearer " # optional (default: "Bearer ") ``` -Retrieves an ID token from the AWS Instance Metadata Service v2 for the provided `audience`, caches it until shortly before expiry, and attaches it to the chosen header on each outgoing request. +Retrieves the IAM role session token from the AWS Instance Metadata Service v2, caches it until shortly before expiry, and attaches it to the chosen header on each outgoing request. --- From b572470e8c163c5bc19bc71007b92519ff28ca34 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:26:54 -0800 Subject: [PATCH 03/12] Rename AWS outgoing plugin to IMDS with alias --- .../{aws_oidc => aws_imds}/outgoing.go | 66 ++++++++++--------- .../{aws_oidc => aws_imds}/outgoing_test.go | 24 +++---- app/auth/plugins/plugins.go | 2 +- docs/auth-plugins.md | 9 +-- 4 files changed, 50 insertions(+), 51 deletions(-) rename app/auth/plugins/{aws_oidc => aws_imds}/outgoing.go (75%) rename app/auth/plugins/{aws_oidc => aws_imds}/outgoing_test.go (90%) diff --git a/app/auth/plugins/aws_oidc/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go similarity index 75% rename from app/auth/plugins/aws_oidc/outgoing.go rename to app/auth/plugins/aws_imds/outgoing.go index a25890f..9b06e96 100644 --- a/app/auth/plugins/aws_oidc/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -1,4 +1,4 @@ -package awsoidc +package awsimds import ( "context" @@ -13,16 +13,15 @@ import ( authplugins "github.com/winhowes/AuthTranslator/app/auth" ) -// awsOIDCParams configures the AWS OIDC plugin. -type awsOIDCParams struct { - Audience string `json:"audience"` - Header string `json:"header"` - Prefix string `json:"prefix"` +// awsIMDSParams configures the AWS IMDS plugin. +type awsIMDSParams struct { + Header string `json:"header"` + Prefix string `json:"prefix"` } -// AWSOIDC fetches the IAM role session token from the AWS Instance Metadata +// AWSIMDS fetches the IAM role session token from the AWS Instance Metadata // Service (IMDSv2) and adds it to outgoing requests. -type AWSOIDC struct{} +type AWSIMDS struct{} // MetadataHost is the base URL for the AWS metadata service. It can be // overridden in tests. @@ -33,28 +32,32 @@ var HTTPClient = &http.Client{Timeout: 5 * time.Second} var tokenCache = struct { sync.Mutex - m map[string]cachedToken -}{m: make(map[string]cachedToken)} + ct cachedToken +}{ct: cachedToken{}} type cachedToken struct { token string exp time.Time } +// AWSOIDC is kept as a backward-compatible alias for configurations still +// referencing the old plugin name. It delegates all behavior to AWSIMDS but +// advertises the legacy `aws_oidc` name. +type AWSOIDC struct{ AWSIMDS } + func (a *AWSOIDC) Name() string { return "aws_oidc" } -func (a *AWSOIDC) RequiredParams() []string { return []string{"audience"} } +func (a *AWSIMDS) Name() string { return "aws_imds" } + +func (a *AWSIMDS) RequiredParams() []string { return nil } -func (a *AWSOIDC) OptionalParams() []string { return []string{"header", "prefix"} } +func (a *AWSIMDS) OptionalParams() []string { return []string{"header", "prefix"} } -func (a *AWSOIDC) ParseParams(m map[string]interface{}) (interface{}, error) { - p, err := authplugins.ParseParams[awsOIDCParams](m) +func (a *AWSIMDS) ParseParams(m map[string]interface{}) (interface{}, error) { + p, err := authplugins.ParseParams[awsIMDSParams](m) if err != nil { return nil, err } - if p.Audience == "" { - return nil, fmt.Errorf("missing audience") - } if p.Header == "" { p.Header = "Authorization" } @@ -64,25 +67,25 @@ func (a *AWSOIDC) ParseParams(m map[string]interface{}) (interface{}, error) { return p, nil } -func (a *AWSOIDC) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { - cfg, ok := params.(*awsOIDCParams) +func (a *AWSIMDS) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { + cfg, ok := params.(*awsIMDSParams) if !ok { return fmt.Errorf("invalid config") } - tok, exp := getCachedToken(cfg.Audience) + tok, exp := getCachedToken() if tok == "" || time.Now().After(exp.Add(-1*time.Minute)) { var err error - tok, exp, err = fetchToken(ctx, cfg.Audience) + tok, exp, err = fetchToken(ctx) if err != nil { return err } - setCachedToken(cfg.Audience, tok, exp) + setCachedToken(tok, exp) } r.Header.Set(cfg.Header, cfg.Prefix+tok) return nil } -func fetchToken(ctx context.Context, aud string) (string, time.Time, error) { +func fetchToken(ctx context.Context) (string, time.Time, error) { metaToken, err := fetchMetadataToken(ctx) if err != nil { return "", time.Time{}, err @@ -197,20 +200,19 @@ func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*rol return &rc, nil } -func getCachedToken(aud string) (string, time.Time) { +func getCachedToken() (string, time.Time) { tokenCache.Lock() defer tokenCache.Unlock() - ct, ok := tokenCache.m[aud] - if !ok { - return "", time.Time{} - } - return ct.token, ct.exp + return tokenCache.ct.token, tokenCache.ct.exp } -func setCachedToken(aud, tok string, exp time.Time) { +func setCachedToken(tok string, exp time.Time) { tokenCache.Lock() - tokenCache.m[aud] = cachedToken{token: tok, exp: exp} + tokenCache.ct = cachedToken{token: tok, exp: exp} tokenCache.Unlock() } -func init() { authplugins.RegisterOutgoing(&AWSOIDC{}) } +func init() { + authplugins.RegisterOutgoing(&AWSIMDS{}) + authplugins.RegisterOutgoing(&AWSOIDC{}) +} diff --git a/app/auth/plugins/aws_oidc/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go similarity index 90% rename from app/auth/plugins/aws_oidc/outgoing_test.go rename to app/auth/plugins/aws_imds/outgoing_test.go index 1d5dbda..a7cb3ce 100644 --- a/app/auth/plugins/aws_oidc/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -1,4 +1,4 @@ -package awsoidc +package awsimds import ( "context" @@ -14,7 +14,6 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { sessionToken := "sts-session-token" metaToken := "meta123" roleName := "example-role" - aud := "urn:test" var requestCount int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -51,10 +50,10 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.m = map[string]cachedToken{} + tokenCache.ct = cachedToken{} - plugin := &AWSOIDC{} - paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) if err != nil { t.Fatalf("parse params: %v", err) } @@ -86,7 +85,6 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { metaToken := "meta123" roleName := "role" sessionTokens := []string{"first", "second"} - aud := "urn:test" var credIndex int srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -113,10 +111,10 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.m = map[string]cachedToken{} + tokenCache.ct = cachedToken{} - plugin := &AWSOIDC{} - paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) if err != nil { t.Fatalf("parse params: %v", err) } @@ -134,8 +132,6 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { } func TestErrorResponses(t *testing.T) { - aud := "urn:test" - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/latest/api/token": @@ -152,10 +148,10 @@ func TestErrorResponses(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.m = map[string]cachedToken{} + tokenCache.ct = cachedToken{} - plugin := &AWSOIDC{} - paramsRaw, err := plugin.ParseParams(map[string]interface{}{"audience": aud}) + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) if err != nil { t.Fatalf("parse params: %v", err) } diff --git a/app/auth/plugins/plugins.go b/app/auth/plugins/plugins.go index 971ba1c..a9e666c 100644 --- a/app/auth/plugins/plugins.go +++ b/app/auth/plugins/plugins.go @@ -1,7 +1,7 @@ package plugins import ( - _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_oidc" + _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_imds" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_oidc" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace" diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index f55b988..3dcf9e5 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -33,7 +33,7 @@ AuthTranslator’s behaviour is extended by **plugins** – small Go packages th | Outbound | `google_oidc` | Attaches a Google identity token from the metadata service. | | Outbound | `gcp_token` | Uses a metadata service access token. | | Outbound | `azure_oidc` | Retrieves an Azure access token from the Instance Metadata Service. | -| Outbound | `aws_oidc` | Retrieves an AWS OIDC token from the Instance Metadata Service (IMDSv2). | +| Outbound | `aws_imds` | Retrieves an AWS IMDS session token from the Instance Metadata Service (IMDSv2). | | Outbound | `hmac_signature` | Computes an HMAC for the request. | | Outbound | `jwt` | Adds a signed JWT to the request. | | Outbound | `mtls` | Sends a client certificate and exposes the CN via header. | @@ -100,19 +100,20 @@ outgoing_auth: Obtains an access token from the Azure Instance Metadata Service for the specified `resource`, caches it, and attaches it to the configured header on each outgoing request. -### Outbound `aws_oidc` +### Outbound `aws_imds` ```yaml outgoing_auth: - - type: aws_oidc + - type: aws_imds params: - audience: urn:example header: Authorization # optional (default: Authorization) prefix: "Bearer " # optional (default: "Bearer ") ``` Retrieves the IAM role session token from the AWS Instance Metadata Service v2, caches it until shortly before expiry, and attaches it to the chosen header on each outgoing request. +> **Note:** The legacy type name `aws_oidc` remains supported for backward compatibility but now resolves to the IMDS session token flow described above. New configurations should prefer `aws_imds` to reflect the actual authentication mechanism. + --- ## Writing your own plugin From 16d04575364bf97f27e2bfae373b3ee455159f70 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:31:39 -0800 Subject: [PATCH 04/12] Fix gofmt on plugin registry --- app/auth/plugins/plugins.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/auth/plugins/plugins.go b/app/auth/plugins/plugins.go index a9e666c..12dcc16 100644 --- a/app/auth/plugins/plugins.go +++ b/app/auth/plugins/plugins.go @@ -1,7 +1,7 @@ package plugins import ( - _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_imds" + _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_imds" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_oidc" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace" From a30283a554a0e5785f38d8e51d3a3f28e09babb0 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:50:39 -0800 Subject: [PATCH 05/12] Sign AWS IMDS requests with SigV4 credentials --- app/auth/plugins/aws_imds/outgoing.go | 280 +++++++++++++++++---- app/auth/plugins/aws_imds/outgoing_test.go | 85 ++++--- docs/auth-plugins.md | 6 +- 3 files changed, 288 insertions(+), 83 deletions(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 9b06e96..d8edaa6 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -1,11 +1,18 @@ package awsimds import ( + "bytes" "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "io" "net/http" + "net/url" + "path" + "sort" "strings" "sync" "time" @@ -15,8 +22,8 @@ import ( // awsIMDSParams configures the AWS IMDS plugin. type awsIMDSParams struct { - Header string `json:"header"` - Prefix string `json:"prefix"` + Region string `json:"region"` + Service string `json:"service"` } // AWSIMDS fetches the IAM role session token from the AWS Instance Metadata @@ -30,15 +37,7 @@ var MetadataHost = "http://169.254.169.254" // HTTPClient is used for all metadata requests. var HTTPClient = &http.Client{Timeout: 5 * time.Second} -var tokenCache = struct { - sync.Mutex - ct cachedToken -}{ct: cachedToken{}} - -type cachedToken struct { - token string - exp time.Time -} +var nowFunc = time.Now // AWSOIDC is kept as a backward-compatible alias for configurations still // referencing the old plugin name. It delegates all behavior to AWSIMDS but @@ -51,20 +50,10 @@ func (a *AWSIMDS) Name() string { return "aws_imds" } func (a *AWSIMDS) RequiredParams() []string { return nil } -func (a *AWSIMDS) OptionalParams() []string { return []string{"header", "prefix"} } +func (a *AWSIMDS) OptionalParams() []string { return []string{"region", "service"} } func (a *AWSIMDS) ParseParams(m map[string]interface{}) (interface{}, error) { - p, err := authplugins.ParseParams[awsIMDSParams](m) - if err != nil { - return nil, err - } - if p.Header == "" { - p.Header = "Authorization" - } - if p.Prefix == "" { - p.Prefix = "Bearer " - } - return p, nil + return authplugins.ParseParams[awsIMDSParams](m) } func (a *AWSIMDS) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { @@ -72,45 +61,48 @@ func (a *AWSIMDS) AddAuth(ctx context.Context, r *http.Request, params interface if !ok { return fmt.Errorf("invalid config") } - tok, exp := getCachedToken() - if tok == "" || time.Now().After(exp.Add(-1*time.Minute)) { + creds, exp := getCachedCreds() + if creds == nil || nowFunc().After(exp.Add(-1*time.Minute)) { var err error - tok, exp, err = fetchToken(ctx) + creds, exp, err = fetchCredentials(ctx) if err != nil { return err } - setCachedToken(tok, exp) + setCachedCreds(creds, exp) } - r.Header.Set(cfg.Header, cfg.Prefix+tok) - return nil + region, service, err := determineRegionService(r.URL.Host, cfg) + if err != nil { + return err + } + return signRequest(r, region, service, creds) } -func fetchToken(ctx context.Context) (string, time.Time, error) { +func fetchCredentials(ctx context.Context) (*roleCredentials, time.Time, error) { metaToken, err := fetchMetadataToken(ctx) if err != nil { - return "", time.Time{}, err + return nil, time.Time{}, err } roleName, err := fetchRoleName(ctx, metaToken) if err != nil { - return "", time.Time{}, err + return nil, time.Time{}, err } credentials, err := fetchRoleCredentials(ctx, metaToken, roleName) if err != nil { - return "", time.Time{}, err + return nil, time.Time{}, err } - if credentials.Token == "" { - return "", time.Time{}, fmt.Errorf("empty session token from IMDS for role %s", roleName) + if credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" || credentials.Token == "" { + return nil, time.Time{}, fmt.Errorf("incomplete credentials from IMDS for role %s", roleName) } exp, err := time.Parse(time.RFC3339, credentials.Expiration) if err != nil { - return "", time.Time{}, fmt.Errorf("parse expiration: %w", err) + return nil, time.Time{}, fmt.Errorf("parse expiration: %w", err) } - return credentials.Token, exp, nil + return credentials, exp, nil } func fetchMetadataToken(ctx context.Context) (string, error) { @@ -168,8 +160,10 @@ func fetchRoleName(ctx context.Context, metaToken string) (string, error) { } type roleCredentials struct { - Expiration string `json:"Expiration"` - Token string `json:"Token"` + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + Token string `json:"Token"` + Expiration string `json:"Expiration"` } func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*roleCredentials, error) { @@ -200,16 +194,208 @@ func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*rol return &rc, nil } -func getCachedToken() (string, time.Time) { - tokenCache.Lock() - defer tokenCache.Unlock() - return tokenCache.ct.token, tokenCache.ct.exp +type cachedCreds struct { + creds *roleCredentials + exp time.Time +} + +var credsCache = struct { + sync.Mutex + cc cachedCreds +}{cc: cachedCreds{}} + +func getCachedCreds() (*roleCredentials, time.Time) { + credsCache.Lock() + defer credsCache.Unlock() + return credsCache.cc.creds, credsCache.cc.exp +} + +func setCachedCreds(creds *roleCredentials, exp time.Time) { + credsCache.Lock() + credsCache.cc = cachedCreds{creds: creds, exp: exp} + credsCache.Unlock() +} + +func determineRegionService(host string, cfg *awsIMDSParams) (string, string, error) { + region := strings.TrimSpace(cfg.Region) + service := strings.TrimSpace(cfg.Service) + + if region != "" && service != "" { + return region, service, nil + } + + host = strings.Split(host, ":")[0] // strip port if present + parts := strings.Split(host, ".") + if len(parts) >= 4 && parts[len(parts)-2] == "amazonaws" { + if service == "" { + service = parts[0] + } + if region == "" { + region = parts[1] + } + } + + if region == "" || service == "" { + return "", "", fmt.Errorf("aws_imds requires region and service; set params or use standard AWS hostname") + } + + return region, service, nil +} + +func signRequest(r *http.Request, region, service string, creds *roleCredentials) error { + now := nowFunc().UTC() + amzDate := now.Format("20060102T150405Z") + dateStamp := now.Format("20060102") + + if r.Header == nil { + r.Header = http.Header{} + } + + host := r.Host + if host == "" && r.URL != nil { + host = r.URL.Host + } + if host == "" { + return fmt.Errorf("request host is required for signing") + } + r.Header.Set("Host", host) + r.Header.Set("X-Amz-Date", amzDate) + r.Header.Set("X-Amz-Security-Token", creds.Token) + + body, err := readBody(r) + if err != nil { + return err + } + payloadHash := hashSHA256Hex(body) + r.Header.Set("X-Amz-Content-Sha256", payloadHash) + + signedHeaders, canonicalHeaders := canonicalizeHeaders(r.Header) + canonicalQuery := canonicalizeQuery(r.URL) + canonicalURI := canonicalURI(r.URL) + canonicalRequest := strings.Join([]string{ + r.Method, + canonicalURI, + canonicalQuery, + canonicalHeaders, + "", + signedHeaders, + payloadHash, + }, "\n") + + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, service) + stringToSign := strings.Join([]string{ + "AWS4-HMAC-SHA256", + amzDate, + credentialScope, + hashSHA256Hex([]byte(canonicalRequest)), + }, "\n") + + signingKey := buildSigningKey(creds.SecretAccessKey, dateStamp, region, service) + signature := hex.EncodeToString(hmacSHA256(signingKey, stringToSign)) + + authHeader := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", creds.AccessKeyID, credentialScope, signedHeaders, signature) + r.Header.Set("Authorization", authHeader) + + return nil +} + +func readBody(r *http.Request) ([]byte, error) { + if r.Body == nil { + return []byte{}, nil + } + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + r.Body = io.NopCloser(bytes.NewReader(body)) + r.ContentLength = int64(len(body)) + return body, nil +} + +func canonicalizeHeaders(h http.Header) (string, string) { + lowerVals := make(map[string][]string, len(h)) + for k, v := range h { + lowerVals[strings.ToLower(k)] = v + } + keys := make([]string, 0, len(lowerVals)) + for k := range lowerVals { + keys = append(keys, k) + } + sort.Strings(keys) + var canonical strings.Builder + for _, k := range keys { + values := lowerVals[k] + for i := range values { + values[i] = strings.Join(strings.Fields(values[i]), " ") + } + canonical.WriteString(k) + canonical.WriteString(":") + canonical.WriteString(strings.Join(values, ",")) + canonical.WriteString("\n") + } + return strings.Join(keys, ";"), canonical.String() +} + +func canonicalizeQuery(u *url.URL) string { + if u == nil { + return "" + } + values, _ := url.ParseQuery(u.RawQuery) + if len(values) == 0 { + return "" + } + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + var parts []string + for _, k := range keys { + vals := values[k] + sort.Strings(vals) + for _, v := range vals { + parts = append(parts, fmt.Sprintf("%s=%s", escapeQueryComponent(k), escapeQueryComponent(v))) + } + } + return strings.Join(parts, "&") +} + +func escapeQueryComponent(v string) string { + escaped := url.QueryEscape(v) + escaped = strings.ReplaceAll(escaped, "+", "%20") + escaped = strings.ReplaceAll(escaped, "*", "%2A") + escaped = strings.ReplaceAll(escaped, "%7E", "~") + return escaped +} + +func canonicalURI(u *url.URL) string { + if u == nil { + return "/" + } + uri := u.EscapedPath() + if uri == "" { + uri = "/" + } + return path.Clean(uri) +} + +func hashSHA256Hex(b []byte) string { + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +func hmacSHA256(key []byte, msg string) []byte { + h := hmac.New(sha256.New, key) + h.Write([]byte(msg)) + return h.Sum(nil) } -func setCachedToken(tok string, exp time.Time) { - tokenCache.Lock() - tokenCache.ct = cachedToken{token: tok, exp: exp} - tokenCache.Unlock() +func buildSigningKey(secret, dateStamp, region, service string) []byte { + kDate := hmacSHA256([]byte("AWS4"+secret), dateStamp) + kRegion := hmacSHA256(kDate, region) + kService := hmacSHA256(kRegion, service) + kSigning := hmacSHA256(kService, "aws4_request") + return kSigning } func init() { diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index a7cb3ce..4863363 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -5,43 +5,35 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "strings" "testing" "time" ) -func TestAddAuthFetchesAndCachesToken(t *testing.T) { +func TestAddAuthFetchesAndSigns(t *testing.T) { expires := time.Now().Add(2 * time.Minute).UTC().Truncate(time.Second) - sessionToken := "sts-session-token" metaToken := "meta123" roleName := "example-role" + creds := map[string]interface{}{ + "AccessKeyId": "AKIDEXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "Token": "sts-session-token", + "Expiration": expires.Format(time.RFC3339), + } var requestCount int + fixedNow := time.Date(2023, 1, 2, 15, 4, 5, 0, time.UTC) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/latest/api/token": requestCount++ - if r.Method != http.MethodPut { - t.Fatalf("expected PUT for token, got %s", r.Method) - } - if ttl := r.Header.Get("X-aws-ec2-metadata-token-ttl-seconds"); ttl == "" { - t.Fatalf("missing TTL header") - } w.Write([]byte(metaToken)) case "/latest/meta-data/iam/security-credentials/": requestCount++ - if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken { - t.Fatalf("expected metadata token %q, got %q", metaToken, got) - } w.Write([]byte(roleName)) case "/latest/meta-data/iam/security-credentials/" + roleName: requestCount++ - if got := r.Header.Get("X-aws-ec2-metadata-token"); got != metaToken { - t.Fatalf("expected metadata token %q, got %q", metaToken, got) - } - json.NewEncoder(w).Encode(map[string]interface{}{ - "Token": sessionToken, - "Expiration": expires.Format(time.RFC3339), - }) + json.NewEncoder(w).Encode(creds) default: t.Fatalf("unexpected path %s", r.URL.Path) } @@ -50,32 +42,53 @@ func TestAddAuthFetchesAndCachesToken(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.ct = cachedToken{} + credsCache.cc = cachedCreds{} + prevNow := nowFunc + nowFunc = func() time.Time { return fixedNow } + defer func() { nowFunc = prevNow }() plugin := &AWSIMDS{} paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) if err != nil { t.Fatalf("parse params: %v", err) } - req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req, _ := http.NewRequest(http.MethodGet, "https://s3.us-west-2.amazonaws.com/example", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { t.Fatalf("AddAuth: %v", err) } - if got := req.Header.Get("Authorization"); got != "Bearer "+sessionToken { - t.Fatalf("unexpected header: %s", got) + if got := req.Header.Get("X-Amz-Security-Token"); got != creds["Token"] { + t.Fatalf("missing security token header: %s", got) + } + authz := req.Header.Get("Authorization") + if !strings.HasPrefix(authz, "AWS4-HMAC-SHA256 ") { + t.Fatalf("expected SigV4 auth header, got %s", authz) + } + if !strings.Contains(authz, "Credential=AKIDEXAMPLE/20230102/us-west-2/s3/aws4_request") { + t.Fatalf("unexpected credential scope: %s", authz) + } + if !strings.Contains(authz, "SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token") { + t.Fatalf("missing signed headers: %s", authz) } + if !strings.Contains(authz, "Signature=") { + t.Fatalf("missing signature: %s", authz) + } + if got := req.Header.Get("X-Amz-Content-Sha256"); got != "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" { + t.Fatalf("unexpected payload hash: %s", got) + } + if requestCount != 3 { t.Fatalf("expected 3 metadata requests, got %d", requestCount) } // Second call should use cache. - req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req2, _ := http.NewRequest(http.MethodGet, "https://s3.us-west-2.amazonaws.com/example", nil) if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { t.Fatalf("AddAuth second: %v", err) } if requestCount != 3 { - t.Fatalf("expected cached token, still %d requests", requestCount) + t.Fatalf("expected cached credentials reuse, got %d metadata calls", requestCount) } } @@ -84,8 +97,7 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { expLater := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) metaToken := "meta123" roleName := "role" - sessionTokens := []string{"first", "second"} - var credIndex int + credIndex := 0 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { @@ -99,8 +111,10 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { exp = expLater } json.NewEncoder(w).Encode(map[string]interface{}{ - "Token": sessionTokens[credIndex], - "Expiration": exp.Format(time.RFC3339), + "AccessKeyId": "AKID", + "SecretAccessKey": "SECRET", + "Token": []string{"first", "second"}[credIndex], + "Expiration": exp.Format(time.RFC3339), }) credIndex++ default: @@ -111,23 +125,28 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.ct = cachedToken{} + credsCache.cc = cachedCreds{} + prevNow := nowFunc + current := time.Now() + nowFunc = func() time.Time { return current } + defer func() { nowFunc = prevNow }() plugin := &AWSIMDS{} paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) if err != nil { t.Fatalf("parse params: %v", err) } - req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + req, _ := http.NewRequest(http.MethodGet, "https://s3.us-east-1.amazonaws.com", nil) if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { t.Fatalf("AddAuth: %v", err) } - req2, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + current = current.Add(2 * time.Minute) + req2, _ := http.NewRequest(http.MethodGet, "https://s3.us-east-1.amazonaws.com", nil) if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { t.Fatalf("AddAuth second: %v", err) } if credIndex != 2 { - t.Fatalf("expected token refresh, stage %d", credIndex) + t.Fatalf("expected credential refresh, got %d", credIndex) } } @@ -148,7 +167,7 @@ func TestErrorResponses(t *testing.T) { MetadataHost = srv.URL HTTPClient = srv.Client() - tokenCache.ct = cachedToken{} + credsCache.cc = cachedCreds{} plugin := &AWSIMDS{} paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index 3dcf9e5..da59516 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -106,11 +106,11 @@ configured header on each outgoing request. outgoing_auth: - type: aws_imds params: - header: Authorization # optional (default: Authorization) - prefix: "Bearer " # optional (default: "Bearer ") + region: us-west-2 # optional, inferred from AWS hostname if omitted + service: s3 # optional, inferred from AWS hostname if omitted ``` -Retrieves the IAM role session token from the AWS Instance Metadata Service v2, caches it until shortly before expiry, and attaches it to the chosen header on each outgoing request. +Retrieves temporary IAM role credentials from the AWS Instance Metadata Service v2, caches them until shortly before expiry, and applies SigV4 signing (including `X-Amz-Security-Token`) to each outgoing request. If the upstream hostname follows the standard `service.region.amazonaws.com` pattern, the plugin auto‑discovers the service and region; otherwise set them explicitly. > **Note:** The legacy type name `aws_oidc` remains supported for backward compatibility but now resolves to the IMDS session token flow described above. New configurations should prefer `aws_imds` to reflect the actual authentication mechanism. From d199d5820f801aafd6f82afbde82b5eb346ba88b Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Tue, 25 Nov 2025 19:57:44 -0800 Subject: [PATCH 06/12] Improve AWS IMDS host parsing for SigV4 --- app/auth/plugins/aws_imds/outgoing.go | 16 +++++++++++----- app/auth/plugins/aws_imds/outgoing_test.go | 11 +++++++++++ 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index d8edaa6..671eca6 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -227,11 +227,17 @@ func determineRegionService(host string, cfg *awsIMDSParams) (string, string, er host = strings.Split(host, ":")[0] // strip port if present parts := strings.Split(host, ".") if len(parts) >= 4 && parts[len(parts)-2] == "amazonaws" { - if service == "" { - service = parts[0] - } - if region == "" { - region = parts[1] + // Use the right-most service and region portions to support hosts with + // additional labels (e.g., bucket.s3.us-west-2.amazonaws.com). + serviceIdx := len(parts) - 4 + regionIdx := len(parts) - 3 + if serviceIdx >= 0 && regionIdx >= 0 { + if service == "" { + service = parts[serviceIdx] + } + if region == "" { + region = parts[regionIdx] + } } } diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index 4863363..a3cdabf 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -150,6 +150,17 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { } } +func TestDetermineRegionServiceWithResourcePrefix(t *testing.T) { + cfg := &awsIMDSParams{} + region, service, err := determineRegionService("mybucket.s3.us-west-2.amazonaws.com", cfg) + if err != nil { + t.Fatalf("determineRegionService: %v", err) + } + if region != "us-west-2" || service != "s3" { + t.Fatalf("unexpected derived values region=%s service=%s", region, service) + } +} + func TestErrorResponses(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { From d8b5ec904cc5a8b96c35258bc293e4a465045242 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 03:55:15 -0800 Subject: [PATCH 07/12] Preserve canonical URI slashes for AWS IMDS signing --- app/auth/plugins/aws_imds/outgoing.go | 6 ++++-- app/auth/plugins/aws_imds/outgoing_test.go | 24 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 671eca6..914c8ad 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -11,7 +11,6 @@ import ( "io" "net/http" "net/url" - "path" "sort" "strings" "sync" @@ -382,7 +381,10 @@ func canonicalURI(u *url.URL) string { if uri == "" { uri = "/" } - return path.Clean(uri) + if !strings.HasPrefix(uri, "/") { + uri = "/" + uri + } + return uri } func hashSHA256Hex(b []byte) string { diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index a3cdabf..0001027 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -190,3 +191,26 @@ func TestErrorResponses(t *testing.T) { t.Fatalf("expected error from metadata token fetch") } } + +func TestCanonicalURIPreservesTrailingAndRepeatedSlashes(t *testing.T) { + cases := []struct { + name string + path string + want string + }{ + {name: "empty", path: "", want: "/"}, + {name: "root", path: "/", want: "/"}, + {name: "trailing slash", path: "/foo/bar/", want: "/foo/bar/"}, + {name: "repeated slashes", path: "/foo//bar//baz", want: "/foo//bar//baz"}, + {name: "no leading slash", path: "foo/bar", want: "/foo/bar"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + u := &url.URL{Path: tc.path} + if got := canonicalURI(u); got != tc.want { + t.Fatalf("canonicalURI(%q) = %q, want %q", tc.path, got, tc.want) + } + }) + } +} From 3bcdac62ade8284bdc2636c0f389ead737498822 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 04:05:34 -0800 Subject: [PATCH 08/12] Handle dualstack hostnames in AWS IMDS detection --- app/auth/plugins/aws_imds/outgoing.go | 6 +++++- app/auth/plugins/aws_imds/outgoing_test.go | 11 +++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 914c8ad..70dbc3f 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -232,7 +232,11 @@ func determineRegionService(host string, cfg *awsIMDSParams) (string, string, er regionIdx := len(parts) - 3 if serviceIdx >= 0 && regionIdx >= 0 { if service == "" { - service = parts[serviceIdx] + candidate := parts[serviceIdx] + if candidate == "dualstack" && serviceIdx > 0 { + candidate = parts[serviceIdx-1] + } + service = candidate } if region == "" { region = parts[regionIdx] diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index 0001027..24fca57 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -162,6 +162,17 @@ func TestDetermineRegionServiceWithResourcePrefix(t *testing.T) { } } +func TestDetermineRegionServiceDualstack(t *testing.T) { + cfg := &awsIMDSParams{} + region, service, err := determineRegionService("s3.dualstack.us-east-1.amazonaws.com", cfg) + if err != nil { + t.Fatalf("determineRegionService: %v", err) + } + if region != "us-east-1" || service != "s3" { + t.Fatalf("unexpected derived values region=%s service=%s", region, service) + } +} + func TestErrorResponses(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { From 8ea06269947cf93f977b639e6689cf688fae6028 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 04:17:49 -0800 Subject: [PATCH 09/12] Normalize dot segments in AWS IMDS canonical URI --- app/auth/plugins/aws_imds/outgoing.go | 59 +++++++++++++++++++++- app/auth/plugins/aws_imds/outgoing_test.go | 3 ++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 70dbc3f..11b5671 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -388,7 +388,64 @@ func canonicalURI(u *url.URL) string { if !strings.HasPrefix(uri, "/") { uri = "/" + uri } - return uri + return removeDotSegments(uri) +} + +func removeDotSegments(path string) string { + if path == "" { + return "" + } + input := path + output := "" + for len(input) > 0 { + switch { + case strings.HasPrefix(input, "../"): + input = input[3:] + case strings.HasPrefix(input, "./"): + input = input[2:] + case strings.HasPrefix(input, "/./"): + input = "/" + input[3:] + case input == "/.": + input = "/" + case strings.HasPrefix(input, "/../"): + input = "/" + input[4:] + output = removeLastSegment(output) + case input == "/..": + input = "/" + output = removeLastSegment(output) + case input == "." || input == "..": + input = "" + default: + var segment string + if strings.HasPrefix(input, "/") { + if idx := strings.Index(input[1:], "/"); idx != -1 { + segment = input[:idx+1] + input = input[idx+1:] + } else { + segment = input + input = "" + } + } else { + if idx := strings.IndexByte(input, '/'); idx != -1 { + segment = input[:idx] + input = input[idx:] + } else { + segment = input + input = "" + } + } + output += segment + } + } + return output +} + +func removeLastSegment(path string) string { + idx := strings.LastIndex(path, "/") + if idx == -1 { + return "" + } + return path[:idx] } func hashSHA256Hex(b []byte) string { diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index 24fca57..5f507bf 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -213,6 +213,9 @@ func TestCanonicalURIPreservesTrailingAndRepeatedSlashes(t *testing.T) { {name: "root", path: "/", want: "/"}, {name: "trailing slash", path: "/foo/bar/", want: "/foo/bar/"}, {name: "repeated slashes", path: "/foo//bar//baz", want: "/foo//bar//baz"}, + {name: "dot segment", path: "/foo/./bar", want: "/foo/bar"}, + {name: "parent segment", path: "/foo/../bar", want: "/bar"}, + {name: "parent with trailing", path: "/foo/bar/../", want: "/foo/"}, {name: "no leading slash", path: "foo/bar", want: "/foo/bar"}, } From 50ce3d0ec5f9adb0494b62d5fdafeded40e6d582 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 04:29:04 -0800 Subject: [PATCH 10/12] Disable proxies for AWS IMDS metadata client --- app/auth/plugins/aws_imds/outgoing.go | 15 +++++++++- app/auth/plugins/aws_imds/outgoing_test.go | 32 ++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 11b5671..4c09acb 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -34,7 +34,10 @@ type AWSIMDS struct{} var MetadataHost = "http://169.254.169.254" // HTTPClient is used for all metadata requests. -var HTTPClient = &http.Client{Timeout: 5 * time.Second} +var HTTPClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: defaultIMDSTransport(), +} var nowFunc = time.Now @@ -448,6 +451,16 @@ func removeLastSegment(path string) string { return path[:idx] } +func defaultIMDSTransport() *http.Transport { + if t, ok := http.DefaultTransport.(*http.Transport); ok { + clone := t.Clone() + clone.Proxy = nil + return clone + } + + return &http.Transport{Proxy: nil} +} + func hashSHA256Hex(b []byte) string { h := sha256.Sum256(b) return hex.EncodeToString(h[:]) diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index 5f507bf..a371686 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -41,9 +41,16 @@ func TestAddAuthFetchesAndSigns(t *testing.T) { })) defer srv.Close() + prevHost := MetadataHost + prevClient := HTTPClient MetadataHost = srv.URL HTTPClient = srv.Client() credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() prevNow := nowFunc nowFunc = func() time.Time { return fixedNow } defer func() { nowFunc = prevNow }() @@ -124,9 +131,16 @@ func TestExpiresSoonTriggersRefresh(t *testing.T) { })) defer srv.Close() + prevHost := MetadataHost + prevClient := HTTPClient MetadataHost = srv.URL HTTPClient = srv.Client() credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() prevNow := nowFunc current := time.Now() nowFunc = func() time.Time { return current } @@ -188,9 +202,16 @@ func TestErrorResponses(t *testing.T) { })) defer srv.Close() + prevHost := MetadataHost + prevClient := HTTPClient MetadataHost = srv.URL HTTPClient = srv.Client() credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() plugin := &AWSIMDS{} paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) @@ -228,3 +249,14 @@ func TestCanonicalURIPreservesTrailingAndRepeatedSlashes(t *testing.T) { }) } } + +func TestDefaultHTTPClientDisablesProxy(t *testing.T) { + client := HTTPClient + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type %T", client.Transport) + } + if transport.Proxy != nil { + t.Fatalf("expected proxy to be disabled, got %v", transport.Proxy) + } +} From c9f062f668b111a1fa43ce2339bc67eb21167759 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 18:31:24 -0800 Subject: [PATCH 11/12] Fix AWS IMDS canonical request formatting --- app/auth/plugins/aws_imds/outgoing.go | 1 - app/auth/plugins/aws_imds/outgoing_test.go | 30 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index 4c09acb..be68d97 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -289,7 +289,6 @@ func signRequest(r *http.Request, region, service string, creds *roleCredentials canonicalURI, canonicalQuery, canonicalHeaders, - "", signedHeaders, payloadHash, }, "\n") diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index a371686..0a3282d 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -2,7 +2,9 @@ package awsimds import ( "context" + "encoding/hex" "encoding/json" + "fmt" "net/http" "net/http/httptest" "net/url" @@ -86,6 +88,34 @@ func TestAddAuthFetchesAndSigns(t *testing.T) { t.Fatalf("unexpected payload hash: %s", got) } + // Verify the canonical request and signature are fully deterministic. + headers := req.Header.Clone() + headers.Del("Authorization") + signedHeaders, canonicalHeaders := canonicalizeHeaders(headers) + canonicalQuery := canonicalizeQuery(req.URL) + canonicalURI := canonicalURI(req.URL) + canonicalRequest := strings.Join([]string{ + req.Method, + canonicalURI, + canonicalQuery, + canonicalHeaders, + signedHeaders, + req.Header.Get("X-Amz-Content-Sha256"), + }, "\n") + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", fixedNow.Format("20060102"), "us-west-2", "s3") + stringToSign := strings.Join([]string{ + "AWS4-HMAC-SHA256", + req.Header.Get("X-Amz-Date"), + credentialScope, + hashSHA256Hex([]byte(canonicalRequest)), + }, "\n") + signingKey := buildSigningKey(creds["SecretAccessKey"].(string), fixedNow.Format("20060102"), "us-west-2", "s3") + expectedSig := hex.EncodeToString(hmacSHA256(signingKey, stringToSign)) + expectedAuth := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", creds["AccessKeyId"], credentialScope, signedHeaders, expectedSig) + if authz != expectedAuth { + t.Fatalf("unexpected authorization header:\nwant %s\ngot %s", expectedAuth, authz) + } + if requestCount != 3 { t.Fatalf("expected 3 metadata requests, got %d", requestCount) } From 290e9e8633d6f384f6813b00276cfeabe24b26d2 Mon Sep 17 00:00:00 2001 From: Winston Howes Date: Wed, 26 Nov 2025 20:32:43 -0800 Subject: [PATCH 12/12] Close request bodies when signing AWS IMDS requests --- app/auth/plugins/aws_imds/outgoing.go | 7 ++- app/auth/plugins/aws_imds/outgoing_test.go | 67 ++++++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go index be68d97..c34432b 100644 --- a/app/auth/plugins/aws_imds/outgoing.go +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -314,10 +314,15 @@ func readBody(r *http.Request) ([]byte, error) { if r.Body == nil { return []byte{}, nil } - body, err := io.ReadAll(r.Body) + origBody := r.Body + body, err := io.ReadAll(origBody) + closeErr := origBody.Close() if err != nil { return nil, err } + if closeErr != nil { + return nil, closeErr + } r.Body = io.NopCloser(bytes.NewReader(body)) r.ContentLength = int64(len(body)) return body, nil diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go index 0a3282d..d488e40 100644 --- a/app/auth/plugins/aws_imds/outgoing_test.go +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -4,7 +4,9 @@ import ( "context" "encoding/hex" "encoding/json" + "errors" "fmt" + "io" "net/http" "net/http/httptest" "net/url" @@ -13,6 +15,27 @@ import ( "time" ) +type closingBuffer struct { + data []byte + readIdx int + closed bool + closeErr error +} + +func (c *closingBuffer) Read(p []byte) (int, error) { + if c.readIdx >= len(c.data) { + return 0, io.EOF + } + n := copy(p, c.data[c.readIdx:]) + c.readIdx += n + return n, nil +} + +func (c *closingBuffer) Close() error { + c.closed = true + return c.closeErr +} + func TestAddAuthFetchesAndSigns(t *testing.T) { expires := time.Now().Add(2 * time.Minute).UTC().Truncate(time.Second) metaToken := "meta123" @@ -254,6 +277,50 @@ func TestErrorResponses(t *testing.T) { } } +func TestReadBodyClosesOriginal(t *testing.T) { + recorder := &closingBuffer{data: []byte("payload")} + req, _ := http.NewRequest(http.MethodPost, "http://example.com", recorder) + + body, err := readBody(req) + if err != nil { + t.Fatalf("readBody: %v", err) + } + if !recorder.closed { + t.Fatalf("expected original body to be closed") + } + if string(body) != "payload" { + t.Fatalf("unexpected body contents: %s", string(body)) + } + if req.Body == recorder { + t.Fatalf("expected request body to be replaced") + } + reread, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("reading replaced body: %v", err) + } + if string(reread) != "payload" { + t.Fatalf("unexpected reread body: %s", string(reread)) + } + if req.ContentLength != int64(len(body)) { + t.Fatalf("unexpected content length: %d", req.ContentLength) + } +} + +func TestReadBodyCloseError(t *testing.T) { + recorder := &closingBuffer{data: []byte("payload"), closeErr: errors.New("close failure")} + req, _ := http.NewRequest(http.MethodPost, "http://example.com", recorder) + + if _, err := readBody(req); err == nil { + t.Fatalf("expected close error") + } + if !recorder.closed { + t.Fatalf("expected original body to be closed even on error") + } + if req.Body != recorder { + t.Fatalf("expected body not to be replaced on error") + } +} + func TestCanonicalURIPreservesTrailingAndRepeatedSlashes(t *testing.T) { cases := []struct { name string