{-# LANGUAGE ExistentialQuantification, GeneralizedNewtypeDeriving, MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, UndecidableInstances, FlexibleContexts, DeriveDataTypeable #-}
-- |The "Traced" module provides a simple way of tracing expression evaluation.
-- A value of type @Traced a@ has both a value of type @a@ and an expression tree
-- that describes how the value was computed.
-- 
-- There are instances for the 'Traced' type for all numeric classes to make
-- it simple to trace numeric expressions.
-- 
-- The expression tree associated with a traced value is exactly that: a tree.
-- But evaluation of expressions in Haskell typically has sharing to avoid recomputation.
-- This sharing can be recovered by the (impure) 'reShare' function.
--
-- $examples
module Debug.Traced(
           Traced, traced, named, nameTraced, unknown, unTraced, tracedD,
	   TracedV,
           TracedD(..), unTracedD,
           liftT, liftFun, Liftable, Typeable,
           ifT, (%==), (%/=), (%<), (%<=), (%>), (%>=),
           Fixity(..),
	   showAsExp, showAsExpFull,
           reShare, simplify,
           AsValue, AsExp, AsFullExp,
           asValue, asExp, asFullExp, asSharedExp
           ) where
import System.Mem.StableName
import System.IO.Unsafe(unsafePerformIO)
import Data.Typeable
import Control.Monad.State
import Data.Maybe(fromMaybe, fromJust)
import Data.List(group, sort)
import Data.Char(isAlpha)
import qualified Data.Map as M
import qualified Data.Set as S
import Text.PrettyPrint.HughesPJ
--import Debug.Trace

import qualified Debug.StableMap as SM

-- | A value of type @a@ together with the expression tree ('TracedD')
-- describing how it was computed.  The parameter @t@ does not occur in the
-- representation; it is a phantom that only selects how the value is shown
-- (see 'AsValue', 'AsExp' and 'AsFullExp').
data Traced t a
    = Traced TracedD a
    deriving (Typeable)

-- | Traced values that are shown as their plain value.
type TracedV a = Traced AsValue a

-- | Expression tree for a traced value.
data TracedD
    = NoValue                                                           -- ^unknown value
    | Name Bool Name TracedD                                            -- ^value with a name; the flag is 'True' for
                                                                        -- variables introduced by sharing recovery
                                                                        -- ('reShare') and 'False' for user-given names
    | forall a . (Show a, Typeable a) => Con a                          -- ^constant
    | forall a . (Show a, Typeable a) => Apply a Name Fixity [TracedD]  -- ^application: result value, function
                                                                        -- name, its fixity and the argument trees
    | Let [(Name, TracedD)] TracedD                                     -- ^(recovered) let expression
    deriving (Typeable)

-- | Names of values, functions and operators in an expression tree.
type Name = String

-- Showing an expression tree shows the value it computes, not the tree
-- itself; names and let-bindings are looked through.  A tree without a
-- value is shown as "__NoValue__".
instance Show TracedD where
    showsPrec d expr = case expr of
        NoValue       -> showString "__NoValue__"
        Name _ _ sub  -> showsPrec d sub
        Con x         -> showsPrec d x
        Apply x _ _ _ -> showsPrec d x
        Let _ body    -> showsPrec d body

-- | Fixity of a function or operator, used to decide how applications are
-- printed.  The 'Int' is the precedence level, as in a Haskell fixity
-- declaration.
data Fixity
    = InfixL Int  -- ^ left-associative infix operator
    | InfixR Int  -- ^ right-associative infix operator
    | Infix Int   -- ^ non-associative infix operator
    | Nonfix      -- ^ ordinary prefix function

-- Build a 'Let' node, leaving it out entirely when there are no bindings.
eLet :: [(Name, TracedD)] -> TracedD -> TracedD
eLet binds body
    | null binds = body
    | otherwise  = Let binds body

-- | Create a traced constant; its expression tree is just the value itself.
traced :: (Show a, Typeable a) => a -> Traced t a
traced x = Traced (Con x) x

-- | Attach a name to a traced value.  The value is unchanged; the name is
-- recorded in the expression tree.
nameTraced :: String -> Traced t a -> Traced t a
nameTraced name (Traced expr x) = Traced (Name False name expr) x

-- | Create a named traced constant.
named :: (Show a, Typeable a) => String -> a -> Traced t a
named name = nameTraced name . traced

-- | Create a named placeholder with no value.  It can be shown (as its
-- name), but forcing its value raises an error, so it cannot be used where
-- a real value is needed.
unknown :: (Show a, Typeable a) => String -> Traced t a
unknown name = nameTraced name (Traced NoValue (error "unTraced: no value"))

-- | Extract the plain value from a traced value.
unTraced :: Traced t a -> a
unTraced (Traced _ x) = x

-- | Extract the expression tree from a traced value.
tracedD :: Traced t a -> TracedD
tracedD (Traced expr _) = expr

-- | Convert an expression tree back to a traced value.  The value is the one
-- stored in the tree (looking through names and lets); the result is
-- 'Nothing' if that value does not have type @a@.  A tree with no value
-- ('NoValue', possibly under a name) gives a traced value whose value is an
-- error.
unTracedD :: (Typeable a) => TracedD -> Maybe (Traced t a)
unTracedD e = fmap (Traced e) (valueOf e)
  where
    -- The value computed by a tree, if it has the requested type.
    valueOf NoValue = Just (error "unTraced: no value")
    valueOf (Name _ n NoValue) = Just (error $ "unTraced: no value: " ++ n)
    valueOf (Name _ _ v) = valueOf v
    valueOf (Con a) = cast a
    valueOf (Apply a _ _ _) = cast a
    valueOf (Let _ v) = valueOf v

-- | Create a traced value whose expression tree applies the named function
-- (with the given fixity) to the given argument trees; the first argument
-- is the resulting value.
apply :: (Typeable a, Show a) => a -> Name -> Fixity -> [TracedD] -> Traced t a
apply result name fixity args = Traced node result
  where node = Apply result name fixity args

-- | Curried functions that can be lifted to traced values.  @a@ is the plain
-- function type and @b@ the corresponding traced function type; @b@ is
-- determined by @a@.
class Liftable a b | a -> b where
    -- | @liftT' name fixity args f@: @args@ are the expression trees of the
    -- arguments supplied so far, most recent first.
    liftT' :: Name -> Fixity -> [TracedD] -> a -> b

-- One more argument: remember its expression tree and pass its plain value on.
instance (Typeable a, Show a, Liftable b tb) => Liftable (a -> b) (Traced t a -> tb) where
    liftT' name fixity args fun arg =
        liftT' name fixity (tracedD arg : args) (fun (unTraced arg))

-- | Final result reached: record the application, with the argument trees
-- put back into left-to-right order.
baseLiftT :: (Typeable a, Show a) => Name -> Fixity -> [TracedD] -> a -> Traced t a
baseLiftT name fixity args result = apply result name fixity (reverse args)

-- Result types that lifted functions may return.
instance Liftable Integer  (Traced t Integer)  where liftT' = baseLiftT
instance Liftable Int      (Traced t Int)      where liftT' = baseLiftT
instance Liftable Double   (Traced t Double)   where liftT' = baseLiftT
instance Liftable Float    (Traced t Float)    where liftT' = baseLiftT
instance Liftable Bool     (Traced t Bool)     where liftT' = baseLiftT
instance Liftable Ordering (Traced t Ordering) where liftT' = baseLiftT
instance Liftable ()       (Traced t ())       where liftT' = baseLiftT

-- | Lift a curried function to traced arguments and result, recording each
-- application under the given name and fixity.  The final result type must
-- be one of those with a base 'Liftable' instance above.
liftT :: (Liftable a b) => Name -> Fixity -> a -> b
liftT name fixity = liftT' name fixity []

-- | 'liftT' for an ordinary prefix function.
liftFun :: (Liftable a b) => Name -> a -> b
liftFun name = liftT name Nonfix

-----------------------------------

-- Numeric instances

-- | Lift a binary function to traced values, recording an application of
-- the given name and fixity to both argument trees.
binOp :: (Show c, Typeable c) =>
         (a->b->c) -> (String, Fixity) -> Traced t a -> Traced t b -> Traced t c
binOp f (name, fixity) l r = apply result name fixity [tracedD l, tracedD r]
  where result = f (unTraced l) (unTraced r)

-- | Lift a unary function to traced values, recording a prefix application.
unOp :: (Show b, Typeable b) => (a->b) -> String -> Traced t a -> Traced t b
unOp f name x = apply result name Nonfix [tracedD x]
  where result = f (unTraced x)

-- Traced values are compared by their plain values; the expression trees
-- play no part.
instance (Eq a) => Eq (Traced t a) where
    (==) x y = (==) (unTraced x) (unTraced y)

instance (Ord a) => Ord (Traced t a) where
    compare x y = compare (unTraced x) (unTraced y)

-- The phantom type @t@ selects how a traced value is shown:
--   * a type whose @show 0@ starts with "AsE" ('AsExp'): the expression tree;
--   * one starting with "AsF" ('AsFullExp'): the expression tree, with the
--     value of each named subexpression in a comment;
--   * anything else (e.g. 'AsValue'): just the plain value.
-- A named value without a value is always shown as its name.
--
-- The mode is found by showing @0 :: t@, so @t@ must be 'Show' as well as
-- 'Num'.  Since GHC 7.4, 'Show' is no longer a superclass of 'Num', so it has
-- to be required explicitly.
instance (Num t, Show t, Show a) => Show (Traced t a) where
    showsPrec _ (Traced (Name _ s NoValue) _) = showString s
    showsPrec p v
        | doExp     = showString (render $ ppTracedD full p $ tracedD v)
        | otherwise = showsPrec p $ unTraced v
      where
        -- A value of the mode type, used only to find out which mode it is.
        modeTag :: (Num t) => Traced t a -> t
        modeTag _ = 0
        -- This is a rather gross hack. :)
        (doExp, full) = case show (modeTag v) of
                        'A':'s':'E':_ -> (True, False)
                        'A':'s':'F':_ -> (True, True)
                        _             -> (False, False)

