{-# LANGUAGE DeriveDataTypeable        #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE TupleSections             #-}
{-# LANGUAGE TypeSynonymInstances      #-}

module Language.Haskell.Liquid.TransformRec (
     transformRecExpr, transformScope
     ) where

import           Bag
import           Coercion
import           Control.Arrow       (second, (***))
import           Control.Monad.State
import           CoreLint
import           CoreSyn
import qualified Data.HashMap.Strict as M
import           ErrUtils
import           Id                  (idOccInfo, setIdInfo)
import           IdInfo
import           MkCore              (mkCoreLams)
import           SrcLoc
import           Type                (mkForAllTys)
import           TypeRep
import           Unique              hiding (deriveUnique)
import           Var
import           Name (isSystemName)
import           Language.Haskell.Liquid.GhcMisc
import           Language.Haskell.Liquid.Misc (mapSndM)
import           Language.Fixpoint.Misc       (mapSnd)

import           Data.List                (foldl', isInfixOf)
import           Control.Applicative      ((<$>))

import qualified Data.List as L

transformRecExpr :: CoreProgram -> CoreProgram
transformRecExpr cbs
  | isEmptyBag $ filterBag isTypeError e
  =  {-trace "new cbs"-} pg 
  | otherwise 
  = error ("INITIAL\n" ++ showPpr pg0 ++ "\nTRANSFORMED\n" ++ showPpr pg ++ "Type-check" ++ showSDoc (pprMessageBag e))
  where pg0    = evalState (transPg cbs) initEnv
        (_, e) = lintCoreBindings [] pg
        pg     = inlineFailCases pg0


inlineFailCases :: CoreProgram -> CoreProgram
inlineFailCases = (go [] <$>)
  where 
    go su (Rec xes)    = Rec (mapSnd (go' su) <$> xes)
    go su (NonRec x e) = NonRec x (go' su e)

    go' su (App (Var x) _)       | isFailId x, Just e <- getFailExpr x su = e  
    go' su (Let (NonRec x ex) e) | isFailId x   = go' (addFailExpr x (go' su ex) su) e

    go' su (App e1 e2)      = App (go' su e1) (go' su e2)
    go' su (Lam x e)        = Lam x (go' su e)
    go' su (Let xs e)       = Let (go su xs) (go' su e)
    go' su (Case e x t alt) = Case (go' su e) x t (goalt su <$> alt) 
    go' su (Cast e c)       = Cast (go' su e) c
    go' su (Tick t e)       = Tick t (go' su e)
    go' su e                = e

    goalt su (c, xs, e)     = (c, xs, go' su e)

    isFailId x  = isLocalId x && (isSystemName $ varName x) && L.isPrefixOf "fail" (show x)
    getFailExpr = L.lookup

    addFailExpr x (Lam _ e) su = (x, e):su 
    addFailExpr x e         _  = error "internal error" -- this cannot happen

isTypeError s | isInfixOf "Non term variable" (showSDoc s) = False
isTypeError _ = True

scopeTr = outerScTr . innerScTr
transformScope = outerScTr . innerScTr

outerScTr = mapNonRec (go [])
  where
   go ack x (xe : xes) | isCaseArg x xe = go (xe:ack) x xes
   go ack _ xes        = ack ++ xes

isCaseArg x (NonRec _ (Case (Var z) _ _ _)) = z == x
isCaseArg _ _                               = False

innerScTr = (mapBnd scTrans <$>)

scTrans x e = mapExpr scTrans $ foldr Let e0 bs
  where (bs, e0)           = go [] x e
        go bs x (Let b e)  | isCaseArg x b = go (b:bs) x e
        go bs x (Tick t e) = second (Tick t) $ go bs x e
        go bs x e          = (bs, e)

type TE = State TrEnv

data TrEnv = Tr { freshIndex  :: !Int
                , loc         :: SrcSpan
                }

initEnv = Tr 0 noSrcSpan

transPg = mapM transBd

transBd (NonRec x e) = liftM (NonRec x) (transExpr =<< mapBdM transBd e)
transBd (Rec xes)    = liftM Rec $ mapM (mapSndM (mapBdM transBd)) xes

transExpr :: CoreExpr -> TE CoreExpr
transExpr e
  | (isNonPolyRec e') && (not (null tvs)) 
  = trans tvs ids bs e'
  | otherwise
  = return e
  where (tvs, ids, e'')       = collectTyAndValBinders e
        (bs, e')              = collectNonRecLets e''

isNonPolyRec (Let (Rec xes) _) = any nonPoly (snd <$> xes)
isNonPolyRec _                 = False

nonPoly = null . fst . collectTyBinders

collectNonRecLets = go []
  where go bs (Let b@(NonRec _ _) e') = go (b:bs) e'
        go bs e'                      = (reverse bs, e')

appTysAndIds tvs ids x = mkApps (mkTyApps (Var x) (map TyVarTy tvs)) (map Var ids)

trans vs ids bs (Let (Rec xes) e)
  = liftM (mkLam . mkLet) (makeTrans vs liveIds e')
  where liveIds = mkAlive <$> ids
        mkLet e = foldr Let e bs
        mkLam e = foldr Lam e $ vs ++ liveIds
        e'      = Let (Rec xes') e
        xes'    = (second mkLet) <$> xes

makeTrans vs ids (Let (Rec xes) e)
 = do fids    <- mapM (mkFreshIds vs ids) xs
      let (ids', ys) = unzip fids
      let yes  = appTysAndIds vs ids <$> ys
      ys'     <- mapM fresh xs
      let su   = M.fromList $ zip xs (Var <$> ys')
      let rs   = zip ys' yes
      let es'  = zipWith (mkE ys) ids' es
      let xes' = zip ys es'
      return   $ mkRecBinds rs (Rec xes') (sub su e)
 where 
   (xs, es)       = unzip xes
   mkSu ys ids'   = mkSubs ids vs ids' (zip xs ys)
   mkE ys ids' e' = mkCoreLams (vs ++ ids') (sub (mkSu ys ids') e')

mkRecBinds :: [(b, Expr b)] -> Bind b -> Expr b -> Expr b
mkRecBinds xes rs e = Let rs (foldl' f e xes)
  where f e (x, xe) = Let (NonRec x xe) e  

mkSubs ids tvs xs ys = M.fromList $ s1 ++ s2
  where s1 = (second (appTysAndIds tvs xs)) <$> ys
        s2 = zip ids (Var <$> xs)

mkFreshIds tvs ids x
  = do ids'  <- mapM fresh ids
       let t  = mkForAllTys tvs $ mkType (reverse ids') $ varType x
       let x' = setVarType x t
       return (ids', x')
  where 
    mkType ids ty = foldl (\t x -> FunTy (varType x) t) ty ids

class Freshable a where
  fresh :: a -> TE a

instance Freshable Int where
  fresh _ = freshInt

instance Freshable Unique where
  fresh _ = freshUnique

instance Freshable Var where
  fresh v = liftM (setVarUnique v) freshUnique

freshInt
  = do s <- get
       let n = freshIndex s
       put s{freshIndex = n+1}
       return n

freshUnique = liftM (mkUnique 'X') freshInt

mkAlive x
  | isId x && isDeadOcc (idOccInfo x)
  = setIdInfo x (setOccInfo (idInfo x) NoOccInfo)
  | otherwise
  = x

class Subable a where
  sub   :: M.HashMap CoreBndr CoreExpr -> a -> a
  subTy :: M.HashMap TyVar Type -> a -> a

instance Subable CoreExpr where
  sub s (Var v)        = M.lookupDefault (Var v) v s
  sub _ (Lit l)        = Lit l
  sub s (App e1 e2)    = App (sub s e1) (sub s e2)
  sub s (Lam b e)      = Lam b (sub s e)
  sub s (Let b e)      = Let (sub s b) (sub s e)
  sub s (Case e b t a) = Case (sub s e) (sub s b) t (map (sub s) a)
  sub s (Cast e c)     = Cast (sub s e) c
  sub s (Tick t e)     = Tick t (sub s e)
  sub _ (Type t)       = Type t
  sub _ (Coercion c)   = Coercion c

  subTy s (Var v)      = Var (subTy s v)
  subTy _ (Lit l)      = Lit l
  subTy s (App e1 e2)  = App (subTy s e1) (subTy s e2)
  subTy s (Lam b e)    | isTyVar b = Lam v' (subTy s e)
   where v' = case M.lookup b s of
               Nothing          -> b
               Just (TyVarTy v) -> v

  subTy s (Lam b e)      = Lam (subTy s b) (subTy s e)
  subTy s (Let b e)      = Let (subTy s b) (subTy s e)
  subTy s (Case e b t a) = Case (subTy s e) (subTy s b) (subTy s t) (map (subTy s) a)
  subTy s (Cast e c)     = Cast (subTy s e) (subTy s c)
  subTy s (Tick t e)     = Tick t (subTy s e)
  subTy s (Type t)       = Type (subTy s t)
  subTy s (Coercion c)   = Coercion (subTy s c)

instance Subable Coercion where
  sub _ c                = c
  subTy _ _              = error "subTy Coercion"

instance Subable (Alt Var) where
 sub s (a, b, e)   = (a, map (sub s) b,   sub s e)
 subTy s (a, b, e) = (a, map (subTy s) b, subTy s e)

instance Subable Var where
 sub s v   | M.member v s = subVar $ s M.! v 
           | otherwise    = v
 subTy s v = setVarType v (subTy s (varType v))

subVar (Var x) = x
subVar  _      = error "sub Var"

instance Subable (Bind Var) where
 sub s (NonRec x e)   = NonRec (sub s x) (sub s e)
 sub s (Rec xes)      = Rec ((sub s *** sub s) <$> xes)

 subTy s (NonRec x e) = NonRec (subTy s x) (subTy s e)
 subTy s (Rec xes)    = Rec ((subTy s  *** subTy s) <$> xes)

instance Subable Type where
 sub _ e   = e
 subTy     = substTysWith

substTysWith s tv@(TyVarTy v)  = M.lookupDefault tv v s
substTysWith s (FunTy t1 t2)   = FunTy (substTysWith s t1) (substTysWith s t2)
substTysWith s (ForAllTy v t)  = ForAllTy v (substTysWith (M.delete v s) t)
substTysWith s (TyConApp c ts) = TyConApp c (map (substTysWith s) ts)
substTysWith s (AppTy t1 t2)   = AppTy (substTysWith s t1) (substTysWith s t2)

mapNonRec f (NonRec x xe:xes) = NonRec x xe : f x (mapNonRec f xes)
mapNonRec f (xe:xes)          = xe : mapNonRec f xes
mapNonRec _ []                = []

mapBnd f (NonRec b e)             = NonRec b (mapExpr f  e)
mapBnd f (Rec bs)                 = Rec (map (second (mapExpr f)) bs)

mapExpr f (Let (NonRec x ex) e)   = Let (NonRec x (f x ex) ) (f x e)
mapExpr f (App e1 e2)             = App  (mapExpr f e1) (mapExpr f e2)
mapExpr f (Lam b e)               = Lam b (mapExpr f e)
mapExpr f (Let bs e)              = Let (mapBnd f bs) (mapExpr f e)
mapExpr f (Case e b t alt)        = Case e b t (map (mapAlt f) alt)
mapExpr f (Tick t e)              = Tick t (mapExpr f e)
mapExpr _  e                      = e

mapAlt f (d, bs, e) = (d, bs, mapExpr f e)

-- Do not apply transformations to inner code

mapBdM _ = return

-- mapBdM f (Let b e)        = liftM2 Let (f b) (mapBdM f e)
-- mapBdM f (App e1 e2)      = liftM2 App (mapBdM f e1) (mapBdM f e2)
-- mapBdM f (Lam b e)        = liftM (Lam b) (mapBdM f e)
-- mapBdM f (Case e b t alt) = liftM (Case e b t) (mapM (mapBdAltM f) alt)
-- mapBdM f (Tick t e)       = liftM (Tick t) (mapBdM f e)
-- mapBdM _  e               = return  e
-- 
-- mapBdAltM f (d, bs, e) = liftM ((,,) d bs) (mapBdM f e)