# SPDX-License-Identifier: Apache-2.0
#
# This file is part of the M2-ISA-R project: https://github.com/tum-ei-eda/M2-ISA-R
#
# Copyright (C) 2022
# Chair of Electrical Design Automation
# Technical University of Munich
"""A transformation module for simplifying M2-ISA-R behavior expressions. The following
simplifications are done:
* Resolvable :class:`m2isar.metamodel.arch.Constant` s are replaced by
  :class:`m2isar.metamodel.behav.IntLiteral` s representing their value
* Fully resolvable arithmetic operations are carried out and their results
  represented as a matching :class:`m2isar.metamodel.behav.IntLiteral`
* Conditionals with fully resolvable conditions are either discarded entirely
  or transformed into code blocks without any conditions; loop conditions are
  simplified, but the loops themselves are kept
* Ternaries with fully resolvable conditions are transformed into only the matching part
* Type conversions of :class:`m2isar.metamodel.behav.IntLiteral` s apply the desired
  type directly to the :class:`IntLiteral` and discard the type conversion
"""
from ...metamodel import arch, behav
from .ExprVisitor import ExprVisitor
from functools import singledispatchmethod
# pylint: disable=unused-argument
class ExprSimplifierVisitor(ExprVisitor):
    """Visitor that simplifies behavior expression trees.

    Every ``generate`` overload first simplifies the children of a node and
    then tries to fold the node itself. Nodes are mutated in place; the
    returned value is the node that should replace ``expr`` in its parent.
    This may be ``expr`` itself, one of its children, or a freshly created
    :class:`behav.IntLiteral`.
    """

    @singledispatchmethod
    def generate(self, expr: behav.BaseNode, context=None):
        """Fallback for node types without a registered handler.

        Raises:
            NotImplementedError: always; no simplification rule exists for
                the type of ``expr``.
        """
        raise NotImplementedError(f"No visit method implemented for type {type(expr).__name__} in {type(self).__name__}")

    @staticmethod
    def _widen_literal(literal, other):
        """Widen ``literal`` to the size of the object referenced by ``other``.

        Only acts when ``literal`` is an IntLiteral, ``other`` is a
        NamedReference/IndexedReference, and the referenced object is wider
        than the literal.
        """
        if isinstance(literal, behav.IntLiteral) and isinstance(other, (behav.NamedReference, behav.IndexedReference)):
            if literal.bit_size < other.reference.size:
                literal.bit_size = other.reference.size

    @staticmethod
    def _fold_binary(op: str, left: int, right: int) -> int:
        """Compute ``left <op> right`` for two constant integer operands.

        ``&&`` and ``||`` are not Python operators (``eval`` would raise
        SyntaxError), so they are evaluated explicitly and yield 0 or 1.

        Raises:
            ValueError: e.g. for a negative shift count.
            ZeroDivisionError: for ``/`` or ``%`` by zero.
        """
        if op == "&&":
            return int(bool(left) and bool(right))
        if op == "||":
            return int(bool(left) or bool(right))
        # Operands are ints taken from IntLiteral nodes, so only numbers and
        # the operator string reach eval.
        # pylint: disable=eval-used
        return int(eval(f"{left}{op}{right}"))

    @staticmethod
    def _fold_unary(op: str, value: int) -> int:
        """Compute ``<op> value`` for a constant integer operand.

        ``!`` is not a Python operator, so logical negation is evaluated
        explicitly and yields 0 or 1.
        """
        if op == "!":
            return int(not value)
        # pylint: disable=eval-used
        return int(eval(f"{op}{value}"))

    @generate.register
    def _(self, expr: behav.Operation, context):
        """Simplify every statement of an operation.

        Handlers may return a list of statements, which is spliced into the
        statement list. A statement whose simplification raises
        NotImplementedError or ValueError is reported and kept as-is
        (possibly partially simplified, since handlers mutate nodes in place).
        """
        statements = []
        for stmt in expr.statements:
            try:
                simplified = self.generate(stmt, context)
            except (NotImplementedError, ValueError):
                print(f"cant simplify {stmt}")
                # Keep the statement instead of dropping it: removing it
                # would change the behavior of the operation.
                simplified = stmt
            if isinstance(simplified, list):
                statements.extend(simplified)
            else:
                statements.append(simplified)
        expr.statements = statements
        return expr

    @generate.register
    def _(self, expr: behav.Block, context):
        """Simplify every statement of a block."""
        expr.statements = [self.generate(x, context) for x in expr.statements]
        return expr

    @generate.register
    def _(self, expr: behav.BinaryOperation, context):
        """Simplify both operands, then fold the operation where possible.

        - Two literal operands are folded into a single IntLiteral, wide
          enough for both operands and the result. Division or modulo by a
          constant zero is left unfolded.
        - ``&&`` / ``||`` with one literal operand are short-circuited.
        - A literal operand next to a reference is widened to the size of
          the referenced object.
        """
        expr.left = self.generate(expr.left, context)
        expr.right = self.generate(expr.right, context)

        self._widen_literal(expr.left, expr.right)
        self._widen_literal(expr.right, expr.left)

        if isinstance(expr.left, behav.IntLiteral) and isinstance(expr.right, behav.IntLiteral):
            try:
                res = self._fold_binary(expr.op.value, expr.left.value, expr.right.value)
            except ZeroDivisionError:
                # Keep the expression as written instead of crashing.
                return expr
            return behav.IntLiteral(res, max(expr.left.bit_size, expr.right.bit_size, res.bit_length()))

        if expr.op.value == "&&":
            if isinstance(expr.left, behav.IntLiteral):
                return expr.right if expr.left.value else expr.left
            if isinstance(expr.right, behav.IntLiteral):
                return expr.left if expr.right.value else expr.right

        if expr.op.value == "||":
            if isinstance(expr.left, behav.IntLiteral):
                return expr.left if expr.left.value else expr.right
            if isinstance(expr.right, behav.IntLiteral):
                return expr.right if expr.right.value else expr.left

        return expr

    @generate.register
    def _(self, expr: behav.SliceOperation, context):
        """Simplify the sliced expression and both slice bounds."""
        expr.expr = self.generate(expr.expr, context)
        expr.left = self.generate(expr.left, context)
        expr.right = self.generate(expr.right, context)
        return expr

    @generate.register
    def _(self, expr: behav.ConcatOperation, context):
        """Simplify both concatenated operands."""
        expr.left = self.generate(expr.left, context)
        expr.right = self.generate(expr.right, context)
        return expr

    @generate.register
    def _(self, expr: behav.NumberLiteral, context):
        """Literals are already as simple as possible."""
        return expr

    @generate.register
    def _(self, expr: behav.IntLiteral, context):
        """Literals are already as simple as possible."""
        return expr

    @generate.register
    def _(self, expr: behav.StringLiteral, context):
        """Literals are already as simple as possible."""
        return expr

    @generate.register
    def _(self, expr: behav.ScalarDefinition, context):
        """Scalar definitions are returned unchanged."""
        return expr

    @generate.register
    def _(self, expr: behav.Break, context):
        """Break statements are returned unchanged."""
        return expr

    @generate.register
    def _(self, expr: behav.Assignment, context):
        """Simplify target and value.

        A literal value assigned to a reference is widened to the size of
        the referenced object.
        """
        expr.target = self.generate(expr.target, context)
        expr.expr = self.generate(expr.expr, context)
        self._widen_literal(expr.expr, expr.target)
        return expr

    @generate.register
    def _(self, expr: behav.Conditional, context):
        """Simplify an if / else-if / else chain.

        Branches whose condition is a false literal are removed. A true
        literal condition ends the chain:

        - If no unresolved condition precedes it, its statement is returned
          directly.
        - Otherwise it becomes the final else branch.

        If every condition was a false literal, the else statement is returned
        (when present). Without an else, a Conditional with no branches is
        returned.
        """
        expr.conds = [self.generate(x, context) for x in expr.conds]
        expr.stmts = [self.generate(x, context) for x in expr.stmts]

        conds = []
        stmts = []
        for cond, stmt in zip(expr.conds, expr.stmts):
            if isinstance(cond, behav.IntLiteral):
                if not cond.value:
                    # Statically false: this branch can never be taken.
                    continue
                if not conds:
                    # First reachable branch is statically true: it always runs.
                    return stmt
                # Statically true after unresolved conditions: it acts as the
                # final else branch; any later branches are unreachable.
                stmts.append(stmt)
                expr.conds = conds
                expr.stmts = stmts
                return expr
            conds.append(cond)
            stmts.append(stmt)

        # One more statement than conditions means there is an else branch.
        if len(expr.stmts) > len(expr.conds):
            if not conds:
                # Every condition was statically false: only the else remains.
                return expr.stmts[-1]
            stmts.append(expr.stmts[-1])

        expr.conds = conds
        expr.stmts = stmts
        return expr

    @generate.register
    def _(self, expr: behav.Loop, context):
        """Simplify the loop condition and body; the loop itself is kept."""
        expr.cond = self.generate(expr.cond, context)
        expr.stmts = [self.generate(x, context) for x in expr.stmts]
        return expr

    @generate.register
    def _(self, expr: behav.Ternary, context):
        """Simplify all parts.

        If the condition is a literal, return the branch it selects.
        """
        expr.cond = self.generate(expr.cond, context)
        expr.then_expr = self.generate(expr.then_expr, context)
        expr.else_expr = self.generate(expr.else_expr, context)
        if isinstance(expr.cond, behav.IntLiteral):
            return expr.then_expr if expr.cond.value else expr.else_expr
        return expr

    @generate.register
    def _(self, expr: behav.Return, context):
        """Simplify the returned expression, if there is one."""
        if expr.expr is not None:
            expr.expr = self.generate(expr.expr, context)
        return expr

    @generate.register
    def _(self, expr: behav.UnaryOperation, context):
        """Simplify the operand; fold the operation if the operand is a literal."""
        expr.right = self.generate(expr.right, context)
        if isinstance(expr.right, behav.IntLiteral):
            res = self._fold_unary(expr.op.value, expr.right.value)
            return behav.IntLiteral(res, max(expr.right.bit_size, res.bit_length()))
        return expr

    @generate.register
    def _(self, expr: behav.NamedReference, context):
        """Replace a reference to a resolved Constant by an IntLiteral.

        Constants without a value (``value is None``) are left as references,
        since they cannot be resolved here.
        """
        ref = expr.reference
        if isinstance(ref, arch.Constant) and ref.value is not None:
            return behav.IntLiteral(ref.value, ref.size, ref.signed)
        return expr

    @generate.register
    def _(self, expr: behav.IndexedReference, context):
        """Simplify the index expression; the referenced object is kept."""
        expr.index = self.generate(expr.index, context)
        return expr

    @generate.register
    def _(self, expr: behav.TypeConv, context):
        """Apply a type conversion of a literal directly to the literal.

        The target width is the explicit ``size``, or the inferred type's
        width when no size is given. Signedness is taken from ``data_type``.
        Conversions of non-literals are returned unchanged (operand simplified).
        """
        expr.expr = self.generate(expr.expr, context)
        if isinstance(expr.expr, behav.IntLiteral):
            size = expr.size
            if size is None:
                assert expr.inferred_type is not None
                size = expr.inferred_type.width
            assert size is not None
            # NOTE(review): the literal's value is not truncated or
            # sign-adjusted to the new width - confirm that IntLiteral or the
            # backend handles out-of-range values.
            expr.expr.bit_size = size
            assert expr.expr.bit_size is not None
            expr.expr.signed = expr.data_type == arch.DataType.S
            return expr.expr
        return expr

    @generate.register
    def _(self, expr: behav.Callable, context):
        """Simplify every call argument."""
        expr.args = [self.generate(stmt, context) for stmt in expr.args]
        return expr

    @generate.register
    def _(self, expr: behav.ProcedureCall, context):
        """Simplify every call argument."""
        expr.args = [self.generate(stmt, context) for stmt in expr.args]
        return expr

    @generate.register
    def _(self, expr: behav.Group, context):
        """Simplify the grouped expression; drop the group around a literal."""
        expr.expr = self.generate(expr.expr, context)
        if isinstance(expr.expr, behav.IntLiteral):
            return expr.expr
        return expr