# pyright: strict

"""
implements an interval analysis (each integer variable is approximated by a range of values)
"""

from dataclasses import dataclass
from cfg import Cfg
from iteration import Transfer, fixpoint_iteration
from syntax import *

@dataclass
class Interval:
    """ potentially unbounded interval

    None in either bound means that we don't have info.
    Hence `Interval(None,None)` denotes all integers
    """
    lower_bound: int | None
    upper_bound: int | None
    def __str__(self):
        if self.lower_bound is None:
            l = "-∞"
        else:
            l = str(self.lower_bound)
        if self.upper_bound is None:
            u = "+∞"
        else:
            u = str(self.upper_bound)
        return f"[{l}; {u}]"

class Top:
    """Arithmetic abstract value used when nothing is known about a result,
    e.g. for a division whose divisor interval contains zero."""

    def __str__(self) -> str:
        return "TOP"

@dataclass
class BTop:
    """Unknown outcome of a boolean expression.

    `has_error` says whether evaluating the expression may also fail:
    - BTop(False): result unknown, but evaluation cannot raise an error;
    - BTop(True): result unknown, and an error may occur.
    """
    has_error: bool

    def __str__(self):
        detail = "maybe error" if self.has_error else "no error"
        return f"TOP ({detail})"

# Abstract environment: maps each variable name to the interval of values it may hold.
# A variable missing from the dict evaluates to None (error) in eval_aexp.
type abstract_env = dict[str, Interval]
"""mapping from variables to abstract values.

As with concrete environment, a variable not in the dictionary will
lead to an error if we try to obtain its value
"""

# Abstract state: an environment, or None which encodes bottom
# (no concrete state can reach this program point).
type state = abstract_env | None
"""abstract state is either an abstract env or bottom
"""

def interval_leq(i1: Interval, i2: Interval) -> bool:
    """Return True when i1 is included in i2 (order of the interval lattice).

    On each side, a missing bound of i2 accepts anything, whereas a missing
    bound of i1 is only accepted when i2 is unbounded on that side too.
    """
    lower_ok = i2.lower_bound is None or (
        i1.lower_bound is not None and i1.lower_bound >= i2.lower_bound
    )
    upper_ok = i2.upper_bound is None or (
        i1.upper_bound is not None and i1.upper_bound <= i2.upper_bound
    )
    return lower_ok and upper_ok

def interval_join(i1: Interval, i2: Interval) -> Interval:
    """Least upper bound: the smallest interval containing both arguments.

    A side that is unbounded in either argument stays unbounded.
    """
    lo = (
        min(i1.lower_bound, i2.lower_bound)
        if i1.lower_bound is not None and i2.lower_bound is not None
        else None
    )
    hi = (
        max(i1.upper_bound, i2.upper_bound)
        if i1.upper_bound is not None and i2.upper_bound is not None
        else None
    )
    return Interval(lo, hi)

def interval_meet(i1: Interval, i2: Interval) -> Interval | None:
    """Greatest lower bound: the intersection of the two intervals.

    Returns None (bottom) when the intersection is empty.
    """
    lo1, lo2 = i1.lower_bound, i2.lower_bound
    hi1, hi2 = i1.upper_bound, i2.upper_bound
    # a missing bound imposes no constraint, so the other argument's bound wins
    lo = lo2 if lo1 is None else (lo1 if lo2 is None else max(lo1, lo2))
    hi = hi2 if hi1 is None else (hi1 if hi2 is None else min(hi1, hi2))
    if lo is None or hi is None or lo <= hi:
        return Interval(lo, hi)
    return None

def interval_widen(i1: Interval, i2: Interval) -> Interval:
    """Widening operator for intervals.

    `i1` is the previous value and `i2` the new one. Each bound of i1 is kept
    if i2 does not go beyond it; otherwise it is pushed to infinity (None).
    The result therefore contains both i1 and i2, and because bounds can only
    be dropped, repeated widening stabilizes after finitely many steps.

    Note: widening [-∞; +∞] by anything yields [-∞; +∞]. The former special
    case returning i2 there produced a value smaller than i1, breaking the
    upper-bound property the fixpoint iteration relies on.
    """
    lo1, hi1 = i1.lower_bound, i1.upper_bound
    lo2, hi2 = i2.lower_bound, i2.upper_bound
    if lo1 is None or lo2 is None or lo2 < lo1:
        lo = None
    else:
        lo = lo1
    if hi1 is None or hi2 is None or hi2 > hi1:
        hi = None
    else:
        hi = hi1
    return Interval(lo, hi)

