From 9830432d0cd0d84eda5e86dbfb906e151f9b0133 Mon Sep 17 00:00:00 2001
From: Elara6331 <elara@elara.ws>
Date: Wed, 13 Nov 2024 15:26:08 -0800
Subject: [PATCH] Allow calling arbitrary values

---
 hisscl/ast.py       |  2 +-
 hisscl/interp.py    | 15 +++++++--------
 hisscl/lexer.py     | 21 ++++++++++++---------
 hisscl/parser.py    | 30 ++++++++++++++++--------------
 test/test_parser.py |  7 +++++--
 5 files changed, 41 insertions(+), 34 deletions(-)

diff --git a/hisscl/ast.py b/hisscl/ast.py
index 1d27de6..d1bd9c9 100644
--- a/hisscl/ast.py
+++ b/hisscl/ast.py
@@ -40,7 +40,7 @@ class VariableRef:
 @dataclasses.dataclass
 class FunctionCall:
     pos: Position
-    name: str
+    value: 'Val'
     args: list['Value']
 
 Ref = VariableRef | FunctionCall
diff --git a/hisscl/interp.py b/hisscl/interp.py
index d867282..a8fdff3 100644
--- a/hisscl/interp.py
+++ b/hisscl/interp.py
@@ -80,20 +80,19 @@ class Interp:
         return self._is_numerical(val) or isinstance(val, str)
     
     def _exec_func_call(self, call: ast.FunctionCall) -> typing.Any:
-        if call.name not in self.vars:
-            raise KeyError(f'{call.pos}: no such function: {repr(call.name)}')
-        elif not callable(self.vars[call.name]):
+        val = self._convert_value(call.value)
+        if not callable(val):
             raise ValueError(f'{call.pos}: cannot call non-callable object')
         args = []
         for arg in call.args:
             if isinstance(arg, ast.Expansion):
-                val = self._convert_value(arg.value)
-                if not isinstance(val, typing.Iterable):
-                    raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(val).__name__})")
-                args.extend(val)
+                arg_val = self._convert_value(arg.value)
+                if not isinstance(arg_val, typing.Iterable):
+                    raise ValueError(f"{arg.pos}: cannot perform expansion on non-iterable value ({type(arg_val).__name__})")
+                args.extend(arg_val)
             else:
                 args.append(self._convert_value(arg))
-        return self.vars[call.name](*args)
+        return val(*args)
         
     def _eval_unary_expr(self, expr: ast.UnaryExpression) -> float | int | bool:
         val = self._convert_value(expr.value)
diff --git a/hisscl/lexer.py b/hisscl/lexer.py
index a8677cc..719e634 100644
--- a/hisscl/lexer.py
+++ b/hisscl/lexer.py
@@ -41,6 +41,9 @@ class Lexer:
         self.unread = ''
         self.stream = stream
         self.pos.name = name
+        
+    def _pos(self) -> ast.Position:
+        return dataclasses.replace(self.pos)
 
     def _peek(self, n: int) -> str:
         if self.unread != '':
@@ -191,15 +194,15 @@ class Lexer:
 
         match char:
             case '{' | '}':
-                return Token.CURLY, self.pos, char
+                return Token.CURLY, self._pos(), char
             case '[' | ']':
-                return Token.SQUARE, self.pos,  char
+                return Token.SQUARE, self._pos(),  char
             case '(' | ')':
-                return Token.PAREN, self.pos, char
+                return Token.PAREN, self._pos(), char
             case ',':
-                return Token.COMMA, self.pos, char
+                return Token.COMMA, self._pos(), char
             case ':':
-                return Token.COLON, self.pos, char
+                return Token.COLON, self._pos(), char
             case '"':
                     return self._scan_str()
             case '<':
@@ -230,12 +233,12 @@ class Lexer:
             case '.':
                 if (next := self._read()) != '.':
                     self._unread(next)
-                    return Token.DOT, self.pos, next
+                    return Token.DOT, self._pos(), next
                 elif (next := self._read()) != '.':
                     raise ExpectedError(self.pos, '.', next)
