diff --git a/app/auth/plugins/aws_imds/outgoing.go b/app/auth/plugins/aws_imds/outgoing.go new file mode 100644 index 0000000..c34432b --- /dev/null +++ b/app/auth/plugins/aws_imds/outgoing.go @@ -0,0 +1,490 @@ +package awsimds + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "sort" + "strings" + "sync" + "time" + + authplugins "github.com/winhowes/AuthTranslator/app/auth" +) + +// awsIMDSParams configures the AWS IMDS plugin. +type awsIMDSParams struct { + Region string `json:"region"` + Service string `json:"service"` +} + +// AWSIMDS fetches the IAM role session token from the AWS Instance Metadata +// Service (IMDSv2) and adds it to outgoing requests. +type AWSIMDS struct{} + +// MetadataHost is the base URL for the AWS metadata service. It can be +// overridden in tests. +var MetadataHost = "http://169.254.169.254" + +// HTTPClient is used for all metadata requests. +var HTTPClient = &http.Client{ + Timeout: 5 * time.Second, + Transport: defaultIMDSTransport(), +} + +var nowFunc = time.Now + +// AWSOIDC is kept as a backward-compatible alias for configurations still +// referencing the old plugin name. It delegates all behavior to AWSIMDS but +// advertises the legacy `aws_oidc` name. +type AWSOIDC struct{ AWSIMDS } + +func (a *AWSOIDC) Name() string { return "aws_oidc" } + +func (a *AWSIMDS) Name() string { return "aws_imds" } + +func (a *AWSIMDS) RequiredParams() []string { return nil } + +func (a *AWSIMDS) OptionalParams() []string { return []string{"region", "service"} } + +func (a *AWSIMDS) ParseParams(m map[string]interface{}) (interface{}, error) { + return authplugins.ParseParams[awsIMDSParams](m) +} + +func (a *AWSIMDS) AddAuth(ctx context.Context, r *http.Request, params interface{}) error { + cfg, ok := params.(*awsIMDSParams) + if !ok { + return fmt.Errorf("invalid config") + } + creds, exp := getCachedCreds() + if creds == nil || nowFunc().After(exp.Add(-1*time.Minute)) { + var err error + creds, exp, err = fetchCredentials(ctx) + if err != nil { + return err + } + setCachedCreds(creds, exp) + } + region, service, err := determineRegionService(r.URL.Host, cfg) + if err != nil { + return err + } + return signRequest(r, region, service, creds) +} + +func fetchCredentials(ctx context.Context) (*roleCredentials, time.Time, error) { + metaToken, err := fetchMetadataToken(ctx) + if err != nil { + return nil, time.Time{}, err + } + + roleName, err := fetchRoleName(ctx, metaToken) + if err != nil { + return nil, time.Time{}, err + } + + credentials, err := fetchRoleCredentials(ctx, metaToken, roleName) + if err != nil { + return nil, time.Time{}, err + } + + if credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" || credentials.Token == "" { + return nil, time.Time{}, fmt.Errorf("incomplete credentials from IMDS for role %s", roleName) + } + + exp, err := time.Parse(time.RFC3339, credentials.Expiration) + if err != nil { + return nil, time.Time{}, fmt.Errorf("parse expiration: %w", err) + } + + return credentials, exp, nil +} + +func fetchMetadataToken(ctx context.Context) (string, error) { + tokenURL := fmt.Sprintf("%s/latest/api/token", MetadataHost) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, tokenURL, nil) + if err != nil { + return "", err + } + req.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") + + resp, err := HTTPClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token fetch status %s: %s", resp.Status, body) + } + + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(tokenBytes), nil +} + +func fetchRoleName(ctx context.Context, metaToken string) (string, error) { + roleURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/", MetadataHost) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, roleURL, nil) + if err != nil { + return "", err + } + req.Header.Set("X-aws-ec2-metadata-token", metaToken) + + resp, err := HTTPClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("role name status %s: %s", resp.Status, body) + } + + roleBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + roleName := strings.TrimSpace(string(roleBytes)) + if roleName == "" { + return "", fmt.Errorf("empty role name from IMDS") + } + return roleName, nil +} + +type roleCredentials struct { + AccessKeyID string `json:"AccessKeyId"` + SecretAccessKey string `json:"SecretAccessKey"` + Token string `json:"Token"` + Expiration string `json:"Expiration"` +} + +func fetchRoleCredentials(ctx context.Context, metaToken, roleName string) (*roleCredentials, error) { + credsURL := fmt.Sprintf("%s/latest/meta-data/iam/security-credentials/%s", MetadataHost, roleName) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, credsURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("X-aws-ec2-metadata-token", metaToken) + + resp, err := HTTPClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("role credentials status %s: %s", resp.Status, body) + } + + var rc roleCredentials + if err := json.NewDecoder(resp.Body).Decode(&rc); err != nil { + return nil, err + } + if rc.Expiration == "" { + return nil, fmt.Errorf("missing expiration in role credentials") + } + return &rc, nil +} + +type cachedCreds struct { + creds *roleCredentials + exp time.Time +} + +var credsCache = struct { + sync.Mutex + cc cachedCreds +}{cc: cachedCreds{}} + +func getCachedCreds() (*roleCredentials, time.Time) { + credsCache.Lock() + defer credsCache.Unlock() + return credsCache.cc.creds, credsCache.cc.exp +} + +func setCachedCreds(creds *roleCredentials, exp time.Time) { + credsCache.Lock() + credsCache.cc = cachedCreds{creds: creds, exp: exp} + credsCache.Unlock() +} + +func determineRegionService(host string, cfg *awsIMDSParams) (string, string, error) { + region := strings.TrimSpace(cfg.Region) + service := strings.TrimSpace(cfg.Service) + + if region != "" && service != "" { + return region, service, nil + } + + host = strings.Split(host, ":")[0] // strip port if present + parts := strings.Split(host, ".") + if len(parts) >= 4 && parts[len(parts)-2] == "amazonaws" { + // Use the right-most service and region portions to support hosts with + // additional labels (e.g., bucket.s3.us-west-2.amazonaws.com). + serviceIdx := len(parts) - 4 + regionIdx := len(parts) - 3 + if serviceIdx >= 0 && regionIdx >= 0 { + if service == "" { + candidate := parts[serviceIdx] + if candidate == "dualstack" && serviceIdx > 0 { + candidate = parts[serviceIdx-1] + } + service = candidate + } + if region == "" { + region = parts[regionIdx] + } + } + } + + if region == "" || service == "" { + return "", "", fmt.Errorf("aws_imds requires region and service; set params or use standard AWS hostname") + } + + return region, service, nil +} + +func signRequest(r *http.Request, region, service string, creds *roleCredentials) error { + now := nowFunc().UTC() + amzDate := now.Format("20060102T150405Z") + dateStamp := now.Format("20060102") + + if r.Header == nil { + r.Header = http.Header{} + } + + host := r.Host + if host == "" && r.URL != nil { + host = r.URL.Host + } + if host == "" { + return fmt.Errorf("request host is required for signing") + } + r.Header.Set("Host", host) + r.Header.Set("X-Amz-Date", amzDate) + r.Header.Set("X-Amz-Security-Token", creds.Token) + + body, err := readBody(r) + if err != nil { + return err + } + payloadHash := hashSHA256Hex(body) + r.Header.Set("X-Amz-Content-Sha256", payloadHash) + + signedHeaders, canonicalHeaders := canonicalizeHeaders(r.Header) + canonicalQuery := canonicalizeQuery(r.URL) + canonicalURI := canonicalURI(r.URL) + canonicalRequest := strings.Join([]string{ + r.Method, + canonicalURI, + canonicalQuery, + canonicalHeaders, + signedHeaders, + payloadHash, + }, "\n") + + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", dateStamp, region, service) + stringToSign := strings.Join([]string{ + "AWS4-HMAC-SHA256", + amzDate, + credentialScope, + hashSHA256Hex([]byte(canonicalRequest)), + }, "\n") + + signingKey := buildSigningKey(creds.SecretAccessKey, dateStamp, region, service) + signature := hex.EncodeToString(hmacSHA256(signingKey, stringToSign)) + + authHeader := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", creds.AccessKeyID, credentialScope, signedHeaders, signature) + r.Header.Set("Authorization", authHeader) + + return nil +} + +func readBody(r *http.Request) ([]byte, error) { + if r.Body == nil { + return []byte{}, nil + } + origBody := r.Body + body, err := io.ReadAll(origBody) + closeErr := origBody.Close() + if err != nil { + return nil, err + } + if closeErr != nil { + return nil, closeErr + } + r.Body = io.NopCloser(bytes.NewReader(body)) + r.ContentLength = int64(len(body)) + return body, nil +} + +func canonicalizeHeaders(h http.Header) (string, string) { + lowerVals := make(map[string][]string, len(h)) + for k, v := range h { + lowerVals[strings.ToLower(k)] = v + } + keys := make([]string, 0, len(lowerVals)) + for k := range lowerVals { + keys = append(keys, k) + } + sort.Strings(keys) + var canonical strings.Builder + for _, k := range keys { + values := lowerVals[k] + for i := range values { + values[i] = strings.Join(strings.Fields(values[i]), " ") + } + canonical.WriteString(k) + canonical.WriteString(":") + canonical.WriteString(strings.Join(values, ",")) + canonical.WriteString("\n") + } + return strings.Join(keys, ";"), canonical.String() +} + +func canonicalizeQuery(u *url.URL) string { + if u == nil { + return "" + } + values, _ := url.ParseQuery(u.RawQuery) + if len(values) == 0 { + return "" + } + keys := make([]string, 0, len(values)) + for k := range values { + keys = append(keys, k) + } + sort.Strings(keys) + var parts []string + for _, k := range keys { + vals := values[k] + sort.Strings(vals) + for _, v := range vals { + parts = append(parts, fmt.Sprintf("%s=%s", escapeQueryComponent(k), escapeQueryComponent(v))) + } + } + return strings.Join(parts, "&") +} + +func escapeQueryComponent(v string) string { + escaped := url.QueryEscape(v) + escaped = strings.ReplaceAll(escaped, "+", "%20") + escaped = strings.ReplaceAll(escaped, "*", "%2A") + escaped = strings.ReplaceAll(escaped, "%7E", "~") + return escaped +} + +func canonicalURI(u *url.URL) string { + if u == nil { + return "/" + } + uri := u.EscapedPath() + if uri == "" { + uri = "/" + } + if !strings.HasPrefix(uri, "/") { + uri = "/" + uri + } + return removeDotSegments(uri) +} + +func removeDotSegments(path string) string { + if path == "" { + return "" + } + input := path + output := "" + for len(input) > 0 { + switch { + case strings.HasPrefix(input, "../"): + input = input[3:] + case strings.HasPrefix(input, "./"): + input = input[2:] + case strings.HasPrefix(input, "/./"): + input = "/" + input[3:] + case input == "/.": + input = "/" + case strings.HasPrefix(input, "/../"): + input = "/" + input[4:] + output = removeLastSegment(output) + case input == "/..": + input = "/" + output = removeLastSegment(output) + case input == "." || input == "..": + input = "" + default: + var segment string + if strings.HasPrefix(input, "/") { + if idx := strings.Index(input[1:], "/"); idx != -1 { + segment = input[:idx+1] + input = input[idx+1:] + } else { + segment = input + input = "" + } + } else { + if idx := strings.IndexByte(input, '/'); idx != -1 { + segment = input[:idx] + input = input[idx:] + } else { + segment = input + input = "" + } + } + output += segment + } + } + return output +} + +func removeLastSegment(path string) string { + idx := strings.LastIndex(path, "/") + if idx == -1 { + return "" + } + return path[:idx] +} + +func defaultIMDSTransport() *http.Transport { + if t, ok := http.DefaultTransport.(*http.Transport); ok { + clone := t.Clone() + clone.Proxy = nil + return clone + } + + return &http.Transport{Proxy: nil} +} + +func hashSHA256Hex(b []byte) string { + h := sha256.Sum256(b) + return hex.EncodeToString(h[:]) +} + +func hmacSHA256(key []byte, msg string) []byte { + h := hmac.New(sha256.New, key) + h.Write([]byte(msg)) + return h.Sum(nil) +} + +func buildSigningKey(secret, dateStamp, region, service string) []byte { + kDate := hmacSHA256([]byte("AWS4"+secret), dateStamp) + kRegion := hmacSHA256(kDate, region) + kService := hmacSHA256(kRegion, service) + kSigning := hmacSHA256(kService, "aws4_request") + return kSigning +} + +func init() { + authplugins.RegisterOutgoing(&AWSIMDS{}) + authplugins.RegisterOutgoing(&AWSOIDC{}) +} diff --git a/app/auth/plugins/aws_imds/outgoing_test.go b/app/auth/plugins/aws_imds/outgoing_test.go new file mode 100644 index 0000000..d488e40 --- /dev/null +++ b/app/auth/plugins/aws_imds/outgoing_test.go @@ -0,0 +1,359 @@ +package awsimds + +import ( + "context" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +type closingBuffer struct { + data []byte + readIdx int + closed bool + closeErr error +} + +func (c *closingBuffer) Read(p []byte) (int, error) { + if c.readIdx >= len(c.data) { + return 0, io.EOF + } + n := copy(p, c.data[c.readIdx:]) + c.readIdx += n + return n, nil +} + +func (c *closingBuffer) Close() error { + c.closed = true + return c.closeErr +} + +func TestAddAuthFetchesAndSigns(t *testing.T) { + expires := time.Now().Add(2 * time.Minute).UTC().Truncate(time.Second) + metaToken := "meta123" + roleName := "example-role" + creds := map[string]interface{}{ + "AccessKeyId": "AKIDEXAMPLE", + "SecretAccessKey": "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY", + "Token": "sts-session-token", + "Expiration": expires.Format(time.RFC3339), + } + var requestCount int + fixedNow := time.Date(2023, 1, 2, 15, 4, 5, 0, time.UTC) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + requestCount++ + w.Write([]byte(metaToken)) + case "/latest/meta-data/iam/security-credentials/": + requestCount++ + w.Write([]byte(roleName)) + case "/latest/meta-data/iam/security-credentials/" + roleName: + requestCount++ + json.NewEncoder(w).Encode(creds) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + prevHost := MetadataHost + prevClient := HTTPClient + MetadataHost = srv.URL + HTTPClient = srv.Client() + credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() + prevNow := nowFunc + nowFunc = func() time.Time { return fixedNow } + defer func() { nowFunc = prevNow }() + + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "https://s3.us-west-2.amazonaws.com/example", nil) + + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { + t.Fatalf("AddAuth: %v", err) + } + + if got := req.Header.Get("X-Amz-Security-Token"); got != creds["Token"] { + t.Fatalf("missing security token header: %s", got) + } + authz := req.Header.Get("Authorization") + if !strings.HasPrefix(authz, "AWS4-HMAC-SHA256 ") { + t.Fatalf("expected SigV4 auth header, got %s", authz) + } + if !strings.Contains(authz, "Credential=AKIDEXAMPLE/20230102/us-west-2/s3/aws4_request") { + t.Fatalf("unexpected credential scope: %s", authz) + } + if !strings.Contains(authz, "SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token") { + t.Fatalf("missing signed headers: %s", authz) + } + if !strings.Contains(authz, "Signature=") { + t.Fatalf("missing signature: %s", authz) + } + if got := req.Header.Get("X-Amz-Content-Sha256"); got != "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" { + t.Fatalf("unexpected payload hash: %s", got) + } + + // Verify the canonical request and signature are fully deterministic. + headers := req.Header.Clone() + headers.Del("Authorization") + signedHeaders, canonicalHeaders := canonicalizeHeaders(headers) + canonicalQuery := canonicalizeQuery(req.URL) + canonicalURI := canonicalURI(req.URL) + canonicalRequest := strings.Join([]string{ + req.Method, + canonicalURI, + canonicalQuery, + canonicalHeaders, + signedHeaders, + req.Header.Get("X-Amz-Content-Sha256"), + }, "\n") + credentialScope := fmt.Sprintf("%s/%s/%s/aws4_request", fixedNow.Format("20060102"), "us-west-2", "s3") + stringToSign := strings.Join([]string{ + "AWS4-HMAC-SHA256", + req.Header.Get("X-Amz-Date"), + credentialScope, + hashSHA256Hex([]byte(canonicalRequest)), + }, "\n") + signingKey := buildSigningKey(creds["SecretAccessKey"].(string), fixedNow.Format("20060102"), "us-west-2", "s3") + expectedSig := hex.EncodeToString(hmacSHA256(signingKey, stringToSign)) + expectedAuth := fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s", creds["AccessKeyId"], credentialScope, signedHeaders, expectedSig) + if authz != expectedAuth { + t.Fatalf("unexpected authorization header:\nwant %s\ngot %s", expectedAuth, authz) + } + + if requestCount != 3 { + t.Fatalf("expected 3 metadata requests, got %d", requestCount) + } + + // Second call should use cache. + req2, _ := http.NewRequest(http.MethodGet, "https://s3.us-west-2.amazonaws.com/example", nil) + if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { + t.Fatalf("AddAuth second: %v", err) + } + if requestCount != 3 { + t.Fatalf("expected cached credentials reuse, got %d metadata calls", requestCount) + } +} + +func TestExpiresSoonTriggersRefresh(t *testing.T) { + expSoon := time.Now().Add(30 * time.Second).UTC().Truncate(time.Second) + expLater := time.Now().Add(10 * time.Minute).UTC().Truncate(time.Second) + metaToken := "meta123" + roleName := "role" + credIndex := 0 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + w.Write([]byte(metaToken)) + case "/latest/meta-data/iam/security-credentials/": + w.Write([]byte(roleName)) + case "/latest/meta-data/iam/security-credentials/" + roleName: + exp := expSoon + if credIndex > 0 { + exp = expLater + } + json.NewEncoder(w).Encode(map[string]interface{}{ + "AccessKeyId": "AKID", + "SecretAccessKey": "SECRET", + "Token": []string{"first", "second"}[credIndex], + "Expiration": exp.Format(time.RFC3339), + }) + credIndex++ + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + prevHost := MetadataHost + prevClient := HTTPClient + MetadataHost = srv.URL + HTTPClient = srv.Client() + credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() + prevNow := nowFunc + current := time.Now() + nowFunc = func() time.Time { return current } + defer func() { nowFunc = prevNow }() + + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "https://s3.us-east-1.amazonaws.com", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err != nil { + t.Fatalf("AddAuth: %v", err) + } + current = current.Add(2 * time.Minute) + req2, _ := http.NewRequest(http.MethodGet, "https://s3.us-east-1.amazonaws.com", nil) + if err := plugin.AddAuth(context.Background(), req2, paramsRaw); err != nil { + t.Fatalf("AddAuth second: %v", err) + } + if credIndex != 2 { + t.Fatalf("expected credential refresh, got %d", credIndex) + } +} + +func TestDetermineRegionServiceWithResourcePrefix(t *testing.T) { + cfg := &awsIMDSParams{} + region, service, err := determineRegionService("mybucket.s3.us-west-2.amazonaws.com", cfg) + if err != nil { + t.Fatalf("determineRegionService: %v", err) + } + if region != "us-west-2" || service != "s3" { + t.Fatalf("unexpected derived values region=%s service=%s", region, service) + } +} + +func TestDetermineRegionServiceDualstack(t *testing.T) { + cfg := &awsIMDSParams{} + region, service, err := determineRegionService("s3.dualstack.us-east-1.amazonaws.com", cfg) + if err != nil { + t.Fatalf("determineRegionService: %v", err) + } + if region != "us-east-1" || service != "s3" { + t.Fatalf("unexpected derived values region=%s service=%s", region, service) + } +} + +func TestErrorResponses(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest/api/token": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("no token")) + case "/latest/meta-data/iam/security-credentials/": + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("no role")) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer srv.Close() + + prevHost := MetadataHost + prevClient := HTTPClient + MetadataHost = srv.URL + HTTPClient = srv.Client() + credsCache.cc = cachedCreds{} + defer func() { + MetadataHost = prevHost + HTTPClient = prevClient + credsCache.cc = cachedCreds{} + }() + + plugin := &AWSIMDS{} + paramsRaw, err := plugin.ParseParams(map[string]interface{}{}) + if err != nil { + t.Fatalf("parse params: %v", err) + } + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + if err := plugin.AddAuth(context.Background(), req, paramsRaw); err == nil { + t.Fatalf("expected error from metadata token fetch") + } +} + +func TestReadBodyClosesOriginal(t *testing.T) { + recorder := &closingBuffer{data: []byte("payload")} + req, _ := http.NewRequest(http.MethodPost, "http://example.com", recorder) + + body, err := readBody(req) + if err != nil { + t.Fatalf("readBody: %v", err) + } + if !recorder.closed { + t.Fatalf("expected original body to be closed") + } + if string(body) != "payload" { + t.Fatalf("unexpected body contents: %s", string(body)) + } + if req.Body == recorder { + t.Fatalf("expected request body to be replaced") + } + reread, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("reading replaced body: %v", err) + } + if string(reread) != "payload" { + t.Fatalf("unexpected reread body: %s", string(reread)) + } + if req.ContentLength != int64(len(body)) { + t.Fatalf("unexpected content length: %d", req.ContentLength) + } +} + +func TestReadBodyCloseError(t *testing.T) { + recorder := &closingBuffer{data: []byte("payload"), closeErr: errors.New("close failure")} + req, _ := http.NewRequest(http.MethodPost, "http://example.com", recorder) + + if _, err := readBody(req); err == nil { + t.Fatalf("expected close error") + } + if !recorder.closed { + t.Fatalf("expected original body to be closed even on error") + } + if req.Body != recorder { + t.Fatalf("expected body not to be replaced on error") + } +} + +func TestCanonicalURIPreservesTrailingAndRepeatedSlashes(t *testing.T) { + cases := []struct { + name string + path string + want string + }{ + {name: "empty", path: "", want: "/"}, + {name: "root", path: "/", want: "/"}, + {name: "trailing slash", path: "/foo/bar/", want: "/foo/bar/"}, + {name: "repeated slashes", path: "/foo//bar//baz", want: "/foo//bar//baz"}, + {name: "dot segment", path: "/foo/./bar", want: "/foo/bar"}, + {name: "parent segment", path: "/foo/../bar", want: "/bar"}, + {name: "parent with trailing", path: "/foo/bar/../", want: "/foo/"}, + {name: "no leading slash", path: "foo/bar", want: "/foo/bar"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + u := &url.URL{Path: tc.path} + if got := canonicalURI(u); got != tc.want { + t.Fatalf("canonicalURI(%q) = %q, want %q", tc.path, got, tc.want) + } + }) + } +} + +func TestDefaultHTTPClientDisablesProxy(t *testing.T) { + client := HTTPClient + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("unexpected transport type %T", client.Transport) + } + if transport.Proxy != nil { + t.Fatalf("expected proxy to be disabled, got %v", transport.Proxy) + } +} diff --git a/app/auth/plugins/plugins.go b/app/auth/plugins/plugins.go index 4e914ed..37c8252 100644 --- a/app/auth/plugins/plugins.go +++ b/app/auth/plugins/plugins.go @@ -1,6 +1,7 @@ package plugins import ( + _ "github.com/winhowes/AuthTranslator/app/auth/plugins/aws_imds" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/azure_managed_identity" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/basic" _ "github.com/winhowes/AuthTranslator/app/auth/plugins/findreplace" diff --git a/docs/auth-plugins.md b/docs/auth-plugins.md index 5896042..89a7c63 100644 --- a/docs/auth-plugins.md +++ b/docs/auth-plugins.md @@ -32,6 +32,7 @@ AuthTranslator’s behaviour is extended by **plugins** – small Go packages th | Outbound | `basic` | Adds HTTP Basic credentials to the upstream request. | | Outbound | `google_oidc` | Attaches a Google identity token from the metadata service. | | Outbound | `gcp_token` | Uses a metadata service access token. | +| Outbound | `aws_imds` | Retrieves an AWS IMDS session token from the Instance Metadata Service (IMDSv2). | | Outbound | `azure_managed_identity` | Retrieves an Azure access token from the Instance Metadata Service. | | Outbound | `hmac_signature` | Computes an HMAC for the request. | | Outbound | `jwt` | Adds a signed JWT to the request. | @@ -99,6 +100,20 @@ outgoing_auth: Obtains an access token from the Azure Instance Metadata Service for the specified `resource`, caches it, and attaches it to the configured header on each outgoing request. +### Outbound `aws_imds` + +```yaml +outgoing_auth: + - type: aws_imds + params: + region: us-west-2 # optional, inferred from AWS hostname if omitted + service: s3 # optional, inferred from AWS hostname if omitted +``` + +Retrieves temporary IAM role credentials from the AWS Instance Metadata Service v2, caches them until shortly before expiry, and applies SigV4 signing (including `X-Amz-Security-Token`) to each outgoing request. If the upstream hostname follows the standard `service.region.amazonaws.com` pattern, the plugin auto‑discovers the service and region; otherwise set them explicitly. + +> **Note:** The legacy type name `aws_oidc` remains supported for backward compatibility but now resolves to the IMDS session token flow described above. New configurations should prefer `aws_imds` to reflect the actual authentication mechanism. + --- ## Writing your own plugin