Implement indices

This commit is contained in:
Elara 2024-11-10 21:11:14 -08:00
parent 0dcacdc04a
commit 58b45e6cac
3 changed files with 61 additions and 28 deletions

View File

@ -80,7 +80,13 @@ class Expansion:
pos: Position pos: Position
value: 'Value' value: 'Value'
Expression = BinaryExpression | UnaryExpression | Expansion @dataclasses.dataclass
class Index:
pos: Position
value: 'Value'
index: 'Value'
Expression = BinaryExpression | UnaryExpression | Expansion | Index
Value = Literal | Collection | Expression | Ref Value = Literal | Collection | Expression | Ref
@dataclasses.dataclass @dataclasses.dataclass

View File

@ -4,9 +4,9 @@ from . import parser
import typing import typing
import io import io
__all__ = ['TypeError', 'Block', 'Interp'] __all__ = ['OperandError', 'Block', 'Interp']
class TypeError(Exception): class OperandError(Exception):
def __init__(self, pos: ast.Position, action: str, issue: str, val: typing.Any): def __init__(self, pos: ast.Position, action: str, issue: str, val: typing.Any):
super().__init__(f'{pos}: cannot perform {action} on {issue} operand ({type(val).__name__})') super().__init__(f'{pos}: cannot perform {action} on {issue} operand ({type(val).__name__})')
@ -33,6 +33,20 @@ class Interp:
def update(self, vars: dict[str, typing.Any]): def update(self, vars: dict[str, typing.Any]):
self.vars.update(vars) self.vars.update(vars)
def _eval_index(self, index: ast.Index) -> typing.Any:
val = self._convert_value(index.value)
if not hasattr(val, '__getitem__'):
raise ValueError(f'{index.value.pos}: value is not indexable ({type(val).__getitem__})')
index_val = self._convert_value(index.index)
if type(index_val) is int and hasattr(val, '__len__') and index_val >= len(val):
raise IndexError(f'{index.index.pos}: index out of range ({index_val} with length {len(val)})')
elif type(index_val) is not int:
if isinstance(val, list) or isinstance(val, tuple) or isinstance(val, str):
raise TypeError(f'{index.index.pos}: {type(val).__name__} indices must be integers, not {type(index_val).__name__}')
elif hasattr(val, '__contains__') and index_val not in val:
raise IndexError(f'{index.index.pos}: index {repr(index_val)} does not exist in value')
return val[index_val]
def _convert_value(self, val: ast.Value) -> typing.Any: def _convert_value(self, val: ast.Value) -> typing.Any:
if isinstance(val, ast.VariableRef): if isinstance(val, ast.VariableRef):
if val.name not in self.vars: if val.name not in self.vars:
@ -50,6 +64,8 @@ class Interp:
return self._eval_binary_expr(val) return self._eval_binary_expr(val)
elif isinstance(val, ast.UnaryExpression): elif isinstance(val, ast.UnaryExpression):
return self._eval_unary_expr(val) return self._eval_unary_expr(val)
elif isinstance(val, ast.Index):
return self._eval_index(val)
elif isinstance(val, ast.Expansion): elif isinstance(val, ast.Expansion):
raise ValueError(f'{val.pos}: cannot use expansion operator outside of a function call') raise ValueError(f'{val.pos}: cannot use expansion operator outside of a function call')
@ -80,11 +96,11 @@ class Interp:
match expr.op.value: match expr.op.value:
case '!': case '!':
if type(val) is not bool: if type(val) is not bool:
raise TypeError(expr.value.pos, 'NOT operation', 'non-boolean', val) raise OperandError(expr.value.pos, 'NOT operation', 'non-boolean', val)
return not val return not val
case '-': case '-':
if not self._is_numerical(val): if not self._is_numerical(val):
raise TypeError(expr.value.pos, 'negation', 'non-numerical', val) raise OperandError(expr.value.pos, 'negation', 'non-numerical', val)
return -val return -val
case _: case _:
raise ValueError(f'{expr.op.pos}: unknown unary operation: {repr(expr.op.value)}') raise ValueError(f'{expr.op.pos}: unknown unary operation: {repr(expr.op.value)}')
@ -100,69 +116,69 @@ class Interp:
return left != right return left != right
case '+': case '+':
if not self._is_numerical(left): if not self._is_numerical(left):
raise TypeError(expr.left.pos, 'addition operation', 'non-numerical', left) raise OperandError(expr.left.pos, 'addition operation', 'non-numerical', left)
elif not self._is_numerical(right): elif not self._is_numerical(right):
raise TypeError(expr.right.pos, 'addition operation', 'non-numerical', right) raise OperandError(expr.right.pos, 'addition operation', 'non-numerical', right)
return left + right return left + right
case '-': case '-':
if not self._is_numerical(left): if not self._is_numerical(left):
raise TypeError(expr.left.pos, 'subtraction operation', 'non-numerical', left) raise OperandError(expr.left.pos, 'subtraction operation', 'non-numerical', left)
elif not self._is_numerical(right): elif not self._is_numerical(right):
raise TypeError(expr.right.pos, 'subtraction operation', 'non-numerical', right) raise OperandError(expr.right.pos, 'subtraction operation', 'non-numerical', right)
return left - right return left - right
case '*': case '*':
if not self._is_numerical(left): if not self._is_numerical(left):
raise TypeError(expr.left.pos, 'multiplication operation', 'non-numerical', left) raise OperandError(expr.left.pos, 'multiplication operation', 'non-numerical', left)
elif not self._is_numerical(right): elif not self._is_numerical(right):
raise TypeError(expr.right.pos, 'multiplication operation', 'non-numerical', right) raise OperandError(expr.right.pos, 'multiplication operation', 'non-numerical', right)
return left * right return left * right
case '/': case '/':
if not self._is_numerical(left): if not self._is_numerical(left):
raise TypeError(expr.left.pos, 'division operation', 'non-numerical', left) raise OperandError(expr.left.pos, 'division operation', 'non-numerical', left)
elif not self._is_numerical(right): elif not self._is_numerical(right):
raise TypeError(expr.right.pos, 'division operation', 'non-numerical', right) raise OperandError(expr.right.pos, 'division operation', 'non-numerical', right)
return left / right return left / right
case '%': case '%':
if not self._is_numerical(left): if not self._is_numerical(left):
raise TypeError(expr.left.pos, 'modulo operation', 'non-numerical', left) raise OperandError(expr.left.pos, 'modulo operation', 'non-numerical', left)
elif not self._is_numerical(right): elif not self._is_numerical(right):
raise TypeError(expr.right.pos, 'modulo operation', 'non-numerical', right) raise OperandError(expr.right.pos, 'modulo operation', 'non-numerical', right)
return left % right return left % right
case '>': case '>':
if not self._is_comparable(left): if not self._is_comparable(left):
raise TypeError(expr.left.pos, 'comparison', 'non-comparable', left) raise OperandError(expr.left.pos, 'comparison', 'non-comparable', left)
elif not self._is_comparable(right): elif not self._is_comparable(right):
raise TypeError(expr.right.pos, 'comparison', 'non-comparable', right) raise OperandError(expr.right.pos, 'comparison', 'non-comparable', right)
return left > right return left > right
case '<': case '<':
if not self._is_comparable(left): if not self._is_comparable(left):
raise TypeError(expr.left.pos, 'comparison', 'non-comparable', left) raise OperandError(expr.left.pos, 'comparison', 'non-comparable', left)
elif not self._is_comparable(right): elif not self._is_comparable(right):
raise TypeError(expr.right.pos, 'comparison', 'non-comparable', right) raise OperandError(expr.right.pos, 'comparison', 'non-comparable', right)
return left < right return left < right
case '<=': case '<=':
if not self._is_comparable(left): if not self._is_comparable(left):
raise TypeError(expr.left.pos, 'comparison', 'non-comparable', left) raise OperandError(expr.left.pos, 'comparison', 'non-comparable', left)
elif not self._is_comparable(right): elif not self._is_comparable(right):
raise TypeError(expr.right.pos, 'comparison', 'non-comparable', right) raise OperandError(expr.right.pos, 'comparison', 'non-comparable', right)
return left <= right return left <= right
case '>=': case '>=':
if not self._is_comparable(left): if not self._is_comparable(left):
raise TypeError(expr.left.pos, 'comparison', 'non-comparable', left) raise OperandError(expr.left.pos, 'comparison', 'non-comparable', left)
elif not self._is_comparable(right): elif not self._is_comparable(right):
raise TypeError(expr.right.pos, 'comparison', 'non-comparable', right) raise OperandError(expr.right.pos, 'comparison', 'non-comparable', right)
return left >= right return left >= right
case '||': case '||':
if type(left) is not bool: if type(left) is not bool:
raise TypeError(expr.left.pos, 'OR operation', 'non-boolean', left) raise OperandError(expr.left.pos, 'OR operation', 'non-boolean', left)
elif type(right) is not bool: elif type(right) is not bool:
raise TypeError(expr.right.pos, 'OR operation', 'non-boolean', right) raise OperandError(expr.right.pos, 'OR operation', 'non-boolean', right)
return left or right return left or right
case '&&': case '&&':
if type(left) is not bool: if type(left) is not bool:
raise TypeError(expr.left.pos, 'AND operation', 'non-boolean', left) raise OperandError(expr.left.pos, 'AND operation', 'non-boolean', left)
elif type(right) is not bool: elif type(right) is not bool:
raise TypeError(expr.right.pos, 'AND operation', 'non-boolean', right) raise OperandError(expr.right.pos, 'AND operation', 'non-boolean', right)
return left and right return left and right
case _: case _:
raise ValueError(f'{expr.op.pos}: unknown binary operation: {repr(expr.op.value)}') raise ValueError(f'{expr.op.pos}: unknown binary operation: {repr(expr.op.value)}')

View File

@ -29,10 +29,21 @@ class Parser:
def _unscan(self, tok: lexer.Token, pos: ast.Position, lit: str): def _unscan(self, tok: lexer.Token, pos: ast.Position, lit: str):
self._prev = tok, pos, lit self._prev = tok, pos, lit
def _parse_index(self, val: ast.Value) -> ast.Index:
index = ast.Index(pos=val.pos, value=val, index=self._parse_expr())
tok, pos, lit = self._scan()
if tok != lexer.Token.SQUARE or lit != ']':
raise ExpectedError(pos, 'closing square bracket', lit)
return index
def _parse_expr(self) -> ast.Value: def _parse_expr(self) -> ast.Value:
left = self._parse_value() left = self._parse_value()
tok, pos, lit = self._scan() tok, pos, lit = self._scan()
while tok == lexer.Token.SQUARE and lit == '[':
left = self._parse_index(left)
# Scan the next token for the next if statement
tok, pos, lit = self._scan()
if tok != lexer.Token.OPERATOR: if tok != lexer.Token.OPERATOR:
self._unscan(tok, pos, lit) self._unscan(tok, pos, lit)
return left return left