def has_strict_positive_val(i: Interval) -> bool:
    """Return True when the upper bound of i is missing (+∞) or positive."""
    upper = i.upper_bound
    if upper is None:
        return True
    return upper > 0

def has_strict_negative_val(i: Interval) -> bool:
    """Return True when the lower bound of i is missing (-∞) or negative."""
    lower = i.lower_bound
    return True if lower is None else lower < 0

def contains_zero(i: Interval) -> bool:
    """Return True when 0 lies within i (missing bounds count as infinite)."""
    reaches_down = i.lower_bound is None or i.lower_bound <= 0
    reaches_up = i.upper_bound is None or i.upper_bound >= 0
    return reaches_down and reaches_up

def is_zero(i: Interval) -> bool:
    """Return True when i is exactly the singleton [0; 0]."""
    return (i.lower_bound, i.upper_bound) == (0, 0)

def interval_opp(i: Interval) -> Interval:
    """Negation {-x | x in i}: the bounds swap sides and change sign."""
    new_lower = None if i.upper_bound is None else -i.upper_bound
    new_upper = None if i.lower_bound is None else -i.lower_bound
    return Interval(new_lower, new_upper)

def interval_add(i1: Interval, i2: Interval) -> Interval:
    """Addition of two intervals: bounds are summed pairwise, and a side that
    is unbounded in either operand stays unbounded."""
    lo = (
        None
        if i1.lower_bound is None or i2.lower_bound is None
        else i1.lower_bound + i2.lower_bound
    )
    hi = (
        None
        if i1.upper_bound is None or i2.upper_bound is None
        else i1.upper_bound + i2.upper_bound
    )
    return Interval(lo, hi)

def interval_mul(i1: Interval, i2: Interval) -> Interval:
    """Multiplication of two intervals.

    The result is the hull of the four products of the bounds, a missing bound
    standing for -∞ (lower) or +∞ (upper). Products use the convention
    0 * ±∞ = 0: an infinite bound is never actually reached, so a zero factor
    only ever produces zero (e.g. [0; 0] * [-∞; +∞] = [0; 0]).
    Malformed inputs (lower bound above upper bound) give [-∞; +∞].
    """
    for itv in (i1, i2):
        lb, ub = itv.lower_bound, itv.upper_bound
        if lb is not None and ub is not None and lb > ub:
            return Interval(None, None)

    inf = float("inf")

    def ext_mul(x: float, y: float) -> float:
        # Infinite factors are handled by hand: Python gives nan for 0 * inf,
        # and multiplying a huge int by a float can raise OverflowError.
        if x == 0 or y == 0:
            return 0
        if abs(x) == inf or abs(y) == inf:
            return inf if (x > 0) == (y > 0) else -inf
        return x * y

    a1 = -inf if i1.lower_bound is None else i1.lower_bound
    b1 = inf if i1.upper_bound is None else i1.upper_bound
    a2 = -inf if i2.lower_bound is None else i2.lower_bound
    b2 = inf if i2.upper_bound is None else i2.upper_bound
    products = [ext_mul(x, y) for x in (a1, b1) for y in (a2, b2)]
    lo, hi = min(products), max(products)
    # finite products are exact ints; only the infinities are floats
    return Interval(
        None if lo == -inf else int(lo),
        None if hi == inf else int(hi),
    )