-                return Token.ELLIPSIS, self.pos, "..."
+                return Token.ELLIPSIS, self._pos(), "..."
             case '':
-                return Token.EOF, self.pos, char
+                return Token.EOF, self._pos(), char
 
         if is_numeric(char):
             return self._scan_number(char)
@@ -244,7 +247,7 @@ class Lexer:
         elif is_operator(char):
             return self._scan_operator(char)
 
-        return Token.ILLEGAL, self.pos, char
+        return Token.ILLEGAL, self._pos(), char
 
 def is_whitespace(char: str) -> bool:
     return char in (' ', '\t', '\r', '\n')
diff --git a/hisscl/parser.py b/hisscl/parser.py
index c9eef52..c0af86c 100644
--- a/hisscl/parser.py
+++ b/hisscl/parser.py
@@ -98,15 +98,14 @@ class Parser:
             
         return ast.Object(start_pos, items)
         
-    def _parse_func_call(self) -> ast.FunctionCall:
-        id_tok, id_pos, id_lit = self._scan()
-        tok, pos, lit = self._scan()
-        if tok != lexer.Token.PAREN or lit != '(':
-            raise ExpectedError(pos, 'opening parentheses', lit)
-        
-        tok, pos, lit = self._scan()
+    def _parse_func_call(self, val: ast.Value, start_pos: ast.Position) -> ast.FunctionCall:        
+        tok, pos, lit = self._scan()        
         if tok == lexer.Token.PAREN and lit == ')':
-            return ast.FunctionCall(pos=id_pos, name=id_lit, args=[])
+            out = ast.FunctionCall(pos=start_pos, value=val, args=[])
+            while self.lexer._peek(1) == '(':
+                _, start_pos, _ = self._scan()
+                out = self._parse_func_call(out, start_pos)
+            return out
         self._unscan(tok, pos, lit)
         
         args: list[ast.Value] = []
@@ -125,7 +124,12 @@ class Parser:
                 break
             else:
                 raise ExpectedError(pos, 'comma or closing parentheses', lit)
-        return ast.FunctionCall(pos=id_pos, name=id_lit, args=args)
+        
+        out = ast.FunctionCall(pos=start_pos, value=val, args=args)
+        while self.lexer._peek(1) == '(':
+            _, start_pos, _ = self._scan()
+            out = self._parse_func_call(out, start_pos)
+        return out
     
     def _parse_value(self) -> ast.Value:
         out = None
@@ -140,11 +144,7 @@ class Parser:
             case lexer.Token.STRING:
                 out = ast.String(pos=pos, value=pyast.literal_eval(lit))
             case lexer.Token.IDENT:
-                if self.lexer._peek(1) == '(':
-                    self._unscan(tok, pos, lit)
-                    out = self._parse_func_call()
-                else:
-                    out = ast.VariableRef(pos=pos, name=lit)
+                out = ast.VariableRef(pos=pos, name=lit)
             case lexer.Token.HEREDOC:
                 out = ast.String(pos=pos, value=lit)
             case lexer.Token.OPERATOR:
@@ -171,6 +171,8 @@ class Parser:
         tok, pos, lit = self._scan()
         if tok == lexer.Token.SQUARE and lit == '[':
             out = self._parse_index(out, pos)
+        elif tok == lexer.Token.PAREN and lit == '(':
+            out = self._parse_func_call(out, pos)
         elif tok == lexer.Token.DOT:
             out = self._parse_getattr(out, pos)
         else:
diff --git a/test/test_parser.py b/test/test_parser.py
index ea80f9e..81ae951 100644
--- a/test/test_parser.py
+++ b/test/test_parser.py
@@ -180,8 +180,11 @@ class TestExpressions(unittest.TestCase):
     def test_expansion(self):
         val = parser.Parser(io.StringIO('x(y...)'), 'TestExpressions.test_expansion')._parse_expr()
         self.assertEqual(val, ast.FunctionCall(
-            pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
-            name = 'x',
+            pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=2),
+            value = ast.VariableRef(
+                pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=1),
+                name = 'x',
+            ),
             args = [
                 ast.Expansion(
                     pos = ast.Position(name='TestExpressions.test_expansion', line=1, col=3),