-- Arithmetic on traced values records each operation in the expression tree
-- under its Haskell name and fixity; numeric literals become traced constants.
-- 'Show a' is required by 'traced', 'binOp' and 'unOp'.  It used to be implied
-- by 'Num a', but 'Show' is no longer a superclass of 'Num' (GHC 7.4 and later).
instance (Num t, Typeable a, Show a, Num a) => Num (Traced t a) where
    (+)           = binOp (+) ("+", InfixL 6)
    (-)           = binOp (-) ("-", InfixL 6)
    (*)           = binOp (*) ("*", InfixL 7)
    negate        = unOp negate "negate"
    abs           = unOp abs    "abs"
    signum        = unOp signum "signum"
    fromInteger   = traced . fromInteger

-- Traced division; rational literals become traced constants.
-- 'Show a' is listed explicitly because 'traced' and 'binOp' need it and it
-- is not implied by 'Fractional a' on GHC 7.4 and later.
instance (Num t, Typeable a, Show a, Fractional a) => Fractional (Traced t a) where
    (/)           = binOp (/) ("/", InfixL 7)
    fromRational  = traced . fromRational

-- Traced integer division.  'toInteger' uses the plain value.
-- 'Show a' is listed explicitly because 'binOp' needs it (and the 'Enum'
-- superclass instance requires it); it is not implied by 'Integral a' on
-- GHC 7.4 and later.
instance (Num t, Typeable a, Show a, Integral a) => Integral (Traced t a) where
    quot          = binOp quot ("quot", InfixL 7)
    rem           = binOp rem ("rem", InfixL 7)
    div           = binOp div ("div", InfixL 7)
    mod           = binOp mod ("mod", InfixL 7)
    toInteger     = toInteger . unTraced
    -- The pair methods are defined from the single operations, so their
    -- components are traced exactly like 'quot'/'rem' and 'div'/'mod'.
    -- The class default for 'divMod' is built from 'quotRem' and a sign
    -- correction, which would produce a different tree.
    quotRem x y   = (quot x y, rem x y)
    divMod x y    = (div x y, mod x y)

