Allow calling arbitrary values

This commit is contained in:
Elara 2024-11-13 15:30:50 -08:00
parent fb81eb84a7
commit 7a55ee0777
5 changed files with 41 additions and 34 deletions

View File

@ -40,7 +40,7 @@ class VariableRef:
@dataclasses.dataclass
class FunctionCall:
pos: Position
name: str
value: 'Value'
args: list['Value']
Ref = VariableRef | FunctionCall

View File

@ -80,20 +80,19 @@ class Interp:
return self._is_numerical(val) or isinstance(val, str)
def _exec_func_call(self, call: ast.FunctionCall) -> typing.Any:
if call.name not in self.vars:
raise KeyError(f'{call.pos}: no such function: {repr(call.name)}')
elif not callable(self.vars[call.name]):
val = self._convert_value(call.value)
if not callable(val):
raise ValueError(f'{call.pos}: cannot call non-callable object')
args = []
for arg in call.args:
if isinstance(arg, ast.Expansion):
val = self._convert_value(arg.value)
if not isinstance(val, typing.Iterable):
raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(val).__name__})")
args.extend(val)
arg_val = self._convert_value(arg.value)
if not isinstance(arg_val, typing.Iterable):
raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(arg_val).__name__})")
args.extend(arg_val)
else:
args.append(self._convert_value(arg))
return self.vars[call.name](*args)
return val(*args)
def _eval_unary_expr(self, expr: ast.UnaryExpression) -> float | int | bool:
val = self._convert_value(expr.value)

View File

@ -41,6 +41,9 @@ class Lexer:
self.unread = ''
self.stream = stream
self.pos.name = name
def _pos(self) -> ast.Position:
return dataclasses.replace(self.pos)
def _peek(self, n: int) -> str:
if self.unread != '':
@ -191,15 +194,15 @@ class Lexer:
match char:
case '{' | '}':
return Token.CURLY, self.pos, char
return Token.CURLY, self._pos(), char
case '[' | ']':
return Token.SQUARE, self.pos, char
return Token.SQUARE, self._pos(), char
case '(' | ')':
return Token.PAREN, self.pos, char
return Token.PAREN, self._pos(), char
case ',':
return Token.COMMA, self.pos, char
return Token.COMMA, self._pos(), char
case ':':
return Token.COLON, self.pos, char
return Token.COLON, self._pos(), char
case '"':
return self._scan_str()
case '<':
@ -230,12 +233,12 @@ class Lexer:
case '.':
if (next := self._read()) != '.':
self._unread(next)
return Token.DOT, self.pos, next
return Token.DOT, self._pos(), next
elif (next := self._read()) != '.':
raise ExpectedError(self.pos, '.', next)
return Token.ELLIPSIS, self.pos, "..."
return Token.ELLIPSIS, self._pos(), "..."
case '':
return Token.EOF, self.pos, char
return Token.EOF, self._pos(), char
if is_numeric(char):
return self._scan_number(char)
@ -244,7 +247,7 @@ class Lexer:
elif is_operator(char):
return self._scan_operator(char)
return Token.ILLEGAL, self.pos, char
return Token.ILLEGAL, self._pos(), char
def is_whitespace(char: str) -> bool:
return char in (' ', '\t', '\r', '\n')

View File

@ -98,15 +98,14 @@ class Parser:
return ast.Object(start_pos, items)
def _parse_func_call(self) -> ast.FunctionCall:
id_tok, id_pos, id_lit = self._scan()
tok, pos, lit = self._scan()
if tok != lexer.Token.PAREN or lit != '(':
raise ExpectedError(pos, 'opening parentheses', lit)
tok, pos, lit = self._scan()
def _parse_func_call(self, val: ast.Value, start_pos: ast.Position) -> ast.FunctionCall:
tok, pos, lit = self._scan()
if tok == lexer.Token.PAREN and lit == ')':
return ast.FunctionCall(pos=id_pos, name=id_lit, args=[])
out = ast.FunctionCall(pos=start_pos, value=val, args=[])
while self.lexer._peek(1) == '(':
_, start_pos, _ = self._scan()
out = self._parse_func_call(out, start_pos)
return out
self._unscan(tok, pos, lit)
args: list[ast.Value] = []
@ -125,7 +124,12 @@ class Parser:
break
else:
raise ExpectedError(pos, 'comma or closing parentheses', lit)
return ast.FunctionCall(pos=id_pos, name=id_lit, args=args)
out = ast.FunctionCall(pos=start_pos, value=val, args=args)
while self.lexer._peek(1) == '(':
_, start_pos, _ = self._scan()
out = self._parse_func_call(out, start_pos)
return out
def _parse_value(self) -> ast.Value:
out = None
@ -140,11 +144,7 @@ class Parser:
case lexer.Token.STRING:
out = ast.String(pos=pos, value=pyast.literal_eval(lit))
case lexer.Token.IDENT:
if self.lexer._peek(1) == '(':
self._unscan(tok, pos, lit)
out = self._parse_func_call()
else:
out = ast.VariableRef(pos=pos, name=lit)
out = ast.VariableRef(pos=pos, name=lit)
case lexer.Token.HEREDOC:
out = ast.String(pos=pos, value=lit)
case lexer.Token.OPERATOR:
@ -171,6 +171,8 @@ class Parser:
tok, pos, lit = self._scan()
if tok == lexer.Token.SQUARE and lit == '[':
out = self._parse_index(out, pos)
elif tok == lexer.Token.PAREN and lit == '(':
out = self._parse_func_call(out, pos)
elif tok == lexer.Token.DOT:
out = self._parse_getattr(out, pos)
else:

View File

@ -180,8 +180,11 @@ class TestExpressions(unittest.TestCase):
def test_expansion(self):
val = parser.Parser(io.StringIO('x(y...)'), 'TestExpressions.test_expansion')._parse_expr()
self.assertEqual(val, ast.FunctionCall(
pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
name = 'x',
pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=2),
value = ast.VariableRef(
pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
name = 'x',
),
args = [
ast.Expansion(
pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=3),