Allow calling arbitrary values
@@ -40,7 +40,7 @@ class VariableRef:
 @dataclasses.dataclass
 class FunctionCall:
     pos: Position
-    name: str
+    value: 'Value'
     args: list['Value']
 
 Ref = VariableRef | FunctionCall
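
With this change a call node no longer stores a bare function name; its callee is itself a value, so anything that evaluates to a callable can be invoked. A minimal sketch of the new shape, using stand-in dataclasses whose fields are taken from the hunks and tests in this commit:

import dataclasses

# Stand-ins for the real classes in the project's ast module; only the fields
# visible in this commit are modelled here.
@dataclasses.dataclass
class Position:
    name: str = ''
    line: int = 1
    col: int = 1

@dataclasses.dataclass
class VariableRef:
    pos: Position
    name: str

@dataclasses.dataclass
class FunctionCall:
    pos: Position
    value: object          # was `name: str`; now any value that may evaluate to a callable
    args: list

# f(1) is now represented with the callee as a nested value:
call = FunctionCall(
    pos=Position(col=2),
    value=VariableRef(pos=Position(col=1), name='f'),
    args=[1],
)
assert call.value.name == 'f'
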
@@ -80,20 +80,19 @@ class Interp:
         return self._is_numerical(val) or isinstance(val, str)
 
     def _exec_func_call(self, call: ast.FunctionCall) -> typing.Any:
-        if call.name not in self.vars:
-            raise KeyError(f'{call.pos}: no such function: {repr(call.name)}')
-        elif not callable(self.vars[call.name]):
+        val = self._convert_value(call.value)
+        if not callable(val):
            raise ValueError(f'{call.pos}: cannot call non-callable object')
         args = []
         for arg in call.args:
             if isinstance(arg, ast.Expansion):
-                val = self._convert_value(arg.value)
-                if not isinstance(val, typing.Iterable):
-                    raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(val).__name__})")
-                args.extend(val)
+                arg_val = self._convert_value(arg.value)
+                if not isinstance(arg_val, typing.Iterable):
+                    raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(arg_val).__name__})")
+                args.extend(arg_val)
             else:
                 args.append(self._convert_value(arg))
-        return self.vars[call.name](*args)
+        return val(*args)
 
     def _eval_unary_expr(self, expr: ast.UnaryExpression) -> float | int | bool:
         val = self._convert_value(expr.value)
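
The interpreter change follows the same idea: _exec_func_call no longer requires call.name to exist in self.vars; it converts call.value like any other expression and only checks that the result is callable. A rough standalone model of that logic, where convert() is a hypothetical stand-in for self._convert_value and plain tuples stand in for ast argument nodes:

from collections.abc import Iterable
from typing import Any

def exec_func_call(convert, callee, raw_args, pos='1:1') -> Any:
    # Evaluate the callee itself instead of looking a name up in a vars table.
    val = convert(callee)
    if not callable(val):
        raise ValueError(f'{pos}: cannot call non-callable object')
    args = []
    for is_expansion, raw in raw_args:   # (is_expansion, value) pairs stand in for ast nodes
        arg_val = convert(raw)
        if is_expansion:
            if not isinstance(arg_val, Iterable):
                raise ValueError(f'{pos}: cannot perform expansion on non-iterable value')
            args.extend(arg_val)
        else:
            args.append(arg_val)
    return val(*args)

# Any expression that evaluates to a callable now works, e.g. an element of a list:
funcs = [lambda *xs: sum(xs)]
print(exec_func_call(lambda v: v, funcs[0], [(False, 1), (True, [2, 3])]))  # 6
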
@@ -41,6 +41,9 @@ class Lexer:
         self.unread = ''
         self.stream = stream
         self.pos.name = name
 
+    def _pos(self) -> ast.Position:
+        return dataclasses.replace(self.pos)
+
     def _peek(self, n: int) -> str:
         if self.unread != '':
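
The remaining lexer hunks switch every token return from self.pos to the new self._pos(). self.pos is one mutable Position that the lexer keeps updating as it reads, so handing the object out directly would leave every token pointing at wherever the lexer currently is; dataclasses.replace with no overrides returns a fresh copy of all fields. A small demonstration, assuming Position is an ordinary dataclass:

import dataclasses

@dataclasses.dataclass
class Position:          # stand-in with the fields the tests use
    name: str = ''
    line: int = 1
    col: int = 1

pos = Position(name='demo')
shared = pos                         # what returning self.pos hands out
snapshot = dataclasses.replace(pos)  # what _pos() hands out

pos.col = 99                         # the lexer keeps advancing...
print(shared.col)    # 99 -- follows the lexer's current position
print(snapshot.col)  # 1  -- keeps the position recorded at scan time
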
@@ -191,15 +194,15 @@ class Lexer:
 
         match char:
             case '{' | '}':
-                return Token.CURLY, self.pos, char
+                return Token.CURLY, self._pos(), char
             case '[' | ']':
-                return Token.SQUARE, self.pos,  char
+                return Token.SQUARE, self._pos(),  char
             case '(' | ')':
-                return Token.PAREN, self.pos, char
+                return Token.PAREN, self._pos(), char
             case ',':
-                return Token.COMMA, self.pos, char
+                return Token.COMMA, self._pos(), char
             case ':':
-                return Token.COLON, self.pos, char
+                return Token.COLON, self._pos(), char
             case '"':
                     return self._scan_str()
             case '<':
@@ -230,12 +233,12 @@ class Lexer:
             case '.':
                 if (next := self._read()) != '.':
                     self._unread(next)
-                    return Token.DOT, self.pos, next
+                    return Token.DOT, self._pos(), next
                 elif (next := self._read()) != '.':
                     raise ExpectedError(self.pos, '.', next)
-                return Token.ELLIPSIS, self.pos, "..."
+                return Token.ELLIPSIS, self._pos(), "..."
             case '':
-                return Token.EOF, self.pos, char
+                return Token.EOF, self._pos(), char
 
         if is_numeric(char):
             return self._scan_number(char)
@@ -244,7 +247,7 @@ class Lexer:
         elif is_operator(char):
             return self._scan_operator(char)
 
-        return Token.ILLEGAL, self.pos, char
+        return Token.ILLEGAL, self._pos(), char
 
 def is_whitespace(char: str) -> bool:
     return char in (' ', '\t', '\r', '\n')
@@ -98,15 +98,14 @@ class Parser:
 
         return ast.Object(start_pos, items)
 
-    def _parse_func_call(self) -> ast.FunctionCall:
-        id_tok, id_pos, id_lit = self._scan()
-        tok, pos, lit = self._scan()
-        if tok != lexer.Token.PAREN or lit != '(':
-            raise ExpectedError(pos, 'opening parentheses', lit)
-
-        tok, pos, lit = self._scan()
+    def _parse_func_call(self, val: ast.Value, start_pos: ast.Position) -> ast.FunctionCall:
+        tok, pos, lit = self._scan()
         if tok == lexer.Token.PAREN and lit == ')':
-            return ast.FunctionCall(pos=id_pos, name=id_lit, args=[])
+            out = ast.FunctionCall(pos=start_pos, value=val, args=[])
+            while self.lexer._peek(1) == '(':
+                _, start_pos, _ = self._scan()
+                out = self._parse_func_call(out, start_pos)
+            return out
         self._unscan(tok, pos, lit)
 
         args: list[ast.Value] = []
@@ -125,7 +124,12 @@ class Parser:
                 break
             else:
                 raise ExpectedError(pos, 'comma or closing parentheses', lit)
-        return ast.FunctionCall(pos=id_pos, name=id_lit, args=args)
+
+        out = ast.FunctionCall(pos=start_pos, value=val, args=args)
+        while self.lexer._peek(1) == '(':
+            _, start_pos, _ = self._scan()
+            out = self._parse_func_call(out, start_pos)
+        return out
 
     def _parse_value(self) -> ast.Value:
         out = None
@@ -140,11 +144,7 @@ class Parser:
             case lexer.Token.STRING:
                 out = ast.String(pos=pos, value=pyast.literal_eval(lit))
             case lexer.Token.IDENT:
-                if self.lexer._peek(1) == '(':
-                    self._unscan(tok, pos, lit)
-                    out = self._parse_func_call()
-                else:
-                    out = ast.VariableRef(pos=pos, name=lit)
+                out = ast.VariableRef(pos=pos, name=lit)
             case lexer.Token.HEREDOC:
                 out = ast.String(pos=pos, value=lit)
             case lexer.Token.OPERATOR:
@@ -171,6 +171,8 @@ class Parser:
         tok, pos, lit = self._scan()
         if tok == lexer.Token.SQUARE and lit == '[':
             out = self._parse_index(out, pos)
+        elif tok == lexer.Token.PAREN and lit == '(':
+            out = self._parse_func_call(out, pos)
         elif tok == lexer.Token.DOT:
             out = self._parse_getattr(out, pos)
         else:
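
Because _parse_func_call now receives the already-parsed callee and keeps wrapping the result while the next character is '(', call expressions can chain. A rough sketch of the node shape a source string like f(1)(2) would presumably produce, with stand-in dataclasses and positions omitted:

from dataclasses import dataclass, field

@dataclass
class VariableRef:      # stand-in; positions omitted for brevity
    name: str

@dataclass
class FunctionCall:     # stand-in mirroring the new value/args layout
    value: object
    args: list = field(default_factory=list)

# f(1)(2): the inner call becomes the value of the outer one.
inner = FunctionCall(value=VariableRef(name='f'), args=[1])
outer = FunctionCall(value=inner, args=[2])
assert outer.value.value.name == 'f'
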
@@ -180,8 +180,11 @@ class TestExpressions(unittest.TestCase):
     def test_expansion(self):
         val = parser.Parser(io.StringIO('x(y...)'), 'TestExpressions.test_expansion')._parse_expr()
         self.assertEqual(val, ast.FunctionCall(
-            pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
-            name = 'x',
+            pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=2),
+            value = ast.VariableRef(
+                pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
+                name = 'x',
+            ),
             args = [
                 ast.Expansion(
                     pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=3),