module Cryptol.Transform.Specialize
where
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.TypeMap
import Cryptol.TypeCheck.Subst
import qualified Cryptol.ModuleSystem as M
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.ModuleSystem.Name
import           Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (catMaybes)
import MonadLib hiding (mapM)
-- | Specialization cache: for each declaration in scope, the declaration
-- itself plus a table keyed by type-argument lists.  Each table entry holds
-- the name of the specialized copy and the copy itself ('Nothing' while the
-- copy is still being built).
type SpecCache = Map Name (Decl, TypesMap (Name, Maybe Decl))

-- | Specialization transformer: the cache as state over the module monad.
type SpecT m a = StateT SpecCache (M.ModuleT m) a

-- | Specialization running over the module monad in 'IO'.
type SpecM a = SpecT IO a

-- | Run a specialization computation from an initial cache, returning the
-- result paired with the final cache.
runSpecT :: SpecCache -> SpecT m a -> M.ModuleT m (a, SpecCache)
runSpecT = runStateT

-- | Embed a module-system action in the specialization monad.
liftSpecT :: Monad m => M.ModuleT m a -> SpecT m a
liftSpecT = lift

-- | Read the current specialization cache.
getSpecCache :: Monad m => SpecT m SpecCache
getSpecCache = get

-- | Overwrite the specialization cache.
setSpecCache :: Monad m => SpecCache -> SpecT m ()
setSpecCache newCache = set newCache

-- | Update the specialization cache with a pure function.
modifySpecCache :: Monad m => (SpecCache -> SpecCache) -> SpecT m ()
modifySpecCache f = modify f

-- | Update the state of any 'StateM' monad with a pure function.
modify :: StateM m s => (s -> s) -> m ()
modify f = do
  s <- get
  set (f s)