-- Enumeration of traced values.  'toEnum' gives a traced constant and
-- 'fromEnum' uses the plain value; 'succ' and 'pred' are recorded as
-- applications.  The list methods enumerate the plain values with the
-- underlying type's own 'Enum' instance and wrap each element as a traced
-- constant.  The class defaults instead round-trip through
-- 'fromEnum'/'toEnum', which truncates non-integral types (e.g. @[0.5 .. 2]@
-- at 'Double') and runs past the last constructor of bounded types
-- (e.g. @[False ..]@ reaches @toEnum 2 :: Bool@).
instance (Show a, Typeable a, Enum a) => Enum (Traced t a) where
    toEnum        = traced . toEnum
    fromEnum      = fromEnum . unTraced
    succ          = unOp succ "succ"
    pred          = unOp pred "pred"
    enumFrom x    = map traced (enumFrom (unTraced x))
    enumFromThen x y =
        map traced (enumFromThen (unTraced x) (unTraced y))
    enumFromTo x y =
        map traced (enumFromTo (unTraced x) (unTraced y))
    enumFromThenTo x y z =
        map traced (enumFromThenTo (unTraced x) (unTraced y) (unTraced z))
 
-- Conversion to 'Rational' uses the plain value.  'Show a' is listed
-- explicitly because the 'Num' superclass instance needs it and it is not
-- implied by 'Real a' on GHC 7.4 and later.
instance (Num t, Typeable a, Show a, Real a) => Real (Traced t a) where
    toRational    = toRational . unTraced

