{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TupleSections #-}

module Language.Haskell.Liquid.Bare.Measure (
    makeHaskellMeasures
  , makeHaskellInlines
  , makeHaskellBounds
  , makeMeasureSpec
  , makeMeasureSpec'
  , makeClassMeasureSpec
  , makeMeasureSelectors
  , strengthenHaskellMeasures
  , varMeasures
  ) where

import CoreSyn
import DataCon
import TyCon
import Id
import Name
import Type hiding (isFunTy)
import Var

import Prelude hiding (mapM, error)
import Control.Arrow ((&&&))
import Control.Monad hiding (forM, mapM)
import Control.Monad.Except hiding (forM, mapM)
import Control.Monad.State hiding (forM, mapM)
import Data.Bifunctor
import Data.Maybe
import Data.Char (toUpper)

import Data.Traversable (forM, mapM)
import Text.PrettyPrint.HughesPJ (text)
import Text.Parsec.Pos (SourcePos)

import qualified Data.List as L

import qualified Data.HashMap.Strict as M
import qualified Data.HashSet        as S

import Language.Fixpoint.Misc (mlookup, sortNub)
import Language.Fixpoint.Types (Symbol, dummySymbol, symbolString, symbol, Expr(..))
import Language.Fixpoint.SortCheck (isFirstOrder)

import qualified Language.Fixpoint.Types as F

import Language.Haskell.Liquid.Transforms.CoreToLogic
import Language.Haskell.Liquid.Misc
import Language.Haskell.Liquid.GHC.Misc (dropModuleNames, getSourcePos, getSourcePosE, sourcePosSrcSpan, isDataConId)
import Language.Haskell.Liquid.Types.RefType (generalize, ofType, uRType, typeSort)
import Language.Haskell.Liquid.Types
import Language.Haskell.Liquid.Types.Bounds

import qualified Language.Haskell.Liquid.Measure as Ms

import Language.Haskell.Liquid.Bare.Env
import Language.Haskell.Liquid.Bare.Misc       (simpleSymbolVar, hasBoolResult)
import Language.Haskell.Liquid.Bare.Expand
import Language.Haskell.Liquid.Bare.Lookup
import Language.Haskell.Liquid.Bare.OfType
import Language.Haskell.Liquid.Bare.Resolve
import Language.Haskell.Liquid.Bare.RefToLogic

makeHaskellMeasures :: F.TCEmb TyCon -> [CoreBind] -> ModName -> (ModName, Ms.BareSpec) -> BareM (Ms.MSpec SpecType DataCon)
makeHaskellMeasures _   _   name' (name, _   ) | name /= name'
  = return mempty
makeHaskellMeasures tce cbs _     (_   , spec)
  = do lmap <- gets logicEnv
       Ms.mkMSpec' <$> mapM (makeMeasureDefinition tce lmap cbs') (S.toList $ Ms.hmeas spec)
  where
    cbs'                  = concatMap unrec cbs
    unrec cb@(NonRec _ _) = [cb]
    unrec (Rec xes)       = [NonRec x e | (x, e) <- xes]

makeHaskellInlines :: F.TCEmb TyCon -> [CoreBind] -> ModName -> (ModName, Ms.BareSpec) -> BareM ()
makeHaskellInlines _   _   name' (name, _   ) | name /= name'
  = return mempty
