diff --git a/cmd/pgcql-cli/main.go b/cmd/pgcql-cli/main.go index 11049f2..2b150bc 100644 --- a/cmd/pgcql-cli/main.go +++ b/cmd/pgcql-cli/main.go @@ -14,8 +14,7 @@ func main() { flag.StringVar(&serverChoiceColumn, "s", "text", "column for cql.serverChoice") def := pgcql.NewPgDefinition() if serverChoiceColumn != "" { - serverChoice := &pgcql.FieldString{} - serverChoice.WithFullText("english").SetColumn(serverChoiceColumn) + serverChoice := pgcql.NewFieldString().WithFullText("english").WithColumn(serverChoiceColumn) def.AddField("cql.serverChoice", serverChoice) } flag.Parse() @@ -26,8 +25,7 @@ func main() { } for i := 0; i < len(flag.Args()); i++ { if i < len(flag.Args())-1 { - field := &pgcql.FieldString{} - field.WithLikeOps() + field := pgcql.NewFieldString().WithLikeOps() def.AddField(flag.Args()[i], field) continue } diff --git a/pgcql/pg_field_common.go b/pgcql/pg_field_common.go index 866d75b..5a17324 100644 --- a/pgcql/pg_field_common.go +++ b/pgcql/pg_field_common.go @@ -1,8 +1,6 @@ package pgcql import ( - "fmt" - "github.com/indexdata/cql-go/cql" ) @@ -24,7 +22,7 @@ func (f *FieldCommon) handleUnorderedRelation(sc cql.SearchClause) (string, erro case cql.NE: return "<>", nil default: - return "", fmt.Errorf("unsupported relation %s", sc.Relation) + return "", &PgError{message: "unsupported relation " + string(sc.Relation)} } } @@ -35,7 +33,7 @@ func (f *FieldCommon) handleOrderedRelation(sc cql.SearchClause) (string, error) case "=", "<>", ">", "<", "<=", ">=": return string(sc.Relation), nil default: - return "", &PgError{message: "unsupported operator " + string(sc.Relation)} + return "", &PgError{message: "unsupported relation " + string(sc.Relation)} } } diff --git a/pgcql/pg_field_string.go b/pgcql/pg_field_string.go index 6c936dd..f9412d7 100644 --- a/pgcql/pg_field_string.go +++ b/pgcql/pg_field_string.go @@ -2,6 +2,7 @@ package pgcql import ( "fmt" + "strings" "github.com/indexdata/cql-go/cql" ) @@ -11,9 +12,19 @@ type FieldString struct { language string enableLike bool enableExact bool + enableSplit bool } -func (f *FieldString) WithFullText(language string) Field { +func NewFieldString() *FieldString { + return &FieldString{} +} + +func (f *FieldString) WithColumn(column string) *FieldString { + f.column = column + return f +} + +func (f *FieldString) WithFullText(language string) *FieldString { if language == "" { f.language = "simple" } else { @@ -22,18 +33,32 @@ func (f *FieldString) WithFullText(language string) Field { return f } -func (f *FieldString) WithLikeOps() Field { +func (f *FieldString) WithLikeOps() *FieldString { f.enableExact = true f.enableLike = true return f } -func (f *FieldString) WithExact() Field { +func (f *FieldString) WithExact() *FieldString { f.enableExact = true return f } +func (f *FieldString) WithSplit() *FieldString { + f.enableSplit = true + return f +} + func maskedExact(cqlTerm string) (string, error) { + terms, err := maskedSplit(cqlTerm, "") + if err != nil { + return "", err + } + return terms[0], nil +} + +func maskedSplit(cqlTerm string, splitChars string) ([]string, error) { + terms := make([]string, 0) var pgTerm []rune backslash := false @@ -43,29 +68,36 @@ func maskedExact(cqlTerm string) (string, error) { case '*', '"', '?', '^', '\\': pgTerm = append(pgTerm, c) default: - return "", fmt.Errorf("a masking backslash in a CQL string must be followed by *, ?, ^, \" or \\") + return terms, fmt.Errorf("a masking backslash in a CQL string must be followed by *, ?, ^, \" or \\") } backslash = false } else { switch c { case '*': - return "", fmt.Errorf("masking op * unsupported") + return terms, fmt.Errorf("masking op * unsupported") case '?': - return "", fmt.Errorf("masking op ? unsupported") + return terms, fmt.Errorf("masking op ? unsupported") case '^': - return "", fmt.Errorf("anchor op ^ unsupported") + return terms, fmt.Errorf("anchor op ^ unsupported") case '\\': - // Do nothing, just set backslash to true + backslash = true default: + if strings.ContainsRune(splitChars, c) { + if len(pgTerm) > 0 { + terms = append(terms, string(pgTerm)) + } + pgTerm = []rune{} + continue + } pgTerm = append(pgTerm, c) } - backslash = c == '\\' } } if backslash { - return "", fmt.Errorf("a CQL string must not end with a masking backslash") + return terms, fmt.Errorf("a CQL string must not end with a masking backslash") } - return string(pgTerm), nil + terms = append(terms, string(pgTerm)) + return terms, nil } func maskedLike(cqlTerm string) (string, bool, error) { @@ -95,13 +127,12 @@ func maskedLike(cqlTerm string) (string, bool, error) { case '^': return "", false, fmt.Errorf("anchor op ^ unsupported") case '\\': - // Do nothing, just set backslash to true + backslash = true case '%', '_': pgTerm = append(pgTerm, '\\', c) default: pgTerm = append(pgTerm, c) } - backslash = c == '\\' } } if backslash { @@ -117,30 +148,61 @@ func (f *FieldString) handleEmptyTerm(sc cql.SearchClause) string { return "" } +func (f *FieldString) generateFullText(sc cql.SearchClause, queryArgumentIndex int, pgfunc string) (string, []any, error) { + pgTerm, err := maskedExact(sc.Term) + if err != nil { + return "", nil, err + } + sql := "to_tsvector('" + f.language + "', " + f.column + ") @@ " + pgfunc + "('" + f.language + "', " + fmt.Sprintf("$%d", queryArgumentIndex) + ")" + return sql, []any{pgTerm}, nil +} + +func (f *FieldString) generateIn(sc cql.SearchClause, queryArgumentIndex int, not bool) (string, []any, error) { + pgTerms, err := maskedSplit(sc.Term, " ") + if err != nil { + return "", nil, err + } + sql := f.column + if not { + sql += " NOT" + } + sql += " IN(" + anyTerms := make([]any, len(pgTerms)) + for i, v := range pgTerms { + if i > 0 { + sql += ", " + } + sql += fmt.Sprintf("$%d", queryArgumentIndex+i) + anyTerms[i] = v + } + sql += ")" + return sql, anyTerms, nil +} + func (f *FieldString) Generate(sc cql.SearchClause, queryArgumentIndex int) (string, []any, error) { sql := f.handleEmptyTerm(sc) if sql != "" { return sql, nil, nil } fulltext := f.language != "" - var pgfunc string if fulltext { - if sc.Relation == cql.ADJ || sc.Relation == cql.EQ { - pgfunc = "phraseto_tsquery" - } else if sc.Relation == cql.ALL { - pgfunc = "plainto_tsquery" + switch sc.Relation { + case cql.ADJ, cql.EQ: + return f.generateFullText(sc, queryArgumentIndex, "phraseto_tsquery") + case cql.ALL: + return f.generateFullText(sc, queryArgumentIndex, "plainto_tsquery") } } - if pgfunc != "" { - pgTerm, err := maskedExact(sc.Term) - if err != nil { - return "", nil, err + if f.enableSplit { + if sc.Relation == cql.ANY { + return f.generateIn(sc, queryArgumentIndex, false) + } + if sc.Relation == cql.NE { + return f.generateIn(sc, queryArgumentIndex, true) } - sql := "to_tsvector('" + f.language + "', " + f.column + ") @@ " + pgfunc + "('" + f.language + "', " + fmt.Sprintf("$%d", queryArgumentIndex) + ")" - return sql, []any{pgTerm}, nil } if !f.enableExact { - return "", nil, &PgError{message: "exact search not supported"} + return "", nil, &PgError{message: "unsupported relation " + string(sc.Relation)} } if f.enableLike && (sc.Relation == cql.EQ || sc.Relation == cql.EXACT || sc.Relation == cql.NE) { pgTerm, ops, err := maskedLike(sc.Term) diff --git a/pgcql/pgcql_test.go b/pgcql/pgcql_test.go index 561af00..1eee2ca 100644 --- a/pgcql/pgcql_test.go +++ b/pgcql/pgcql_test.go @@ -27,8 +27,10 @@ func TestParsing(t *testing.T) { assert.Equal(t, title.GetColumn(), "Title", "GetColumn() should return the column name") - author := &FieldString{} - author.WithLikeOps().SetColumn("Author") + author := NewFieldString().WithLikeOps().WithColumn("Author") + + tag := &FieldString{} + tag.WithSplit().SetColumn("Tag") serverChoice := &FieldString{} serverChoice.WithExact().SetColumn("T") @@ -36,7 +38,7 @@ func TestParsing(t *testing.T) { full := &FieldString{} full.WithFullText("english") - def.AddField("title", title).AddField("author", author).AddField("cql.serverChoice", serverChoice).AddField("full", full) + def.AddField("title", title).AddField("author", author).AddField("cql.serverChoice", serverChoice).AddField("full", full).AddField("tag", tag) price := &FieldNumber{} def.AddField("price", price) @@ -54,6 +56,9 @@ func TestParsing(t *testing.T) { {"title==2", "Title = $1", []any{"2"}}, {"title exact 2", "Title = $1", []any{"2"}}, {"title<>2", "Title <> $1", []any{"2"}}, + {"tag any \"1 23 45\"", "Tag IN($1, $2, $3)", []any{"1", "23", "45"}}, + {"tag <> \"1 23 45\"", "Tag NOT IN($1, $2, $3)", []any{"1", "23", "45"}}, + {"tag any \"*\"", "error: masking op * unsupported", nil}, {"a or b and c", "(T = $1 OR T = $2) AND T = $3", []any{"a", "b", "c"}}, {"title = abc", "Title = $1", []any{"abc"}}, {"author = \"test\"", "Author = $1", []any{"test"}}, @@ -87,7 +92,8 @@ func TestParsing(t *testing.T) { {"full adj \"abc\"", "to_tsvector('english', full) @@ phraseto_tsquery('english', $1)", []any{"abc"}}, {"full all \"abc\"", "to_tsvector('english', full) @@ plainto_tsquery('english', $1)", []any{"abc"}}, {"full=\"a*\"", "error: masking op * unsupported", nil}, - {"full any x", "error: exact search not supported", nil}, + {"full any x", "error: unsupported relation any", nil}, + {"full > x", "error: unsupported relation >", nil}, {"price = 10", "price = $1", []any{10.0}}, {"price == 10", "price = $1", []any{10.0}}, {"price exact 10", "price = $1", []any{10.0}}, @@ -97,7 +103,7 @@ func TestParsing(t *testing.T) { {"price < 10.95", "price < $1", []any{10.95}}, {"price <= 10.95", "price <= $1", []any{10.95}}, {"price <= beta", "error: invalid number beta", nil}, - {"price all 10.95", "error: unsupported operator all", nil}, + {"price all 10.95", "error: unsupported relation all", nil}, {"price = \"\"", "price IS NOT NULL", []any{}}, } { var parser cql.Parser diff --git a/pgcql/pgx_test.go b/pgcql/pgx_test.go index 9975d4c..29ae282 100644 --- a/pgcql/pgx_test.go +++ b/pgcql/pgx_test.go @@ -54,16 +54,16 @@ func TestPgx(t *testing.T) { err := conn.Close(ctx) assert.NoError(t, err, "failed to close db connection") }() - _, err = conn.Exec(ctx, "CREATE TABLE mytable (id SERIAL PRIMARY KEY, title TEXT, author TEXT, year INT)") + _, err = conn.Exec(ctx, "CREATE TABLE mytable (id SERIAL PRIMARY KEY, title TEXT, author TEXT, tag TEXT, year INT)") assert.NoError(t, err, "failed to create mytable") var rows pgx.Rows - rows, err = conn.Query(ctx, "INSERT INTO mytable (title, author, year) VALUES ($1, $2, $3)", "the art of computer programming, volume 1", "donald e. knuth", 1968) + rows, err = conn.Query(ctx, "INSERT INTO mytable (title, author, tag, year) VALUES ($1, $2, $3, $4)", "the art of computer programming, volume 1", "donald e. knuth", "tag1", 1968) assert.NoError(t, err, "failed to insert data") rows.Close() - rows, err = conn.Query(ctx, "INSERT INTO mytable (title, author, year) VALUES ($1, $2, $3)", "the TeXbook", "d. e. knuth", 1984) + rows, err = conn.Query(ctx, "INSERT INTO mytable (title, author, tag, year) VALUES ($1, $2, $3, $4)", "the TeXbook", "d. e. knuth", "tag2", 1984) assert.NoError(t, err, "failed to insert data") rows.Close() @@ -77,6 +77,7 @@ func TestPgx(t *testing.T) { def.AddField("title", (&FieldString{}).WithExact()) def.AddField("author", (&FieldString{}).WithExact()) def.AddField("year", (&FieldNumber{})) + def.AddField("tag", (&FieldString{}).WithSplit()) var parser cql.Parser for _, testcase := range []struct { @@ -107,6 +108,9 @@ func TestPgx(t *testing.T) { {"year <= 1984", []int{1, 2}}, {"year >= 1984", []int{2, 3}}, {"year > 1984", []int{3}}, + {"tag any \"tag1\"", []int{1}}, + {"tag <> \"tag1\"", []int{2}}, + {"tag any \"tag1 tag2 tag3\"", []int{1, 2}}, } { runQuery(t, parser, conn, ctx, def, testcase.query, testcase.expectedIds) }