package salix

import (
	"bytes"
	"errors"
	"fmt"
	"html"
	"io"
	"reflect"

	"go.elara.ws/salix/internal/ast"
)

// Sentinel errors produced while executing templates and while
// validating or calling template functions.
var (
	ErrNoSuchFunc           = errors.New("no such function")
	ErrNoSuchVar            = errors.New("no such variable")
	ErrNoSuchMethod         = errors.New("no such method")
	ErrNoSuchField          = errors.New("no such field")
	ErrNoSuchTag            = errors.New("no such tag")
	ErrNotOperatorNonBool   = errors.New("not operator cannot be used on a non-bool value")
	ErrParamNumMismatch     = errors.New("incorrect parameter amount")
	ErrIncorrectParamType   = errors.New("incorrect parameter type for function")
	ErrEndTagWithoutStart   = errors.New("end tag without a start tag")
	ErrIncorrectIndexType   = errors.New("incorrect index type")
	ErrIndexOutOfRange      = errors.New("index out of range")
	ErrMapIndexNotFound     = errors.New("map index not found")
	ErrMapInvalidIndexType  = errors.New("invalid map index type")
	ErrFuncTooManyReturns   = errors.New("template functions can only have two return values")
	ErrFuncNoReturns        = errors.New("template functions must return at least one value")
	ErrFuncSecondReturnType = errors.New("the second return value of a template function must be an error")
)

// HTML is a string that is written to the template output verbatim,
// even when HTML escaping is enabled on the template.
type HTML string

// Template represents a Salix template
type Template struct {
	// file is the template's file name; it is passed along when
	// building positioned errors.
	file string
	// ast holds the parsed nodes that Execute walks.
	ast []ast.Node
	// escapeHTML controls whether expression output is HTML-escaped.
	escapeHTML bool

	// Registries of tags, functions and global variables available
	// to the template, keyed by name.
	tags  map[string]Tag
	funcs map[string]reflect.Value
	vars  map[string]reflect.Value
}

// New returns an empty Template whose tag and function registries are
// pre-populated with the default tags and functions.
func New() *Template {
	base := new(Template)
	base.funcs = make(map[string]reflect.Value)
	base.vars = make(map[string]reflect.Value)
	base.tags = make(map[string]Tag)
	return base.WithTagMap(defaultTags).WithFuncMap(defaultFuncs)
}

