diff --git a/cmd/cql-cli/main.go b/cmd/cql-cli/main.go
index 5b358c4..084f9af 100644
--- a/cmd/cql-cli/main.go
+++ b/cmd/cql-cli/main.go
@@ -36,7 +36,11 @@ func main() {
case "struct":
fmt.Printf("%+v\n", query)
case "xcql":
- fmt.Print((&cql.Xcql{}).Marshal(query, 2))
+ os.Stdout.WriteString("\n")
+ err := (&cql.Xcql{}).Write(query, 2, os.Stdout)
+ if err != nil {
+ fmt.Fprintln(os.Stderr, "ERROR", err)
+ }
default:
fmt.Fprintln(os.Stderr, "Unknown output format:", outFmt)
os.Exit(1)
diff --git a/cql/parser_test.go b/cql/parser_test.go
index 41956c6..fffd3f3 100644
--- a/cql/parser_test.go
+++ b/cql/parser_test.go
@@ -1,6 +1,7 @@
package cql
import (
+ "errors"
"strings"
"testing"
)
@@ -507,6 +508,24 @@ func TestParseXml(t *testing.T) {
ok: false,
expect: "EOF expected at position 3",
},
+ {
+ name: "invalid",
+ input: "\"\x05\"",
+ ok: true,
+ expect: `
+
+
+cql.serverChoice
+
+=
+
+` + "\xef\xbf\xbd" + // replacement character, #FFFD
+ `
+
+
+
+`,
+ },
} {
t.Run(testcase.name, func(t *testing.T) {
node, err := p.Parse(testcase.input)
@@ -515,7 +534,11 @@ func TestParseXml(t *testing.T) {
t.Fatalf("expected OK for query %s . Got error: %s", testcase.input, err)
}
var xcql Xcql
- xml := xcql.Marshal(node, testcase.tab)
+ bytes, err := xcql.MarshalIndent(node, testcase.tab)
+ if err != nil {
+ t.Fatalf("error marshalling query %s: %s", testcase.input, err)
+ }
+ xml := string(bytes)
if xml != testcase.expect {
t.Fatalf("Different XML for query %s\nExpect:\n%s\nGot:\n%s", testcase.input, testcase.expect, xml)
}
@@ -1055,3 +1078,28 @@ func TestBoolClauseString(t *testing.T) {
t.Fatalf("expected:\n%s\nwas:\n%s", in, out)
}
}
+
+type FailWriter struct{}
+
+func (f *FailWriter) Write(p []byte) (n int, err error) {
+ return 0, errors.New("write error")
+}
+
+func TestXcqlFailWriter(t *testing.T) {
+ var p Parser
+ query, err := p.Parse("a")
+ if err != nil {
+ t.Fatalf("parse error: %s", err)
+ }
+ var xcql Xcql
+ err = xcql.Write(query, 0, &FailWriter{})
+ if err == nil {
+ t.Fatalf("expected error but got nil")
+ }
+ xcql.err = nil
+ xcql.cdata("hello")
+ err = xcql.err
+ if err == nil {
+ t.Fatalf("expected error but got nil")
+ }
+}
diff --git a/cql/xcql.go b/cql/xcql.go
index 09b4e12..c566de4 100644
--- a/cql/xcql.go
+++ b/cql/xcql.go
@@ -1,40 +1,36 @@
package cql
import (
- "strings"
- "unicode/utf8"
+ "bytes"
+ "encoding/xml"
+ "io"
)
type Xcql struct {
- sb strings.Builder
+ w io.Writer
+ err error
tab int
}
func (xcql *Xcql) cdata(msg string) {
- pos := 0
- for pos < len(msg) {
- r, w := utf8.DecodeRuneInString(msg[pos:])
- switch r {
- case utf8.RuneError:
- return
- case '&':
- xcql.sb.WriteString("&")
- case '<':
- xcql.sb.WriteString("<")
- case '>':
- xcql.sb.WriteString(">")
- default:
- xcql.sb.WriteRune(r)
- }
- pos += w
+ err := xml.EscapeText(xcql.w, []byte(msg))
+ if err != nil && xcql.err == nil {
+ xcql.err = err
+ }
+}
+
+func (xcql *Xcql) write(msg string) {
+ _, err := xcql.w.Write([]byte(msg))
+ if err != nil && xcql.err == nil {
+ xcql.err = err
}
}
func (xcql *Xcql) pr(level int, msg string) {
for i := 0; i < level*xcql.tab; i++ {
- xcql.sb.WriteString(" ")
+ xcql.write(" ")
}
- xcql.sb.WriteString(msg)
+ xcql.write(msg)
}
func (xcql *Xcql) toXmlMod(modifiers []Modifier, level int) {
@@ -150,11 +146,17 @@ func (xcql *Xcql) toXmlSort(query Query, level int) {
}
}
-func (xcql *Xcql) Marshal(query Query, tab int) string {
- xcql.sb.Reset()
+func (xcql *Xcql) Write(query Query, tab int, w io.Writer) error {
+ xcql.w = w
xcql.tab = tab
xcql.pr(0, "\n")
xcql.toXmlSort(query, 1)
xcql.pr(0, "\n")
- return xcql.sb.String()
+ return xcql.err
+}
+
+func (xcql *Xcql) MarshalIndent(query Query, tab int) ([]byte, error) {
+ buf := new(bytes.Buffer)
+ err := xcql.Write(query, tab, buf)
+ return buf.Bytes(), err
}