{-# LANGUAGE CPP, MultiWayIf #-}
module TmOracle (
        
        PmExpr(..), PmLit(..), SimpleEq, ComplexEq, PmVarEnv, falsePmExpr,
        eqPmLit, filterComplex, isNotPmExprOther, runPmPprM, lhsExprToPmExpr,
        hsExprToPmExpr, pprPmExprWithParens,
        
        tmOracle, TmState, initialTmState, solveOneEq, extendSubst, canDiverge,
        
        toComplex, exprDeepLookup, pmLitType, flattenPmVarEnv
    ) where
#include "HsVersions.h"
import GhcPrelude
import PmExpr
import Id
import Name
import Type
import HsLit
import TcHsSyn
import MonadUtils
import Util
import Outputable
import NameEnv
-- | The term oracle substitution: maps a variable @Name@ to the expression it
-- has been equated with. Extended by @extendSubst@ / @extendSubstAndSolve@
-- and chased transitively by @varDeepLookup@ / @exprDeepLookup@.
type PmVarEnv = NameEnv PmExpr
-- | Oracle environment: a flag recording whether some equality could not be
-- handled (it is set to True when a @PmExprOther@ or an unclassified
-- equality shows up), paired with the current substitution.
type TmOracleEnv = (Bool, PmVarEnv)
-- | Check whether variable @x@ is still unconstrained enough to diverge.
--
-- @x@ is first resolved through the substitution. If it resolves to some
-- other expression than a bare variable, the answer is False. If it
-- resolves to a variable @y@, the answer is True exactly when neither @x@
-- nor @y@ occurs in any equality waiting on the standby list (an equality
-- mentioning the variable is presumably taken to force it; see callers).
canDiverge :: Name -> TmState -> Bool
canDiverge x (standby, (_, env)) =
  case varDeepLookup env x of
    PmExprVar y -> not (mentioned x || mentioned y)
    _           -> False
  where
    -- Does the variable occur on either side of some standby equality?
    mentioned :: Name -> Bool
    mentioned v = any (\(lhs, rhs) -> varIn v lhs || varIn v rhs) standby
-- | Occurrence check: does the variable occur anywhere inside the
-- expression? Literals and opaque @PmExprOther@ terms never contain it.
varIn :: Name -> PmExpr -> Bool
varIn x = occurs
  where
    occurs (PmExprVar y)    = x == y
    occurs (PmExprCon _ es) = any occurs es
    occurs (PmExprEq e1 e2) = occurs e1 || occurs e2
    occurs (PmExprLit _)    = False
    occurs (PmExprOther _)  = False
-- | Rewrite every binding of the substitution so that its right-hand side
-- is fully resolved against the same substitution.
flattenPmVarEnv :: PmVarEnv -> PmVarEnv
flattenPmVarEnv env = mapNameEnv resolve env
  where
    resolve = exprDeepLookup env
-- | Full oracle state: the standby list of equalities that could not be
-- decided yet (re-examined by @extendSubstAndSolve@ when the substitution
-- grows), together with the @TmOracleEnv@.
type TmState = ([ComplexEq], TmOracleEnv)
-- | Starting state: no standby equalities, nothing flagged as unhandled, and
-- an empty substitution.
initialTmState :: TmState
initialTmState = (noStandby, (noUnhandled, noBindings))
  where
    noStandby   = []
    noUnhandled = False
    noBindings  = emptyNameEnv
-- | Add one equality to the oracle. The equality is resolved against the
-- current substitution and simplified before being solved. @Nothing@ means
-- it contradicts what is already known.
solveOneEq :: TmState -> ComplexEq -> Maybe TmState
solveOneEq solver_env@(_, (_, subst)) complex = solveComplexEq solver_env prepared
  where
    prepared = simplifyComplexEq (applySubstComplexEq subst complex)
-- | Solve a single (already substituted and simplified) equality against the
-- oracle state.
--
-- Returns @Nothing@ when the equality is contradictory: two literals that
-- @eqPmLit@ reports as different, or two different constructors.
-- Otherwise returns the new state, in which the equality has either
--
--   * been discharged with no change (equal literals, identical variables),
--   * been decomposed into argument equalities (same constructor),
--   * been pushed onto the standby list (false against an equality, or an
--     equality against an equality),
--   * extended the substitution via @extendSubstAndSolve@ (a variable side),
--   * or been given up on, setting the unhandled flag to True
--     (a @PmExprOther@ on either side, or the catch-all case).
--
-- NOTE: a failing guard falls through to the later alternatives. For
-- example, a nullary constructor that is neither true nor false (per
-- @isTruePmExpr@ / @isFalsePmExpr@) against a @PmExprEq@ ends up in the
-- catch-all.
solveComplexEq :: TmState -> ComplexEq -> Maybe TmState
solveComplexEq solver_state@(standby, (unhandled, env)) eq@(e1, e2) = case eq of
  -- Opaque expressions: nothing can be learned, so record them as unhandled.
  (PmExprOther _,_)            -> Just (standby, (True, env))
  (_,PmExprOther _)            -> Just (standby, (True, env))
  (PmExprLit l1, PmExprLit l2) -> case eqPmLit l1 l2 of
    -- Equal literals add nothing; unequal literals are a contradiction.
    True  -> Just solver_state
    False -> Nothing
  -- Same constructor: solve the argument equalities pairwise.
  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2  -> foldlM solveComplexEq solver_state (zip ts1 ts2)
    | otherwise -> Nothing
  -- A boolean constant equated with an equality: if true, solve the inner
  -- equality; if false, park the whole equation on the standby list.
  (PmExprCon _ [], PmExprEq t1 t2)
    | isTruePmExpr e1  -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e1 -> Just (eq:standby, (unhandled, env))
  (PmExprEq t1 t2, PmExprCon _ [])
    | isTruePmExpr e2   -> solveComplexEq solver_state (t1, t2)
    | isFalsePmExpr e2  -> Just (eq:standby, (unhandled, env))
  -- Variables: bind the variable in the substitution (trivial if both
  -- sides are the same variable).
  (PmExprVar x, PmExprVar y)
    | x == y    -> Just solver_state
    | otherwise -> extendSubstAndSolve x e2 solver_state
  (PmExprVar x, _) -> extendSubstAndSolve x e2 solver_state
  (_, PmExprVar x) -> extendSubstAndSolve x e1 solver_state
  -- Equality between two equalities: cannot decide now, keep for later.
  (PmExprEq _ _, PmExprEq _ _) -> Just (eq:standby, (unhandled, env))
  -- Anything else: emit a WARN diagnostic (presumably only in debug builds,
  -- per HsVersions.h) and mark the state as unhandled.
  _ -> WARN( True, text "solveComplexEq: Catch all" <+> ppr eq )
       Just (standby, (True, env)) 
-- | Bind @x := e@ in the substitution, then revisit the standby list.
--
-- The binding is applied to every standby equality with @substComplexEq@.
-- Equalities it reports as affected are simplified and solved again,
-- starting from the extended state. The remaining ones stay on standby.
-- Any contradiction found while re-solving yields @Nothing@.
extendSubstAndSolve :: Name -> PmExpr -> TmState -> Maybe TmState
extendSubstAndSolve x e (standby, (unhandled, env)) =
    foldlM solveComplexEq start (map simplifyComplexEq affected)
  where
    (affected, untouched) = partitionWith (substComplexEq x e) standby
    start                 = (untouched, (unhandled, extendNameEnv env x e))
-- | Record @y := e@ without re-solving the standby list.
--
-- The expression is resolved against the current substitution and
-- simplified first. If the result is a @PmExprOther@, nothing is bound and
-- the unhandled flag is set to True instead.
extendSubst :: Id -> PmExpr -> TmState -> TmState
extendSubst y e (standby, (unhandled, env)) =
  if isNotPmExprOther resolved
    then (standby, (unhandled, extendNameEnv env (idName y) resolved))
    else (standby, (True, env))
  where
    resolved = fst (simplifyPmExpr (exprDeepLookup env e))
-- | Simplify both sides of an equality independently, discarding the
-- progress flags.
simplifyComplexEq :: ComplexEq -> ComplexEq
simplifyComplexEq (lhs, rhs) = (simplified lhs, simplified rhs)
  where
    simplified = fst . simplifyPmExpr
-- | Simplify an expression. The Bool is True when some part of it was
-- simplified.
--
-- Constructor applications are simplified argument-wise, and equalities are
-- handed to @simplifyEqExpr@. Every other expression is returned unchanged
-- with False.
simplifyPmExpr :: PmExpr -> (PmExpr, Bool)
simplifyPmExpr (PmExprCon c ts) =
  case mapAndUnzip simplifyPmExpr ts of
    (args, progress) -> (PmExprCon c args, or progress)
simplifyPmExpr (PmExprEq t1 t2) = simplifyEqExpr t1 t2
simplifyPmExpr other = (other, False)
-- | Try to decide the equality @e1 ~ e2@ syntactically.
--
-- Returns the simplified expression and a flag that is True when progress
-- was made. Progress means the result is @truePmExpr@ / @falsePmExpr@, or it
-- came from a recursive call that reported progress. When nothing can be
-- decided, the equality is rebuilt as a @PmExprEq@ and the flag is False.
--
--   * identical variables are equal;
--   * literals are compared with @eqPmLit@;
--   * if either side is itself an equality, both sides are simplified and
--     the comparison is retried only if one of them made progress;
--   * applications of one constructor are compared argument-wise, and
--     different constructors are unequal;
--   * anything else is left undecided.
--
-- NOTE(review): in the same-constructor case, the pairwise results @tss@
-- are only consulted if simplifying some argument made progress. So two
-- applications of the same constructor to distinct literals (already in
-- simplest form) are left undecided here rather than reported as false.
simplifyEqExpr :: PmExpr -> PmExpr -> (PmExpr, Bool)
simplifyEqExpr e1 e2 = case (e1, e2) of
  -- The same variable on both sides is trivially equal.
  (PmExprVar x, PmExprVar y)
    | x == y -> (truePmExpr, True)
  -- Literals are decided by eqPmLit.
  (PmExprLit l1, PmExprLit l2) -> case eqPmLit l1 l2 of
    -- Either outcome counts as progress.
    True  -> (truePmExpr,  True)
    False -> (falsePmExpr, True)
  -- Nested equalities: simplify both sides and retry only on progress.
  (PmExprEq {}, _) -> case (simplifyPmExpr e1, simplifyPmExpr e2) of
    ((e1', True ), (e2', _    )) -> simplifyEqExpr e1' e2'
    ((e1', _    ), (e2', True )) -> simplifyEqExpr e1' e2'
    ((e1', False), (e2', False)) -> (PmExprEq e1' e2', False) -- no progress
  (_, PmExprEq {}) -> case (simplifyPmExpr e1, simplifyPmExpr e2) of
    ((e1', True ), (e2', _    )) -> simplifyEqExpr e1' e2'
    ((e1', _    ), (e2', True )) -> simplifyEqExpr e1' e2'
    ((e1', False), (e2', False)) -> (PmExprEq e1' e2', False) -- no progress
  -- Constructor applications: compare arguments pairwise when the heads agree.
  (PmExprCon c1 ts1, PmExprCon c2 ts2)
    | c1 == c2 ->
        let (ts1', bs1) = mapAndUnzip simplifyPmExpr ts1
            (ts2', bs2) = mapAndUnzip simplifyPmExpr ts2
            (tss, _bss) = zipWithAndUnzip simplifyEqExpr ts1' ts2'
            worst_case  = PmExprEq (PmExprCon c1 ts1') (PmExprCon c2 ts2')
        in  if | not (or bs1 || or bs2) -> (worst_case, False) -- no argument changed
               | all isTruePmExpr  tss  -> (truePmExpr, True)
               | any isFalsePmExpr tss  -> (falsePmExpr, True)
               | otherwise              -> (worst_case, False)
    | otherwise -> (falsePmExpr, True)
  -- Everything else (including distinct variables) stays undecided.
  _other_equality -> (original, False)
  where
    original = PmExprEq e1 e2 -- the input equality, rebuilt unchanged
-- | Resolve both sides of an equality against the substitution.
applySubstComplexEq :: PmVarEnv -> ComplexEq -> ComplexEq
applySubstComplexEq env (lhs, rhs) = (resolve lhs, resolve rhs)
  where
    resolve = exprDeepLookup env
-- | Look a variable up in the substitution and resolve the result
-- transitively. An unbound variable is returned as itself.
-- NOTE(review): a cyclic substitution would make this loop; callers are
-- presumably expected never to build one.
varDeepLookup :: PmVarEnv -> Name -> PmExpr
varDeepLookup env x = case lookupNameEnv env x of
  Just bound -> exprDeepLookup env bound
  Nothing    -> PmExprVar x
{-# INLINE varDeepLookup #-}
-- | Replace every variable inside an expression by its transitive binding
-- in the substitution. Literals and opaque expressions are untouched.
exprDeepLookup :: PmVarEnv -> PmExpr -> PmExpr
exprDeepLookup env expr = case expr of
  PmExprVar x    -> varDeepLookup env x
  PmExprCon c es -> PmExprCon c (map (exprDeepLookup env) es)
  PmExprEq e1 e2 -> PmExprEq (exprDeepLookup env e1) (exprDeepLookup env e2)
  _              -> expr
-- | Feed a list of equalities to the oracle, left to right, starting from
-- the given state. The first contradiction short-circuits to @Nothing@.
tmOracle :: TmState -> [ComplexEq] -> Maybe TmState
tmOracle start = foldlM solveOneEq start
-- | The type of a literal: @hsLitType@ for plain literals, @overLitType@
-- for overloaded ones.
pmLitType :: PmLit -> Type
pmLitType lit = case lit of
  PmSLit   simple     -> hsLitType   simple
  PmOLit _ overloaded -> overLitType overloaded