Add support for variadic functions

This commit is contained in:
Elara 2024-01-18 23:42:18 -08:00
parent 25037db86a
commit 9bf56b50a4
2 changed files with 18 additions and 6 deletions

View File

@ -475,7 +475,10 @@ func (t *Template) execMethodCall(mc ast.MethodCall, local map[string]any) (any,
// execFunc executes a function call
func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, local map[string]any) (any, error) {
fnType := fn.Type()
if fnType.NumIn() != len(args) {
lastIndex := fnType.NumIn() - 1
isVariadic := fnType.IsVariadic()
if !isVariadic && fnType.NumIn() != len(args) {
return nil, ast.PosError(node, "%s: invalid parameter amount: %d (expected %d)", valueToString(node), len(args), fnType.NumIn())
}
@ -483,7 +486,7 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo
return nil, err
}
params := make([]reflect.Value, fnType.NumIn())
params := make([]reflect.Value, 0, fnType.NumIn())
for i, arg := range args {
if _, ok := arg.(ast.Assignment); ok {
return nil, ast.PosError(arg, "%s: an assignment cannot be used as a function argument", valueToString(node))
@ -492,13 +495,21 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo
if err != nil {
return nil, err
}
params[i] = reflect.ValueOf(paramVal)
params = append(params, reflect.ValueOf(paramVal))
if isVariadic && i >= lastIndex {
if params[i].CanConvert(fnType.In(lastIndex).Elem()) {
params[i] = params[i].Convert(fnType.In(lastIndex).Elem())
} else {
return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i))
}
} else {
if params[i].CanConvert(fnType.In(i)) {
params[i] = params[i].Convert(fnType.In(i))
} else {
return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i))
}
}
}
ret := fn.Call(params)
if len(ret) == 1 {

View File

@ -19,6 +19,7 @@ var globalVars = map[string]any{
"count": strings.Count,
"split": strings.Split,
"join": strings.Join,
"sprintf": fmt.Sprintf,
}
func tmplLen(v any) (int, error) {