Skip to content

Commit 3bbfed9

Browse files
Add ConvertOrderBy method for MongoDB-style sorting (#40)
Add `ORDER BY` support with `ConvertOrderBy` method. - Converts MongoDB-style sort objects to PostgreSQL `ORDER BY` clauses. - Supports both regular columns and JSONB fields with dual sorting. - Includes integration tests and fuzz tests.
1 parent da10c9d commit 3bbfed9

File tree

7 files changed

+522
-13
lines changed

7 files changed

+522
-13
lines changed

README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,50 @@ values := []any{"aztec", "nuke", "", 2, 10}
9797
(given "customdata" is configured with `filter.WithNestedJSONB("customdata", "password", "playerCount")`)
9898

9999

100+
## Order By Support
101+
102+
In addition to filtering, this package also supports converting MongoDB-style sort objects into PostgreSQL ORDER BY clauses using the `ConvertOrderBy` method:
103+
104+
```go
105+
// Convert a sort object to an ORDER BY clause
106+
sortInput := []byte(`{"playerCount": -1, "name": 1}`)
107+
orderBy, err := converter.ConvertOrderBy(sortInput)
108+
if err != nil {
109+
// handle error
110+
}
111+
fmt.Println(orderBy) // "playerCount" DESC, "name" ASC
112+
113+
db.Query("SELECT * FROM games ORDER BY " + orderBy)
114+
```
115+
116+
### Sort Direction Values:
117+
- `1`: Ascending (ASC)
118+
- `-1`: Descending (DESC)
119+
120+
### Return value
121+
The `ConvertOrderBy` method returns a string that can be directly used in an SQL ORDER BY clause. When the input is an empty object or `nil`, it returns an empty string. Keep in mind that the method does not add the `ORDER BY` keyword itself; you need to include it in your SQL query.
122+
123+
### JSONB Field Sorting:
124+
For JSONB fields, the package generates sophisticated ORDER BY clauses that handle both numeric and text sorting:
125+
126+
```go
127+
// With WithNestedJSONB("metadata", "created_at"):
128+
sortInput := []byte(`{"score": -1}`)
129+
orderBy, err := converter.ConvertOrderBy(sortInput)
130+
// Generates: (CASE WHEN jsonb_typeof(metadata->'score') = 'number' THEN (metadata->>'score')::numeric END) DESC NULLS LAST, metadata->>'score' DESC NULLS LAST
131+
```
132+
133+
This ensures proper sorting whether the JSONB field contains numeric or text values.
134+
135+
> [!TIP]
136+
> Always add an `, id ASC` to your ORDER BY clause to ensure a consistent order (where `id` is your primary key).
137+
> ```go
138+
> if orderBy != "" {
139+
> orderBy += ", "
140+
> }
141+
> orderBy += "id ASC"
142+
> ```
143+
100144
## Difference with MongoDB
101145
102146
- The MongoDB query filters don't have the option to compare fields with each other. This package adds the `$field` operator to compare fields with each other.

filter/converter.go

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
208208
// `column != ANY(...)` does not work, so we need to do `NOT column = ANY(...)` instead.
209209
neg = "NOT "
210210
}
211-
inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key), paramIndex))
211+
inner = append(inner, fmt.Sprintf("(%s%s = ANY($%d))", neg, c.columnName(key, true), paramIndex))
212212
paramIndex++
213213
if c.arrayDriver != nil {
214214
v[operator] = c.arrayDriver(v[operator])
@@ -245,7 +245,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
245245
//
246246
// EXISTS (SELECT 1 FROM unnest("foo") AS __filter_placeholder WHERE ("__filter_placeholder"::text = $1))
247247
//
248-
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key), c.placeholderName, innerConditions))
248+
inner = append(inner, fmt.Sprintf("EXISTS (SELECT 1 FROM unnest(%s) AS %s WHERE %s)", c.columnName(key, true), c.placeholderName, innerConditions))
249249
}
250250
values = append(values, innerValues...)
251251
case "$field":
@@ -254,7 +254,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
254254
return "", nil, fmt.Errorf("invalid value for $field operator (must be string): %v", v[operator])
255255
}
256256