def interval_div(i1: Interval, i2: Interval) -> Interval | Top | None:
    """Floor division (Python `//`) of two intervals.

    Returns None when the divisor is exactly [0; 0] (the division always
    fails), Top() when the divisor merely contains 0 (it may fail), and
    otherwise the smallest interval containing every x // y, x in i1, y in i2.

    For a divisor of constant sign, x // y is monotone in x, and for a fixed x
    it is monotone in y, so the extreme results are reached at the bounds.
    When the divisor is unbounded, x // y tends to 0 or -1 depending on signs
    (e.g. 5 // y -> 0 and -5 // y -> -1 as y -> +∞).
    """
    if is_zero(i2):
        return None
    if contains_zero(i2):
        return Top()

    a, b = i1.lower_bound, i1.upper_bound
    c, d = i2.lower_bound, i2.upper_bound

    if c is not None and c > 0:
        # positive divisor: x // y is non-decreasing in x,
        # so the minimum uses x = a and the maximum uses x = b
        if a is None:
            lo = None
        elif a >= 0:
            # a // y shrinks as y grows: smallest for the largest divisor
            lo = 0 if d is None else a // d
        else:
            # negative a: a // y is most negative for the smallest divisor
            lo = a // c
        if b is None:
            hi = None
        elif b >= 0:
            hi = b // c
        else:
            # negative b: b // y gets closer to -1 as y grows
            hi = -1 if d is None else b // d
        return Interval(lo, hi)

    if d is not None and d < 0:
        # negative divisor: x // y is non-increasing in x,
        # so the minimum uses x = b and the maximum uses x = a
        if b is None:
            lo = None
        elif b >= 0:
            lo = b // d
        else:
            # b / y is positive and shrinks toward 0 as |y| grows
            lo = 0 if c is None else b // c
        if a is None:
            hi = None
        elif a > 0:
            # a / y is negative and approaches 0 from below as |y| grows
            hi = -1 if c is None else a // c
        else:
            hi = a // d
        return Interval(lo, hi)

    # only reachable with a malformed divisor (lower bound above upper bound)
    return Top()

def eval_aexp(env: abstract_env, e: ArithExpr) -> Interval | Top | None:
    """evaluate an arithmetic expression in an abstract environment
    returns None in case of error
    """
    match e:
        case AECst(value): return Interval(value,value)
        case AEVar(var):
            if var in env: return env[var]
            else: return None
        case AEUop(uop,expr):
            res = eval_aexp(env,expr)
            if res is None or isinstance(res,Top): return res
            if uop == Uop.OPP: return interval_opp(res)
            return None
        case AEBop(bop,left_expr,right_expr):
            v1 = eval_aexp(env,left_expr)
            v2 = eval_aexp(env,right_expr)
            if v1 is None or v2 is None: return None
            if isinstance(v1,Top) or isinstance(v2,Top):
                return Top()
            match bop:
                case Bop.ADD: return interval_add(v1,v2)
                case Bop.MUL: return interval_mul(v1,v2)
                case Bop.DIV: return interval_div(v1,v2)
        case _: pass

def eval_bexp(env: abstract_env, e: BoolExpr) -> bool | BTop | None:
    """abstract evaluation of a boolean expression"""
    match e:
        case BEPlain(aexpr):
            res = eval_aexp(env, aexpr)
            if res is None: return None
            if isinstance(res,Top): return BTop(True)
            if res.lower_bound == 0 and res.upper_bound == 0: return True
            if not interval_leq(Interval(0,0), res): return False
            return BTop(False)
        case BEEq(left_expr,right_expr):
            v1 = eval_aexp(env, left_expr)
            v2 = eval_aexp(env, right_expr)
            if v1 is None or v2 is None: return None
            if isinstance(v1,Top) or isinstance(v2,Top): return BTop(True)
            if v1.lower_bound is None or v2.lower_bound is None or v1.upper_bound is None or v2.upper_bound is None:
                return BTop(False)
            if v1.lower_bound > v2.upper_bound or v1.upper_bound < v2.lower_bound:
                return False
            if v1.lower_bound == v1.upper_bound and v2.lower_bound == v2.upper_bound and v1.lower_bound == v2.lower_bound:
                return True
            return BTop(False)
        case BELeq(left_expr,right_expr):
            v1 = eval_aexp(env, left_expr)
            v2 = eval_aexp(env, right_expr)
            if v1 is None or v2 is None: return None
            if isinstance(v1,Top) or isinstance(v2,Top): return BTop(True)
            if v1.upper_bound is None:
                if v1.lower_bound is None:
                    return BTop(False)
                if v2.upper_bound is None:
                    return BTop(False)
                if v2.upper_bound < v1.lower_bound:
                    return False
                return BTop(False)
            if v2.lower_bound is None:
                return BTop(False)
            if v1.upper_bound <= v2.lower_bound:
                return True
            if v1.lower_bound is None:
                return BTop(False)
            if v2.upper_bound is None:
                return BTop(False)
            if v2.upper_bound < v1.lower_bound:
                return False
            return BTop(False)
        case BENeg(expr):
            v = eval_bexp(env,expr)
            if v is None: return None
            if isinstance(v,BTop): return v
            return not v
        case _: pass

