diff --git a/cmd/allowlist/main.go b/cmd/allowlist/main.go index 1d061af..7391f5a 100644 --- a/cmd/allowlist/main.go +++ b/cmd/allowlist/main.go @@ -11,7 +11,12 @@ import ( "github.com/winhowes/AuthTranslator/cmd/allowlist/plugins" ) -var file = flag.String("file", "allowlist.yaml", "allowlist file") +var ( + file = flag.String("file", "allowlist.yaml", "allowlist file") + yamlMarshal = yaml.Marshal + writeFile = os.WriteFile + exitFunc = os.Exit +) func usage() { fmt.Fprintf(flag.CommandLine.Output(), `Usage: allowlist [options] \n\n`) @@ -61,6 +66,7 @@ func addEntry(args []string) { fs.Parse(args) if *integ == "" || *caller == "" || *capName == "" { fmt.Println("-integration, -caller and -capability required") + fs.Usage() return } var params map[string]interface{} @@ -132,16 +138,16 @@ func addEntry(args []string) { callerCfg.Capabilities = append(callerCfg.Capabilities, plugins.CapabilityConfig{Name: *capName, Params: params}) } - out, err := yaml.Marshal(entries) + out, err := yamlMarshal(entries) if err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + exitFunc(1) } out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null")) - if err := os.WriteFile(*file, out, 0644); err != nil { + if err := writeFile(*file, out, 0644); err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + exitFunc(1) } } @@ -158,6 +164,7 @@ func removeEntry(args []string) { if *integ == "" || *caller == "" || *capName == "" { fmt.Println("-integration, -caller and -capability required") + fs.Usage() return } @@ -208,14 +215,14 @@ func removeEntry(args []string) { break } - out, err := yaml.Marshal(entries) + out, err := yamlMarshal(entries) if err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + exitFunc(1) } out = bytes.ReplaceAll(out, []byte("params: {}"), []byte("params: null")) - if err := os.WriteFile(*file, out, 0644); err != nil { + if err := writeFile(*file, out, 0644); err != nil { fmt.Fprintln(os.Stderr, err) - os.Exit(1) + exitFunc(1) } } diff --git a/cmd/allowlist/main_test.go b/cmd/allowlist/main_test.go index f01ae23..4c9d890 100644 --- a/cmd/allowlist/main_test.go +++ b/cmd/allowlist/main_test.go @@ -3,6 +3,7 @@ package main import ( "bytes" "flag" + "fmt" yaml "gopkg.in/yaml.v3" "io" "os" @@ -397,6 +398,29 @@ func TestAddEntryParamTrim(t *testing.T) { } } +func TestAddEntryIgnoresEmptyParams(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "allow.yaml") + old := *file + *file = path + t.Cleanup(func() { *file = old }) + + addEntry([]string{"-integration", "foo", "-caller", "u1", "-capability", "cap", "-params", "k=v1,, ,other=v2"}) + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed reading file: %v", err) + } + var entries []plugins.AllowlistEntry + if err := yaml.Unmarshal(data, &entries); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + params := entries[0].Callers[0].Capabilities[0].Params + if len(params) != 2 || params["k"] != "v1" || params["other"] != "v2" { + t.Fatalf("unexpected params: %#v", params) + } +} + func TestAddEntryMissingArgs(t *testing.T) { tmpDir := t.TempDir() path := filepath.Join(tmpDir, "allow.yaml") @@ -482,6 +506,46 @@ func TestRemoveEntryCapabilityNotFound(t *testing.T) { } } +func TestRemoveEntrySkipsOtherCallers(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "allow.yaml") + + initial := []plugins.AllowlistEntry{ + { + Integration: "foo", + Callers: []plugins.CallerConfig{ + {ID: "u1", Capabilities: []plugins.CapabilityConfig{{Name: "cap1"}}}, + {ID: "u2", Capabilities: []plugins.CapabilityConfig{{Name: "cap2"}}}, + }, + }, + } + data, _ := yaml.Marshal(initial) + if err := os.WriteFile(path, data, 0644); err != nil { + t.Fatal(err) + } + + old := *file + *file = path + t.Cleanup(func() { *file = old }) + + removeEntry([]string{"-integration", "foo", "-caller", "u2", "-capability", "cap2"}) + + out, err := os.ReadFile(path) + if err != nil { + t.Fatalf("failed reading file: %v", err) + } + var entries []plugins.AllowlistEntry + if err := yaml.Unmarshal(out, &entries); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + want := []plugins.AllowlistEntry{ + {Integration: "foo", Callers: []plugins.CallerConfig{{ID: "u1", Capabilities: []plugins.CapabilityConfig{{Name: "cap1"}}}}}, + } + if !reflect.DeepEqual(entries, want) { + t.Fatalf("entries mismatch: %#v", entries) + } +} + func TestRemoveEntryFormatsParamsNull(t *testing.T) { tmpDir := t.TempDir() path := filepath.Join(tmpDir, "allow.yaml") @@ -677,6 +741,90 @@ func TestAddEntryWriteError(t *testing.T) { } } +func TestAddEntryMarshalError(t *testing.T) { + oldMarshal := yamlMarshal + oldExit := exitFunc + yamlMarshal = func(interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + } + exitFunc = func(code int) { + panic(fmt.Sprintf("exit %d", code)) + } + t.Cleanup(func() { + yamlMarshal = oldMarshal + exitFunc = oldExit + }) + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected exit") + } + }() + + addEntry([]string{"-integration", "foo", "-caller", "u1", "-capability", "cap"}) +} + +func TestRemoveEntryMarshalError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "allow.yaml") + initial := []plugins.AllowlistEntry{{Integration: "foo", Callers: []plugins.CallerConfig{{ID: "u1", Capabilities: []plugins.CapabilityConfig{{Name: "cap"}}}}}} + data, _ := yaml.Marshal(initial) + os.WriteFile(path, data, 0644) + + oldMarshal := yamlMarshal + oldExit := exitFunc + oldFile := *file + yamlMarshal = func(interface{}) ([]byte, error) { + return nil, fmt.Errorf("marshal error") + } + exitFunc = func(code int) { + panic(fmt.Sprintf("exit %d", code)) + } + *file = path + t.Cleanup(func() { + yamlMarshal = oldMarshal + exitFunc = oldExit + *file = oldFile + }) + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected exit") + } + }() + + removeEntry([]string{"-integration", "foo", "-caller", "u1", "-capability", "cap"}) +} + +func TestRemoveEntryWriteError(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "allow.yaml") + initial := []plugins.AllowlistEntry{{Integration: "foo", Callers: []plugins.CallerConfig{{ID: "u1", Capabilities: []plugins.CapabilityConfig{{Name: "cap"}}}}}} + data, _ := yaml.Marshal(initial) + os.WriteFile(path, data, 0644) + + oldWrite := writeFile + oldExit := exitFunc + oldFile := *file + writeFile = func(string, []byte, os.FileMode) error { + return fmt.Errorf("write fail") + } + exitFunc = func(code int) { + panic(fmt.Sprintf("exit %d", code)) + } + *file = path + t.Cleanup(func() { + writeFile = oldWrite + exitFunc = oldExit + *file = oldFile + }) + defer func() { + if r := recover(); r == nil { + t.Fatalf("expected exit") + } + }() + + removeEntry([]string{"-integration", "foo", "-caller", "u1", "-capability", "cap"}) +} + func TestRemoveEntryInvalidYAML(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "allow.yaml")