257-
inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key), c.columnName(vv)))
257+
inner = append(inner, fmt.Sprintf("(%s = %s)", c.columnName(key, true), c.columnName(vv, true)))
258258
default:
259259
value := v[operator]
260260
isNumericOperator := false
@@ -274,8 +274,8 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
274274
return "", nil, fmt.Errorf("invalid value for %s operator (must be object with $field key only): %v", operator, value)
275275
}
276276

277-
left := c.columnName(key)
278-
right := c.columnName(field)
277+
left := c.columnName(key, true)
278+
right := c.columnName(field, true)
279279

280280
if isNumericOperator {
281281
if c.isNestedColumn(key) {
@@ -304,9 +304,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
304304
}
305305

306306
if isNumericOperator && isNumeric(value) && c.isNestedColumn(key) {
307-
inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key), op, paramIndex))
307+
inner = append(inner, fmt.Sprintf("((%s)::numeric %s $%d)", c.columnName(key, true), op, paramIndex))
308308
} else {
309-
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key), op, paramIndex))
309+
inner = append(inner, fmt.Sprintf("(%s %s $%d)", c.columnName(key, true), op, paramIndex))
310310
}
311311
paramIndex++
312312
values = append(values, value)
@@ -329,9 +329,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
329329
}
330330
}
331331
if isNestedColumn {
332-
conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key)))
332+
conditions = append(conditions, fmt.Sprintf("(jsonb_path_match(%s, 'exists($.%s)') AND %s IS NULL)", c.nestedColumn, key, c.columnName(key, true)))
333333
} else {
334-
conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key)))
334+
conditions = append(conditions, fmt.Sprintf("(%s IS NULL)", c.columnName(key, true)))
335335
}
336336
default:
337337
// Prevent cryptic errors like:
@@ -341,9 +341,9 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
341341
}
342342
if isNumeric(value) && c.isNestedColumn(key) {
343343
// If the value is numeric and the column is a nested JSONB column, we need to cast the column to numeric.
344-
conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key), paramIndex))
344+
conditions = append(conditions, fmt.Sprintf("((%s)::numeric = $%d)", c.columnName(key, true), paramIndex))
345345
} else {
346-
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key), paramIndex))
346+
conditions = append(conditions, fmt.Sprintf("(%s = $%d)", c.columnName(key, true), paramIndex))
347347
}
348348
paramIndex++
349349
values = append(values, value)
@@ -358,7 +358,7 @@ func (c *Converter) convertFilter(filter map[string]any, paramIndex int) (string
358358
return result, values, nil
359359
}
360360

361-
func (c *Converter) columnName(column string) string {
361+
func (c *Converter) columnName(column string, jsonFieldAsText bool) string {
362362
if column == c.placeholderName {
363363
return fmt.Sprintf(`%q::text`, column)
364364
}
@@ -370,7 +370,10 @@ func (c *Converter) columnName(column string) string {
370370
return fmt.Sprintf("%q", column)
371371
}
372372
}
373-
return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column)
373+
if jsonFieldAsText {
374+
return fmt.Sprintf(`%q->>'%s'`, c.nestedColumn, column)
375+
}
376+
return fmt.Sprintf(`%q->'%s'`, c.nestedColumn, column)
374377
}
375378

