Improve error handling in expressions

This commit is contained in:
Elara 2023-10-30 17:16:45 -07:00
parent dce4ee2e8f
commit 7e307d8811
2 changed files with 60 additions and 64 deletions

103
expr.go
View File

@ -1,12 +1,20 @@
package salix package salix
import ( import (
"errors"
"reflect" "reflect"
"strings" "strings"
"go.elara.ws/salix/internal/ast" "go.elara.ws/salix/internal/ast"
) )
var (
ErrModulusFloat = errors.New("modulo operation cannot be performed on floats")
ErrTypeMismatch = errors.New("mismatched types")
ErrLogicalNonBool = errors.New("logical operations may only be performed on boolean values")
ErrInOpInvalidTypes = errors.New("the in operator can only be used on strings, arrays, and slices")
)
func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) { func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) {
val, err := t.getValue(expr.First, local) val, err := t.getValue(expr.First, local)
if err != nil { if err != nil {
@ -20,136 +28,145 @@ func (t *Template) evalExpr(expr ast.Expr, local map[string]any) (any, error) {
return nil, err return nil, err
} }
b := reflect.ValueOf(val) b := reflect.ValueOf(val)
a = reflect.ValueOf(t.performOp(a, b, exprB.Operator))
result, err := t.performOp(a, b, exprB.Operator)
if err != nil {
return nil, err
}
a = reflect.ValueOf(result)
} }
return a.Interface(), nil return a.Interface(), nil
} }
func (t *Template) performOp(a, b reflect.Value, op ast.Op) any { func (t *Template) performOp(a, b reflect.Value, op ast.Operator) (any, error) {
if op.Op() == "in" { if op.Value == "in" {
switch b.Kind() { switch b.Kind() {
case reflect.Slice, reflect.Array: case reflect.Slice, reflect.Array:
if a.CanConvert(b.Type().Elem()) { if a.CanConvert(b.Type().Elem()) {
a = a.Convert(b.Type().Elem()) a = a.Convert(b.Type().Elem())
} else { } else {
panic("todo: invalid in operation") return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type())
} }
case reflect.String: case reflect.String:
if a.Kind() != reflect.String { if a.Kind() != reflect.String {
panic("todo: invalid in operation") return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type())
} }
default:
return nil, t.posError(op, "%w (got %s and %s)", ErrInOpInvalidTypes, a.Type(), b.Type())
} }
} else if b.CanConvert(a.Type()) { } else if b.CanConvert(a.Type()) {
b = b.Convert(a.Type()) b = b.Convert(a.Type())
} else { } else {
panic("todo: invalid operation") return nil, t.posError(op, "%w (%s and %s)", ErrTypeMismatch, a.Type(), b.Type())
} }
switch op.Op() { switch op.Value {
case "==": case "==":
return a.Equal(b) return a.Equal(b), nil
case "&&": case "&&":
if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool { if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool {
panic("todo: invalid logical") return nil, t.posError(op, "%w", ErrLogicalNonBool)
} }
return a.Bool() && b.Bool() return a.Bool() && b.Bool(), nil
case "||": case "||":
if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool { if a.Kind() != reflect.Bool || b.Kind() != reflect.Bool {
panic("todo: invalid logical") return nil, t.posError(op, "%w", ErrLogicalNonBool)
} }
return a.Bool() || b.Bool() return a.Bool() || b.Bool(), nil
case ">=": case ">=":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() >= b.Int() return a.Int() >= b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() >= b.Uint() return a.Uint() >= b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() >= b.Float() return a.Float() >= b.Float(), nil
} }
case "<=": case "<=":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() <= b.Int() return a.Int() <= b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() <= b.Uint() return a.Uint() <= b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() <= b.Float() return a.Float() <= b.Float(), nil
} }
case ">": case ">":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() > b.Int() return a.Int() > b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() > b.Uint() return a.Uint() > b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() > b.Float() return a.Float() > b.Float(), nil
} }
case "<": case "<":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() < b.Int() return a.Int() < b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() < b.Uint() return a.Uint() < b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() < b.Float() return a.Float() < b.Float(), nil
} }
case "+": case "+":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() + b.Int() return a.Int() + b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() + b.Uint() return a.Uint() + b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() + b.Float() return a.Float() + b.Float(), nil
} }
case "-": case "-":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() - b.Int() return a.Int() - b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() - b.Uint() return a.Uint() - b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() - b.Float() return a.Float() - b.Float(), nil
} }
case "*": case "*":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() * b.Int() return a.Int() * b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() * b.Uint() return a.Uint() * b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() * b.Float() return a.Float() * b.Float(), nil
} }
case "/": case "/":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() / b.Int() return a.Int() / b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() / b.Uint() return a.Uint() / b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() / b.Float() return a.Float() / b.Float(), nil
} }
case "%": case "%":
switch a.Kind() { switch a.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a.Int() % b.Int() return a.Int() % b.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return a.Uint() % b.Uint() return a.Uint() % b.Uint(), nil
case reflect.Float64, reflect.Float32: case reflect.Float64, reflect.Float32:
return a.Float() % b.Float() return nil, t.posError(op, "%w", ErrModulusFloat)
}
case "in": case "in":
if a.Kind() == reflect.String && b.Kind() == reflect.String { if a.Kind() == reflect.String && b.Kind() == reflect.String {
return strings.Contains(b.String(), a.String()) return strings.Contains(b.String(), a.String()), nil
} else { } else {
for i := 0; i < b.Len(); i++ { for i := 0; i < b.Len(); i++ {
if a.Equal(b.Index(i)) { if a.Equal(b.Index(i)) {
return true return true, nil
} }
} }
return false return false, nil
} }
} }
return false return false, nil
} }

View File

@ -159,10 +159,6 @@ func (b Bool) Pos() Position {
return b.Position return b.Position
} }
type Op interface {
Op() string
}
type Operator struct { type Operator struct {
Value string Value string
Position Position Position Position
@ -172,23 +168,6 @@ func (op Operator) Pos() Position {
return op.Position return op.Position
} }
func (op Operator) Op() string {
return op.Value
}
type Logical struct {
Value string
Position Position
}
func (l Logical) Pos() Position {
return l.Position
}
func (l Logical) Op() string {
return l.Value
}
type Ternary struct { type Ternary struct {
Condition Node Condition Node
IfTrue Node IfTrue Node