diff --git a/executor.go b/executor.go index 500b768..ed7fb0e 100644 --- a/executor.go +++ b/executor.go @@ -2,6 +2,7 @@ package graphql import ( "encoding/json" + "errors" "fmt" "math" "reflect" @@ -13,16 +14,18 @@ import ( . "github.com/playlyfe/go-graphql/language" "github.com/playlyfe/go-graphql/utils" + nctx "golang.org/x/net/context" ) type ResolveParams struct { - Executor *Executor - Request *Document - Schema *Document - Context interface{} - Source interface{} - Args map[string]interface{} - Field *Field + Executor *Executor + Request *Document + Schema *Document + Context interface{} + RequestContext nctx.Context + Source interface{} + Args map[string]interface{} + Field *Field } type Error struct { @@ -42,6 +45,7 @@ func (list *ErrorList) Add(err *Error) { } type RequestContext struct { + NetContext nctx.Context AppContext interface{} Document *Document ErrorList *ErrorList @@ -84,6 +88,8 @@ type GroupedField struct { Fields []*Field } +var ContextDoneErr error = errors.New("Context canceled / done before query complete.") + /** * Prepares an object map of variableValues of the correct type based on the * provided variable definitions and arbitrary input. If the input cannot be @@ -113,7 +119,6 @@ func (executor *Executor) resolveNamedType(ntype ASTNode) *NamedType { return unmodifiedType.(*NamedType) } } - return nil } func (executor *Executor) printType(ttype ASTNode) string { @@ -127,14 +132,6 @@ func (executor *Executor) printType(ttype ASTNode) string { panic("Unexpected AST node type") } -// TODO: Implement type printer -/*func PrintType (ntype ASTNode) string { - output := "" - - - } -}*/ - /** * Prepares an object map of argument values given a list of argument * definitions and list of argument AST nodes. @@ -292,7 +289,7 @@ func (executor *Executor) variableValue(context interface{}, ntype ASTNode, inpu } } } - for key, _ := range object { + for key := range object { if _, exists := inputType.FieldIndex[key]; !exists { return nil, &GraphQLError{ Message: fmt.Sprintf("In field %q: Unknown field", key), @@ -740,6 +737,10 @@ func handleGQLError(result map[string]interface{}, err error) (map[string]interf } func (executor *Executor) Execute(context interface{}, request string, variables map[string]interface{}, operationName string) (map[string]interface{}, error) { + return executor.ExecuteWithContext(context, request, variables, operationName, nil) +} + +func (executor *Executor) ExecuteWithContext(context interface{}, request string, variables map[string]interface{}, operationName string, netContext nctx.Context) (map[string]interface{}, error) { parser := &Parser{} result := map[string]interface{}{} document, err := parser.Parse(&ParseParams{ @@ -749,7 +750,17 @@ func (executor *Executor) Execute(context interface{}, request string, variables return handleGQLError(result, err) } + if netContext == nil { + netContext = nctx.Background() + } + + subContext, cancelFn := nctx.WithCancel(netContext) + defer func() { + cancelFn() + }() + reqCtx := &RequestContext{ + NetContext: subContext, AppContext: context, Document: document, ErrorList: &ErrorList{}, @@ -782,10 +793,11 @@ func (executor *Executor) Execute(context interface{}, request string, variables } else { if executor.Before != nil { err = executor.Before(&ResolveParams{ - Executor: executor, - Schema: executor.Schema.Document, - Request: reqCtx.Document, - Context: reqCtx.AppContext, + RequestContext: reqCtx.NetContext, + Executor: executor, + Schema: executor.Schema.Document, + Request: reqCtx.Document, + Context: reqCtx.AppContext, }, selectedOperation.Operation) if err != nil { result, err = handleGQLError(result, err) @@ -824,10 +836,11 @@ func (executor *Executor) Execute(context interface{}, request string, variables if executor.After != nil { err = executor.After(&ResolveParams{ - Executor: executor, - Schema: executor.Schema.Document, - Request: reqCtx.Document, - Context: reqCtx.AppContext, + RequestContext: reqCtx.NetContext, + Executor: executor, + Schema: executor.Schema.Document, + Request: reqCtx.Document, + Context: reqCtx.AppContext, }, result) if err != nil { return nil, err @@ -838,12 +851,10 @@ func (executor *Executor) Execute(context interface{}, request string, variables } func (executor *Executor) selectionSet(reqCtx *RequestContext, isParallel bool, objectType *ObjectTypeDefinition, source interface{}, selectionSet *SelectionSet) (map[string]interface{}, error) { - //log.Printf("collecting fields") groupedFields, err := executor.collectFields(reqCtx, objectType, selectionSet, &utils.Set{}) if err != nil { return nil, err } - //log.Printf("resolving fields") return executor.resolveGroupedFields(reqCtx, isParallel, objectType, source, groupedFields) } @@ -1040,15 +1051,24 @@ func (executor *Executor) doesFragmentTypeApply(objectType *ObjectTypeDefinition func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParallel bool, objectType *ObjectTypeDefinition, source interface{}, groupedFields []*GroupedField) (map[string]interface{}, error) { result := map[string]interface{}{} - // TODO: Use go routines? if isParallel && !executor.Debug { + doneCh := reqCtx.NetContext.Done() errs := []error{} panics := []interface{}{} wg := sync.WaitGroup{} mutex := sync.Mutex{} errMutex := sync.Mutex{} + contextCanceled := false + + GroupedFieldsLoop: for _, groupForResponseKey := range groupedFields { - //log.Printf("evaluating field entry for '%s'", responseKey) + select { + case <-doneCh: + contextCanceled = true + break GroupedFieldsLoop + default: + } + wg.Add(1) go func(responseKey string, fields []*Field) { defer func() { @@ -1066,9 +1086,7 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle errs = append(errs, err) errMutex.Unlock() } - //log.Printf("Adding '%s' with value '%#v' to response", key, value) if key != "" { - //log.Printf("Adding key '%s' with value '%#v' to response", responseKey, value) mutex.Lock() result[responseKey] = value mutex.Unlock() @@ -1082,18 +1100,16 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle } if len(errs) > 0 { return nil, errs[0] + } else if contextCanceled { + return nil, ContextDoneErr } - } else { for _, groupForResponseKey := range groupedFields { - //log.Printf("evaluating field entry for '%s'", responseKey) key, value, err := executor.getFieldEntry(reqCtx, objectType, source, groupForResponseKey.ResponseKey, groupForResponseKey.Fields) if err != nil { return nil, err } - //log.Printf("Adding '%s' with value '%#v' to response", key, value) if key != "" { - //log.Printf("Adding key '%s' with value '%#v' to response", responseKey, value) result[groupForResponseKey.ResponseKey] = value } } @@ -1103,14 +1119,11 @@ func (executor *Executor) resolveGroupedFields(reqCtx *RequestContext, isParalle func (executor *Executor) getFieldEntry(reqCtx *RequestContext, objectType *ObjectTypeDefinition, object interface{}, responseKey string, fields []*Field) (string, interface{}, error) { firstField := fields[0] - //log.Printf("Test %#v", firstField.Name.Value) fieldType := executor.getFieldTypeFromObjectType(objectType, firstField) if fieldType == nil { - //log.Printf("field type of selection '%s' could not be determined", firstField.Name.Value) return "", nil, nil } resolvedObject, err := executor.resolveFieldOnObject(reqCtx, objectType, object, fieldType, firstField) - //log.Printf("field %s resolved to %#v", firstField.Name.Value, resolvedObject) if err != nil { return "", nil, err } @@ -1161,12 +1174,9 @@ func (executor *Executor) completeValueCatchingError(reqCtx *RequestContext, obj } func (executor *Executor) completeValue(reqCtx *RequestContext, objectType *ObjectTypeDefinition, fieldType ASTNode, field *Field, result interface{}, subSelectionSet *SelectionSet) (interface{}, error) { - //var err error - //log.Printf("completing value on %#v", result) if nonNullType, ok := fieldType.(*NonNullType); ok { innerType := nonNullType.Type completedResult, err := executor.completeValue(reqCtx, objectType, innerType, field, result, subSelectionSet) - //log.Printf("completed result of %#v is %#v", result, completedResult) if err != nil { return nil, err } @@ -1342,12 +1352,17 @@ func (executor *Executor) completeValue(reqCtx *RequestContext, objectType *Obje Field: field, } } - return nil, nil } func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectType *ObjectTypeDefinition, object interface{}, fieldType ASTNode, firstField *Field) (interface{}, error) { + doneCh := reqCtx.NetContext.Done() + select { + case <-doneCh: + return nil, ContextDoneErr + default: + } - resolverName := objectType.Name.Value + "/" + firstField.Name.Value + resolverName := fmt.Sprintf("%s/%s", objectType.Name.Value, firstField.Name.Value) if resolver, ok := executor.Resolvers[resolverName]; ok { var resolveFn ResolveFn var beforeFn BeforeFn @@ -1378,13 +1393,14 @@ func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectTyp // Execute the before function if it is defined resolveParams := &ResolveParams{ - Executor: executor, - Schema: executor.Schema.Document, - Request: reqCtx.Document, - Context: reqCtx.AppContext, - Source: object, - Args: args, - Field: firstField, + RequestContext: reqCtx.NetContext, + Executor: executor, + Schema: executor.Schema.Document, + Request: reqCtx.Document, + Context: reqCtx.AppContext, + Source: object, + Args: args, + Field: firstField, } if beforeFn != nil { @@ -1438,6 +1454,13 @@ func (executor *Executor) resolveFieldOnObject(reqCtx *RequestContext, objectTyp return result, nil } + // Check again to see if the context is finished before we bother doing more work. + select { + case <-doneCh: + return nil, ContextDoneErr + default: + } + sourceVal := reflect.ValueOf(object) sourceValType := sourceVal.Type() sourceValKind := sourceValType.Kind() diff --git a/executor_test.go b/executor_test.go index 0cc9957..21f329e 100644 --- a/executor_test.go +++ b/executor_test.go @@ -10,14 +10,14 @@ import ( ) type Author struct { - ID int `json:"id"` - Name string `json:"name"` - IsPublished string `json:"isPublished"` - Author *Author `json:"author"` - Title string `json:"title"` - Body string `json:"body"` - keywords []interface{} `json:"keywords"` - RecentArticle *Article `json:"recentArticle"` + ID int `json:"id"` + Name string `json:"name"` + IsPublished string `json:"isPublished"` + Author *Author `json:"author"` + Title string `json:"title"` + Body string `json:"body"` + keywords []interface{} + RecentArticle *Article `json:"recentArticle"` } type Image struct { @@ -3392,9 +3392,9 @@ func TestExecutor(t *testing.T) { "deserializedValue": nil, }, "errors": []map[string]interface{}{ - map[string]interface{}{ + { "locations": []map[string]interface{}{ - map[string]interface{}{ + { "line": 3, "column": 13, }, @@ -3501,9 +3501,9 @@ func TestExecutor(t *testing.T) { "deserializedValue": nil, }, "errors": []map[string]interface{}{ - map[string]interface{}{ + { "locations": []map[string]interface{}{ - map[string]interface{}{ + { "line": 3, "column": 11, },