{-# LANGUAGE DeriveDataTypeable        #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TupleSections             #-}
{-# LANGUAGE TypeSynonymInstances      #-}

module Language.Haskell.Liquid.Transforms.Rec (
     transformRecExpr, transformScope
     , outerScTr , innerScTr
     , isIdTRecBound, setIdTRecBound
     ) where

import           Bag
import           Coercion
import           Control.Arrow                        (second)
import           Control.Monad.State
import           CoreSyn
import           CoreUtils
import qualified Data.HashMap.Strict                  as M
import           Data.Hashable
import           ErrUtils
import           Id
import           IdInfo
import           Language.Haskell.Liquid.GHC.Misc
import           Language.Haskell.Liquid.GHC.Play
import           Language.Haskell.Liquid.Misc         (mapSndM)
import           Language.Fixpoint.Misc               (mapSnd) -- , traceShow)
import           Language.Haskell.Liquid.Types.Errors
import           MkCore                               (mkCoreLams)
import           Name                                 (isSystemName)
import           Outputable                           (SDoc)
import           Prelude                              hiding (error)
import           SrcLoc
import           Type                                 (mkForAllTys, splitForAllTys)
import           TyCoRep
import           Unique                               hiding (deriveUnique)
import           Var

import           Data.List                            (foldl', isInfixOf)


import qualified Data.List                            as L


transformRecExpr :: CoreProgram -> CoreProgram
transformRecExpr cbs
  | isEmptyBag $ filterBag isTypeError e
  =  {-trace "new cbs"-} pg
  | otherwise
  = panic Nothing ("Type-check" ++ showSDoc (pprMessageBag e))
  where pg0    = evalState (transPg (inlineLoopBreaker <$> cbs)) initEnv
        (_, e) = lintCoreBindings [] pg
        pg     = inlineFailCases pg0




inlineLoopBreaker :: Bind Id -> Bind Id
inlineLoopBreaker (NonRec x e) | Just (lbx, lbe) <- hasLoopBreaker be
  = Rec [(x, foldr Lam (sub (M.singleton lbx e') lbe) (αs ++ as))]
  where
    (αs, as, be) = collectTyAndValBinders e

    e' = foldl' App (foldl' App (Var x) ((Type . TyVarTy) <$> αs)) (Var <$> as)

    hasLoopBreaker (Let (Rec [(x1, e1)]) (Var x2)) | isLoopBreaker x1 && x1 == x2 = Just (x1, e1)
    hasLoopBreaker _                               = Nothing

    isLoopBreaker =  isStrongLoopBreaker . occInfo . idInfo

inlineLoopBreaker bs
  = bs

inlineFailCases :: CoreProgram -> CoreProgram
inlineFailCases = (go [] <$>)
  where
    go su (Rec xes)    = Rec (mapSnd (go' su) <$> xes)
    go su (NonRec x e) = NonRec x (go' su e)

    go' su (App (Var x) _)       | isFailId x, Just e <- getFailExpr x su = e
    go' su (Let (NonRec x ex) e) | isFailId x   = go' (addFailExpr x (go' su ex) su) e

    go' su (App e1 e2)      = App (go' su e1) (go' su e2)
    go' su (Lam x e)        = Lam x (go' su e)
    go' su (Let xs e)       = Let (go su xs) (go' su e)
    go' su (Case e x t alt) = Case (go' su e) x t (goalt su <$> alt)
    go' su (Cast e c)       = Cast (go' su e) c
    go' su (Tick t e)       = Tick t (go' su e)
    go' _  e                = e

    goalt su (c, xs, e)     = (c, xs, go' su e)

    isFailId x  = isLocalId x && (isSystemName $ varName x) && L.isPrefixOf "fail" (show x)
    getFailExpr = L.lookup

    addFailExpr x (Lam _ e) su = (x, e):su
    addFailExpr _ _         _  = impossible Nothing "internal error" -- this cannot happen

isTypeError :: SDoc -> Bool
isTypeError s | isInfixOf "Non term variable" (showSDoc s) = False
isTypeError _ = True

-- No need for this transformation after ghc-8!!!
transformScope :: [Bind Id] -> [Bind Id]
transformScope = outerScTr . innerScTr 

outerScTr :: [Bind Id] -> [Bind Id]
outerScTr = mapNonRec (go [])
  where
   go ack x (xe : xes) | isCaseArg x xe = go (xe:ack) x xes
   go ack _ xes        = ack ++ xes

isCaseArg :: Id -> Bind t -> Bool
isCaseArg x (NonRec _ (Case (Var z) _ _ _)) = z == x
isCaseArg _ _                               = False

innerScTr :: Functor f => f (Bind Id) -> f (Bind Id)
innerScTr = (mapBnd scTrans <$>)

scTrans :: Id -> Expr Id -> Expr Id
scTrans x e = mapExpr scTrans $ foldr Let e0 bs
  where (bs, e0)           = go [] x e
        go bs x (Let b e)  | isCaseArg x b = go (b:bs) x e
        go bs x (Tick t e) = second (Tick t) $ go bs x e
        go bs _ e          = (bs, e)

type TE = State TrEnv

data TrEnv = Tr { freshIndex  :: !Int
                , _loc        :: SrcSpan
                }

initEnv :: TrEnv
initEnv = Tr 0 noSrcSpan

transPg :: Traversable t
        => t (Bind CoreBndr)
        -> State TrEnv (t (Bind CoreBndr))
transPg = mapM transBd

transBd :: Bind CoreBndr
        -> State TrEnv (Bind CoreBndr)
transBd (NonRec x e) = liftM (NonRec x) (transExpr =<< mapBdM transBd e)
transBd (Rec xes)    = liftM Rec $ mapM (mapSndM (mapBdM transBd)) xes

transExpr :: CoreExpr -> TE CoreExpr
transExpr e
  | (isNonPolyRec e') && (not (null tvs))
  = trans tvs ids bs e'
  | otherwise
  = return e
  where (tvs, ids, e'')       = collectTyAndValBinders e
        (bs, e')              = collectNonRecLets e''

isNonPolyRec :: Expr CoreBndr -> Bool
isNonPolyRec (Let (Rec xes) _) = any nonPoly (snd <$> xes)
isNonPolyRec _                 = False

nonPoly :: CoreExpr -> Bool
nonPoly = null . fst . splitForAllTys . exprType

collectNonRecLets :: Expr t -> ([Bind t], Expr t)
collectNonRecLets = go []
  where go bs (Let b@(NonRec _ _) e') = go (b:bs) e'
        go bs e'                      = (reverse bs, e')

appTysAndIds :: [Var] -> [Id] -> Id -> Expr b
appTysAndIds tvs ids x = mkApps (mkTyApps (Var x) (map TyVarTy tvs)) (map Var ids)

trans :: Foldable t
      => [TyVar]
      -> [Var]
      -> t (Bind Id)
      -> Expr Var
      -> State TrEnv (Expr Id)
trans vs ids bs (Let (Rec xes) e)
  = liftM (mkLam . mkLet) (makeTrans vs liveIds e')
  where liveIds = mkAlive <$> ids
        mkLet e = foldr Let e bs
        mkLam e = foldr Lam e $ vs ++ liveIds
        e'      = Let (Rec xes') e
        xes'    = (second mkLet) <$> xes

trans _ _ _ _ = panic Nothing "TransformRec.trans called with invalid input"

makeTrans :: [TyVar]
          -> [Var]
          -> Expr Var
          -> State TrEnv (Expr Var)
makeTrans vs ids (Let (Rec xes) e)
 = do fids    <- mapM (mkFreshIds vs ids) xs
      let (ids', ys) = unzip fids
      let yes  = appTysAndIds vs ids <$> ys
      ys'     <- mapM fresh xs
      let su   = M.fromList $ zip xs (Var <$> ys')
      let rs   = zip ys' yes
      let es'  = zipWith (mkE ys) ids' es
      let xes' = zip ys es'
      return   $ mkRecBinds rs (Rec xes') (sub su e)
 where
   (xs, es)       = unzip xes
   mkSu ys ids'   = mkSubs ids vs ids' (zip xs ys)
   mkE ys ids' e' = mkCoreLams (vs ++ ids') (sub (mkSu ys ids') e')

makeTrans _ _ _ = panic Nothing "TransformRec.makeTrans called with invalid input"

mkRecBinds :: [(b, Expr b)] -> Bind b -> Expr b -> Expr b
mkRecBinds xes rs e = Let rs (foldl' f e xes)
  where f e (x, xe) = Let (NonRec x xe) e

mkSubs :: (Eq k, Hashable k)
       => [k] -> [Var] -> [Id] -> [(k, Id)] -> M.HashMap k (Expr b)
mkSubs ids tvs xs ys = M.fromList $ s1 ++ s2
  where s1 = (second (appTysAndIds tvs xs)) <$> ys
        s2 = zip ids (Var <$> xs)

mkFreshIds :: [TyVar]
           -> [Var]
           -> Var
           -> State TrEnv ([Var], Id)
mkFreshIds tvs ids x
  = do ids'  <- mapM fresh ids
       let ids'' = map setIdTRecBound ids'
       let t  = mkForAllTys ((`TvBndr` Required) <$> tvs) $ mkType (reverse ids'') $ varType x
       let x' = setVarType x t
       return (ids'', x')
  where
    mkType ids ty = foldl (\t x -> FunTy (varType x) t) ty ids

-- NOTE [Don't choose transform-rec binders as decreasing params]
-- --------------------------------------------------------------
--
-- We don't want to select a binder created by TransformRec as the
-- decreasing parameter, since the user didn't write it. Furthermore,
-- consider T1065. There we have an inner loop that decreases on the
-- sole list parameter. But TransformRec prepends the parameters to the
-- outer `groupByFB` to the inner `groupByFBCore`, and now the first
-- decreasing parameter is the constant `xs0`. Disaster!
--
-- So we need a way to signal to L.H.L.Constraint.Generate that we
-- should ignore these copied Vars. The easiest way to do that is to set
-- a flag on the Var that we know won't be set, and it just so happens
-- GHC has a bunch of optional flags that can be set by various Core
-- analyses that we don't run...
setIdTRecBound :: Id -> Id
-- This is an ugly hack..
setIdTRecBound = modifyIdInfo (`setCafInfo` NoCafRefs)

isIdTRecBound :: Id -> Bool
isIdTRecBound = not . mayHaveCafRefs . cafInfo . idInfo

class Freshable a where
  fresh :: a -> TE a

instance Freshable Int where
  fresh _ = freshInt

instance Freshable Unique where
  fresh _ = freshUnique

instance Freshable Var where
  fresh v = liftM (setVarUnique v) freshUnique

freshInt :: MonadState TrEnv m => m Int
freshInt
  = do s <- get
       let n = freshIndex s
       put s{freshIndex = n+1}
       return n

freshUnique :: MonadState TrEnv m => m Unique
freshUnique = liftM (mkUnique 'X') freshInt


mapNonRec :: (b -> [Bind b] -> [Bind b]) -> [Bind b] -> [Bind b]
mapNonRec f (NonRec x xe:xes) = NonRec x xe : f x (mapNonRec f xes)
mapNonRec f (xe:xes)          = xe : mapNonRec f xes
mapNonRec _ []                = []

mapBnd :: (b -> Expr b -> Expr b) -> Bind b -> Bind b
mapBnd f (NonRec b e)             = NonRec b (mapExpr f  e)
mapBnd f (Rec bs)                 = Rec (map (second (mapExpr f)) bs)

mapExpr :: (b -> Expr b -> Expr b) -> Expr b -> Expr b
mapExpr f (Let (NonRec x ex) e)   = Let (NonRec x (f x ex) ) (f x e)
mapExpr f (App e1 e2)             = App  (mapExpr f e1) (mapExpr f e2)
mapExpr f (Lam b e)               = Lam b (mapExpr f e)
mapExpr f (Let bs e)              = Let (mapBnd f bs) (mapExpr f e)
mapExpr f (Case e b t alt)        = Case e b t (map (mapAlt f) alt)
mapExpr f (Tick t e)              = Tick t (mapExpr f e)
mapExpr _  e                      = e

mapAlt :: (b -> Expr b -> Expr b) -> (t, t1, Expr b) -> (t, t1, Expr b)
mapAlt f (d, bs, e) = (d, bs, mapExpr f e)

-- Do not apply transformations to inner code

mapBdM :: Monad m => t -> a -> m a
mapBdM _ = return

-- mapBdM f (Let b e)        = liftM2 Let (f b) (mapBdM f e)
-- mapBdM f (App e1 e2)      = liftM2 App (mapBdM f e1) (mapBdM f e2)
-- mapBdM f (Lam b e)        = liftM (Lam b) (mapBdM f e)
-- mapBdM f (Case e b t alt) = liftM (Case e b t) (mapM (mapBdAltM f) alt)
-- mapBdM f (Tick t e)       = liftM (Tick t) (mapBdM f e)
-- mapBdM _  e               = return  e
--
-- mapBdAltM f (d, bs, e) = liftM ((,,) d bs) (mapBdM f e)