-- |
-- Module      :  Cryptol.Transform.Specialize
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

module Cryptol.Transform.Specialize
where

import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.TypeMap
import Cryptol.TypeCheck.Subst
import qualified Cryptol.ModuleSystem as M
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.ModuleSystem.Name

import           Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)

import MonadLib hiding (mapM)

import Prelude ()
import Prelude.Compat

-- Specializer Monad -----------------------------------------------------------

-- | The specializer's state: a table of the declarations that can
-- currently be specialized. A Name should have an entry in the
-- SpecCache iff it is specializable.
--
-- Each entry pairs the original 'Decl' with a 'TypesMap' keyed by the
-- list of type arguments at which the declaration has been
-- instantiated. Each Name starts out with an empty TypesMap. An
-- instantiation maps to the name chosen for the specialized copy
-- together with the specialized 'Decl'; the 'Decl' is 'Nothing' while
-- the body of that copy is still being specialized.
type SpecCache = Map Name (Decl, TypesMap (Name, Maybe Decl))

-- | The specializer monad: a state layer carrying the 'SpecCache',
-- stacked on top of the module-system monad @M.ModuleT m@.
type SpecT m a = StateT SpecCache (M.ModuleT m) a

-- | The specializer monad with 'IO' as the base of the stack.
type SpecM a = SpecT IO a

-- | Run a specializer action from the given initial cache, yielding
-- the action's result paired with the final cache.
runSpecT :: SpecCache -> SpecT m a -> M.ModuleT m (a, SpecCache)
runSpecT = runStateT

-- | Lift an action of the underlying module-system monad into the
-- specializer monad.
liftSpecT :: Monad m => M.ModuleT m a -> SpecT m a
liftSpecT = lift

-- | Read the current specialization cache.
getSpecCache :: Monad m => SpecT m SpecCache
getSpecCache = get

-- | Replace the current specialization cache with the given one.
setSpecCache :: Monad m => SpecCache -> SpecT m ()
setSpecCache cache = set cache

-- | Apply a function to the current specialization cache and store
-- the result as the new cache.
modifySpecCache :: Monad m => (SpecCache -> SpecCache) -> SpecT m ()
modifySpecCache update = modify update

-- | Update the state of any 'StateM' monad by reading it, applying
-- the given function, and writing the result back.
modify :: StateM m s => (s -> s) -> m ()
modify f = do
  s <- get
  set (f s)


-- Specializer -----------------------------------------------------------------

