Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 73 additions & 50 deletions executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package graphql

import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
Expand All @@ -13,16 +14,18 @@ import (

. "github.com/playlyfe/go-graphql/language"
"github.com/playlyfe/go-graphql/utils"
nctx "golang.org/x/net/context"
)

type ResolveParams struct {
Executor *Executor
Request *Document
Schema *Document
Context interface{}
Source interface{}
Args map[string]interface{}
Field *Field
Executor *Executor
Request *Document
Schema *Document
Context interface{}
RequestContext nctx.Context
Source interface{}
Args map[string]interface{}
Field *Field
}

type Error struct {
Expand All @@ -42,6 +45,7 @@ func (list *ErrorList) Add(err *Error) {
}

type RequestContext struct {
NetContext nctx.Context
AppContext interface{}
Document *Document
ErrorList *ErrorList
Expand Down Expand Up @@ -84,6 +88,8 @@ type GroupedField struct {
Fields []*Field
}

var ContextDoneErr error = errors.New("Context canceled / done before query complete.")

/**
* Prepares an object map of variableValues of the correct type based on the
* provided variable definitions and arbitrary input. If the input cannot be
Expand Down Expand Up @@ -113,7 +119,6 @@ func (executor *Executor) resolveNamedType(ntype ASTNode) *NamedType {
return unmodifiedType.(*NamedType)
}
}
return nil
}

func (executor *Executor) printType(ttype ASTNode) string {
Expand All @@ -127,14 +132,6 @@ func (executor *Executor) printType(ttype ASTNode) string {
panic("Unexpected AST node type")
}

// TODO: Implement type printer
/*func PrintType (ntype ASTNode) string {
output := ""


}
}*/

/**
* Prepares an object map of argument values given a list of argument
* definitions and list of argument AST nodes.
Expand Down Expand Up @@ -292,7 +289,7 @@ func (executor *Executor) variableValue(context interface{}, ntype ASTNode, inpu
}
}
}
for key, _ := range object {
for key := range object {
if _, exists := inputType.FieldIndex[key]; !exists {
return nil, &GraphQLError{
Message: fmt.Sprintf("In field %q: Unknown field", key),
Expand Down Expand Up @@ -740,6 +737,10 @@ func handleGQLError(result map[string]interface{}, err error) (map[string]interf
}

func (executor *Executor) Execute(context interface{}, request string, variables map[string]interface{}, operationName string) (map[string]interface{}, error) {
return executor.ExecuteWithContext(context, request, variables, operationName, nil)
}

func (executor *Executor) ExecuteWithContext(context interface{}, request string, variables map[string]interface{}, operationName string, netContext nctx.Context) (map[string]interface{}, error) {
parser := &Parser{}
result := map[string]interface{}{}
document, err := parser.Parse(&ParseParams{
Expand All @@ -749,7 +750,17 @@ func (executor *Executor) Execute(context interface{}, request string, variables
return handleGQLError(result, err)
}

if netContext == nil {
netContext = nctx.Background()
}

subContext, cancelFn := nctx.WithCancel(netContext)
defer func() {
cancelFn()
}()

reqCtx := &RequestContext{
NetContext: subContext,
AppContext: context,
Document: document,
ErrorList: &ErrorList{},
Expand Down Expand Up @@ -782,10 +793,11 @@ func (executor *Executor) Execute(context interface{}, request string, variables
} else {
if executor.Before != nil {
err = executor.Before(&ResolveParams{
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
RequestContext: reqCtx.NetContext,
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
}, selectedOperation.Operation)
if err != nil {
result, err = handleGQLError(result, err)
Expand Down Expand Up @@ -824,10 +836,11 @@ func (executor *Executor) Execute(context interface{}, request string, variables

if executor.After != nil {
err = executor.After(&ResolveParams{
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
RequestContext: reqCtx.NetContext,
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
}, result)
if err != nil {
return nil, err
Expand All @@ -838,12 +851,10 @@ func (executor *Executor) Execute(context interface{}, request string, variables
}

func (executor *Executor) selectionSet(reqCtx *RequestContext, isParallel bool, objectType *ObjectTypeDefinition, source interface{}, selectionSet *SelectionSet) (map[string]interface{}, error) {
//log.Printf("collecting fields")
groupedFields, err := executor.collectFields(reqCtx, objectType, selectionSet, &utils.Set{})
if err != nil {
return nil, err
}
//log.Printf("resolving fields")
return executor.resolveGroupedFields(reqCtx, isParallel, objectType, source, groupedFields)

}
Expand Down Expand Up @@ -1040,15 +1051,24 @@ func (executor *Executor) doesFragmentTypeApply(objectType *ObjectTypeDefinition
func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParallel bool, objectType *ObjectTypeDefinition, source interface{}, groupedFields []*GroupedField) (map[string]interface{}, error) {
result := map[string]interface{}{}

// TODO: Use go routines?
if isParallel && !executor.Debug {
doneCh := reqCtx.NetContext.Done()
errs := []error{}
panics := []interface{}{}
wg := sync.WaitGroup{}
mutex := sync.Mutex{}
errMutex := sync.Mutex{}
contextCanceled := false

GroupedFieldsLoop:
for _, groupForResponseKey := range groupedFields {
//log.Printf("evaluating field entry for '%s'", responseKey)
select {
case <-doneCh:
contextCanceled = true
break GroupedFieldsLoop
default:
}

wg.Add(1)
go func(responseKey string, fields []*Field) {
defer func() {
Expand All @@ -1066,9 +1086,7 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle
errs = append(errs, err)
errMutex.Unlock()
}
//log.Printf("Adding '%s' with value '%#v' to response", key, value)
if key != "" {
//log.Printf("Adding key '%s' with value '%#v' to response", responseKey, value)
mutex.Lock()
result[responseKey] = value
mutex.Unlock()
Expand All @@ -1082,18 +1100,16 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle
}
if len(errs) > 0 {
return nil, errs[0]
} else if contextCanceled {
return nil, ContextDoneErr
}

} else {
for _, groupForResponseKey := range groupedFields {
//log.Printf("evaluating field entry for '%s'", responseKey)
key, value, err := executor.getFieldEntry(reqCtx, objectType, source, groupForResponseKey.ResponseKey, groupForResponseKey.Fields)
if err != nil {
return nil, err
}
//log.Printf("Adding '%s' with value '%#v' to response", key, value)
if key != "" {
//log.Printf("Adding key '%s' with value '%#v' to response", responseKey, value)
result[groupForResponseKey.ResponseKey] = value
}
}
Expand All @@ -1103,14 +1119,11 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle

func (executor *Executor) getFieldEntry(reqCtx *RequestContext, objectType *ObjectTypeDefinition, object interface{}, responseKey string, fields []*Field) (string, interface{}, error) {
firstField := fields[0]
//log.Printf("Test %#v", firstField.Name.Value)
fieldType := executor.getFieldTypeFromObjectType(objectType, firstField)
if fieldType == nil {
//log.Printf("field type of selection '%s' could not be determined", firstField.Name.Value)
return "", nil, nil
}
resolvedObject, err := executor.resolveFieldOnObject(reqCtx, objectType, object, fieldType, firstField)
//log.Printf("field %s resolved to %#v", firstField.Name.Value, resolvedObject)
if err != nil {
return "", nil, err
}
Expand Down Expand Up @@ -1161,12 +1174,9 @@ func (executor *Executor) completeValueCatchingError(reqCtx *RequestContext, obj
}

func (executor *Executor) completeValue(reqCtx *RequestContext, objectType *ObjectTypeDefinition, fieldType ASTNode, field *Field, result interface{}, subSelectionSet *SelectionSet) (interface{}, error) {
//var err error
//log.Printf("completing value on %#v", result)
if nonNullType, ok := fieldType.(*NonNullType); ok {
innerType := nonNullType.Type
completedResult, err := executor.completeValue(reqCtx, objectType, innerType, field, result, subSelectionSet)
//log.Printf("completed result of %#v is %#v", result, completedResult)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1342,12 +1352,17 @@ func (executor *Executor) completeValue(reqCtx *RequestContext, objectType *Obje
Field: field,
}
}
return nil, nil
}

func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectType *ObjectTypeDefinition, object interface{}, fieldType ASTNode, firstField *Field) (interface{}, error) {
doneCh := reqCtx.NetContext.Done()
select {
case <-doneCh:
return nil, ContextDoneErr
default:
}

resolverName := objectType.Name.Value + "/" + firstField.Name.Value
resolverName := fmt.Sprintf("%s/%s", objectType.Name.Value, firstField.Name.Value)
if resolver, ok := executor.Resolvers[resolverName]; ok {
var resolveFn ResolveFn
var beforeFn BeforeFn
Expand Down Expand Up @@ -1378,13 +1393,14 @@ func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectTyp

// Execute the before function if it is defined
resolveParams := &ResolveParams{
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
Source: object,
Args: args,
Field: firstField,
RequestContext: reqCtx.NetContext,
Executor: executor,
Schema: executor.Schema.Document,
Request: reqCtx.Document,
Context: reqCtx.AppContext,
Source: object,
Args: args,
Field: firstField,
}

if beforeFn != nil {
Expand Down Expand Up @@ -1438,6 +1454,13 @@ func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectTyp
return result, nil
}

// Check again to see if the context is finished before we bother doing more work.
select {
case <-doneCh:
return nil, ContextDoneErr
default:
}

sourceVal := reflect.ValueOf(object)
sourceValType := sourceVal.Type()
sourceValKind := sourceValType.Kind()
Expand Down
24 changes: 12 additions & 12 deletions executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (
)

type Author struct {
ID int `json:"id"`
Name string `json:"name"`
IsPublished string `json:"isPublished"`
Author *Author `json:"author"`
Title string `json:"title"`
Body string `json:"body"`
keywords []interface{} `json:"keywords"`
RecentArticle *Article `json:"recentArticle"`
ID int `json:"id"`
Name string `json:"name"`
IsPublished string `json:"isPublished"`
Author *Author `json:"author"`
Title string `json:"title"`
Body string `json:"body"`
keywords []interface{}
RecentArticle *Article `json:"recentArticle"`
}

type Image struct {
Expand Down Expand Up @@ -3392,9 +3392,9 @@ func TestExecutor(t *testing.T) {
"deserializedValue": nil,
},
"errors": []map[string]interface{}{
map[string]interface{}{
{
"locations": []map[string]interface{}{
map[string]interface{}{
{
"line": 3,
"column": 13,
},
Expand Down Expand Up @@ -3501,9 +3501,9 @@ func TestExecutor(t *testing.T) {
"deserializedValue": nil,
},
"errors": []map[string]interface{}{
map[string]interface{}{
{
"locations": []map[string]interface{}{
map[string]interface{}{
{
"line": 3,
"column": 11,
},
Expand Down