-- | Specialize a top-level expression against the declaration groups of all
-- loaded modules.
--
-- The outer type abstractions of @expr@ are peeled off with 'destETAbs' and
-- re-applied around the specialized body with @foldr ETAbs@.  Any
-- specialized copies that were produced are bound around the body by
-- 'specializeEWhere'.  The computation starts from an empty cache, and the
-- final cache is discarded.
specialize :: Expr -> M.ModuleCmd Expr
specialize expr (ev, byteReader, modEnv) =
  M.runModuleT (ev, byteReader, modEnv) (fst <$> runSpecT Map.empty spec)
  where
  spec = do
    let (tparams, body) = destETAbs expr
    body' <- specializeEWhere body (allDeclGroups modEnv)
    return (foldr ETAbs body' tparams)
-- | Walk an expression, replacing references to cached declarations with
-- references to their specialized copies.  Sub-expressions are processed
-- left to right.
specializeExpr :: Expr -> SpecM Expr
specializeExpr expr =
  case expr of
    -- Possible (instantiated) references to a declaration.
    EVar {}           -> specializeConst expr
    ETApp {}          -> specializeConst expr
    EProofApp {}      -> specializeConst expr

    -- Local declarations get their own cache scope.
    EWhere e dgs      -> specializeEWhere e dgs

    -- The body of a type abstraction is specialized with an empty cache;
    -- the outer cache is put back afterwards.
    ETAbs t e         -> do
      outer <- getSpecCache
      setSpecCache Map.empty
      body <- go e
      setSpecCache outer
      pure (ETAbs t body)

    -- Structural cases: rebuild the node from specialized children.
    EList es t        -> flip EList t <$> traverse go es
    ETuple es         -> ETuple <$> traverse go es
    ERec fs           -> ERec <$> traverse go fs
    ESel e s          -> (\e' -> ESel e' s) <$> go e
    ESet e s v        -> do
      e' <- go e
      v' <- go v
      pure (ESet e' s v')
    EIf c a b         -> do
      c' <- go c
      a' <- go a
      b' <- go b
      pure (EIf c' a' b')
    EComp len t e mss -> EComp len t <$> go e <*> traverse (traverse specializeMatch) mss
    EApp f x          -> EApp <$> go f <*> go x
    EAbs qn t e       -> EAbs qn t <$> go e
    EProofAbs p e     -> EProofAbs p <$> go e
  where
  go = specializeExpr
-- | Specialize one match of a comprehension arm.
--
-- Generators (@From@) have their source expression specialized.  @Let@
-- bindings are accepted only when their signature has no type variables,
-- and in that case they are returned unchanged (their right-hand side is
-- not specialized).  Polymorphic @Let@ bindings fail.
specializeMatch :: Match -> SpecM Match
specializeMatch m =
  case m of
    From qn l t e -> From qn l t <$> specializeExpr e
    Let decl
      | isMono decl -> return m
      | otherwise   -> fail "unimplemented: specializeMatch Let unimplemented"
  where
  isMono = null . sVars . dSignature
  
-- | Run @action@ with the declarations of @dgs@ in scope for specialization,
-- then rebuild declaration groups from the specialized copies it produced.
--
-- Returns:
--
--   * the result of @action@;
--   * new declaration groups holding only the finished specialized copies
--     recorded in the cache.  The original declarations are not included,
--     and a group with no copies disappears;
--   * for each declared name, the table from type arguments to the name of
--     its specialized copy.
--
-- On exit, the cache entries for these names are removed and any entries
-- they shadowed on entry are restored.  Other cache changes made by
-- @action@ are kept.
withDeclGroups :: [DeclGroup] -> SpecM a
                  -> SpecM (a, [DeclGroup], Map Name (TypesMap Name))
withDeclGroups dgs action = do
  origCache <- getSpecCache
  let decls = concatMap groupDecls dgs
  -- Each declaration starts with no specialized instances.
  let newCache = Map.fromList [ (dName d, (d, emptyTM)) | d <- decls ]
  -- Outer entries with the same names, to be restored on exit.
  let savedCache = Map.intersection origCache newCache
  
  -- Map.union is left-biased: the new entries shadow the outer ones.
  setSpecCache (Map.union newCache origCache)
  result <- action
  
  -- The finished copies of one declaration.  Table entries whose
  -- declaration is still Nothing are skipped.  The lazy pattern assumes the
  -- name is still in the cache here (it was inserted above).  If it were
  -- missing, forcing tm would raise a pattern-match error.
  let splitDecl :: Decl -> SpecM [Decl]
      splitDecl d = do
        ~(Just (_, tm)) <- Map.lookup (dName d) <$> getSpecCache
        return (catMaybes $ map (snd . snd) $ toListTM tm)
  -- A recursive group keeps all of its copies together in one recursive
  -- group.  Each copy of a non-recursive declaration becomes its own group.
  let splitDeclGroup :: DeclGroup -> SpecM [DeclGroup]
      splitDeclGroup (Recursive ds) = do
        ds' <- concat <$> traverse splitDecl ds
        if null ds'
          then return []
          else return [Recursive ds']
      splitDeclGroup (NonRecursive d) = map NonRecursive <$> splitDecl d
  dgs' <- concat <$> traverse splitDeclGroup dgs
  
  -- Name tables, restricted to the declarations of this scope.
  newCache' <- flip Map.intersection newCache <$> getSpecCache
  let nameTable = fmap (fmap fst . snd) newCache'
  
  -- Drop this scope's entries and put back the outer ones they shadowed.
  modifySpecCache (Map.union savedCache . flip Map.difference newCache)
  return (result, dgs', nameTable)
-- | Specialize the body of a @where@ with its local declaration groups in
-- scope.
--
-- Only the specialized copies of the local declarations are kept.  If there
-- are none, the bare body is returned instead of an empty 'EWhere'.
specializeEWhere :: Expr -> [DeclGroup] -> SpecM Expr
specializeEWhere e dgs = do
  (body, groups, _) <- withDeclGroups dgs (specializeExpr e)
  pure (mkWhere body groups)
  where
  mkWhere body []     = body
  mkWhere body groups = EWhere body groups
-- | Specialize a list of declaration groups.
--
-- Specialization starts from the declarations whose signature has neither
-- type variables nor constraints.  Each of them is referenced once, so it
-- gets a copy, and that copy pulls in whatever it uses.  Returns the
-- rebuilt groups together with the per-name tables from type arguments to
-- specialized names.
specializeDeclGroups :: [DeclGroup] -> SpecM ([DeclGroup], Map Name (TypesMap Name))
specializeDeclGroups dgs =
  dropResult <$> withDeclGroups dgs (traverse specializeExpr roots)
  where
  roots = [ EVar (dName d)
          | d <- concatMap groupDecls dgs
          , isMonoType (dSignature d)
          ]
  isMonoType s = null (sVars s) && null (sProps s)
  dropResult (_, dgs', names) = (dgs', names)
-- | Specialize a (possibly instantiated) reference to a declaration.  That
-- is, an expression of the form @EVar f@ under any number of 'ETApp' and
-- 'EProofApp' nodes.
--
-- If @f@ has an entry in the specialization cache, the whole expression is
-- replaced by @EVar f'@.  Here @f'@ names a copy of @f@'s declaration with a
-- monomorphic signature, instantiated at the given type arguments.  A copy
-- is built only the first time a given list of type arguments is seen;
-- later references reuse the cached name.  Any other expression, including
-- a variable that is not in the cache, is returned unchanged.
specializeConst :: Expr -> SpecM Expr
specializeConst e0 = do
  -- n = number of proof applications; ts = type arguments in the order
  -- they are applied.
  let (e1, n) = destEProofApps e0
  let (e2, ts) = destETApps e1
  case e2 of
    EVar qname ->
      do cache <- getSpecCache
         case Map.lookup qname cache of
           Nothing -> return e0 -- not a declaration being specialized
           Just (decl, tm) ->
             case lookupTM ts tm of
               Just (qname', _) -> return (EVar qname') -- reuse the existing copy
               Nothing -> do  -- first use at these type arguments
                 qname' <- freshName qname ts -- name for the new copy
                 sig' <- instantiateSchema ts n (dSignature decl)
                 -- Register the new name, with no declaration yet, before
                 -- specializing the body.  References to the same instance
                 -- inside the body (e.g. recursive calls) then resolve to
                 -- qname' instead of starting another copy.
                 modifySpecCache (Map.adjust (fmap (insertTM ts (qname', Nothing))) qname)
                 rhs' <- case dDefinition decl of
                           DExpr e -> do e' <- specializeExpr =<< instantiateExpr ts n e
                                         return (DExpr e')
                           DPrim   -> return DPrim
                 let decl' = decl { dName = qname', dSignature = sig', dDefinition = rhs' }
                 -- Record the finished declaration for this instance.
                 modifySpecCache (Map.adjust (fmap (insertTM ts (qname', Just decl'))) qname)
                 return (EVar qname')
    _ -> return e0 -- head is not a variable: nothing to specialize
-- | Strip the outer 'EProofApp' nodes from an expression, returning the
-- remaining expression and how many nodes were removed.
destEProofApps :: Expr -> (Expr, Int)
destEProofApps (EProofApp e) =
  let (core, k) = destEProofApps e
  in (core, k + 1)
destEProofApps e = (e, 0)
-- | Strip the outer 'ETApp' nodes from an expression.  Returns the head and
-- the type arguments in the order they are applied (first argument first),
-- so that @foldl ETApp f tys@ rebuilds the input.
destETApps :: Expr -> (Expr, [Type])
destETApps expr = collect expr []
  where
    collect (ETApp f ty) tys = collect f (ty : tys)
    collect f tys            = (f, tys)
-- | Strip the outer 'EProofAbs' nodes from an expression.
--
-- Returns the abstracted props outermost first, together with the
-- remaining body, so that @foldr EProofAbs body props@ rebuilds the input.
-- The accumulator collects props innermost-first, hence the final
-- 'reverse'.  Without it, the list came back in reverse source order,
-- unlike 'destETApps'.
destEProofAbs :: Expr -> ([Prop], Expr)
destEProofAbs = go []
  where
    go ps (EProofAbs p e) = go (p : ps) e
    go ps e               = (reverse ps, e)
-- | Strip the outer 'ETAbs' nodes from an expression.
--
-- Returns the bound type parameters outermost first, together with the
-- remaining body, so that @foldr ETAbs body params@ rebuilds the input.
-- 'specialize' relies on this.  The accumulator collects binders
-- innermost-first, hence the final 'reverse'.  Without it, expressions with
-- two or more type parameters had their binders swapped when rebuilt.
destETAbs :: Expr -> ([TParam], Expr)
destETAbs = go []
  where
    go ts (ETAbs t e) = go (t : ts) e
    go ts e           = (reverse ts, e)
-- | Create a fresh name for a specialized copy of @n@.
--
-- The new name reuses @n@'s identifier and location, plus its fixity and
-- declaration info for declared names.  The type arguments are currently
-- ignored.
freshName :: Name -> [Type] -> SpecM Name
freshName n _ = liftSupply mkName
  where
  mkName =
    case nameInfo n of
      Declared ns s -> mkDeclared ns s (nameIdent n) (nameFixity n) (nameLoc n)
      Parameter     -> mkParameter (nameIdent n) (nameLoc n)
-- | Instantiate a schema at the given type arguments, producing a
-- monomorphic schema (no type parameters, no props).
--
-- Fails unless exactly one type per parameter is given and @n@ equals the
-- number of props.
instantiateSchema :: [Type] -> Int -> Schema -> SpecM Schema
instantiateSchema ts n (Forall params props ty) =
  if length params /= length ts
    then fail "instantiateSchema: wrong number of type arguments"
    else if length props /= n
      then fail "instantiateSchema: wrong number of prop arguments"
      else return (Forall [] [] (apSubst (listParamSubst (zip params ts)) ty))
-- | Instantiate an expression's outer abstractions.
--
-- Each leading 'ETAbs' consumes one type (substituted into the body).
-- Once all types are used, @n@ leading 'EProofAbs' nodes are dropped.
-- Fails if the expression's shape does not match the arguments.
instantiateExpr :: [Type] -> Int -> Expr -> SpecM Expr
instantiateExpr ts n e =
  case (ts, n, e) of
    ([], 0, _)                      -> return e
    ([], _, EProofAbs _ body)       -> instantiateExpr [] (n - 1) body
    (t : rest, _, ETAbs param body) ->
      instantiateExpr rest n (apSubst (singleTParamSubst param t) body)
    _ -> fail "instantiateExpr: wrong number of type/proof arguments"
-- | Every declaration group of every loaded module, module by module in the
-- order 'M.loadedModules' reports them.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups env = [ dg | m <- M.loadedModules env, dg <- mDecls m ]
-- | Apply an effectful function to the second component of a pair, keeping
-- the first component as-is.
traverseSnd :: Functor f => (b -> f c) -> (a, b) -> f (a, c)
traverseSnd f (x, y) = fmap (\z -> (x, z)) (f y)