diff --git a/for_tag.go b/for_tag.go index 4741bf2..f4ecb3b 100644 --- a/for_tag.go +++ b/for_tag.go @@ -1,6 +1,7 @@ package salix import ( + "errors" "reflect" "go.elara.ws/salix/ast" @@ -10,34 +11,26 @@ import ( type forTag struct{} func (ft forTag) Run(tc *TagContext, block, args []ast.Node) error { - if len(args) == 0 || len(args) > 2 { + if len(args) == 0 || len(args) > 3 { return tc.PosError(tc.Tag, "invalid argument amount") } - var expr ast.Expr - if len(args) == 1 { - expr2, ok := args[0].(ast.Expr) - if !ok { - return tc.PosError(args[0], "invalid argument type: %T (expected ast.Expr)", args[0]) - } - expr = expr2 - } else if len(args) == 2 { - expr2, ok := args[1].(ast.Expr) - if !ok { - return tc.PosError(args[1], "invalid argument type: %T (expected ast.Expr)", args[1]) - } - expr = expr2 + expr, ok := args[len(args)-1].(ast.Expr) + if !ok { + return tc.PosError(args[0], "invalid argument type: %T (expected ast.Expr)", args[0]) } var vars []string var in reflect.Value - if len(args) == 2 { - varName, ok := unwrap(args[0]).(ast.Ident) - if !ok { - return tc.PosError(args[0], "invalid argument type: %T (expected ast.Ident)", expr.First) + if len(args) > 1 { + for _, arg := range args[:len(args)-1] { + varName, ok := unwrap(arg).(ast.Ident) + if !ok { + return tc.PosError(arg, "invalid argument type: %T (expected ast.Ident)", expr.First) + } + vars = append(vars, varName.Value) } - vars = append(vars, varName.Value) } varName, ok := unwrap(expr.First).(ast.Ident) @@ -70,6 +63,8 @@ func (ft forTag) Run(tc *TagContext, block, args []ast.Node) error { } else if len(vars) == 2 { local[vars[0]] = i local[vars[1]] = in.Index(i).Interface() + } else { + return errors.New("slices and arrays can only use two for loop variables") } err = tc.Execute(block, local) @@ -80,18 +75,25 @@ func (ft forTag) Run(tc *TagContext, block, args []ast.Node) error { case reflect.Map: local := map[string]any{} iter := in.MapRange() + i := 0 for iter.Next() { if len(vars) == 1 { local[vars[0]] = iter.Value().Interface() } else if len(vars) == 2 { local[vars[0]] = iter.Key().Interface() local[vars[1]] = iter.Value().Interface() + } else if len(vars) == 3 { + local[vars[0]] = i + local[vars[1]] = iter.Key().Interface() + local[vars[2]] = iter.Value().Interface() } err = tc.Execute(block, local) if err != nil { return err } + + i++ } }