{-# LANGUAGE CPP, ExistentialQuantification, GeneralizedNewtypeDeriving, MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, UndecidableInstances, FlexibleContexts, DeriveDataTypeable, OverlappingInstances #-}
module Debug.TracedInternal(
           Traced, traced, named, nameTraced, unknown, unTraced, tracedD,
           TraceLevel(..), Traceable(..),
           TracedD(..), unTracedD,
           binOp, unOp, apply,
           liftT, liftFun, Liftable(..), baseLiftT,
           Fixity(..),
	   showAsExp, showAsExpFull,
           pPrintTraced,
           reShare, simplify
           ) where
import System.Mem.StableName
import System.IO.Unsafe(unsafePerformIO)
import Data.Typeable
#if __GLASGOW_HASKELL__ >= 706
import Data.Data(Data(..), mkNoRepType)
#elif __GLASGOW_HASKELL__ >= 610
import Data.Data(Data(..), mkNorepType)
#else
import Data.Generics(Data(..), mkNorepType)
#endif
import Control.Monad.State
import Data.Maybe(fromMaybe)
import Data.List(group, sort)
import Data.Char(isAlpha)
import qualified Data.Map as M
import qualified Data.Set as S
import Text.PrettyPrint.HughesPJ
--import Debug.Trace

import qualified Debug.StableMap as SM

-- | Traced values of some type.
data Traced a = Traced (Maybe TracedD) a
    deriving (Typeable, Data)

instance (Read a, Traceable a) => Read (Traced a) where
    readsPrec p s = [ (traced x, r) | (x, r) <- readsPrec p s]

data TraceLevel = TLValue | TLExp | TLFullExp
    deriving (Eq, Ord, Show, Typeable, Data)

class (Typeable a, Show a) => Traceable a where
    trPPrint :: TraceLevel -> Int -> a -> Doc

instance (Typeable a, Show a) => Traceable a where
    trPPrint _ p x = text (showsPrec p x "")

-- | Expression tree for a traced value.
data TracedD
    = NoValue                                                    -- ^unknown value
    | forall a . Name Bool Name TracedD                          -- ^value with a name
    | forall a . (Traceable a) => Con a                          -- ^constant
    | forall a . (Traceable a) => Apply a Name Fixity [TracedD]  -- ^application
    | forall a . Let [(Name, TracedD)] TracedD                   -- ^(recovered) let expression
    deriving (Typeable)
type Name = String

instance Data TracedD where
    gfoldl _ z c = z c
    gunfold _ _ _ = error "gunfold: TracedD"
    toConstr _ = error "toConstr: TracedD"
#if __GLASGOW_HASKELL__ >= 706
    dataTypeOf _ = mkNoRepType "Debug.TracedInternal.TracedD"
#else
    dataTypeOf _ = mkNorepType "Debug.TracedInternal.TracedD"
#endif

instance Show TracedD where
    showsPrec _ NoValue = showString "__NoValue__"
    showsPrec p (Name _ _ v) = showsPrec p v
    showsPrec p (Con a) = showsPrec p a
    showsPrec p (Apply a _ _ _) = showsPrec p a
    showsPrec p (Let _ v) = showsPrec p v

instance Eq TracedD where
    x == y  =  compare x y == EQ

-- Pure structural comparison, good for putting values into a map
instance Ord TracedD where
    compare NoValue NoValue = EQ
    compare NoValue _       = LT
    compare (Name _ _ _) NoValue = GT
    compare (Name b1 n1 v1) (Name b2 n2 v2) = compare (b1, n1, v1) (b2, n2, v2)
    compare (Name _ _ _) _ = LT
    compare (Con _) NoValue = GT
    compare (Con _) (Name _ _ _) = GT
    compare (Con a1) (Con a2) = compareValues a1 a2
    compare (Con _) _ = LT
    compare (Apply _ n1 f1 as1) (Apply _ n2 f2 as2) = compare (n1, f1, as1) (n2, f2, as2)
    compare (Apply _ _ _ _) (Let _ _) = LT
    compare (Apply _ _ _ _) _ = GT
    compare (Let bs1 t1) (Let bs2 t2) = compare (bs1, t1) (bs2, t2)
    compare (Let _ _) _ = GT

-- We need an ordering on values (possibly of different types).
-- It doesn't matter what the ordering is, as long as it's total.
-- We fudge it by using the hash of the stable name as the unique
-- identifier for the value and the compare those.
compareValues :: (Traceable a, Traceable b) => a -> b -> Ordering
compareValues x y = compare (uniq x) (uniq y)
  where uniq a = seq a $
                 hashStableName (unsafePerformIO $ makeStableName a)

-- |Fixity for identifier.
data Fixity = InfixL Int | InfixR Int | Infix Int | Nonfix
    deriving (Eq, Ord, Show)

eLet :: [(Name, TracedD)] -> TracedD -> TracedD
eLet [] e = e
eLet bs e = Let bs e

-- | Create a traced value.
traced :: (Traceable a) => a -> Traced a
traced a = Traced Nothing a

-- | Add a named to a traced value.
nameTraced :: (Traceable a) => String -> Traced a -> Traced a
nameTraced s t@(Traced (Just (Name False s' _)) _) | s == s' = t
nameTraced s t@(Traced _ a) = Traced (Just $ Name False s $ tracedD t) a

-- | Create a named traced value.
named :: (Traceable a) => String -> a -> Traced a
named s a = nameTraced s $ traced a

-- | Create a named thing with no value.  Cannot be used where a real value is needed.
unknown :: (Traceable a) => String -> Traced a
unknown s = Traced (Just (Name False s NoValue)) (error $ "Data.Traced: unknown `" ++ s ++ "' used")

-- | Extract the real value from a traced value.
unTraced :: Traced a -> a
unTraced (Traced _ a) = a

-- | Extract the expression tree from a traced value.
tracedD :: (Traceable a) => Traced a -> TracedD
tracedD (Traced (Just d) _) = d
tracedD (Traced Nothing a) = Con a

-- | Convert an expression tree to a traced value, if the types are correct.
unTracedD :: (Traceable a) => TracedD -> Maybe (Traced a)
unTracedD e =
    let trcd = Traced (Just e) in
    case e of
    NoValue -> Just $ trcd (error "unTraced: no value")
    Name _ n NoValue -> Just $ trcd (error $ "unTraced: no value: " ++ n)
    Name _ _ v -> liftM (trcd . unTraced) $ unTracedD v
    Con a -> liftM trcd $ cast a
    Apply a _ _ _ -> liftM trcd $ cast a
    Let _ v -> liftM (trcd . unTraced) $ unTracedD v

-- |Create a traced value with an 'Apply' expression tree.
apply :: (Traceable a) => a -> Name -> Fixity -> [TracedD] -> Traced a
apply r op fx as = Traced (Just $ Apply r op fx as) r

class Liftable a b | a -> b, b -> a where
    liftT' :: Name -> Fixity -> [TracedD] -> a -> b

instance (Traceable a, Liftable b tb) => Liftable (a -> b) (Traced a -> tb) where
    liftT' n fx as f = \ x -> liftT' n fx (tracedD x:as) (f (unTraced x))

baseLiftT :: (Traceable a) => Name -> Fixity -> [TracedD] -> a -> Traced a
baseLiftT n fx as r = Traced (Just $ Apply r n fx (reverse as)) r

instance Liftable Integer  (Traced Integer)  where liftT' = baseLiftT
instance Liftable Int      (Traced Int)      where liftT' = baseLiftT
instance Liftable Double   (Traced Double)   where liftT' = baseLiftT
instance Liftable Float    (Traced Float)    where liftT' = baseLiftT
instance Liftable Bool     (Traced Bool)     where liftT' = baseLiftT
instance Liftable Ordering (Traced Ordering) where liftT' = baseLiftT
instance Liftable ()       (Traced ())       where liftT' = baseLiftT

liftT :: (Liftable a b) => Name -> Fixity -> a -> b
liftT n fx = liftT' n fx []

liftFun :: (Liftable a b) => Name -> a -> b
liftFun n = liftT n Nonfix

-----------------------------------

-- Numeric instances

binOp :: (Traceable a, Traceable b, Traceable c) =>
         (a->b->c) -> (String, Fixity) -> Traced a -> Traced b -> Traced c
binOp f (n, fx) x y = apply (unTraced x `f` unTraced y) n fx [tracedD x, tracedD y]

unOp :: (Traceable a, Traceable b) => (a->b) -> String -> Traced a -> Traced b
unOp f op x = apply (f $ unTraced x) op Nonfix [tracedD x]

instance (Eq a) => Eq (Traced a) where
    x == x'  =  unTraced x == unTraced x'

instance (Ord a) => Ord (Traced a) where
    x `compare` x' =  unTraced x `compare` unTraced x'

instance (Show a) => Show (Traced a) where
    showsPrec _ (Traced (Just (Name _ s NoValue)) _) = showString s
    showsPrec p v = showsPrec p $ unTraced v

instance (Traceable a, Num a) => Num (Traced a) where
    (+)           = binOp (+) ("+", InfixL 6)
    (-)           = binOp (-) ("-", InfixL 6)
    (*)           = binOp (*) ("*", InfixL 7)
    negate        = unOp negate "negate"
    abs           = unOp abs    "abs"
    signum        = unOp signum "signum"
    fromInteger   = traced . fromInteger

instance (Traceable a, Fractional a) => Fractional (Traced a) where
    (/)           = binOp (/) ("/", InfixL 7)
    fromRational  = traced . fromRational

instance (Traceable a, Integral a) => Integral (Traced a) where
    quot          = binOp quot ("quot", InfixL 7)
    rem           = binOp rem ("rem", InfixL 7)
    div           = binOp div ("div", InfixL 7)
    mod           = binOp mod ("mod", InfixL 7)
    toInteger     = toInteger . unTraced
    quotRem x y   = (quot x y, rem x y)

instance (Traceable a, Enum a) => Enum (Traced a) where
    toEnum        = traced . toEnum
    fromEnum      = fromEnum . unTraced

instance (Traceable a, Real a) => Real (Traced a) where
    toRational    = toRational . unTraced

instance (Traceable a, RealFrac a) => RealFrac (Traced a) where
    properFraction c = (i, traced c') where (i, c') = properFraction (unTraced c)

instance (Traceable a, Floating a) => Floating (Traced a) where
    pi = named "pi" pi
    exp = unOp exp "exp"
    sqrt = unOp sqrt "sqrt"
    log = unOp log "log"
    (**) = binOp (**) ("**", InfixR 8)
    logBase = binOp logBase ("logBase", Nonfix)
    sin = unOp sin "sin"
    tan = unOp tan "tan"
    cos = unOp cos "cos"
    asin = unOp asin "asin"
    atan = unOp atan "atan"
    acos = unOp acos "acos"
    sinh = unOp sinh "sinh"
    tanh = unOp tanh "tanh"
    cosh = unOp cosh "cosh"
    asinh = unOp asinh "asinh"
    atanh = unOp atanh "atanh"
    acosh = unOp acosh "acosh"

instance (Traceable a, RealFloat a) => RealFloat (Traced a) where
    floatRadix = floatRadix . unTraced
    floatDigits = floatDigits . unTraced
    floatRange  = floatRange . unTraced
    decodeFloat = decodeFloat . unTraced
    encodeFloat m e = traced (encodeFloat m e)
    exponent = exponent . unTraced
    significand = traced . significand . unTraced
    scaleFloat k = traced . scaleFloat k . unTraced
    isNaN = isNaN . unTraced
    isInfinite = isInfinite . unTraced
    isDenormalized = isDenormalized . unTraced
    isNegativeZero = isNegativeZero . unTraced
    isIEEE = isIEEE . unTraced
    atan2 = binOp atan2 ("atan2", Nonfix)


-----------------------------------

-- Pretty printing of a traced value

ppTracedD :: TraceLevel -> Int -> TracedD -> Doc
ppTracedD _     _ NoValue = text "undefined"
ppTracedD _     _ (Name _ n NoValue) = text n

ppTracedD TLValue p (Name _ _ x) = ppTracedD TLValue p x
ppTracedD TLValue p (Con x) = trPPrint TLValue p x
ppTracedD TLValue p (Apply x _ _ _) = trPPrint TLValue p x
ppTracedD TLValue p (Let _ x) = ppTracedD TLValue p x

ppTracedD TLExp p (Name False _ v) = ppTracedD TLExp p v
ppTracedD TLExp _ (Name True s _) = text s
ppTracedD TLFullExp  _ (Name _ n v) = text n <> text "{-" <> trPPrint TLValue 0 v <> text "-}"

ppTracedD b     p (Con v) = trPPrint b p v
ppTracedD _     _ (Apply _ f _ []) = text f
ppTracedD b     p (Apply _ "negate" _ [x]) = -- A hack for negate
    ppParens (p >= 6) (text "-" <> ppTracedD b 7 x)
ppTracedD b     p (Apply _ op Nonfix as) =
    ppParens (p > 10) $
    text op <+> fsep (map (ppTracedD b 11) as)
ppTracedD b     p (Apply _ op f [x,y]) =
    let (ql,q,qr) = case f of
                    InfixL d -> (d,d,d+1)
		    InfixR d -> (d+1,d,d)
		    Infix  d -> (d+1,d,d+1)
		    Nonfix   -> error "ppTracedD: impossible"
        op' = if isAlpha (head op) then "`" ++ op ++ "`" else op
    in  ppParens (p > q) $
        sep [ppTracedD b ql x <+> text op', ppTracedD b qr y]
ppTracedD _     _ (Apply _ _ _ _) = error "ppTracedD: bad binop"
ppTracedD b     p (Let bs v) =
    ppParens (p > 0) $
    sep (text "let" : map (nest 4 . ppBind) bs ++ [text "in  " <> ppTracedD b 0 v])
  where ppBind (n, e) = text n <+> equals <+> ppTracedD b 0 e <>
                        if b == TLFullExp then text " {- " <> equals <+> text (show e) <> text " -}" <> semi else semi

ppParens :: Bool -> Doc -> Doc
ppParens False d = d
ppParens True d = parens d

-- |Show the expression tree of a traced value.
showAsExp :: (Traceable a) => Traced a -> String
showAsExp = render . pPrintTraced TLExp 0

-- |Show the expression tree of a traced value, also show the value of each variable.
showAsExpFull :: (Traceable a) => Traced a -> String
showAsExpFull = render . pPrintTraced TLFullExp 0

pPrintTraced :: (Traceable a) => TraceLevel -> Int -> Traced a -> Doc
pPrintTraced b p = ppTracedD b p . tracedD

-----------------------------------

reShare :: (Traceable a) => Traced a -> Traced a
reShare = fromMaybe (error "impossible reShare") . unTracedD . share . tracedD

-- This unsafePerformIO is safe in the sense that it doesn't cause any runtime errors,
-- but it does allow observation of how expressions are evaluated.  That's the whole
-- purpose of it.
share :: TracedD -> TracedD
share e = unsafePerformIO $ do
    (v, (_, _, bs)) <- runStateT (share' e) (0, SM.empty, [])
    let unknownBind (n, Name False n' NoValue) = n == n'
        unknownBind _ = False
    return $ Let (filter (not . unknownBind) $ reverse bs) v

share' :: TracedD -> StateT (Integer, SM.StableMap TracedD TracedD, [(Name, TracedD)]) IO TracedD
share' e@NoValue = return e  -- Don't share constants
share' e@(Con _) = return e  -- Don't share constants
share' e = do
    (i, sm, bs) <- get
    sn <- liftIO $ e `seq` makeStableName e
    case SM.lookup sn sm of
        Just ie -> do --liftIO $ putStrLn "Found";
                      return ie
        Nothing -> do
            let n = case e of
                    Name _ s _ -> s   -- reuse the user name
                    _ -> prefix ++ show i
	        ie = Name True n e
            put (i+1, SM.insert sn ie sm, bs)
	    e' <- case e of
	          NoValue -> return e
		  Name b m a -> liftM (Name b m) $ share' a
		  Con _ -> return e
		  Apply a m fx as -> liftM (Apply a m fx) $ mapM share' as
		  Let _ _ -> error "share': Let"
	    (i', sm', bs') <- get
            put (i', sm', (n, e') : bs')
	    return ie

prefix :: String
prefix = "_"

-- |Simplify an expression tree.
simplify :: Traced a -> Traced a
simplify (Traced d a) = Traced (fmap simplifyD d) a

-- Simplify bindings
-- Inline definitions used once and trivial (constants and variables) expressions.
simplifyD :: TracedD -> TracedD
simplifyD elet@(Let binds body) =
    let onceVars = S.fromList $ map head $ filter ((== 1) . length) $ group $ sort $ getVars elet
        getVars NoValue = []
        getVars (Con _) = []
        getVars (Let bs e) = concatMap (getVars . snd) bs ++ getVars e
        getVars (Apply _ _ _ es) = concatMap getVars es
        getVars (Name True v _) = [v]
        getVars (Name False _ e) = getVars e
        isTriv NoValue = True
        isTriv (Con _) = True
        isTriv (Name True _ _) = True
        isTriv _ = False
        subst _ NoValue = NoValue
        subst _ e@(Con _) = e
        subst _ (Let _ _) = error "Traced.simplify: Let"
        subst m (Apply a op fx es) = Apply a op fx (map (subst m) es)
        subst m (Name b v e) =
            case M.lookup v m of
            Nothing -> Name b v (subst m e)
            Just e' -> e'
        shareVar v = take (length prefix) v == prefix
        (bs', bm) = foldr step ([], M.empty) (reverse binds)
        step (v, e) (ds, m) =
            let e' = subst m e
            in  if shareVar v && (v `S.member` onceVars || isTriv e') then
                    (ds, M.insert v e' m)
                else
                    ((v, e') : ds, m)
        body' = subst bm body
    in  eLet (reverse bs') body'
simplifyD e = e