376379
func (c *Converter) isColumnAllowed(column string) bool {
@@ -404,3 +407,77 @@ func (c *Converter) isNestedColumn(column string) bool {
404407
}
405408
return true
406409
}
410+
411+
// ConvertOrderBy converts a JSON object with field names and sort directions
412+
// into a PostgreSQL ORDER BY clause. The JSON object should have keys with values
413+
// of 1 (ASC) or -1 (DESC).
414+
//
415+
// For JSONB fields, it generates clauses that handle both numeric and text sorting.
416+
//
417+
// Example: {"playerCount": -1, "name": 1} -> "playerCount DESC, name ASC"
418+
func (c *Converter) ConvertOrderBy(query []byte) (string, error) {
419+
keyValues, err := objectInOrder(query)
420+
if err != nil {
421+
return "", err
422+
}
423+
424+
parts := make([]string, 0, len(keyValues))
425+
426+
for _, kv := range keyValues {
427+
key, value := kv.Key, kv.Value
428+
429+
if !isValidPostgresIdentifier(key) {
430+
return "", fmt.Errorf("invalid column name: %s", key)
431+
}
432+
if !c.isColumnAllowed(key) {
433+
return "", ColumnNotAllowedError{Column: key}
434+
}
435+
436+
// Convert value to number for direction
437+
var direction string
438+
switch v := value.(type) {
439+
case json.Number:
440+
if num, err := v.Int64(); err == nil {
441+
switch num {
442+
case 1:
443+
direction = "ASC"
444+
case -1:
445+
direction = "DESC"
446+
default:
447+
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
448+
}
449+
} else {
450+
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
451+
}
452+
case float64:
453+
switch v {
454+
case 1:
455+
direction = "ASC"
456+
case -1:
457+
direction = "DESC"
458+
default:
459+
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
460+
}
461+
default:
462+
return "", fmt.Errorf("invalid order direction for field %s: %v (must be 1 or -1)", key, value)
463+
}
464+
465+
var fieldClause string
466+
if c.isNestedColumn(key) {
467+
// For JSONB fields, handle both numeric and text sorting.
468+
// We need to use the raw JSONB reference for jsonb_typeof, but columnName() for the actual sorting
469+
fieldClause = fmt.Sprintf("(CASE WHEN jsonb_typeof(%s) = 'number' THEN (%s)::numeric END) %s NULLS LAST, %s %s NULLS LAST", c.columnName(key, false), c.columnName(key, true), direction, c.columnName(key, true), direction)
470+
} else {
471+
// Regular field.
472+
fieldClause = fmt.Sprintf(`%s %s NULLS LAST`, c.columnName(key, true), direction)
473+
}
474+
475+
parts = append(parts, fieldClause)
476+
}
477+
478+
if len(parts) == 0 {
479+
return "", nil
480+
}
481+
482+
return strings.Join(parts, ", "), nil
483+
}

