From 215faa1303520a353980392e5ae8d550eca55ec6 Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Sun, 24 Aug 2025 18:43:33 +0200 Subject: [PATCH] templates: make "join" work with non-string slices and map values Add a custom join function that allows for non-string slices to be joined, following the same rules as "fmt.Sprint", it will use the fmt.Stringer interface if implemented, or "error" if the type has an "Error()". For maps, it joins the map-values, for example: docker image inspect --format '{{join .Config.Labels ", "}}' ubuntu 24.04, ubuntu Signed-off-by: Sebastiaan van Stijn --- templates/templates.go | 42 ++++++++++++++++- templates/templates_test.go | 90 +++++++++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/templates/templates.go b/templates/templates.go index 4af4496d19a1..ea381a5f68ce 100644 --- a/templates/templates.go +++ b/templates/templates.go @@ -6,6 +6,9 @@ package templates import ( "bytes" "encoding/json" + "fmt" + "reflect" + "sort" "strings" "text/template" ) @@ -26,7 +29,7 @@ var basicFunctions = template.FuncMap{ return strings.TrimSpace(buf.String()) }, "split": strings.Split, - "join": strings.Join, + "join": joinElements, "title": strings.Title, //nolint:nolintlint,staticcheck // strings.Title is deprecated, but we only use it for ASCII, so replacing with golang.org/x/text is out of scope "lower": strings.ToLower, "upper": strings.ToUpper, @@ -103,3 +106,40 @@ func truncateWithLength(source string, length int) string { } return source[:length] } + +// joinElements joins a slice of items with the given separator. It uses +// [strings.Join] if it's a slice of strings, otherwise uses [fmt.Sprint] +// to join each item to the output. +func joinElements(elems any, sep string) (string, error) { + if elems == nil { + return "", nil + } + + if ss, ok := elems.([]string); ok { + return strings.Join(ss, sep), nil + } + + switch rv := reflect.ValueOf(elems); rv.Kind() { //nolint:exhaustive // ignore: too many options to make exhaustive + case reflect.Array, reflect.Slice: + var b strings.Builder + for i := range rv.Len() { + if i > 0 { + b.WriteString(sep) + } + _, _ = fmt.Fprint(&b, rv.Index(i).Interface()) + } + return b.String(), nil + + case reflect.Map: + var out []string + for _, k := range rv.MapKeys() { + out = append(out, fmt.Sprint(rv.MapIndex(k).Interface())) + } + // Not ideal, but trying to keep a consistent order + sort.Strings(out) + return strings.Join(out, sep), nil + + default: + return "", fmt.Errorf("expected slice, got %T", elems) + } +} diff --git a/templates/templates_test.go b/templates/templates_test.go index e9dbaefd0e5e..ed1ee5b95d13 100644 --- a/templates/templates_test.go +++ b/templates/templates_test.go @@ -3,6 +3,7 @@ package templates import ( "bytes" "testing" + "text/template" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -139,3 +140,92 @@ func TestHeaderFunctions(t *testing.T) { }) } } + +type stringerString string + +func (s stringerString) String() string { + return "stringer" + string(s) +} + +type stringerAndError string + +func (s stringerAndError) String() string { + return "stringer" + string(s) +} + +func (s stringerAndError) Error() string { + return "error" + string(s) +} + +func TestJoinElements(t *testing.T) { + tests := []struct { + doc string + data any + expOut string + expErr string + }{ + { + doc: "nil", + data: nil, + expOut: `output: ""`, + }, + { + doc: "non-slice", + data: "hello", + expOut: `output: "`, + expErr: `error calling join: expected slice, got string`, + }, + { + doc: "structs", + data: []struct{ A, B string }{{"1", "2"}, {"3", "4"}}, + expOut: `output: "{1 2}, {3 4}"`, + }, + { + doc: "map with strings", + data: map[string]string{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "1, 2, 3"`, + }, + { + doc: "map with stringers", + data: map[string]stringerString{"A": "1", "B": "2", "C": "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "map with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "stringers", + data: []stringerString{"1", "2", "3"}, + expOut: `output: "stringer1, stringer2, stringer3"`, + }, + { + doc: "stringer with errors", + data: []stringerAndError{"1", "2", "3"}, + expOut: `output: "error1, error2, error3"`, + }, + { + doc: "slice of bools", + data: []bool{true, false, true}, + expOut: `output: "true, false, true"`, + }, + } + + const formatStr = `output: "{{- join . ", " -}}"` + tmpl, err := New("my-template").Funcs(template.FuncMap{"join": joinElements}).Parse(formatStr) + assert.NilError(t, err) + + for _, tc := range tests { + t.Run(tc.doc, func(t *testing.T) { + var b bytes.Buffer + err := tmpl.Execute(&b, tc.data) + if tc.expErr != "" { + assert.ErrorContains(t, err, tc.expErr) + } else { + assert.NilError(t, err) + } + assert.Equal(t, b.String(), tc.expOut) + }) + } +}