From 9bf56b50a46166bf2f4c9f764db0c99c0e6ceea4 Mon Sep 17 00:00:00 2001 From: Elara6331 Date: Thu, 18 Jan 2024 23:42:18 -0800 Subject: [PATCH] Add support for variadic functions --- salix.go | 23 +++++++++++++++++------ vars.go | 1 + 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/salix.go b/salix.go index 809037e..725d015 100644 --- a/salix.go +++ b/salix.go @@ -475,7 +475,10 @@ func (t *Template) execMethodCall(mc ast.MethodCall, local map[string]any) (any, // execFunc executes a function call func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, local map[string]any) (any, error) { fnType := fn.Type() - if fnType.NumIn() != len(args) { + lastIndex := fnType.NumIn() - 1 + isVariadic := fnType.IsVariadic() + + if !isVariadic && fnType.NumIn() != len(args) { return nil, ast.PosError(node, "%s: invalid parameter amount: %d (expected %d)", valueToString(node), len(args), fnType.NumIn()) } @@ -483,7 +486,7 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo return nil, err } - params := make([]reflect.Value, fnType.NumIn()) + params := make([]reflect.Value, 0, fnType.NumIn()) for i, arg := range args { if _, ok := arg.(ast.Assignment); ok { return nil, ast.PosError(arg, "%s: an assignment cannot be used as a function argument", valueToString(node)) @@ -492,11 +495,19 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo if err != nil { return nil, err } - params[i] = reflect.ValueOf(paramVal) - if params[i].CanConvert(fnType.In(i)) { - params[i] = params[i].Convert(fnType.In(i)) + params = append(params, reflect.ValueOf(paramVal)) + if isVariadic && i >= lastIndex { + if params[i].CanConvert(fnType.In(lastIndex).Elem()) { + params[i] = params[i].Convert(fnType.In(lastIndex).Elem()) + } else { + return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i)) + } } else { - return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i)) + if params[i].CanConvert(fnType.In(i)) { + params[i] = params[i].Convert(fnType.In(i)) + } else { + return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i)) + } } } diff --git a/vars.go b/vars.go index 12c61d5..2a65d03 100644 --- a/vars.go +++ b/vars.go @@ -19,6 +19,7 @@ var globalVars = map[string]any{ "count": strings.Count, "split": strings.Split, "join": strings.Join, + "sprintf": fmt.Sprintf, } func tmplLen(v any) (int, error) {