Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 63 additions & 13 deletions jrpc2/jsonrpc2.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jrpc2

import (
"encoding"
"encoding/hex"
"encoding/json"
"errors"
Expand Down Expand Up @@ -312,6 +313,10 @@ func GetNamedParams(target Method) map[string]interface{} {
}

func isZero(x interface{}) bool {
if x == nil {
return true
}

return reflect.DeepEqual(x, reflect.Zero(reflect.TypeOf(x)).Interface())
}

Expand Down Expand Up @@ -354,10 +359,7 @@ func ParseNamedParams(target Method, params map[string]interface{}) error {
targetValue := reflect.Indirect(reflect.ValueOf(target))
err := innerParseNamed(targetValue, params)
if err != nil {
fmt.Println("ERR")
fmt.Println(err)
fmt.Println(targetValue)
fmt.Println(params)
return err
}
return nil
}
Expand All @@ -372,11 +374,16 @@ func innerParseNamed(targetValue reflect.Value, params map[string]interface{}) e
continue
}
fT := tType.Field(i)
// check for the json tag match, as well a simple
// lower case name match
tag, _ := fT.Tag.Lookup("json")
if tag == key || key == strings.ToLower(fT.Name) {

name, omit := parseTag(tag)

if name == key || key == strings.ToLower(fT.Name) {
found = true
if omit && isZero(value) {
break
}

err := innerParse(targetValue, fVal, value)
if err != nil {
return err
Expand Down Expand Up @@ -411,14 +418,12 @@ func innerParse(targetValue reflect.Value, fVal reflect.Value, value interface{}
}

// json.RawMessage escape hatch
var eg json.RawMessage
if fVal.Type() == reflect.TypeOf(eg) {
if strings.Contains(fVal.Type().String(), "RawMessage") {
out, err := json.Marshal(value)
if err != nil {
return err
}
jm := json.RawMessage(out)
fVal.Set(reflect.ValueOf(jm))
fVal.Set(reflect.ValueOf(out).Convert(fVal.Type()))
return nil
}

Expand Down Expand Up @@ -473,8 +478,12 @@ func innerParse(targetValue reflect.Value, fVal reflect.Value, value interface{}
return nil
}

av := value.([]interface{})
av, ok := value.([]interface{})
if !ok {
return NewError(nil, InvalidParams, fmt.Sprintf("Expected JSON array for slice field %s, but got %T", fVal.Type().Name(), value))
}
fVal.Set(reflect.MakeSlice(fVal.Type(), len(av), len(av)))

for i := range av {
err := innerParse(targetValue, fVal.Index(i), av[i])
if err != nil {
Expand Down Expand Up @@ -517,12 +526,53 @@ func innerParse(targetValue reflect.Value, fVal reflect.Value, value interface{}
}
case reflect.Ptr:
if v.Kind() == reflect.Invalid {
// i'm afraid that's a nil, my dear
return nil
}

umType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
tmType := reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
ptrType := fVal.Type()

if ptrType.Implements(umType) || reflect.PointerTo(ptrType.Elem()).Implements(umType) {
n := reflect.New(ptrType.Elem())
data, err := json.Marshal(value)
if err != nil {
return err
}
if err := json.Unmarshal(data, n.Interface()); err != nil {
return err
}
fVal.Set(n)
return nil
}

if ptrType.Implements(tmType) || reflect.PointerTo(ptrType.Elem()).Implements(tmType) {
s, ok := value.(string)
if !ok {
return NewError(nil, InvalidParams, fmt.Sprintf("Expected string input for %s.%s, but got %T", targetValue.Type().Name(), fVal.Type().Name(), value))
}
n := reflect.New(ptrType.Elem())
if err := n.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(s)); err != nil {
return err
}
fVal.Set(n)
return nil
}

if fVal.Type().Elem().Kind() != reflect.Struct {
n := reflect.New(fVal.Type().Elem())
err := innerParse(targetValue, n.Elem(), value)
if err != nil {
return err
}
fVal.Set(n)
return nil
}

if v.Kind() != reflect.Map {
return NewError(nil, InvalidParams, fmt.Sprintf("Types don't match. Expected a map[string]interface{} from the JSON, instead got %s", v.Kind().String()))
}

if fVal.IsNil() {
// You need a new pointer object thing here
// so allocate one with this voodoo-magique
Expand Down
41 changes: 40 additions & 1 deletion jrpc2/jsonrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/assert"
)

//// This section (below) is for method json marshalling,
// This section (below) is for method json marshalling,
// with special emphasis on how the parameters get marshalled
// and unmarshalled to/from 'Method' objects
type HelloMethod struct {
Expand Down Expand Up @@ -653,3 +653,42 @@ func TestServerRegistry(t *testing.T) {
err_ = server.Unregister(method)
assert.Equal(t, "Method not registered", err_.Error())
}

type OmitEmptyMethod struct {
Required string `json:"required"`
Optional *string `json:"optional,omitempty"`
Count *uint32 `json:"count,omitempty"`
}

func (m OmitEmptyMethod) New() interface{} {
return &OmitEmptyMethod{}
}

func (m OmitEmptyMethod) Call() (jrpc2.Result, error) {
return nil, nil
}

func (m OmitEmptyMethod) Name() string {
return "omit_empty"
}

func TestParsingOmitEmptyFields(t *testing.T) {
requestJson := `{"id":1,"method":"omit_empty","params":{"required":"value","optional":"hello","count":7},"jsonrpc":"2.0"}`
s := jrpc2.NewServer()
s.Register(&OmitEmptyMethod{})

var result jrpc2.Request
err := s.Unmarshal([]byte(requestJson), &result)
assert.Nil(t, err)

method, ok := result.Method.(*OmitEmptyMethod)
assert.True(t, ok)
assert.Equal(t, "omit_empty", method.Name())
assert.Equal(t, "value", method.Required)

assert.NotNil(t, method.Optional)
assert.Equal(t, "hello", *method.Optional)

assert.NotNil(t, method.Count)
assert.Equal(t, uint32(7), *method.Count)
}