diff --git a/hisscl/ast.py b/hisscl/ast.py index 1d27de6..d1bd9c9 100644 --- a/hisscl/ast.py +++ b/hisscl/ast.py @@ -40,7 +40,7 @@ class VariableRef: @dataclasses.dataclass class FunctionCall: pos: Position - name: str + value: 'Val' args: list['Value'] Ref = VariableRef | FunctionCall diff --git a/hisscl/interp.py b/hisscl/interp.py index d867282..a8fdff3 100644 --- a/hisscl/interp.py +++ b/hisscl/interp.py @@ -80,20 +80,19 @@ class Interp: return self._is_numerical(val) or isinstance(val, str) def _exec_func_call(self, call: ast.FunctionCall) -> typing.Any: - if call.name not in self.vars: - raise KeyError(f'{call.pos}: no such function: {repr(call.name)}') - elif not callable(self.vars[call.name]): + val = self._convert_value(call.value) + if not callable(val): raise ValueError(f'{call.pos}: cannot call non-callable object') args = [] for arg in call.args: if isinstance(arg, ast.Expansion): - val = self._convert_value(arg.value) - if not isinstance(val, typing.Iterable): - raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(val).__name__})") - args.extend(val) + arg_val = self._convert_value(arg.value) + if not isinstance(arg_val, typing.Iterable): + raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(arg_val).__name__})") + args.extend(arg_val) else: args.append(self._convert_value(arg)) - return self.vars[call.name](*args) + return val(*args) def _eval_unary_expr(self, expr: ast.UnaryExpression) -> float | int | bool: val = self._convert_value(expr.value) diff --git a/hisscl/lexer.py b/hisscl/lexer.py index a8677cc..719e634 100644 --- a/hisscl/lexer.py +++ b/hisscl/lexer.py @@ -41,6 +41,9 @@ class Lexer: self.unread = '' self.stream = stream self.pos.name = name + + def _pos(self) -> ast.Position: + return dataclasses.replace(self.pos) def _peek(self, n: int) -> str: if self.unread != '': @@ -191,15 +194,15 @@ class Lexer: match char: case '{' | '}': - return Token.CURLY, self.pos, char + return Token.CURLY, self._pos(), char case '[' | ']': - return Token.SQUARE, self.pos, char + return Token.SQUARE, self._pos(), char case '(' | ')': - return Token.PAREN, self.pos, char + return Token.PAREN, self._pos(), char case ',': - return Token.COMMA, self.pos, char + return Token.COMMA, self._pos(), char case ':': - return Token.COLON, self.pos, char + return Token.COLON, self._pos(), char case '"': return self._scan_str() case '<': @@ -230,12 +233,12 @@ class Lexer: case '.': if (next := self._read()) != '.': self._unread(next) - return Token.DOT, self.pos, next + return Token.DOT, self._pos(), next elif (next := self._read()) != '.': raise ExpectedError(self.pos, '.', next) - return Token.ELLIPSIS, self.pos, "..." + return Token.ELLIPSIS, self._pos(), "..." case '': - return Token.EOF, self.pos, char + return Token.EOF, self._pos(), char if is_numeric(char): return self._scan_number(char) @@ -244,7 +247,7 @@ class Lexer: elif is_operator(char): return self._scan_operator(char) - return Token.ILLEGAL, self.pos, char + return Token.ILLEGAL, self._pos(), char def is_whitespace(char: str) -> bool: return char in (' ', '\t', '\r', '\n') diff --git a/hisscl/parser.py b/hisscl/parser.py index c9eef52..c0af86c 100644 --- a/hisscl/parser.py +++ b/hisscl/parser.py @@ -98,15 +98,14 @@ class Parser: return ast.Object(start_pos, items) - def _parse_func_call(self) -> ast.FunctionCall: - id_tok, id_pos, id_lit = self._scan() - tok, pos, lit = self._scan() - if tok != lexer.Token.PAREN or lit != '(': - raise ExpectedError(pos, 'opening parentheses', lit) - - tok, pos, lit = self._scan() + def _parse_func_call(self, val: ast.Value, start_pos: ast.Position) -> ast.FunctionCall: + tok, pos, lit = self._scan() if tok == lexer.Token.PAREN and lit == ')': - return ast.FunctionCall(pos=id_pos, name=id_lit, args=[]) + out = ast.FunctionCall(pos=start_pos, value=val, args=[]) + while self.lexer._peek(1) == '(': + _, start_pos, _ = self._scan() + out = self._parse_func_call(out, start_pos) + return out self._unscan(tok, pos, lit) args: list[ast.Value] = [] @@ -125,7 +124,12 @@ class Parser: break else: raise ExpectedError(pos, 'comma or closing parentheses', lit) - return ast.FunctionCall(pos=id_pos, name=id_lit, args=args) + + out = ast.FunctionCall(pos=start_pos, value=val, args=args) + while self.lexer._peek(1) == '(': + _, start_pos, _ = self._scan() + out = self._parse_func_call(out, start_pos) + return out def _parse_value(self) -> ast.Value: out = None @@ -140,11 +144,7 @@ class Parser: case lexer.Token.STRING: out = ast.String(pos=pos, value=pyast.literal_eval(lit)) case lexer.Token.IDENT: - if self.lexer._peek(1) == '(': - self._unscan(tok, pos, lit) - out = self._parse_func_call() - else: - out = ast.VariableRef(pos=pos, name=lit) + out = ast.VariableRef(pos=pos, name=lit) case lexer.Token.HEREDOC: out = ast.String(pos=pos, value=lit) case lexer.Token.OPERATOR: @@ -171,6 +171,8 @@ class Parser: tok, pos, lit = self._scan() if tok == lexer.Token.SQUARE and lit == '[': out = self._parse_index(out, pos) + elif tok == lexer.Token.PAREN and lit == '(': + out = self._parse_func_call(out, pos) elif tok == lexer.Token.DOT: out = self._parse_getattr(out, pos) else: diff --git a/test/test_parser.py b/test/test_parser.py index ea80f9e..81ae951 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -180,8 +180,11 @@ class TestExpressions(unittest.TestCase): def test_expansion(self): val = parser.Parser(io.StringIO('x(y...)'), 'TestExpressions.test_expansion')._parse_expr() self.assertEqual(val, ast.FunctionCall( - pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1), - name = 'x', + pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=2), + value = ast.VariableRef( + pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1), + name = 'x', + ), args = [ ast.Expansion( pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=3),