From f08f613a17b299c8e7e8a6d3038d4398d16171da Mon Sep 17 00:00:00 2001 From: Lephe Date: Tue, 1 Oct 2019 07:04:02 +0200 Subject: [PATCH] improvements on AST shape and printing --- ast.py | 27 +++++++++++++++++++++++++++ fx92.py | 2 +- printer.py | 47 +++++++++++++++++++++++++++++++++++++++-------- 3 files changed, 67 insertions(+), 9 deletions(-) diff --git a/ast.py b/ast.py index 3e9f41c..86e3982 100644 --- a/ast.py +++ b/ast.py @@ -51,13 +51,29 @@ class Node: self.args = args def __str__(self): + """Basic text representation without children.""" + if self.type == N.CONST: + return str(self.value) try: name = N(self.type).name return f"" except ValueError: return f"" + @property + def value(self): + """Retrive the value of a CONST node.""" + if self.type != N.CONST: + raise Exception("Taking value of non-const node") + return self.args[0] + + def constchildren(self): + """Checks whether all arguments are constants.""" + return all(c.type == N.CONST for c in self.args) + def simplify(self): + """Simplify arithmetic expressions.""" + simpl = lambda n: n.simplify() if isinstance(n, Node) else n self.args = [ simpl(arg) for arg in self.args ] arity = len(self.args) @@ -66,9 +82,20 @@ class Node: return Node(N.CONST, 1) if self.type == N.MUL and arity == 1: return self.args[0] + if self.type == N.MUL and self.constchildren(): + prod = 1 + for c in self.args: + prod *= c.value + return Node(N.CONST, prod) + if self.type == N.ADD and arity == 0: return Node(N.CONST, 0) if self.type == N.ADD and arity == 1: return self.args[0] + if self.type == N.ADD and self.constchildren(): + return Node(N.CONST, sum(c.value for c in self.args)) + + if self.type == N.MINUS and self.constchildren(): + return Node(N.CONST, -self.args[0].value) return self diff --git a/fx92.py b/fx92.py index 15c6fca..4f9e4bf 100755 --- a/fx92.py +++ b/fx92.py @@ -37,7 +37,7 @@ def main(argv): ast = parser.parse_program() ast = ast.simplify() - print_ast(ast, lang="fr") + print_ast(ast, lang="ast") if __name__ == "__main__": main(sys.argv) diff --git a/printer.py b/printer.py index eb99a57..04b8e38 100644 --- a/printer.py +++ b/printer.py @@ -7,9 +7,21 @@ __all__ = ["print_ast"] # Message definitions #--- +class MessageAST: + forward = "FORWARD {}" + rotate = "ROTATE {}" + orient = "ORIENT {}" + goto = "GOTO {}, {}" + pendown = "PENDOWN" + penup = "PENUP" + class MessageFrench: - multiply = "mul({})" - goto = "goto {}, {}" + forward = "Avancer de {} pixels" + rotate = "Tourner de {} degrés" + orient = "S'orienter à {} degrés" + goto = "Aller à x={}; y={}" + pendown = "Stylo écrit" + penup = "Stylo relevé" class MessageEnglish: pass @@ -19,8 +31,14 @@ class MessageEnglish: #--- def print_ast(n, lang="en", indent=0): - if lang == "fr": lang = MessageFrench - if lang == "en": lang = MessageEnglish + if lang == "fr": lang = MessageFrench + if lang == "en": lang = MessageEnglish + if lang == "ast": lang = MessageAST + + if n.type == N.PROGRAM: + for arg in n.args: + print_ast(arg, lang=lang, indent=indent) + return print(" " * indent, end="") @@ -30,9 +48,22 @@ def print_ast(n, lang="en", indent=0): if n.type == N.CONST: print(n.args[0]) - elif n.type == N.VAR: - print(f"VAR({n.args[0]})") + return + + if n.type == N.VAR: + print(f"{n.args[0]}") + return + + id = n.type.name.lower() + + if hasattr(lang, id): + print(getattr(lang, id).format(*n.args)) else: print(f"{n.type.name}") - for arg in n.args: - print_ast(arg, lang=lang, indent=indent+2) + + if n.type in [N.FORWARD, N.ROTATE, N.ORIENT, N.GOTO] and \ + n.constchildren(): + return + + for arg in n.args: + print_ast(arg, lang=lang, indent=indent+2)