def reduce_eq(s: state, x: str, i: Interval) -> state:
    """Refine `s` under the hypothesis that variable `x` takes a value in `i`.

    Returns None (bottom) if `s` is already bottom, if `x` is unbound, or if
    the current interval of `x` does not intersect `i`.
    """
    if s is None or x not in s:
        return None
    narrowed = interval_meet(s[x], i)
    return None if narrowed is None else {**s, x: narrowed}

def reduce_neq(s: state, x: str, i: Interval) -> state:
    """Refine `s` under the hypothesis that variable `x` differs from `i`.

    Intervals cannot represent holes, so only one case is exploited: when `i`
    is a single value k lying on a bound of x's interval, that bound is moved
    past k (or the state becomes bottom if x's interval is exactly [k; k]).
    Otherwise `s` is returned unchanged. Returns None if `s` is bottom or `x`
    is unbound.
    """
    if s is None or x not in s:
        return None
    k = i.lower_bound
    # Only a bounded singleton [k; k] can be excluded. Interval(None, None)
    # has equal (missing) bounds but is not a singleton; treating it as one
    # used to crash on `None + 1`.
    if k is None or k != i.upper_bound:
        return s
    cur = s[x]
    if cur.lower_bound == k:
        if cur.upper_bound is None or cur.upper_bound > k:
            return s | {x: Interval(k + 1, cur.upper_bound)}
        return None
    if cur.upper_bound == k:
        if cur.lower_bound is None or cur.lower_bound < k:
            return s | {x: Interval(cur.lower_bound, k - 1)}
        return None
    return s

def reduce_leq(s: state, x: str, upper_bound: int) -> state:
    """Refine `s` under the hypothesis x <= upper_bound.

    Returns None (bottom) if `s` is bottom, if `x` is unbound, or if no value
    of `x` satisfies the hypothesis.
    """
    if s is None or x not in s:
        return None
    clipped = interval_meet(s[x], Interval(None, upper_bound))
    if clipped is None:
        return None
    return {**s, x: clipped}

def reduce_geq(s: state, x: str, lower_bound: int) -> state:
    """Refine `s` under the hypothesis x >= lower_bound.

    Returns None (bottom) if `s` is bottom, if `x` is unbound, or if no value
    of `x` satisfies the hypothesis.
    """
    if s is None or x not in s:
        return None
    clipped = interval_meet(s[x], Interval(lower_bound, None))
    if clipped is None:
        return None
    return {**s, x: clipped}