-- 'properFraction' splits the plain value; the fractional part becomes a new
-- traced constant (how it was derived is not recorded).
instance (Num t, Typeable a, Show a, RealFrac a) => RealFrac (Traced t a) where
    properFraction c = (i, traced c') where (i, c') = properFraction (unTraced c)

-- Traced transcendental functions: each is recorded as a prefix application,
-- '**' as a right-associative infix operator of precedence 8, and 'pi' as a
-- named constant.  'Show a' is listed explicitly because 'named', 'unOp' and
-- 'binOp' need it and it is not implied by 'Floating a' on GHC 7.4 and later.
instance (Num t, Typeable a, Show a, Floating a) => Floating (Traced t a) where
    pi = named "pi" pi
    exp = unOp exp "exp"
    sqrt = unOp sqrt "sqrt"
    log = unOp log "log"
    (**) = binOp (**) ("**", InfixR 8)
    logBase = binOp logBase ("logBase", Nonfix)
    sin = unOp sin "sin"
    tan = unOp tan "tan"
    cos = unOp cos "cos"
    asin = unOp asin "asin"
    atan = unOp atan "atan"
    acos = unOp acos "acos"
    sinh = unOp sinh "sinh"
    tanh = unOp tanh "tanh"
    cosh = unOp cosh "cosh"
    asinh = unOp asinh "asinh"
    atanh = unOp atanh "atanh"
    acosh = unOp acosh "acosh"

-- Floating-point inspection uses the plain value.  Functions returning a new
-- floating value give a traced constant, except 'atan2', which is recorded
-- as an application.  'Show a' is listed explicitly because 'traced' and
-- 'binOp' need it and it is not implied by 'RealFloat a' on GHC 7.4 and later.
instance (Num t, Typeable a, Show a, RealFloat a) => RealFloat (Traced t a) where
    floatRadix = floatRadix . unTraced
    floatDigits = floatDigits . unTraced
    floatRange  = floatRange . unTraced
    decodeFloat = decodeFloat . unTraced
    encodeFloat m e = traced (encodeFloat m e)
    exponent = exponent . unTraced
    significand = traced . significand . unTraced
    scaleFloat k = traced . scaleFloat k . unTraced
    isNaN = isNaN . unTraced
    isInfinite = isInfinite . unTraced
    isDenormalized = isDenormalized . unTraced
    isNegativeZero = isNegativeZero . unTraced
    isIEEE = isIEEE . unTraced
    atan2 = binOp atan2 ("atan2", Nonfix)


