Implement getattr operation

This commit is contained in:
Elara 2024-11-13 13:05:25 -08:00
parent 2278a16ca0
commit fb81eb84a7
6 changed files with 80 additions and 12 deletions

View File

@ -86,7 +86,13 @@ class Index:
value: 'Value' value: 'Value'
index: 'Value' index: 'Value'
Expression = BinaryExpression | UnaryExpression | Expansion | Index @dataclasses.dataclass
class GetAttr:
pos: Position
value: 'Value'
attr: str
Expression = BinaryExpression | UnaryExpression | Expansion | Index | GetAttr
Value = Literal | Collection | Expression | Ref Value = Literal | Collection | Expression | Ref
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -65,6 +65,11 @@ class Interp:
return self._eval_unary_expr(val) return self._eval_unary_expr(val)
elif isinstance(val, ast.Index): elif isinstance(val, ast.Index):
return self._eval_index(val) return self._eval_index(val)
elif isinstance(val, ast.GetAttr):
obj = self._convert_value(val.value)
if not hasattr(obj, val.attr):
raise AttributeError(f'{val.pos}: no such attribute {repr(val.attr)} in object of type {type(obj).__name__}')
return getattr(obj, val.attr)
elif isinstance(val, ast.Expansion): elif isinstance(val, ast.Expansion):
raise ValueError(f'{val.pos}: cannot use expansion operator outside of a function call') raise ValueError(f'{val.pos}: cannot use expansion operator outside of a function call')

View File

@ -25,6 +25,7 @@ class Token(enum.Enum):
COLON = 12 COLON = 12
OPERATOR = 13 OPERATOR = 13
ELLIPSIS = 14 ELLIPSIS = 14
DOT = 15
class ExpectedError(Exception): class ExpectedError(Exception):
def __init__(self, pos: ast.Position, expected: str, got: str): def __init__(self, pos: ast.Position, expected: str, got: str):
@ -228,8 +229,9 @@ class Lexer:
return self.scan() return self.scan()
case '.': case '.':
if (next := self._read()) != '.': if (next := self._read()) != '.':
raise ExpectedError(self.pos, '.', next) self._unread(next)
if (next := self._read()) != '.': return Token.DOT, self.pos, next
elif (next := self._read()) != '.':
raise ExpectedError(self.pos, '.', next) raise ExpectedError(self.pos, '.', next)
return Token.ELLIPSIS, self.pos, "..." return Token.ELLIPSIS, self.pos, "..."
case '': case '':

View File

@ -29,15 +29,27 @@ class Parser:
def _unscan(self, tok: lexer.Token, pos: ast.Position, lit: str): def _unscan(self, tok: lexer.Token, pos: ast.Position, lit: str):
self._prev = tok, pos, lit self._prev = tok, pos, lit
def _parse_index(self, val: ast.Value) -> ast.Index: def _parse_index(self, val: ast.Value, start_pos: ast.Position) -> ast.Index:
index = ast.Index(pos=val.pos, value=val, index=self._parse_expr()) index = ast.Index(pos=start_pos, value=val, index=self._parse_expr())
tok, pos, lit = self._scan() tok, start_pos, lit = self._scan()
if tok != lexer.Token.SQUARE or lit != ']': if tok != lexer.Token.SQUARE or lit != ']':
raise ExpectedError(pos, 'closing square bracket', lit) raise ExpectedError(start_pos, 'closing square bracket', lit)
while self.lexer._peek(1) == '[': while self.lexer._peek(1) == '[':
self._scan() _, start_pos, _ = self._scan()
index = self._parse_index(index) index = self._parse_index(index, start_pos)
return index return index
def _parse_getattr(self, val: ast.Value, start_pos: ast.Position) -> ast.Index | ast.GetAttr:
tok, pos, lit = self._scan()
if tok == lexer.Token.INTEGER:
return ast.Index(pos=start_pos, value=val, index=ast.Integer(pos=pos, value=int(lit)))
elif tok == lexer.Token.IDENT:
return ast.GetAttr(pos=start_pos, value=val, attr=lit)
else:
raise ExpectedError(pos, 'integer or identifier', lit)
while self.lexer._peek(1) == '.':
_, start_pos, _ = self._scan()
index = self._parse_getattr(index, start_pos)
def _parse_expr(self) -> ast.Value: def _parse_expr(self) -> ast.Value:
left = self._parse_value() left = self._parse_value()
@ -156,9 +168,13 @@ class Parser:
case _: case _:
raise ExpectedError(pos, 'value', lit) raise ExpectedError(pos, 'value', lit)
if self.lexer._peek(1) == '[': tok, pos, lit = self._scan()
self._scan() if tok == lexer.Token.SQUARE and lit == '[':
out = self._parse_index(out) out = self._parse_index(out, pos)
elif tok == lexer.Token.DOT:
out = self._parse_getattr(out, pos)
else:
self._unscan(tok, pos, lit)
return out return out

View File

@ -114,6 +114,22 @@ class TestRefs(unittest.TestCase):
cfg = interp.Interp(io.StringIO('x = ["123", "456", "789"][1][2]'), "TestRefs.test_multi_index").run() cfg = interp.Interp(io.StringIO('x = ["123", "456", "789"][1][2]'), "TestRefs.test_multi_index").run()
self.assertIn('x', cfg) self.assertIn('x', cfg)
self.assertEqual(cfg['x'], '6') self.assertEqual(cfg['x'], '6')
def test_index_legacy(self):
i = interp.Interp(io.StringIO('x = y.1'), "TestRefs.test_index_legacy")
i['y'] = [123, 456, 789]
cfg = i.run()
self.assertIn('x', cfg)
self.assertEqual(cfg['x'], 456)
def test_getattr(self):
class Y:
z = 123
i = interp.Interp(io.StringIO('x = Y.z'), "TestRefs.test_getattr")
i['Y'] = Y()
cfg = i.run()
self.assertIn('x', cfg)
self.assertEqual(cfg['x'], 123)
def test_func(self): def test_func(self):
def y(a, b): def y(a, b):

View File

@ -205,6 +205,29 @@ class TestExpressions(unittest.TestCase):
pos = ast.Position(name='TestExpressions.test_index', line=1, col=3), pos = ast.Position(name='TestExpressions.test_index', line=1, col=3),
value = 0, value = 0,
)) ))
def test_index_legacy(self):
val = parser.Parser(io.StringIO('x.0'), 'TestExpressions.test_index_legacy')._parse_expr()
self.assertIsInstance(val, ast.Index)
assert type(val) is ast.Index
self.assertEqual(val.value, ast.VariableRef(
pos = ast.Position(name='TestExpressions.test_index_legacy', line=1, col=1),
name = 'x',
))
self.assertEqual(val.index, ast.Integer(
pos = ast.Position(name='TestExpressions.test_index_legacy', line=1, col=3),
value = 0,
))
def test_getattr(self):
val = parser.Parser(io.StringIO('x.y'), 'TestExpressions.test_getattr')._parse_expr()
self.assertIsInstance(val, ast.GetAttr)
assert type(val) is ast.GetAttr
self.assertEqual(val.value, ast.VariableRef(
pos = ast.Position(name='TestExpressions.test_getattr', line=1, col=1),
name = 'x',
))
self.assertEqual(val.attr, 'y')
def test_unary(self): def test_unary(self):
val = parser.Parser(io.StringIO('!true'), 'TestExpressions.test_unary')._parse_value() val = parser.Parser(io.StringIO('!true'), 'TestExpressions.test_unary')._parse_value()