diff --git a/default.go b/default.go index ddde003..fc26062 100644 --- a/default.go +++ b/default.go @@ -65,7 +65,7 @@ func init() { a.RequireNumOfArguments("len", 1, 1) expression := a.Get(0) - if expression.Kind() == reflect.Ptr { + if expression.Kind() == reflect.Ptr || expression.Kind() == reflect.Interface { expression = expression.Elem() } diff --git a/eval.go b/eval.go index 11d7c36..2763125 100644 --- a/eval.go +++ b/eval.go @@ -55,7 +55,7 @@ func (renderer RendererFunc) Render(r *Runtime) { // Ranger a value implementing a ranger interface is able to iterate on his value // and can be used directly in a range statement type Ranger interface { - Range() (reflect.Value, reflect.Value, bool) + Range() (interface{}, interface{}, bool) } type escapeeWriter struct { @@ -221,6 +221,9 @@ func (state *Runtime) Resolve(name string) reflect.Value { } func (st *Runtime) recover(err *error) { + // reset state scope and context just to be safe (they might not be cleared properly if there was a panic while using the state) + st.scope = &scope{} + st.context = reflect.Value{} pool_State.Put(st) if recovered := recover(); recovered != nil { var is bool @@ -457,21 +460,21 @@ func (st *Runtime) executeList(list *ListNode) { if isSet { if isLet { if isKeyVal { - st.variables[node.Set.Left[0].String()] = indexValue - st.variables[node.Set.Left[1].String()] = rangeValue + st.variables[node.Set.Left[0].String()] = reflect.ValueOf(indexValue) + st.variables[node.Set.Left[1].String()] = reflect.ValueOf(rangeValue) } else { - st.variables[node.Set.Left[0].String()] = rangeValue + st.variables[node.Set.Left[0].String()] = reflect.ValueOf(rangeValue) } } else { if isKeyVal { - st.executeSet(node.Set.Left[0], indexValue) - st.executeSet(node.Set.Left[1], rangeValue) + st.executeSet(node.Set.Left[0], reflect.ValueOf(indexValue)) + st.executeSet(node.Set.Left[1], reflect.ValueOf(rangeValue)) } else { - st.executeSet(node.Set.Left[0], rangeValue) + st.executeSet(node.Set.Left[0], reflect.ValueOf(rangeValue)) } } } else { - st.context = rangeValue + st.context = reflect.ValueOf(rangeValue) } st.executeList(node.List) indexValue, rangeValue, end = ranger.Range() @@ -605,7 +608,12 @@ func (st *Runtime) evalPrimaryExpressionGroup(node Expression) reflect.Value { return baseExpression.MapIndex(indexExpression) case reflect.Array, reflect.String, reflect.Slice: if canNumber(indexType.Kind()) { - return baseExpression.Index(int(castInt64(indexExpression))) + index := int(castInt64(indexExpression)) + if 0 <= index && index < baseExpression.Len() { + return baseExpression.Index(index) + } else { + node.errorf("%s index out of range (index: %d, len: %d)", baseExpression.Kind().String(), index, baseExpression.Len()) + } } else { node.errorf("non numeric value in index expression kind %s", baseExpression.Kind().String()) } @@ -650,6 +658,20 @@ func (st *Runtime) evalPrimaryExpressionGroup(node Expression) reflect.Value { return st.evalBaseExpressionGroup(node) } +// notNil returns false when v.IsValid() == false +// or when v's kind can be nil and v.IsNil() == true +func notNil(v reflect.Value) bool { + if !v.IsValid() { + return false + } + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return !v.IsNil() + default: + return true + } +} + func (st *Runtime) isSet(node Node) bool { nodeType := node.Type() @@ -668,7 +690,7 @@ func (st *Runtime) isSet(node Node) bool { indexExpression := st.evalPrimaryExpressionGroup(node.Index) indexType := indexExpression.Type() - if baseExpression.Kind() == reflect.Ptr { + if baseExpression.Kind() == reflect.Ptr || baseExpression.Kind() == reflect.Interface { baseExpression = baseExpression.Elem() } @@ -682,7 +704,8 @@ func (st *Runtime) isSet(node Node) bool { node.errorf("%s is not assignable|convertible to map key %s", indexType.String(), key.String()) } } - return baseExpression.MapIndex(indexExpression).IsValid() + value := baseExpression.MapIndex(indexExpression) + return notNil(value) case reflect.Array, reflect.String, reflect.Slice: if canNumber(indexType.Kind()) { i := int(castInt64(indexExpression)) @@ -695,7 +718,9 @@ func (st *Runtime) isSet(node Node) bool { i := int(castInt64(indexExpression)) return i >= 0 && i < baseExpression.NumField() } else if indexType.Kind() == reflect.String { - return getFieldOrMethodValue(indexExpression.String(), baseExpression).IsValid() + fieldValue := getFieldOrMethodValue(indexExpression.String(), baseExpression) + return notNil(fieldValue) + } else { node.errorf("non numeric value in index expression kind %s", baseExpression.Kind().String()) } @@ -703,30 +728,21 @@ func (st *Runtime) isSet(node Node) bool { node.errorf("indexing is not supported in value type %s", baseExpression.Kind().String()) } case NodeIdentifier: - if st.Resolve(node.String()).IsValid() == false { - return false - } + value := st.Resolve(node.String()) + return notNil(value) case NodeField: node := node.(*FieldNode) resolved := st.context for i := 0; i < len(node.Ident); i++ { resolved = getFieldOrMethodValue(node.Ident[i], resolved) - if !resolved.IsValid() { + if !notNil(resolved) { return false } } case NodeChain: node := node.(*ChainNode) - var value = st.evalPrimaryExpressionGroup(node.Node) - if !value.IsValid() { - return false - } - for i := 0; i < len(node.Field); i++ { - value := getFieldOrMethodValue(node.Field[i], value) - if !value.IsValid() { - return false - } - } + resolved, _ := st.evalFieldAccessExpression(node) + return notNil(resolved) default: //todo: maybe work some edge cases if !(nodeType > beginExpressions && nodeType < endExpressions) { @@ -1104,14 +1120,9 @@ func (st *Runtime) evalBaseExpressionGroup(node Node) reflect.Value { } return resolved case NodeChain: - node := node.(*ChainNode) - var resolved = st.evalPrimaryExpressionGroup(node.Node) - for i := 0; i < len(node.Field); i++ { - fieldValue := getFieldOrMethodValue(node.Field[i], resolved) - if !fieldValue.IsValid() { - node.errorf("there is no field or method %q in %s", node.Field[i], getTypeString(resolved)) - } - resolved = fieldValue + resolved, err := st.evalFieldAccessExpression(node.(*ChainNode)) + if err != nil { + node.error(err) } return resolved case NodeNumber: @@ -1169,6 +1180,17 @@ func (st *Runtime) evalCommandExpression(node *CommandNode) (reflect.Value, bool return term, false } +func (st *Runtime) evalFieldAccessExpression(node *ChainNode) (reflect.Value, error) { + resolved := st.evalPrimaryExpressionGroup(node.Node) + for i := 0; i < len(node.Field); i++ { + resolved = getFieldOrMethodValue(node.Field[i], resolved) + if !resolved.IsValid() { + return resolved, fmt.Errorf("there is no field or method %q in %s", node.Field[i], getTypeString(resolved)) + } + } + return resolved, nil +} + type escapeWriter struct { rawWriter io.Writer safeWriter SafeWriter @@ -1446,6 +1468,10 @@ var cachedStructsMutex = sync.RWMutex{} var cachedStructsFieldIndex = map[reflect.Type]map[string][]int{} func getFieldOrMethodValue(key string, v reflect.Value) reflect.Value { + if !v.IsValid() { + return reflect.Value{} + } + value := getValue(key, v) if value.Kind() == reflect.Interface && !value.IsNil() { value = value.Elem() @@ -1586,11 +1612,11 @@ type sliceRanger struct { i int } -func (s *sliceRanger) Range() (index, value reflect.Value, end bool) { +func (s *sliceRanger) Range() (index, value interface{}, end bool) { s.i++ - index = reflect.ValueOf(&s.i).Elem() + index = s.i if s.i < s.len { - value = s.v.Index(s.i) + value = s.v.Index(s.i).Interface() return } pool_sliceRanger.Put(s) @@ -1602,8 +1628,9 @@ type chanRanger struct { v reflect.Value } -func (s *chanRanger) Range() (_, value reflect.Value, end bool) { - value, end = s.v.Recv() +func (s *chanRanger) Range() (_, value interface{}, end bool) { + _value, end := s.v.Recv() + value = _value.Interface() if end { pool_chanRanger.Put(s) } @@ -1617,10 +1644,11 @@ type mapRanger struct { i int } -func (s *mapRanger) Range() (index, value reflect.Value, end bool) { +func (s *mapRanger) Range() (index, value interface{}, end bool) { if s.i < s.len { - index = s.keys[s.i] - value = s.v.MapIndex(index) + _index := s.keys[s.i] + index = _index.Interface() + value = s.v.MapIndex(_index).Interface() s.i++ return } diff --git a/eval_test.go b/eval_test.go index 1110249..f7a07f5 100644 --- a/eval_test.go +++ b/eval_test.go @@ -460,6 +460,20 @@ func TestEvalBuiltinExpression(t *testing.T) { RunJetTest(t, data, nil, "LenExpression_1", `{{len("111")}}`, "3") RunJetTest(t, data, nil, "LenExpression_2", `{{isset(data)?len(data):0}}`, "0") RunJetTest(t, data, []string{"", "", "", ""}, "LenExpression_3", `{{len(.)}}`, "4") + data.Set( + "foo", map[string]interface{}{ + "asd": map[string]string{ + "bar": "baz", + }, + }, + ) + RunJetTest(t, data, nil, "IsSetExpression_1", `{{isset(foo)}}`, "true") + RunJetTest(t, data, nil, "IsSetExpression_2", `{{isset(foo.asd)}}`, "true") + RunJetTest(t, data, nil, "IsSetExpression_3", `{{isset(foo.asd.bar)}}`, "true") + RunJetTest(t, data, nil, "IsSetExpression_4", `{{isset(asd)}}`, "false") + RunJetTest(t, data, nil, "IsSetExpression_5", `{{isset(foo.bar)}}`, "false") + RunJetTest(t, data, nil, "IsSetExpression_6", `{{isset(foo.asd.foo)}}`, "false") + RunJetTest(t, data, nil, "IsSetExpression_7", `{{isset(foo.asd.bar.xyz)}}`, "false") } func TestEvalAutoescape(t *testing.T) { diff --git a/func.go b/func.go index 8e6668b..7591ee2 100644 --- a/func.go +++ b/func.go @@ -26,6 +26,11 @@ type Arguments struct { argVal []reflect.Value } +// IsSet checks whether an argument is set or not. It behaves like the build-in isset function. +func (a *Arguments) IsSet(argumentIndex int) bool { + return a.runtime.isSet(a.argExpr[argumentIndex]) +} + // Get gets an argument by index. func (a *Arguments) Get(argumentIndex int) reflect.Value { if argumentIndex < len(a.argVal) { diff --git a/global.go b/global.go new file mode 100644 index 0000000..1926b25 --- /dev/null +++ b/global.go @@ -0,0 +1,9 @@ +package jet + +var abortTemplateOnError = true + +// SetAbortTemplateOnError controls whether the template rendering process should be aborted when an error is encountered. +/// Default behavior is to abort the template rendering process when an error is encountered, so abortOnError == true. +func SetAbortTemplateOnError(abortOnError bool) { + abortTemplateOnError = abortOnError +} diff --git a/node.go b/node.go index 6ae3362..0abf378 100644 --- a/node.go +++ b/node.go @@ -61,7 +61,9 @@ func (node *NodeBase) error(err error) { } func (node *NodeBase) errorf(format string, v ...interface{}) { - panic(fmt.Errorf("Jet Runtime Error(%q:%d): %s", node.TemplateName, node.Line, fmt.Sprintf(format, v...))) + if abortTemplateOnError { + panic(fmt.Errorf("Jet Runtime Error(%q:%d): %s", node.TemplateName, node.Line, fmt.Sprintf(format, v...))) + } } // Type returns itself and provides an easy default implementation