-----------------------------------

-- Boolean operations

-- | Traced version of /if/.  The result records the condition and the branch
-- that was taken; the other branch is shown as @...@.
ifT :: (Show a, Typeable a) => Traced t Bool -> Traced t a -> Traced t a -> Traced t a
ifT c t e = apply (unTraced chosen) "ifT" Nonfix (tracedD c : branchTrees)
  where
    taken = unTraced c
    chosen = if taken then t else e
    -- Placeholder for the branch that was not taken.
    elided = tracedD (unknown "..." `asTypeOf` t)
    branchTrees
        | taken     = [tracedD t, elided]
        | otherwise = [elided, tracedD e]

infix 4 %==, %/=, %<, %<=, %>, %>=

-- | Traced equality; the comparison is recorded in the result's tree.
(%==) :: (Eq a) => Traced t a -> Traced t a -> Traced t Bool
x %== y = binOp (==) ("==", Infix 4) x y

-- | Traced inequality.
(%/=) :: (Eq a) => Traced t a -> Traced t a -> Traced t Bool
x %/= y = binOp (/=) ("/=", Infix 4) x y

-- | Traced ordering comparisons, generating traced booleans.
(%<), (%<=), (%>), (%>=) :: (Ord a) => Traced t a -> Traced t a -> Traced t Bool
x %<  y = binOp (<)  ("<",  Infix 4) x y
x %<= y = binOp (<=) ("<=", Infix 4) x y
x %>  y = binOp (>)  (">",  Infix 4) x y
x %>= y = binOp (>=) (">=", Infix 4) x y


-----------------------------------

-- Pretty printing of a traced value

-- | Pretty-print an expression tree.
--
-- @ppTracedD full p e@ renders @e@ in a context of precedence @p@ (0 at top
-- level, 11 for a function argument), adding parentheses where needed.
-- When @full@ is 'True', named nodes are followed by their value in a
-- @{-...-}@ comment and 'Let' bindings are annotated with their values.
ppTracedD :: Bool -> Int -> TracedD -> Doc
ppTracedD _     _ NoValue = text "undefined"
ppTracedD _     _ (Name _ n NoValue) = text n
ppTracedD False p (Name False _ v) = ppTracedD False p v  -- user names are transparent
ppTracedD False _ (Name True s _) = text s                -- reference to a shared variable
ppTracedD True  _ (Name _ n v) = text n <> text "{-" <> text (show v) <> text "-}"
ppTracedD _     p (Con v) = text (showsPrec p v "")
ppTracedD _     _ (Apply _ f _ []) = text f
ppTracedD b     p (Apply _ "negate" _ [x]) = -- A hack for negate
    ppParens (p >= 6) (text "-" <> ppTracedD b 7 x)
ppTracedD b     p (Apply _ op Nonfix as) = ppPrefixApp b p op as
ppTracedD b     p (Apply _ op f [x,y]) =
    let (ql,q,qr) = case f of
                    InfixL d -> (d,d,d+1)
                    InfixR d -> (d+1,d,d)
                    Infix  d -> (d+1,d,d+1)
                    Nonfix   -> error "ppTracedD: impossible"
        op' = if isIdentName op then "`" ++ op ++ "`" else op
    in  ppParens (p > q) $
        ppTracedD b ql x <+> text op' <+> ppTracedD b qr y
-- An operator declared infix but applied to other than two arguments cannot
-- be printed infix; print it as a prefix application rather than failing.
ppTracedD b     p (Apply _ op _ as) = ppPrefixApp b p op as
ppTracedD b     p (Let bs v) =
    ppParens (p > 0) $
    sep (text "let" : map (nest 4 . ppBind) bs ++ [text "in  " <> ppTracedD b 0 v])
  where ppBind (n, e) = text n <+> equals <+> ppTracedD b 0 e <>
                        if b then text " {- " <> equals <+> text (show e) <> text " -}" <> semi else semi