makeHaskellInlines tce cbs _     (_   , spec)
  = do lmap <- gets logicEnv
       mapM_ (makeMeasureInline tce lmap cbs') (S.toList $ Ms.inlines spec)
  where
    cbs'                  = concatMap unrec cbs
    unrec cb@(NonRec _ _) = [cb]
    unrec (Rec xes)       = [NonRec x e | (x, e) <- xes]

makeMeasureInline :: F.TCEmb TyCon -> LogicMap -> [CoreBind] ->  LocSymbol -> BareM ()
makeMeasureInline tce lmap cbs  x
  = case filter ((val x `elem`) . map (dropModuleNames . simplesymbol) . binders) cbs of
    (NonRec v def:_)   -> do {e <- coreToFun' tce x v def; updateInlines x e}
    (Rec [(v, def)]:_) -> do {e <- coreToFun' tce x v def; updateInlines x e}
    _                  -> throwError $ mkError "Cannot inline haskell function"
  where
    binders (NonRec z _) = [z]
    binders (Rec xes)    = fst <$> xes

    coreToFun' tce x v def = case runToLogic tce lmap mkError $ coreToFun x v def of
                           Left (xs, e)  -> return (TI (symbol <$> xs) (fromLR e))
                           Right e -> throwError e

    fromLR (Left l)  = l 
    fromLR (Right r) = r

    mkError :: String -> Error
    mkError str = ErrHMeas (sourcePosSrcSpan $ loc x) (pprint $ val x) (text str)


updateInlines x v = modify $ \s -> let iold  = M.insert (val x) v (inlines s) in
                                   s{inlines = M.map (f iold) iold }
  where
    f             = txRefToLogic mempty

makeMeasureDefinition :: F.TCEmb TyCon -> LogicMap -> [CoreBind] -> LocSymbol -> BareM (Measure SpecType DataCon)
makeMeasureDefinition tce lmap cbs x
  = case filter ((val x `elem`) . map (dropModuleNames . simplesymbol) . binders) cbs of
    (NonRec v def:_)   -> Ms.mkM x (logicType $ varType v) <$> coreToDef' x v def
    (Rec [(v, def)]:_) -> Ms.mkM x (logicType $ varType v) <$> coreToDef' x v def
    _                  -> throwError $ mkError "Cannot extract measure from haskell function"
  where
    binders (NonRec x _) = [x]
    binders (Rec xes)    = fst <$> xes

    coreToDef' x v def = case runToLogic tce lmap mkError $ coreToDef x v def of
                           Left l  -> return     l
                           Right e -> throwError e

    mkError :: String -> Error
    mkError str = ErrHMeas (sourcePosSrcSpan $ loc x) (pprint $ val x) (text str)

simplesymbol :: CoreBndr -> Symbol
simplesymbol = symbol . getName

strengthenHaskellMeasures :: S.HashSet (Located Var) -> [(Var, Located SpecType)]
strengthenHaskellMeasures hmeas
  = (val &&& fmap strengthenResult) <$> S.toList hmeas

makeMeasureSelectors :: (DataCon, Located DataConP) -> [Measure SpecType DataCon]
makeMeasureSelectors (dc, Loc l l' (DataConP _ vs _ _ _ xts r _))
  = catMaybes (go <$> zip (reverse xts) [1..])
  where
    go ((x,t), i)
      | isFunTy t = Nothing
      | otherwise = Just $ makeMeasureSelector (Loc l l' x) (dty t) dc n i

    dty t         = foldr RAllT  (RFun dummySymbol r (fmap mempty t) mempty) vs
    n             = length xts

makeMeasureSelector x s dc n i = M {name = x, sort = s, eqns = [eqn]}
  where eqn   = Def x [] dc Nothing (((, Nothing) . mkx) <$> [1 .. n]) (E (EVar $ mkx i))
        mkx j = symbol ("xx" ++ show j)


makeMeasureSpec :: (ModName, Ms.Spec BareType LocSymbol) -> BareM (Ms.MSpec SpecType DataCon)
makeMeasureSpec (mod,spec) = inModule mod mkSpec
  where
    mkSpec = mkMeasureDCon =<< mkMeasureSort =<< m
    m      = Ms.mkMSpec <$> mapM expandMeasure (Ms.measures spec)
                        <*> return (Ms.cmeasures spec)
                        <*> mapM expandMeasure (Ms.imeasures spec)

makeMeasureSpec' = mapFst (mapSnd uRType <$>) . Ms.dataConTypes . first (mapReft ur_reft)

makeClassMeasureSpec (Ms.MSpec {..}) = tx <$> M.elems cmeasMap
  where
    tx (M n s _) = (n, CM n (mapReft ur_reft s) -- [(t,m) | (IM n' t m) <- imeas, n == n']
                   )


mkMeasureDCon :: Ms.MSpec t LocSymbol -> BareM (Ms.MSpec t DataCon)
mkMeasureDCon m
  = mkMeasureDCon_ m <$> forM (measureCtors m)
                           (\n -> (val n,) <$> lookupGhcDataCon n)

mkMeasureDCon_ :: Ms.MSpec t LocSymbol -> [(Symbol, DataCon)] -> Ms.MSpec t DataCon
mkMeasureDCon_ m ndcs = m' {Ms.ctorMap = cm'}
  where
    m'  = fmap (tx.val) m
    cm' = hashMapMapKeys (symbol . tx) $ Ms.ctorMap m'
    tx  = mlookup (M.fromList ndcs)

measureCtors ::  Ms.MSpec t LocSymbol -> [LocSymbol]
measureCtors = sortNub . fmap ctor . concat . M.elems . Ms.ctorMap

mkMeasureSort ::  Ms.MSpec BareType LocSymbol -> BareM (Ms.MSpec SpecType LocSymbol)
mkMeasureSort (Ms.MSpec c mm cm im)
  = Ms.MSpec <$> forM c (mapM txDef) <*> forM mm tx <*> forM cm tx <*> forM im tx
    where
      tx :: Measure BareType ctor -> BareM (Measure SpecType ctor)
      tx (M n s eqs) = M n <$> ofMeaSort s <*> mapM txDef eqs

      txDef :: Def BareType ctor -> BareM (Def SpecType ctor)
      txDef def = liftM3 (\xs t bds-> def{ dparams = xs, dsort = t, binds = bds})
                  (mapM (mapSndM ofMeaSort) (dparams def))
                  (mapM ofMeaSort $ dsort def)
                  (mapM (mapSndM $ mapM ofMeaSort) (binds def))


varMeasures :: (Monoid r) => [Var] -> [(Symbol, Located (RRType r))]
varMeasures vars = [ (symbol v, varSpecType v)  | v <- vars
                                                , isDataConId v
                                                , isSimpleType $ varType v ]

isSimpleType :: Type -> Bool
isSimpleType = isFirstOrder . typeSort M.empty

-- OLD isSimpleType t   = null tvs && isNothing (splitFunTy_maybe tb)
-- OLD  where
-- OLD    (tvs, tb)    = splitForAllTys t

varSpecType :: (Monoid r) => Var -> Located (RRType r)
varSpecType v    = Loc l l' (ofType $ varType v)
  where
    l            = getSourcePos  v
    l'           = getSourcePosE v


makeHaskellBounds :: F.TCEmb TyCon -> CoreProgram -> S.HashSet (Var, LocSymbol) -> BareM RBEnv
makeHaskellBounds tce cbs xs
  = do lmap <- gets logicEnv
       M.fromList <$> mapM (makeHaskellBound tce lmap cbs) (S.toList xs)


makeHaskellBound tce lmap  cbs (v, x) = case filter ((v  `elem`) . binders) cbs of
    (NonRec v def:_)   -> do {e <- coreToFun' tce x v def; return $ toBound v x e}
    (Rec [(v, def)]:_) -> do {e <- coreToFun' tce x v def; return $ toBound v x e}
    _                  -> throwError $ mkError "Cannot make bound of haskell function"

  where
    binders (NonRec x _) = [x]
    binders (Rec xes)    = fst <$> xes

    coreToFun' tce x v def = case runToLogic tce lmap mkError $ coreToFun x v def of
                           Left (xs, e) -> return (xs, e)
                           Right e      -> throwError e

    mkError :: String -> Error
    mkError str = ErrHMeas (sourcePosSrcSpan $ loc x) (pprint $ val x) (text str)


toBound :: Var -> LocSymbol -> ([Var], Either F.Expr F.Expr) -> (LocSymbol, RBound)
toBound v x (vs, Left p) = (x', Bound x' fvs ps xs p)
  where
    x'         = capitalizeBound x
    (ps', xs') = L.partition (hasBoolResult . varType) vs
    (ps , xs)  = (txp <$> ps', txx <$> xs')
    txp v      = (dummyLoc $ simpleSymbolVar v, ofType $ varType v)
    txx v      = (dummyLoc $ symbol v,          ofType $ varType v)
    fvs        = (((`RVar` mempty) . RTV) <$> fst (splitForAllTys $ varType v)) :: [RSort]

toBound v x (vs, Right e) = toBound v x (vs, Left e)

capitalizeBound = fmap (symbol . toUpperHead . symbolString)
  where
    toUpperHead []     = []
    toUpperHead (x:xs) = toUpper x:xs

--------------------------------------------------------------------------------
-- | Expand Measures -----------------------------------------------------------
--------------------------------------------------------------------------------

expandMeasure m
  = do eqns <- sequence $ expandMeasureDef <$> eqns m
       return $ m { sort = generalize (sort m)
                  , eqns = eqns }

expandMeasureDef :: Def t LocSymbol -> BareM (Def t LocSymbol)
expandMeasureDef d
  = do body <- expandMeasureBody (loc $ measure d) $ body d
       return $ d { body = body }

expandMeasureBody :: SourcePos -> Body -> BareM Body
expandMeasureBody l (P p)   = P   <$> (resolve l =<< expandExpr p)
expandMeasureBody l (R x p) = R x <$> (resolve l =<< expandExpr p)
expandMeasureBody l (E e)   = E   <$> resolve l e