def reduce_state(s: state,c: BoolExpr) -> state:
    """Refine the abstract state `s` under the hypothesis that `c` holds.

    Handled shapes are `==` and `<=` comparisons, possibly negated, with a
    variable on at least one side: the interval of each such variable is
    intersected with what the comparison allows. Any other condition leaves
    `s` unchanged. Returns None (bottom) when `s` is bottom, when evaluating
    an operand fails, or when the hypothesis contradicts `s`.

    Case order matters: the variable/variable patterns must be tried before
    the variable/expression ones, which would also match them but refine
    only one side.
    """
    if s is None: return s
    match c:
        # x == y: x is narrowed to y's interval and y to x's
        case BEEq(AEVar(x), AEVar(y)):
            vx = eval_aexp(s,AEVar(x))
            vy = eval_aexp(s,AEVar(y))
            if vx is None or vy is None: return None
            if not isinstance(vx,Top):
                s = reduce_eq(s,y,vx)
            # vy was computed before the refinement above
            if not isinstance(vy,Top):
                s = reduce_eq(s,x,vy)
            return s
        # x == e: x is narrowed to the value of e
        case BEEq(AEVar(x), right_expr):
            v = eval_aexp(s,right_expr)
            if v is None: return None
            if isinstance(v,Top): return s
            return reduce_eq(s,x,v)
        # e == y: y is narrowed to the value of e
        case BEEq(left_expr, AEVar(y)):
            v = eval_aexp(s,left_expr)
            if v is None: return None
            if isinstance(v,Top): return s
            return reduce_eq(s,y,v)
        # x <= y: x is at most y's upper bound, y is at least x's lower bound
        case BELeq(AEVar(x), AEVar(y)):
            vx = eval_aexp(s,AEVar(x))
            vy = eval_aexp(s,AEVar(y))
            if vx is None or vy is None: return None
            if not isinstance(vy,Top) and vy.upper_bound is not None:
                s = reduce_leq(s,x,vy.upper_bound)
            if not isinstance(vx,Top) and vx.lower_bound is not None:
                s = reduce_geq(s,y,vx.lower_bound)
            return s
        # x <= e: x is at most the upper bound of e
        case BELeq(AEVar(x),right_expr):
            v = eval_aexp(s,right_expr)
            if v is None: return None
            if isinstance(v,Top) or v.upper_bound is None: return s
            return reduce_leq(s,x,v.upper_bound)
        # e <= y: y is at least the lower bound of e
        case BELeq(left_expr,AEVar(y)):
            v = eval_aexp(s,left_expr)
            if v is None: return None
            if isinstance(v,Top) or v.lower_bound is None: return s
            return reduce_geq(s,y,v.lower_bound)
        # x != y: a singleton on one side may trim a bound of the other (see reduce_neq)
        case BENeg(BEEq(AEVar(x), AEVar(y))):
            vx = eval_aexp(s,AEVar(x))
            vy = eval_aexp(s,AEVar(y))
            if vx is None or vy is None: return None
            if not isinstance(vx,Top):
                s = reduce_neq(s,y,vx)
            if not isinstance(vy,Top):
                s = reduce_neq(s,x,vy)
            return s
        # x != e
        case BENeg(BEEq(AEVar(x), right_expr)):
            v = eval_aexp(s,right_expr)
            if v is None: return None
            if isinstance(v,Top): return s
            return reduce_neq(s,x,v)
        # e != y
        case BENeg(BEEq(left_expr, AEVar(y))):
            v = eval_aexp(s,left_expr)
            if v is None: return None
            if isinstance(v,Top): return s
            return reduce_neq(s,y,v)
        # not (x <= y), i.e. x > y: y <= x.upper - 1 and x >= y.lower + 1
        case BENeg(BELeq(AEVar(x), AEVar(y))):
            vx = eval_aexp(s,AEVar(x))
            vy = eval_aexp(s,AEVar(y))
            if vx is None or vy is None: return None
            if not isinstance(vx,Top) and vx.upper_bound is not None:
                s = reduce_leq(s,y,vx.upper_bound - 1)
            if not isinstance(vy,Top) and vy.lower_bound is not None:
                s = reduce_geq(s,x,vy.lower_bound + 1)
            return s
        # not (x <= e), i.e. x > e: x >= e.lower + 1
        case BENeg(BELeq(AEVar(x),right_expr)):
            v = eval_aexp(s,right_expr)
            if v is None: return None
            if isinstance(v,Top) or v.lower_bound is None: return s
            return reduce_geq(s,x,v.lower_bound + 1)
        # not (e <= y), i.e. y < e: y <= e.upper - 1
        case BENeg(BELeq(left_expr,AEVar(y))):
            v = eval_aexp(s,left_expr)
            if v is None: return None
            if isinstance(v,Top) or v.upper_bound is None: return s
            return reduce_leq(s,y,v.upper_bound - 1)
        # any other condition (e.g. BEPlain, double negation) is not exploited
        case _: return s