-- | Print @op a1 ... an@ as a prefix application (parenthesised above
-- precedence 10).  An operator symbol is wrapped in parentheses, e.g.
-- @(+) x@, so the output stays valid Haskell.
ppPrefixApp :: Bool -> Int -> Name -> [TracedD] -> Doc
ppPrefixApp b p op as =
    ppParens (p > 10) $
    text op' <+> fsep (map (ppTracedD b 11) as)
  where op' = if isIdentName op then op else "(" ++ op ++ ")"

-- | Whether a name is an alphanumeric identifier rather than an operator
-- symbol.  Total: the empty name is treated as an identifier.
isIdentName :: Name -> Bool
isIdentName (c:_) = isAlpha c || c == '_'
isIdentName []    = True
                    
-- | Wrap a document in parentheses when the flag is set.
ppParens :: Bool -> Doc -> Doc
ppParens needed d
    | needed    = parens d
    | otherwise = d

-- | Show the expression tree of a traced value.
showAsExp :: (Show a) => Traced t a -> String
showAsExp v = render (ppTracedD False 0 (tracedD v))

-- | Show the expression tree of a traced value, together with the value of
-- each named subexpression.
showAsExpFull :: (Show a) => Traced t a -> String
showAsExpFull v = render (ppTracedD True 0 (tracedD v))

-----------------------------------

{-
-- |Recover sharing.
class ReShare a b | a -> b where
    reShare :: a -> a

instance (Show a, Typeable a) => ReShare (Traced a) (Traced a) where
-}

-- | Recover sharing: nodes of the expression tree that are the same heap
-- object (detected with stable names) are bound once in a @let@ and referred
-- to by name.  Impure in the sense that the result depends on how the value
-- was built and evaluated.
reShare :: (Typeable a) => Traced t a -> Traced t a
reShare v =
    case unTracedD (share (tracedD v)) of
        Just v' -> v'
        Nothing -> error "impossible reShare"

