Simplify variadic function support and add a test for it

This commit is contained in:
Elara 2024-01-24 22:33:47 -08:00
parent 9bf56b50a4
commit 958a25d559
2 changed files with 21 additions and 11 deletions

View File

@ -172,6 +172,16 @@ func TestMethodCall(t *testing.T) {
} }
} }
func TestVariadic(t *testing.T) {
res := execStr(t, `#(sprintf("%s %d", x, y))`, map[string]any{
"x": "test",
"y": 4,
})
if res != "test 4" {
t.Errorf("Expected %q, got %q", "test 4", res)
}
}
func execStr(t *testing.T, tmplStr string, vars map[string]any) string { func execStr(t *testing.T, tmplStr string, vars map[string]any) string {
t.Helper() t.Helper()
tmpl, err := New().ParseString("test", tmplStr) tmpl, err := New().ParseString("test", tmplStr)

View File

@ -477,7 +477,7 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo
fnType := fn.Type() fnType := fn.Type()
lastIndex := fnType.NumIn() - 1 lastIndex := fnType.NumIn() - 1
isVariadic := fnType.IsVariadic() isVariadic := fnType.IsVariadic()
if !isVariadic && fnType.NumIn() != len(args) { if !isVariadic && fnType.NumIn() != len(args) {
return nil, ast.PosError(node, "%s: invalid parameter amount: %d (expected %d)", valueToString(node), len(args), fnType.NumIn()) return nil, ast.PosError(node, "%s: invalid parameter amount: %d (expected %d)", valueToString(node), len(args), fnType.NumIn())
} }
@ -496,18 +496,18 @@ func (t *Template) execFunc(fn reflect.Value, node ast.Node, args []ast.Node, lo
return nil, err return nil, err
} }
params = append(params, reflect.ValueOf(paramVal)) params = append(params, reflect.ValueOf(paramVal))
var paramType reflect.Type
if isVariadic && i >= lastIndex { if isVariadic && i >= lastIndex {
if params[i].CanConvert(fnType.In(lastIndex).Elem()) { paramType = fnType.In(lastIndex).Elem()
params[i] = params[i].Convert(fnType.In(lastIndex).Elem())
} else {
return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i))
}
} else { } else {
if params[i].CanConvert(fnType.In(i)) { paramType = fnType.In(i)
params[i] = params[i].Convert(fnType.In(i)) }
} else {
return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, fnType.In(i)) if params[i].CanConvert(paramType) {
} params[i] = params[i].Convert(paramType)
} else {
return nil, ast.PosError(node, "%s: invalid parameter type: %T (expected %s)", valueToString(node), paramVal, paramType)
} }
} }