class Interval_interp(Transfer[state]):
    """Interval abstract interpreter: lattice operations and transfer functions.

    States are `abstract_env` dictionaries over `self.variables`, or None for
    bottom (unreachable).
    """
    # all variables of the analyzed instruction, as computed by variables_of_instr
    variables: frozenset[str]

    def __init__(self, instr: Instr):
        self.variables = variables_of_instr(instr)

    def bottom(self) -> state:
        """Bottom element: no reachable state."""
        return None

    def init_state(self) -> state:
        """Entry state: every variable may hold any integer."""
        return dict.fromkeys(self.variables, Interval(None, None))

    def join(self, s1: state, s2: state) -> state:
        """Variable-wise interval join; bottom is the neutral element."""
        if s1 is None:
            return s2
        if s2 is None:
            return s1
        return {var: interval_join(s1[var], s2[var]) for var in self.variables}

    def widen(self, s1: state, s2: state) -> state:
        """Variable-wise widening of s1 (previous state) by s2 (new state)."""
        if s1 is None:
            return s2
        if s2 is None:
            return s1
        return {var: interval_widen(s1[var], s2[var]) for var in self.variables}

    def included(self, s1: state, s2: state) -> bool:
        """Return True when s1 is below s2 for every variable (bottom is below all)."""
        if s1 is None:
            return True
        if s2 is None:
            return False
        return all(interval_leq(s1[var], s2[var]) for var in self.variables)

    def tr_skip(self, s: state) -> state:
        """Transfer function for skip: the state is unchanged."""
        return s

    def tr_set(self, s: state, v: str, e: ArithExpr) -> state:
        """Transfer function for the assignment v := e.

        A certain evaluation failure makes the successor unreachable (bottom);
        an unknown result (Top) is stored as [-∞; +∞].
        """
        if s is None:
            return None
        res = eval_aexp(s, e)
        if res is None:
            return None
        if isinstance(res, Top):
            return s | {v: Interval(None, None)}
        return s | {v: res}

    def tr_test(self, s: state, c: BoolExpr) -> state:
        """Transfer function for an edge guarded by condition `c`.

        The edge is dropped (bottom) when `c` is certainly false or its
        evaluation certainly fails. Otherwise, whether `c` is certainly true
        or unknown, `s` is refined under the hypothesis that `c` holds.
        """
        if s is None:
            return None
        res = eval_bexp(s, c)
        if res is None or res is False:
            return None
        # Refining in the unknown (BTop) case is what lets a guard such as
        # `x <= 10` bound x inside a loop body; skipping it lost that precision.
        return reduce_state(s, c)

    def tr_err(self, s: state, e: Expr) -> state:
        """Transfer function for the edge taken when evaluating `e` fails.

        The error edge is reachable (state `s` is propagated) only if
        evaluating `e` may fail; when evaluation is guaranteed to succeed, the
        edge yields bottom.
        """
        if s is None:
            return None
        if isinstance(e, ArithExpr):
            res = eval_aexp(s, e)
            # None: certain failure; Top: possible failure
            if res is None or isinstance(res, Top):
                return s
            return None
        if isinstance(e, BoolExpr):
            bres = eval_bexp(s, e)
            # None: certain failure; BTop(True): possible failure
            if bres is None or (isinstance(bres, BTop) and bres.has_error):
                return s
            return None
        # unknown kind of expression: stay conservative
        return s

def analyze(i: Instr) -> None:
    """Run the interval analysis on `i` and print the state at every CFG node.

    A node whose state is bottom (None) only gets its header line.
    """
    cfg = Cfg(i)
    results = fixpoint_iteration(Interval_interp(i), cfg)
    for node in cfg.g.nodes:
        print(f"State at {node}:")
        env = results[node]
        if env is None:
            continue
        for var, itv in env.items():
            print(f"  {var}: {itv}")