diff --git a/expr.go b/expr.go index ef319e1..262cac2 100644 --- a/expr.go +++ b/expr.go @@ -1,12 +1,20 @@ package salix import ( + "errors" "reflect" "strings" "go.elara.ws/salix/internal/ast" ) +var ( + ErrModulusFloat = errors.New("modulo operation cannot be performed on floats") + ErrTypeMismatch = errors.New("mismatched types") + ErrLogicalNonBool = errors.New("logical operations may only be performed on boolean values") + ErrInOpInvalidTypes = errors.New("the in operator can only be used on strings, arrays, and slices") +) + func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) { val, err := t.getValue(expr.First, local) if err != nil { @@ -20,136 +28,145 @@ func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) { return nil, err } b := reflect.ValueOf(val) - a = reflect.ValueOf(t.performOp(a, b, exprB.Operator)) + + result, err := t.performOp(a, b, exprB.Operator) + if err != nil { + return nil, err + } + + a = reflect.ValueOf(result) } return a.Interface(), nil } -func (t *Template) performOp(a, b reflect.Value, op ast.Op) any { - if op.Op() == "in" { +func (t *Template) performOp(a, b reflect.Value, op ast.Operator) (any, error) { + if op.Value == "in" { switch b.Kind() { case reflect.Slice, reflect.Array: if a.CanConvert(b.Type().Elem()) { a = a.Convert(b.Type().Elem()) } else { - panic("todo: invalid in operation") + return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type()) } case reflect.String: if a.Kind() != reflect.String { - panic("todo: invalid in operation") + return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type()) } + default: + return nil, t.posError(op, "%w (got %s and %s)", ErrInOpInvalidTypes, a.Type(), b.Type()) } } else if b.CanConvert(a.Type()) { b = b.Convert(a.Type()) } else { - panic("todo: invalid operation") + return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type()) } - switch op.Op() { + switch op.Value { case "==": - return a.Equal(b) + return a.Equal(b), nil case "&&": if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool { - panic("todo: invalid logical") + return nil, t.posError(op, "%w", ErrLogicalNonBool) } - return a.Bool() && b.Bool() + return a.Bool() && b.Bool(), nil case "||": if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool { - panic("todo: invalid logical") + return nil, t.posError(op, "%w", ErrLogicalNonBool) } - return a.Bool() || b.Bool() + return a.Bool() || b.Bool(), nil case ">=": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() >= b.Int() + return a.Int() >= b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() >= b.Uint() + return a.Uint() >= b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() >= b.Float() + return a.Float() >= b.Float(), nil } case "<=": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() <= b.Int() + return a.Int() <= b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() <= b.Uint() + return a.Uint() <= b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() <= b.Float() + return a.Float() <= b.Float(), nil } case ">": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() > b.Int() + return a.Int() > b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() > b.Uint() + return a.Uint() > b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() > b.Float() + return a.Float() > b.Float(), nil } case "<": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() < b.Int() + return a.Int() < b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() < b.Uint() + return a.Uint() < b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() < b.Float() + return a.Float() < b.Float(), nil } case "+": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() + b.Int() + return a.Int() + b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() + b.Uint() + return a.Uint() + b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() + b.Float() + return a.Float() + b.Float(), nil } case "-": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() - b.Int() + return a.Int() - b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() - b.Uint() + return a.Uint() - b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() - b.Float() + return a.Float() - b.Float(), nil } case "*": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() * b.Int() + return a.Int() * b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() * b.Uint() + return a.Uint() * b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() * b.Float() + return a.Float() * b.Float(), nil } case "/": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() / b.Int() + return a.Int() / b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() / b.Uint() + return a.Uint() / b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() / b.Float() + return a.Float() / b.Float(), nil } case "%": switch a.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return a.Int() % b.Int() + return a.Int() % b.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return a.Uint() % b.Uint() + return a.Uint() % b.Uint(), nil case reflect.Float64, reflect.Float32: - return a.Float() % b.Float() + return nil, t.posError(op, "%w", ErrModulusFloat) + } case "in": if a.Kind() == reflect.String && b.Kind() == reflect.String { - return strings.Contains(b.String(), a.String()) + return strings.Contains(b.String(), a.String()), nil } else { for i := 0; i < b.Len(); i++ { if a.Equal(b.Index(i)) { - return true + return true, nil } } - return false + return false, nil } } - return false + return false, nil } diff --git a/internal/ast/ast.go b/internal/ast/ast.go index 9ae2365..c3bbe13 100644 --- a/internal/ast/ast.go +++ b/internal/ast/ast.go @@ -159,10 +159,6 @@ func (b Bool) Pos() Position { return b.Position } -type Op interface { - Op() string -} - type Operator struct { Value string Position Position @@ -172,23 +168,6 @@ func (op Operator) Pos() Position { return op.Position } -func (op Operator) Op() string { - return op.Value -} - -type Logical struct { - Value string - Position Position -} - -func (l Logical) Pos() Position { - return l.Position -} - -func (l Logical) Op() string { - return l.Value -} - type Ternary struct { Condition Node IfTrue Node