diff --git a/expr.go b/expr.go index a668eb9..3d29be9 100644 --- a/expr.go +++ b/expr.go @@ -32,47 +32,14 @@ func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) { return a.Interface(), nil } -func (t *Template) performOp(a, b reflect.Value, op ast.Operator) (any, error) { +func (t *Template) performOp(a, b reflect.Value, op ast.Operator) (result any, err error) { if op.Value == "in" { - switch b.Kind() { - case reflect.Slice, reflect.Array: - if a.CanConvert(b.Type().Elem()) { - a = a.Convert(b.Type().Elem()) - } else { - return nil, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) - } - case reflect.Map: - if a.CanConvert(b.Type().Key()) { - a = a.Convert(b.Type().Key()) - } else { - return nil, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) - } - case reflect.String: - if a.Kind() != reflect.String { - return nil, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) - } - default: - return nil, ast.PosError(op, "the in operator can only be used on strings, arrays, and slices (got %s and %s)", a.Type(), b.Type()) - } - } else if !a.IsValid() && !b.IsValid() { - return true, nil - } else if !a.IsValid() { - return nil, ast.PosError(op, "nil must be on the right side of an expression") - } else if !b.IsValid() { - if op.Value != "==" && op.Value != "!=" { - return nil, ast.PosError(op, "invalid operator for nil value (expected == or !=, got %s)", op.Value) - } - - switch a.Kind() { - case reflect.Chan, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface, reflect.Pointer: - if op.Value == "==" { - return a.IsNil(), nil - } else { - return !a.IsNil(), nil - } - default: - return nil, ast.PosError(op, "values of type %s cannot be compared against nil", a.Type()) + a, b, err = handleIn(op, a, b) + if err != nil { + return nil, err } + } else if !a.IsValid() || !b.IsValid() { + return handleNil(op, a, b) } else if b.CanConvert(a.Type()) { b = b.Convert(a.Type()) } else { @@ -193,3 +160,51 @@ func (t *Template) performOp(a, b reflect.Value, op ast.Operator) (any, error) { } return false, ast.PosError(op, "unknown operator: %q", op.Value) } + +func handleIn(op ast.Operator, a, b reflect.Value) (c, d reflect.Value, err error) { + switch b.Kind() { + case reflect.Slice, reflect.Array: + if a.CanConvert(b.Type().Elem()) { + a = a.Convert(b.Type().Elem()) + } else { + return a, b, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) + } + case reflect.Map: + if a.CanConvert(b.Type().Key()) { + a = a.Convert(b.Type().Key()) + } else { + return a, b, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) + } + case reflect.String: + if a.Kind() != reflect.String { + return a, b, ast.PosError(op, "mismatched types in expression (%s and %s)", a.Type(), b.Type()) + } + default: + return a, b, ast.PosError(op, "the in operator can only be used on strings, arrays, and slices (got %s and %s)", a.Type(), b.Type()) + } + return a, b, nil +} + +func handleNil(op ast.Operator, a, b reflect.Value) (any, error) { + if !a.IsValid() && !b.IsValid() { + return true, nil + } else if !a.IsValid() { + return nil, ast.PosError(op, "nil must be on the right side of an expression") + } else if !b.IsValid() { + if op.Value != "==" && op.Value != "!=" { + return nil, ast.PosError(op, "invalid operator for nil value (expected == or !=, got %s)", op.Value) + } + + switch a.Kind() { + case reflect.Chan, reflect.Slice, reflect.Map, reflect.Func, reflect.Interface, reflect.Pointer: + if op.Value == "==" { + return a.IsNil(), nil + } else { + return !a.IsNil(), nil + } + default: + return nil, ast.PosError(op, "values of type %s cannot be compared against nil", a.Type()) + } + } + return nil, nil +}