-- | Add a `where` clause to the given expression containing
-- type-specialized versions of all functions called (transitively) by
-- the body of the expression.
--
-- Any leading type abstractions are stripped with 'destETAbs' before
-- the body is specialized against the declaration groups of all
-- loaded modules, and are wrapped around the result afterwards. The
-- specializer starts from an empty cache, and the final cache is
-- discarded.
specialize :: Expr -> M.ModuleCmd Expr
specialize expr (ev, modEnv) =
  M.runModuleT (ev, modEnv) (fst <$> runSpecT Map.empty specializeBody)
  where
  (tparams, body) = destETAbs expr

  specializeBody = do
    body' <- specializeEWhere body (allDeclGroups modEnv)
    return (foldr ETAbs body' tparams)

-- | Walk an expression, replacing uses of specializable declarations
-- with references to their type-specialized copies. Variables, type
-- applications and proof applications are handed to
-- 'specializeConst'; `where` expressions go through
-- 'specializeEWhere'; every other form is rebuilt from its
-- specialized sub-expressions.
specializeExpr :: Expr -> SpecM Expr
specializeExpr (EList es t) = do
  es' <- traverse specializeExpr es
  return (EList es' t)
specializeExpr (ETuple es) = do
  es' <- traverse specializeExpr es
  return (ETuple es')
specializeExpr (ERec fs) = do
  fs' <- traverse (traverseSnd specializeExpr) fs
  return (ERec fs')
specializeExpr (ESel e s) = do
  e' <- specializeExpr e
  return (ESel e' s)
specializeExpr (ESet e s v) = do
  e' <- specializeExpr e
  v' <- specializeExpr v
  return (ESet e' s v')
specializeExpr (EIf e1 e2 e3) = do
  e1' <- specializeExpr e1
  e2' <- specializeExpr e2
  e3' <- specializeExpr e3
  return (EIf e1' e2' e3')
-- Bindings within list comprehensions always have monomorphic types.
specializeExpr (EComp len t e mss) = do
  e'   <- specializeExpr e
  mss' <- traverse (traverse specializeMatch) mss
  return (EComp len t e' mss')
specializeExpr var@EVar {} = specializeConst var
-- We need to make sure that after processing `e`, no specialized
-- decls mentioning type variable `t` escape outside the
-- `ETAbs`. To avoid this, we reset to an empty SpecCache while we
-- run `specializeExpr e`, and restore it afterward: this
-- effectively prevents the specializer from registering any type
-- instantiations involving `t` for any decls bound outside the
-- scope of `t`.
specializeExpr (ETAbs t e) = do
  saved <- getSpecCache
  setSpecCache Map.empty
  e' <- specializeExpr e
  setSpecCache saved
  return (ETAbs t e')
specializeExpr tapp@ETApp {} = specializeConst tapp
specializeExpr (EApp e1 e2) = do
  e1' <- specializeExpr e1
  e2' <- specializeExpr e2
  return (EApp e1' e2')
specializeExpr (EAbs qn t e) = do
  e' <- specializeExpr e
  return (EAbs qn t e')
specializeExpr (EProofAbs p e) = do
  e' <- specializeExpr e
  return (EProofAbs p e')
specializeExpr papp@EProofApp {} = specializeConst papp
specializeExpr (EWhere e dgs) = specializeEWhere e dgs

-- | Specialize one match of a list-comprehension arm.
--
-- * For a generator (@From@), the source expression is specialized.
--
-- * For a @Let@ whose signature has no type parameters, the body of
--   the binding is specialized in place, so that polymorphic
--   functions used inside it are instantiated and registered in the
--   cache just like uses anywhere else in the expression. A primitive
--   binding has no body and is returned unchanged.
--
-- * A @Let@ with type parameters is not supported and fails.
specializeMatch :: Match -> SpecM Match
specializeMatch (From qn l t e) = From qn l t <$> specializeExpr e
specializeMatch (Let decl)
  | null (sVars (dSignature decl)) =
      case dDefinition decl of
        DExpr e -> do
          e' <- specializeExpr e
          return (Let (decl { dDefinition = DExpr e' }))
        DPrim   -> return (Let decl)
  | otherwise = fail "unimplemented: specializeMatch Let unimplemented"
  -- TODO: should treat this case like EWhere.


-- | Bring the given declaration groups into scope for specialization,
-- run the action, and harvest the specialized copies it produced.
--
-- Every declaration in @dgs@ is registered in the SpecCache with an
-- empty instance table (taking precedence over an existing entry of
-- the same name) before the action runs. Afterwards each group is
-- rebuilt from the specialized declarations recorded for its members,
-- the local entries are removed from the cache, and any entries they
-- displaced are put back.
--
-- Returns the action's result, the rebuilt declaration groups, and a
-- table giving the names of the specialized bindings for each name
-- bound in @dgs@.
withDeclGroups :: [DeclGroup] -> SpecM a
                  -> SpecM (a, [DeclGroup], Map Name (TypesMap Name))
withDeclGroups dgs action = do
  outerCache <- getSpecCache
  let localCache = Map.fromList
        [ (dName d, (d, emptyTM)) | dg <- dgs, d <- groupDecls dg ]
      displaced  = Map.intersection outerCache localCache
  -- We assume that the names bound in dgs are disjoint from the other names in scope.
  setSpecCache (Map.union localCache outerCache)
  result <- action
  -- Reassemble the DeclGroups from whatever instances were recorded.
  dgs' <- concat <$> traverse regroup dgs
  -- Restrict the updated cache to just the local entries we added.
  finalCache <- getSpecCache
  let nameTable = fmap (fmap fst . snd) (Map.intersection finalCache localCache)
  -- Drop the local definitions again and restore what they displaced.
  modifySpecCache (Map.union displaced . flip Map.difference localCache)
  return (result, dgs', nameTable)
  where
  -- All finished specialized copies recorded for one declaration.
  instancesOf :: Decl -> SpecM [Decl]
  instancesOf d = do
    ~(Just (_, tm)) <- Map.lookup (dName d) <$> getSpecCache
    return [ d' | (_, (_, Just d')) <- toListTM tm ]

  -- A recursive group stays one group (or vanishes when nothing in it
  -- was instantiated); a non-recursive binding yields one group per
  -- instance.
  regroup :: DeclGroup -> SpecM [DeclGroup]
  regroup (Recursive ds) = do
    ds' <- concat <$> traverse instancesOf ds
    return [ Recursive ds' | not (null ds') ]
  regroup (NonRecursive d) = map NonRecursive <$> instancesOf d

-- | Compute the specialization of `EWhere e dgs`. A decl within `dgs`
-- is replicated once for each monomorphic type instance at which it
-- is used; decls not mentioned in `e` (even monomorphic ones) are
-- simply dropped. When no declarations survive, the specialized body
-- is returned on its own, without an enclosing `where`.
specializeEWhere :: Expr -> [DeclGroup] -> SpecM Expr
specializeEWhere e dgs =
  withDeclGroups dgs (specializeExpr e) >>= \(e', dgs', _) ->
    return (attachDecls e' dgs')
  where
  attachDecls body []    = body
  attachDecls body decls = EWhere body decls

-- | Transform the given declaration groups into a set of monomorphic
-- declarations. The declarations whose signatures have neither type
-- parameters nor constraints serve as the roots: a reference to each
-- of them is specialized, which pulls in instantiated versions of the
-- polymorphic decls they refer to. We also return a map relating the
-- names from the original declarations to the generated names.
specializeDeclGroups :: [DeclGroup] -> SpecM ([DeclGroup], Map Name (TypesMap Name))
specializeDeclGroups dgs = do
  (_, dgs', names) <- withDeclGroups dgs (mapM specializeExpr monoRefs)
  return (dgs', names)
  where
  monoRefs =
    [ EVar (dName d)
    | dg <- dgs
    , d  <- groupDecls dg
    , isMonoType (dSignature d)
    ]
  isMonoType sig = null (sVars sig) && null (sProps sig)

-- | Specialize a variable that may be wrapped in type and proof
-- applications. The applications are peeled off to find the head
-- and its type arguments. If the head is a variable with an entry in
-- the SpecCache, the whole expression is replaced by a reference to
-- the instance of that declaration at those type arguments, creating
-- the instance first if it does not exist yet. Anything else is
-- returned unchanged.
specializeConst :: Expr -> SpecM Expr
specializeConst e0 =
  case e2 of
    EVar qname -> do
      cache <- getSpecCache
      case Map.lookup qname cache of
        -- Primitive/unspecializable variable; leave it alone
        Nothing -> return e0
        Just (decl, tm)
          -- Already specialized
          | Just (qname', _) <- lookupTM ts tm -> return (EVar qname')
          -- A new type instance of this function
          | otherwise -> EVar <$> newInstance qname decl
    -- type/proof application to non-variable; not specializable
    _ -> return e0
  where
  (e1, n)  = destEProofApps e0
  (e2, ts) = destETApps e1

  -- Store an entry for the current type arguments under the given name.
  record :: Name -> (Name, Maybe Decl) -> SpecM ()
  record qname entry =
    modifySpecCache (Map.adjust (fmap (insertTM ts entry)) qname)

  -- Build the instance of a declaration at the current type arguments
  -- and return its new name. The name is recorded (without a
  -- declaration) before the body is processed, so that uses of the
  -- same instance encountered while specializing the body resolve to
  -- the new name instead of starting another instance.
  newInstance :: Name -> Decl -> SpecM Name
  newInstance qname decl = do
    qname' <- freshName qname ts
    sig'   <- instantiateSchema ts n (dSignature decl)
    record qname (qname', Nothing)
    rhs' <- case dDefinition decl of
              DExpr e -> DExpr <$> (specializeExpr =<< instantiateExpr ts n e)
              DPrim   -> return DPrim
    let decl' = decl { dName = qname', dSignature = sig', dDefinition = rhs' }
    record qname (qname', Just decl')
    return qname'


-- Utility Functions -----------------------------------------------------------

-- | Strip the outermost proof applications from an expression,
-- returning what is underneath together with how many were removed.
destEProofApps :: Expr -> (Expr, Int)
destEProofApps = strip 0
  where
    strip count expr =
      case expr of
        EProofApp inner -> strip (count + 1) inner
        _               -> (expr, count)

-- | Strip the outermost type applications from an expression,
-- returning the head together with its type arguments. The arguments
-- are listed in application order: the one applied first (innermost)
-- comes first.
destETApps :: Expr -> (Expr, [Type])
destETApps = strip []
  where
    strip args expr =
      case expr of
        ETApp inner ty -> strip (ty : args) inner
        _              -> (expr, args)

-- | Strip the outermost proof abstractions from an expression,
-- returning the abstracted props together with the body. The props
-- are listed outermost first, so that @foldr EProofAbs body props@
-- rebuilds the original expression.
destEProofAbs :: Expr -> ([Prop], Expr)
destEProofAbs = go []
  where
    go ps (EProofAbs p e) = go (p : ps) e
    -- The accumulator is built innermost-first; flip it back into
    -- binding order before returning.
    go ps e               = (reverse ps, e)

-- | Strip the outermost type abstractions from an expression,
-- returning the bound type parameters together with the body. The
-- parameters are listed outermost first, so that
-- @foldr ETAbs body params@ rebuilds the original expression with its
-- binders in their original order.
destETAbs :: Expr -> ([TParam], Expr)
destETAbs = go []
  where
    go ts (ETAbs t e) = go (t : ts) e
    -- The accumulator is built innermost-first; flip it back into
    -- binding order before returning.
    go ts e           = (reverse ts, e)

-- Any top-level declarations in the current module can be found in the
-- ModuleEnv's LoadedModules, and so we can count on freshName to avoid
-- collisions with them.  Any generated name for a specialized function
-- will be qualified with the current @ModName@, so generated names will
-- not collide with local decls either.
-- freshName :: Name -> [Type] -> SpecM Name
-- freshName n [] = return n
-- freshName (QName m name) tys = do
--   let name' = reifyName name tys
--   bNames <- matchingBoundNames m
--   let loop i = let nm = name' ++ "_" ++ show i
--                  in if nm `elem` bNames
--                       then loop $ i + 1
--                       else nm
--   let go = if name' `elem` bNames
--                then loop (1 :: Integer)
--                else name'
--   return $ QName m (mkName go)

-- | Freshen a name by giving it a new unique. The new name is built
-- from the identifier, fixity and location of the given one, using
-- the same kind of name (declared or parameter). The list of type
-- arguments is ignored.
freshName :: Name -> [Type] -> SpecM Name
freshName n _ =
  liftSupply $ case nameInfo n of
    Declared ns s -> mkDeclared ns s (nameIdent n) (nameFixity n) (nameLoc n)
    Parameter     -> mkParameter (nameIdent n) (nameLoc n)

-- matchingBoundNames :: (Maybe ModName) -> SpecM [String]
-- matchingBoundNames m = do
--   qns <- allPublicNames <$> liftSpecT M.getModuleEnv
--   return [ unpack n | QName m' (Name n) <- qns , m == m' ]

-- reifyName :: Name -> [Type] -> String
-- reifyName name tys = intercalate "_" (showName name : concatMap showT tys)
--   where
--     tvInt (TVFree i _ _ _) = i
--     tvInt (TVBound i _) = i
--     showT typ =
--       case typ of
--         TCon tc ts  -> showTCon tc : concatMap showT ts
--         TUser _ _ t -> showT t
--         TVar tv     -> [ "a" ++ show (tvInt tv) ]
--         TRec tr     -> "rec" : concatMap showRecFld tr
--     showTCon tCon =
--       case tCon of
--         TC tc -> showTC tc
--         PC pc -> showPC pc
--         TF tf -> showTF tf
--     showPC pc =
--       case pc of
--         PEqual   -> "eq"
--         PNeq     -> "neq"
--         PGeq     -> "geq"
--         PFin     -> "fin"
--         PHas sel -> "sel_" ++ showSel sel
--         PArith   -> "arith"
--         PCmp     -> "cmp"
--     showTC tc =
--       case tc of
--         TCNum n     -> show n
--         TCInf       -> "inf"
--         TCBit       -> "bit"
--         TCSeq       -> "seq"
--         TCFun       -> "fun"
--         TCTuple n   -> "t" ++ show n
--         TCNewtype _ -> "user"
--     showSel sel = intercalate "_" $
--       case sel of
--         TupleSel  _ sig -> "tup"  : maybe [] ((:[]) . show) sig
--         RecordSel x sig -> "rec"  : showName x : map showName (maybe [] id sig)
--         ListSel   _ sig -> "list" : maybe [] ((:[]) . show) sig
--     showName nm =
--       case nm of
--         Name s       -> unpack s
--         NewName _ n -> "x" ++ show n
--     showTF tf =
--       case tf of
--         TCAdd           -> "add"
--         TCSub           -> "sub"
--         TCMul           -> "mul"
--         TCDiv           -> "div"
--         TCMod           -> "mod"
--         TCExp           -> "exp"
--         TCWidth         -> "width"
--         TCMin           -> "min"
--         TCMax           -> "max"
--         TCLenFromThenTo -> "len_from_then_to"
--     showRecFld (nm,t) = showName nm : showT t



-- | Instantiate a schema at the given type arguments. The result has
-- no type parameters and no constraints; its type is the schema's
-- type with each parameter replaced by the corresponding argument.
-- Fails when the number of type arguments differs from the number of
-- parameters, or when @n@ differs from the number of constraints.
instantiateSchema :: [Type] -> Int -> Schema -> SpecM Schema
instantiateSchema ts n (Forall params props ty) =
  if length params /= length ts
    then fail "instantiateSchema: wrong number of type arguments"
    else if length props /= n
      then fail "instantiateSchema: wrong number of prop arguments"
      else return (Forall [] [] (apSubst (listParamSubst (zip params ts)) ty))

-- | Reduce `length ts` outermost type abstractions and `n` proof
-- abstractions. Each type abstraction is eliminated by substituting
-- the corresponding type argument for its parameter in the body; once
-- the type arguments are used up, each proof abstraction is simply
-- dropped. Fails when the expression does not start with the expected
-- abstractions.
instantiateExpr :: [Type] -> Int -> Expr -> SpecM Expr
instantiateExpr tys n expr =
  case (tys, n, expr) of
    ([], 0, _) ->
      return expr
    ([], _, EProofAbs _ body) ->
      instantiateExpr [] (n - 1) body
    (t : rest, _, ETAbs param body) ->
      instantiateExpr rest n (apSubst (singleSubst (tpVar param) t) body)
    _ ->
      fail "instantiateExpr: wrong number of type/proof arguments"



-- | The declaration groups of every loaded module in the environment,
-- concatenated.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups env = concatMap mDecls (M.loadedModules env)

-- | The loaded-module records held in the environment.
allLoadedModules :: M.ModuleEnv -> [M.LoadedModule]
allLoadedModules env = M.getLoadedModules (M.meLoadedModules env)

-- | The names of the declarations found in the public part of the
-- interface of each loaded module.
allPublicNames :: M.ModuleEnv -> [Name]
allPublicNames env =
  [ name
  | lm   <- allLoadedModules env
  , name <- Map.keys (M.ifDecls (M.ifPublic (M.lmInterface lm)))
  ]

-- | Run an effectful function on the second component of a pair,
-- keeping the first component as it is.
traverseSnd :: Functor f => (b -> f c) -> (a, b) -> f (a, c)
traverseSnd f (x, y) = fmap (\y' -> (x, y')) (f y)