filter/converter_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,3 +641,134 @@ func TestConverter_AccessControl(t *testing.T) {
641641
t.Run("nested but disallow password, disallow",
642642
f(`{"password": "hacks"}`, no("password"), filter.WithNestedJSONB("meta", "created_at"), filter.WithDisallowColumns("password")))
643643
}
644+
645+
func TestConverter_ConvertOrderBy(t *testing.T) {
646+
tests := []struct {
647+
name string
648+
options []filter.Option
649+
input string
650+
expected string
651+
err error
652+
}{
653+
{
654+
"single field ascending",
655+
[]filter.Option{filter.WithAllowAllColumns()},
656+
`{"playerCount": 1}`,
657+
`"playerCount" ASC NULLS LAST`,
658+
nil,
659+
},
660+
{
661+
"single field descending",
662+
[]filter.Option{filter.WithAllowAllColumns()},
663+
`{"playerCount": -1}`,
664+
`"playerCount" DESC NULLS LAST`,
665+
nil,
666+
},
667+
{
668+
"multiple fields",
669+
[]filter.Option{filter.WithAllowAllColumns()},
670+
`{"playerCount": -1, "name": 1}`,
671+
`"playerCount" DESC NULLS LAST, "name" ASC NULLS LAST`,
672+
nil,
673+
},
674+
{
675+
"nested JSONB single field ascending",
676+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
677+
`{"map": 1}`,
678+
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST`,
679+
nil,
680+
},
681+
{
682+
"nested JSONB single field descending",
683+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
684+
`{"map": -1}`,
685+
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
686+
nil,
687+
},
688+
{
689+
"nested JSONB multiple fields",
690+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
691+
`{"map": 1, "bar": -1}`,
692+
`(CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) ASC NULLS LAST, "customdata"->>'map' ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'bar') = 'number' THEN ("customdata"->>'bar')::numeric END) DESC NULLS LAST, "customdata"->>'bar' DESC NULLS LAST`,
693+
nil,
694+
},
695+
{
696+
"mixed nested and regular fields",
697+
[]filter.Option{filter.WithNestedJSONB("customdata", "created_at")},
698+
`{"created_at": 1, "map": -1}`,
699+
`"created_at" ASC NULLS LAST, (CASE WHEN jsonb_typeof("customdata"->'map') = 'number' THEN ("customdata"->>'map')::numeric END) DESC NULLS LAST, "customdata"->>'map' DESC NULLS LAST`,
700+
nil,
701+
},
702+
{
703+
"field name with spaces",
704+
[]filter.Option{filter.WithAllowAllColumns()},
705+
`{"my_field": 1}`,
706+
`"my_field" ASC NULLS LAST`,
707+
nil,
708+
},
709+
{
710+
"empty object",
711+
[]filter.Option{filter.WithAllowAllColumns()},
712+
`{}`,
713+
``,
714+
nil,
715+
},
716+
{
717+
"invalid field name for SQL injection",
718+
[]filter.Option{filter.WithAllowAllColumns()},
719+
`{"my field": 1}`,
720+
``,
721+
fmt.Errorf("invalid column name: my field"),
722+
},
723+
{
724+
"invalid direction value",
725+
[]filter.Option{filter.WithAllowAllColumns()},
726+
`{"playerCount": 2}`,
727+
``,
728+
fmt.Errorf("invalid order direction for field playerCount: 2 (must be 1 or -1)"),
729+
},
730+
{
731+
"invalid direction string",
732+
[]filter.Option{filter.WithAllowAllColumns()},
733+
`{"playerCount": "asc"}`,
734+
``,
735+
fmt.Errorf("invalid order direction for field playerCount: asc (must be 1 or -1)"),
736+
},
737+
{
738+
"disallowed column",
739+
[]filter.Option{filter.WithAllowColumns("name")},
740+
`{"playerCount": 1}`,
741+
``,
742+
filter.ColumnNotAllowedError{Column: "playerCount"},
743+
},
744+
}
745+
746+
for _, tt := range tests {
747+
t.Run(tt.name, func(t *testing.T) {
748+
converter, err := filter.NewConverter(tt.options...)
749+
if err != nil {
750+
t.Fatalf("Failed to create converter: %v", err)
751+
}
752+
753+
result, err := converter.ConvertOrderBy([]byte(tt.input))
754+
755+
if tt.err != nil {
756+
if err == nil {
757+
t.Fatalf("Expected error %v, got nil", tt.err)
758+
}
759+
if err.Error() != tt.err.Error() {
760+
t.Errorf("Expected error %v, got %v", tt.err, err)
761+
}
762+
return
763+
}
764+
765+
if err != nil {
766+
t.Fatalf("Unexpected error: %v", err)
767+
}
768+
769+
if result != tt.expected {
770+
t.Errorf("Expected %q, got %q", tt.expected, result)
771+
}
772+
})
773+
}
774+
}

filter/errors.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,12 @@ type ColumnNotAllowedError struct {
1212
func (e ColumnNotAllowedError) Error() string {
1313
return fmt.Sprintf("column not allowed: %s", e.Column)
1414
}
15+
16+
type InvalidOrderDirectionError struct {
17+
Field string
18+
Value any
19+
}
20+
21+
func (e InvalidOrderDirectionError) Error() string {
22+
return fmt.Sprintf("invalid order direction for field %s: %v (must be 1 or -1)", e.Field, e.Value)
23+
}

0 commit comments

Comments
 (0)