// WithFuncMap adds all the functions in m to the template's
// global function map. If a function with the given name already
// exists, it will be overwritten. A nil m is a no-op.
func (t *Template) WithFuncMap(m map[string]any) *Template { if m != nil { for name, fn := range m { t.funcs[name] = reflect.ValueOf(fn) } } return t } // WithVarMap adds all the variables in m to the template's // global variable map. If a variable with the given name already // exists, it will be overwritten. func (t *Template) WithVarMap(m map[string]any) *Template { if m != nil { for name, val := range m { t.vars[name] = reflect.ValueOf(val) } } return t } // WithTagMap adds all the tags in m to the template's // global tag map. If a tag with the given name already // exists, it will be overwritten. func (t *Template) WithTagMap(m map[string]Tag) *Template { if m != nil { for name, tag := range m { t.tags[name] = tag } } return t } // WithEscapeHTML enables or disables HTML escaping. // The HTML escaping functionality is NOT context-aware. // Using the HTML type allows you to get around the escaping. func (t *Template) WithEscapeHTML(b bool) *Template { t.escapeHTML = true return t } // Execute executes a parsed template and writes // the result to w. func (t *Template) Execute(w io.Writer) error { return t.execute(w, t.ast, nil) } func (t *Template) execute(w io.Writer, nodes []ast.Node, local map[string]any) error { for i := 0; i < len(nodes); i++ { switch node := nodes[i].(type) { case ast.Text: _, err := w.Write(node.Data) if err != nil { return t.posError(node, "%w", err) } case ast.Tag: newOffset, err := t.execTag(node, w, nodes, i, local) if err != nil { return err } i = newOffset case ast.EndTag: // We should never see an end tag here because it // should be taken care of by execTag, so if we do, // return an error because execTag was never called, // which means there was no start tag. 
return ErrEndTagWithoutStart case ast.ExprTag: v, err := t.getValue(node.Value, local) if err != nil { return err } _, err = io.WriteString(w, t.toString(v)) if err != nil { return err } } } return nil } func (t *Template) toString(v any) string { if h, ok := v.(HTML); ok { return string(h) } else if t.escapeHTML { return html.EscapeString(fmt.Sprint(v)) } return fmt.Sprint(v) } // getBlock gets all the nodes in the input, up to the end tag with the given name func (t *Template) getBlock(nodes []ast.Node, offset, startLine int, name string) []ast.Node { var out []ast.Node tagAmount := 1 for i := offset; i < len(nodes); i++ { switch node := nodes[i].(type) { case ast.Tag: // If we encounter another tag with the same name, // increment tagAmount so that we know that the next // end tag isn't the end of this tag. if node.Name.Value == name { tagAmount++ } out = append(out, node) case ast.EndTag: if node.Name.Value == name { tagAmount-- } // Once tagAmount is zero (all the tags of the same name) // have been closed with an end tag, we can handle our newlines // and return the nodes we've accumulated. if tagAmount == 0 { // If the end tag is on the same line as the start tag, // we don't need to remove any newlines. 
if node.Position.Line != startLine { t.handleNewlines(nodes, i) } return out } else { out = append(out, node) } default: out = append(out, node) } } return out } // getValue gets a Go value from an AST node func (t *Template) getValue(node ast.Node, local map[string]any) (any, error) { switch node := node.(type) { case ast.Value: return t.unwrapASTValue(node, local) case ast.Ident: return t.getVar(node, local) case ast.String: return node.Value, nil case ast.Float: return node.Value, nil case ast.Integer: return node.Value, nil case ast.Bool: return node.Value, nil case ast.Expr: return t.evalExpr(node, local) case ast.ExprSegment: return t.evalExprSegment(node, local) case ast.FuncCall: return t.execFuncCall(node, local) case ast.Index: return t.getIndex(node, local) case ast.FieldAccess: return t.getField(node, local) case ast.MethodCall: return t.execMethodCall(node, local) default: return nil, nil } } // unwrapASTValue unwraps an ast.Value node into its underlying value func (t *Template) unwrapASTValue(node ast.Value, local map[string]any) (any, error) { v, err := t.getValue(node.Node, local) if err != nil { return nil, err } if node.Not { rval := reflect.ValueOf(v) if rval.Kind() != reflect.Bool { return nil, ErrNotOperatorNonBool } return !rval.Bool(), nil } return v, err } // getVar tries to get a variable from the local map. If it's not found, // it'll try the global variable map. If it doesn't exist in either map, // it will return an error. func (t *Template) getVar(id ast.Ident, local map[string]any) (any, error) { if local != nil { v, ok := local[id.Value] if ok { return v, nil } } v, ok := t.vars[id.Value] if !ok { return nil, t.posError(id, "%w: %s", ErrNoSuchVar, id.Value) } return v.Interface(), nil } // handleNewlines removes newlines above and below the given index // in the nodes slice. 
func (t *Template) handleNewlines(nodes []ast.Node, i int) { if i != 0 { if node, ok := nodes[i-1].(ast.Text); ok { ni := bytes.LastIndexByte(node.Data, '\n') if ni != -1 { node.Data = node.Data[:ni] } nodes[i-1] = node } } lastIndex := len(nodes) - 1 if i != lastIndex { if node, ok := nodes[i+1].(ast.Text); ok { ni := bytes.IndexByte(node.Data, '\n') if ni != -1 { node.Data = node.Data[ni:] } nodes[i+1] = node } } } // execTag executes a tag func (t *Template) execTag(node ast.Tag, w io.Writer, nodes []ast.Node, i int, local map[string]any) (newOffset int, err error) { tag, ok := t.tags[node.Name.Value] if !ok { return 0, t.posError(node, "%w: %s", ErrNoSuchTag, node.Name.Value) } t.handleNewlines(nodes, i) var block []ast.Node if node.HasBody { block = t.getBlock(nodes, i+1, node.Position.Line, node.Name.Value) i += len(block) + 1 } tc := &TagContext{w, t, local} err = tag.Run(tc, block, node.Params) if err != nil { return 0, err } return i, nil } // execFuncCall executes a function call func (t *Template) execFuncCall(fc ast.FuncCall, local map[string]any) (any, error) { fn, ok := t.funcs[fc.Name.Value] if !ok { return nil, t.posError(fc, "%w: %s", ErrNoSuchFunc, fc.Name.Value) } return t.execFunc(fn, fc, fc.Params, local) } // getIndex tries to evaluate an ast.Index node by indexing the underlying value. 
func (t *Template) getIndex(i ast.Index, local map[string]any) (any, error) { val, err := t.getValue(i.Value, local) if err != nil { return nil, err } index, err := t.getValue(i.Index, local) if err != nil { return nil, err } rval := reflect.ValueOf(val) rindex := reflect.ValueOf(index) switch rval.Kind() { case reflect.Slice, reflect.Array: intType := reflect.TypeOf(0) if rindex.CanConvert(intType) { rindex = rindex.Convert(intType) } else { return nil, ErrIncorrectIndexType } intIndex := rindex.Interface().(int) if intIndex < rval.Len() { return rval.Index(intIndex).Interface(), nil } else { return nil, t.posError(i, "%w: %d", ErrIndexOutOfRange, intIndex) } case reflect.Map: if rindex.CanConvert(rval.Type().Key()) { rindex = rindex.Convert(rval.Type().Key()) } else { return nil, t.posError(i, "%w: %T (expected %s)", ErrMapInvalidIndexType, index, rval.Type().Key()) } if out := rval.MapIndex(rindex); out.IsValid() { return out.Interface(), nil } else { return nil, t.posError(i, "%w: %q", ErrMapIndexNotFound, index) } } return nil, nil } // getField tries to get a struct field from the underlying value func (t *Template) getField(fa ast.FieldAccess, local map[string]any) (any, error) { val, err := t.getValue(fa.Value, local) if err != nil { return nil, err } rval := reflect.ValueOf(val) field := rval.FieldByName(fa.Name.Value) if !field.IsValid() { return nil, t.posError(fa, "%w: %s", ErrNoSuchField, fa.Name.Value) } return field.Interface(), nil } // execMethodCall executes a method call on the underlying value func (t *Template) execMethodCall(mc ast.MethodCall, local map[string]any) (any, error) { val, err := t.getValue(mc.Value, local) if err != nil { return nil, err } rval := reflect.ValueOf(val) mtd := rval.MethodByName(mc.Name.Value) if !mtd.IsValid() { return nil, t.posError(mc, "%w: %s", ErrNoSuchMethod, mc.Name.Value) } return t.execFunc(mtd, mc, mc.Params, local) } // execFunc executes a function call func (t *Template) execFunc(fn reflect.Value, node 
ast.Node, args []ast.Node, local map[string]any) (any, error) { fnType := fn.Type() if fnType.NumIn() != len(args) { return nil, t.posError(node, "%w: %d (expected %d)", ErrParamNumMismatch, len(args), fnType.NumIn()) } if err := validateFunc(fnType); err != nil { return nil, t.posError(node, "%w", err) } params := make([]reflect.Value, fnType.NumIn()) for i, arg := range args { paramVal, err := t.getValue(arg, local) if err != nil { return nil, err } params[i] = reflect.ValueOf(paramVal) if params[i].CanConvert(fnType.In(i)) { params[i] = params[i].Convert(fnType.In(i)) } else { return nil, t.posError(node, "%w", ErrIncorrectParamType) } } ret := fn.Call(params) if len(ret) == 1 { retv := ret[0].Interface() if err, ok := retv.(error); ok { return nil, err } return ret[0].Interface(), nil } else { return ret[0].Interface(), ret[1].Interface().(error) } } func (t *Template) posError(n ast.Node, format string, v ...any) error { return ast.PosError(n, t.file, format, v...) } func validateFunc(t reflect.Type) error { numOut := t.NumOut() if numOut > 2 { return ErrFuncTooManyReturns } else if numOut == 0 { return ErrFuncNoReturns } if numOut == 2 { if !t.Out(1).Implements(reflect.TypeOf(error(nil))) { return ErrFuncSecondReturnType } } return nil }