diff --git a/Pystan/config/sign_analysis.py b/Pystan/config/sign_analysis.py index 38b4e9e2b1f6c8d4b1fd3f6f0ee7773c61b02f89..1d204b5ad331763d7981674774760bdcd08b2ef3 100644 --- a/Pystan/config/sign_analysis.py +++ b/Pystan/config/sign_analysis.py @@ -374,44 +374,51 @@ def reduce_state(s: state, c: BoolExpr) -> state: class Sign_interp(Transfer[state]): variables: frozenset[str] - def __init__(self,instr: Instr): + + def __init__(self, instr: Instr): self.variables = variables_of_instr(instr) + def bottom(self) -> state: return None + def init_state(self) -> state: return dict.fromkeys(self.variables, sign.INT) - def join(self,s1:state,s2:state) -> state: + + def join(self, s1: state, s2: state) -> state: if s1 is None: return s2 if s2 is None: return s1 - res: abstract_env = {} - for var in self.variables: - v1 = s1[var] - v2 = s2[var] - res[var] = sign_join(v1,v2) - return res - - def included(self,s1: state,s2: state) -> bool: + return {v: sign_join(s1[v], s2[v]) for v in self.variables} + + def included(self, s1: state, s2: state) -> bool: if s1 is None: return True if s2 is None: return False - for var in self.variables: - if not sign_leq(s1[var], s2[var]): return False - return True + return all(sign_leq(s1[v], s2[v]) for v in self.variables) - def tr_skip(self,s: state) -> state: + def tr_skip(self, s: state) -> state: return s - def tr_set(self,s: state,v: str,e: ArithExpr) -> state: - raise NotImplementedError - - def tr_test(self,s: state,c: BoolExpr) -> state: - raise NotImplementedError - - def tr_err(self,s: state,e: Expr) -> state: - if s is None: return s - if isinstance(e,ArithExpr): - raise NotImplementedError - if isinstance(e,BoolExpr): - raise NotImplementedError + def tr_set(self, s: state, v: str, e: ArithExpr) -> state: + if s is None: return None + val = eval_aexp(s, e) + if val is None or isinstance(val, Top): + s1 = s.copy() + s1[v] = sign.INT + return s1 + s1 = s.copy() + s1[v] = val + return s1 + + def tr_test(self, s: state, c: BoolExpr) -> state: + return reduce_state(s, c) + + def tr_err(self, s: state, e: Expr) -> state: + if s is None: + return s + match e: + case ArithExpr(): + return s if eval_aexp(s, e) is not None else None + case BoolExpr(): + return s if eval_bexp(s, e) is not None else None def analyze(i: Instr) -> None: cfg = Cfg(i)