fx92-interpreter/fx92/ast.py

125 lines
3.2 KiB
Python

# fx-92 Scientifique Collège+ language interpreter: AST definition
import enum
from decimal import Decimal
def auto_builder():
    """Build a stand-in for ``enum.auto`` on Python versions lacking it.

    The returned zero-argument callable yields 1, 2, 3, ... on successive
    calls, so enum members declared with it are numbered from 1.
    """
    counter = [0]

    def auto():
        counter[0] += 1
        return counter[0]

    return auto


# enum.auto only exists from Python 3.6 on; patch the fallback into the enum
# module solely when it is missing, so N below can use enum.auto() either way.
if not hasattr(enum, "auto"):
    enum.auto = auto_builder()
del auto_builder
#---
# Internal AST node representation
#---
# Internal AST node kinds. Members are numbered consecutively from 1 in the
# order listed below; @unique-style validation rejects duplicate values.
N = enum.unique(enum.IntEnum("N", [
    # Core nodes
    "PROGRAM",
    # Basic statements
    "FORWARD", "ROTATE", "ORIENT", "GOTO", "PENDOWN", "PENUP",
    "ASSIGN", "INPUT", "MESSAGE", "PRINT", "STYLE", "WAIT",
    # Flow control
    "REPEAT", "UNTIL", "IF",
    # Expressions
    "ADD", "SUB", "MUL", "DIV", "MINUS", "EXP", "VAR", "CONST", "FUN",
    "REL",
]))
#---
# AST nodes
#---
class Node:
    """A node of the program AST.

    ``type`` is normally a member of ``N`` (``__str__`` tolerates other
    integers). ``args`` holds the node's operands: child ``Node`` objects,
    raw values (a CONST node stores its value in ``args[0]``), or
    sequences such as a FUN node's argument list in ``args[1]``.
    """

    def __init__(self, type, *args):
        """Instantiate a new AST node of kind *type* with operands *args*."""
        self.type = type
        self.args = args

    def __str__(self):
        """Basic text representation without children.

        CONST nodes print their value; other nodes print as ``<Node:NAME>``,
        or as ``<Node:0x..>`` when *type* is not a member of ``N``.
        """
        if self.type == N.CONST:
            return str(self.value)
        try:
            name = N(self.type).name
        except ValueError:
            return "<Node:{}>".format(hex(self.type))
        return "<Node:{}>".format(name)

    @property
    def value(self):
        """Retrieve the value of a CONST node.

        Raises:
            Exception: if this node is not a CONST node.
        """
        if self.type != N.CONST:
            raise Exception("Taking value of non-const node")
        return self.args[0]

    def constchildren(self):
        """Check whether all arguments are CONST nodes.

        Arguments that are not ``Node`` instances (raw values, lists) count
        as non-constant rather than raising AttributeError. Vacuously true
        when there are no arguments.
        """
        return all(isinstance(c, Node) and c.type == N.CONST
                   for c in self.args)

    def simplify(self):
        """Simplify arithmetic expressions, bottom-up.

        Node children are simplified first, and ``self.args`` is replaced by
        a list of the simplified children. Then:

        - MUL/ADD with no operand become CONST 1/0, with a single operand
          become that operand, and with only CONST operands are folded;
        - ADD(a, MINUS(b)) becomes SUB(a, b);
        - a unary MINUS of a CONST is folded;
        - for FUN, the argument list in ``args[1]`` is simplified
          element-wise (``args[0]`` and any further args are kept).

        Returns the simplified node, which is either ``self`` or a new node.
        """
        def simpl(n):
            return n.simplify() if isinstance(n, Node) else n

        self.args = [simpl(arg) for arg in self.args]
        arity = len(self.args)

        if self.type == N.MUL:
            if arity == 0:
                return Node(N.CONST, Decimal(1))
            if arity == 1:
                return self.args[0]
            if self.constchildren():
                # Folding uses the current decimal context.
                prod = Decimal(1)
                for c in self.args:
                    prod *= c.value
                return Node(N.CONST, prod)

        elif self.type == N.ADD:
            if arity == 0:
                return Node(N.CONST, Decimal(0))
            if arity == 1:
                return self.args[0]
            if self.constchildren():
                total = sum((c.value for c in self.args), Decimal(0))
                return Node(N.CONST, total)
            # arity >= 2 here, so args[1] exists.
            second = self.args[1]
            if (arity == 2 and isinstance(second, Node)
                    and second.type == N.MINUS):
                return Node(N.SUB, self.args[0], second.args[0])

        elif self.type == N.MINUS:
            # MINUS is unary; guard arity so an empty node cannot reach
            # args[0] through the vacuously-true constchildren().
            if arity == 1 and self.constchildren():
                return Node(N.CONST, -self.args[0].value)

        elif self.type == N.FUN:
            # args[1] is a plain sequence, so simpl() above left its
            # elements untouched; simplify them here. Keep any extra args.
            newargs = [simpl(arg) for arg in self.args[1]]
            return Node(N.FUN, self.args[0], newargs, *self.args[2:])

        return self