-- Recover sharing in an expression tree.  Every non-constant node gets a
-- binding ('share''); a node reached more than once shares one binding.
-- Bindings of an 'unknown' placeholder to itself are dropped.
--
-- This unsafePerformIO is safe in the sense that it doesn't cause any runtime errors,
-- but it does allow observation of how expressions are evaluated.  That's the whole
-- purpose of it.
share :: TracedD -> TracedD
share e = unsafePerformIO $ do
    (v, (_, _, bs)) <- runStateT (share' e) (0, SM.empty, [])
    let unknownBind (n, Name False n' NoValue) = n == n'
        unknownBind _ = False
    -- eLet (rather than Let) so that a tree with no remaining bindings, such
    -- as a constant or an 'unknown' placeholder, is not wrapped in an empty
    -- "let ... in".
    return $ eLet (filter (not . unknownBind) $ reverse bs) v

-- Worker for 'share'.  The state holds the number for the next fresh
-- variable, a map from the stable name of each visited node to the variable
-- node standing for it, and the bindings produced so far (most recent first).
-- Returns what replaces @e@ in its parent: @e@ itself for 'NoValue' and
-- constants, otherwise a variable reference (a @Name True@ node).
share' :: TracedD -> StateT (Integer, SM.StableMap TracedD TracedD, [(Name, TracedD)]) IO TracedD
share' e@NoValue = return e  -- No value: nothing to share
share' e@(Con _) = return e  -- Don't share constants
share' e = do
    (next, seen, binds) <- get
    -- The node is forced before its stable name is taken (presumably so the
    -- name refers to the evaluated node rather than a thunk -- confirm).
    sn <- liftIO $ e `seq` makeStableName e
    case SM.lookup sn seen of
        Just var ->
            -- This very node was visited before: refer to its variable.
            return var
        Nothing -> do
            let name = case e of
                    Name _ s _ -> s   -- reuse the user name
                    _          -> prefix ++ show next
                var = Name True name e
            -- Register the variable before visiting the children, so that a
            -- node reached again while they are processed maps to it.
            put (next + 1, SM.insert sn var seen, binds)
            body <- case e of
                NoValue         -> return e
                Name b m a      -> liftM (Name b m) $ share' a
                Con _           -> return e
                Apply a m fx as -> liftM (Apply a m fx) $ mapM share' as
                Let _ _         -> error "share': Let"
            -- Emit this binding after the bindings of its children.
            modify (\ (n, s, bs) -> (n, s, (name, body) : bs))
            return var

-- | Prefix of the variable names invented by sharing recovery.
prefix :: String
prefix = "_"

-- | Simplify the expression tree of a traced value (typically one produced by
-- 'reShare'); the value itself is untouched.
simplify :: Traced t a -> Traced t a
simplify t = case t of
    Traced expr x -> Traced (simplifyD expr) x

-- Simplify the bindings of a top-level 'Let'.  A variable invented by sharing
-- recovery (its name starts with 'prefix') is inlined when it is referenced
-- exactly once, or when its definition is trivial (a constant, 'NoValue' or
-- a variable reference).  Other bindings are kept.  A tree that is not a
-- 'Let' is returned unchanged.
simplifyD :: TracedD -> TracedD
simplifyD elet@(Let binds body) =
    eLet (reverse kept) (subst inlined body)
  where
    -- Variables referenced exactly once in the whole expression.
    onceVars = S.fromList [v | [v] <- group (sort (varRefs elet))]

    -- Every variable reference, with repetitions.
    varRefs NoValue = []
    varRefs (Con _) = []
    varRefs (Let bs e) = concatMap (varRefs . snd) bs ++ varRefs e
    varRefs (Apply _ _ _ es) = concatMap varRefs es
    varRefs (Name True v _) = [v]
    varRefs (Name False _ e) = varRefs e

    isTriv NoValue = True
    isTriv (Con _) = True
    isTriv (Name True _ _) = True
    isTriv _ = False

    -- Replace references to inlined variables by their definitions.
    subst _ NoValue = NoValue
    subst _ e@(Con _) = e
    subst _ (Let _ _) = error "Traced.simplify: Let"
    subst m (Apply a op fx es) = Apply a op fx (map (subst m) es)
    subst m (Name b v e) = fromMaybe (Name b v (subst m e)) (M.lookup v m)

    isShareVar v = take (length prefix) v == prefix

    -- Walk the bindings in order (for a 'Let' built by 'reShare', definitions
    -- precede their uses), accumulating kept bindings (most recent first)
    -- and the definitions of inlined variables.
    (kept, inlined) = go [] M.empty binds
    go acc m [] = (acc, m)
    go acc m ((v, e) : rest)
        | isShareVar v && (v `S.member` onceVars || isTriv e') = go acc (M.insert v e' m) rest
        | otherwise = go ((v, e') : acc) m rest
      where e' = subst m e
simplifyD e = e

-----------------------------------

-- | Display mode: a traced value is shown as its plain value.
type AsValue = Integer

-- | Select the 'AsValue' display mode.
asValue :: Traced t a -> Traced AsValue a
asValue v = idTraced v

-- | Display mode: a traced value is shown as its expression tree.
-- NOTE: the derived 'Show' output ("AsExp ...") is what the 'Show' instance
-- for 'Traced' inspects to recognise this mode, so it must stay derived.
newtype AsExp = AsExp Int deriving (Eq, Show, Num)

-- | Select the 'AsExp' display mode.
asExp :: Traced t a -> Traced AsExp a
asExp v = idTraced v

-- | Display mode: a traced value is shown as its expression tree, with the
-- value of each named subexpression.  The derived 'Show' output
-- ("AsFullExp ...") identifies the mode and must stay derived.
newtype AsFullExp = AsFullExp Int deriving (Eq, Show, Num)

-- | Select the 'AsFullExp' display mode.
asFullExp :: Traced t a -> Traced AsFullExp a
asFullExp v = idTraced v

-- Change the (phantom) display mode, keeping the tree and the value.
idTraced :: Traced t a -> Traced t' a
idTraced (Traced expr x) = Traced expr x

-- | Recover sharing ('reShare'), inline trivial and single-use bindings
-- ('simplify'), and select the 'AsExp' display mode.
asSharedExp :: (Typeable a) => Traced t a -> Traced AsExp a
asSharedExp v = asExp (simplify (reShare v))