{-# LANGUAGE TemplateHaskell #-}

module Feldspar.Run.Compile where

import Control.Monad.Identity
import Control.Monad.Reader
import Data.Map (Map)
import qualified Data.Map as Map

import Data.Constraint (Dict (..))
import Data.Default.Class

import Language.Syntactic hiding ((:+:) (..), (:<:) (..))
import Language.Syntactic.Functional hiding (Binding (..))
import Language.Syntactic.Functional.Tuple

import qualified Control.Monad.Operational.Higher as Oper

import Language.Embedded.Expression
import Language.Embedded.Imperative hiding ((:+:) (..), (:<:) (..))
import Language.Embedded.Concurrent
import qualified Language.Embedded.Imperative as Imp
import Language.Embedded.Backend.C (ExternalCompilerOpts (..))
import qualified Language.Embedded.Backend.C as Imp

import Data.TypedStruct
import Data.Selection
import Feldspar.Primitive.Representation
import Feldspar.Primitive.Backend.C ()
import Feldspar.Representation
import Feldspar.Run.Representation
import Feldspar.Optimize

-- * Struct expressions and variables

-- | Struct expression
--
-- A 'Struct' whose leaves are primitive expressions ('Prim'): a 'Single'
-- leaf for a primitive type, or 'Two' sub-structs for a pair (see the
-- 'Pair'/'Fst'/'Snd' cases of 'translateExp').
type VExp = Struct PrimType' Prim

-- | Struct expression with hidden result type
--
-- Used to store expressions of different types in the same map (see
-- 'envAliases'); the type is recovered by comparing type representations
-- (see 'lookAlias').
data VExp'
  where
    VExp' :: Struct PrimType' Prim a -> VExp'

-- | Create one reference (without an initial value) per leaf of the given
-- type representation, using the given base name for all of them
newRefV :: Monad m => TypeRep a -> String -> TargetT m (Struct PrimType' Imp.Ref a)
newRefV t base = lift (mapStructA (\_ -> newNamedRef base) t)

-- | Create one reference per leaf of the given struct expression, each
-- initialized with the corresponding leaf expression
initRefV :: Monad m => String -> VExp a -> TargetT m (Struct PrimType' Imp.Ref a)
initRefV base e = lift (mapStructA (initNamedRef base) e)

-- | Read every reference in a struct of references, yielding a struct
-- expression of the same shape
getRefV :: Monad m => Struct PrimType' Imp.Ref a -> TargetT m (VExp a)
getRefV refs = lift (mapStructA getRef refs)

-- | Write a struct expression to a struct of references, leaf by leaf (the
-- writes are emitted one after another)
setRefV :: Monad m => Struct PrimType' Imp.Ref a -> VExp a -> TargetT m ()
setRefV r e = lift $ sequence_ $ zipListStruct setRef r e

-- | Turn a struct of references into a struct expression using
-- 'unsafeFreezeRef' on every leaf
unsafeFreezeRefV :: Monad m => Struct PrimType' Imp.Ref a -> TargetT m (VExp a)
unsafeFreezeRefV refs = lift (mapStructA unsafeFreezeRef refs)

-- * Compilation options

-- | Options affecting code generation
--
-- A default set of options is given by 'def'.
--
-- The assertion labels to include in the generated code can be stated using the
-- functions 'select', 'allExcept' and 'selectBy'. For example
--
-- @`def` {compilerAssertions = `allExcept` [`InternalAssertion`]}@
--
-- states that we want to include all except internal assertions.
data CompilerOpts = CompilerOpts
    { compilerAssertions :: Selection AssertionLabel
        -- ^ Which assertions to include in the generated code
    }

-- | Default options: the assertion selection is 'universal', i.e. all
-- assertions are included in the generated code
instance Default CompilerOpts
  where
    def = CompilerOpts
      { compilerAssertions = universal
      }

-- * Translation environment

-- | Translation environment
data Env = Env
    { envAliases :: Map Name VExp'
        -- ^ Variable names in scope, mapped to the expressions they stand for
        --   (extended by 'localAlias', consulted by 'lookAlias')
    , envOptions :: CompilerOpts
        -- ^ Options controlling code generation
    }

-- | Initial translation environment: no aliases and default compiler options
env0 :: Env
env0 = Env {envAliases = Map.empty, envOptions = def}

-- | Run a computation with an extra alias in scope: within it, the given
-- variable name is mapped to the given expression (shadowing any previous
-- alias for the same name)
localAlias :: MonadReader Env m
    => Name    -- ^ Old name
    -> VExp a  -- ^ New expression
    -> m b
    -> m b
localAlias v e = local extend
  where
    extend env = env {envAliases = Map.insert v (VExp' e) (envAliases env)}

-- | Lookup an alias in the environment
--
-- Calls 'error' if the variable is not in scope, or if the stored expression
-- does not have the type described by the given 'TypeRep'. The result is
-- returned lazily, so these errors only surface when the expression is used.
lookAlias :: MonadReader Env m => TypeRep a -> Name -> m (VExp a)
lookAlias t v = do
    env <- asks envAliases
    return $ case Map.lookup v env of
        Nothing -> error $ "lookAlias: variable " ++ show v ++ " not in scope"
        Just (VExp' e) -> case typeEq t (toTypeRep e) of
            Just Dict -> e
            -- Previously an unhandled case (anonymous pattern-match failure)
            Nothing   -> error $ "lookAlias: type mismatch for variable " ++ show v

-- * Translation of expressions

-- | Instruction set of the translated program: the sum of the command types
-- that code produced by 'translateExp' and 'translate' may use ('TargetT' and
-- 'ProgC' are both built on it)
type TargetCMD
    =       RefCMD
    Imp.:+: ArrCMD
    Imp.:+: ControlCMD
    Imp.:+: ThreadCMD
    Imp.:+: ChanCMD
    Imp.:+: PtrCMD
    Imp.:+: FileCMD
    Imp.:+: C_CMD

-- | Target monad during translation
--
-- A reader over the translation environment 'Env' (aliases and compiler
-- options) on top of a program transformer with 'TargetCMD' instructions,
-- 'Prim' expressions and 'PrimType'' as the type predicate.
type TargetT m = ReaderT Env (ProgramT TargetCMD (Param2 Prim PrimType') m)

-- | Monad for translated program
--
-- Same instructions and expressions as 'TargetT', but as a plain 'Program'
-- without the 'Env' reader. This is the type produced by 'translate' and
-- handed to the back ends below.
type ProgC = Program TargetCMD (Param2 Prim PrimType')

-- | Translate an expression
--
-- The expression is first optimized (using the assertion selection from the
-- environment's compiler options) and then translated to a 'VExp'. Auxiliary
-- instructions (references for @let@ bindings and loop state, conditionals,
-- loops and assertions) are emitted in the target monad along the way.
translateExp :: forall m a . Monad m => Data a -> TargetT m (VExp a)
translateExp a = do
    cs <- asks (compilerAssertions . envOptions)
    goAST $ optimize cs $ unData a
  where
    -- Assumes that `b` is not a function type
    goAST :: ASTF FeldDomain b -> TargetT m (VExp b)
    goAST = simpleMatch (\(s :&: ValT t) -> go t s)

    -- Translate an expression of primitive type to a single 'Prim' expression
    goSmallAST :: PrimType' b => ASTF FeldDomain b -> TargetT m (Prim b)
    goSmallAST = fmap extractSingle . goAST

    -- Translate a symbol applied to its arguments; the 'TypeRep' describes the
    -- type of the result
    go :: TypeRep (DenResult sig)
       -> FeldConstructs sig
       -> Args (AST FeldDomain) sig
       -> TargetT m (VExp (DenResult sig))
    go t lit Nil
        | Just (Lit a) <- prj lit
        = return $ mapStruct (constExp . runIdentity) $ toStruct t a
    go t lit Nil
        | Just (Literal a) <- prj lit
        = return $ mapStruct (constExp . runIdentity) $ toStruct t a
    go t var Nil
        | Just (VarT v) <- prj var
        = lookAlias t v
    -- Let binding: the bound value is stored in references, and the body is
    -- translated with the variable aliased to a frozen view of them
    go _ lt (a :* (lam :$ body) :* Nil)
        | Just (Let tag) <- prj lt
        , Just (LamT v)  <- prj lam
        = do let base = if null tag then "let" else tag
             r  <- initRefV base =<< goAST a
             a' <- unsafeFreezeRefV r
             localAlias v a' $ goAST body
    go _ tup (a :* b :* Nil)
        | Just Pair <- prj tup = Two <$> goAST a <*> goAST b
    -- Projections use an explicit `case` rather than a failable pattern bind,
    -- so that no `MonadFail` instance is needed for `TargetT m`
    go _ sel (ab :* Nil)
        | Just Fst <- prj sel = do
            ab' <- goAST ab
            case ab' of
                Two a _ -> return a
                _       -> error "translateExp: Fst applied to a non-pair"
        | Just Snd <- prj sel = do
            ab' <- goAST ab
            case ab' of
                Two _ b -> return b
                _       -> error "translateExp: Snd applied to a non-pair"
    go _ c Nil
        | Just Pi <- prj c = return $ Single $ sugarSymPrim Pi
    go _ op (a :* Nil)
        | Just Neg       <- prj op = liftStruct (sugarSymPrim Neg)       <$> goAST a
        | Just Abs       <- prj op = liftStruct (sugarSymPrim Abs)       <$> goAST a
        | Just Sign      <- prj op = liftStruct (sugarSymPrim Sign)      <$> goAST a
        | Just Exp       <- prj op = liftStruct (sugarSymPrim Exp)       <$> goAST a
        | Just Log       <- prj op = liftStruct (sugarSymPrim Log)       <$> goAST a
        | Just Sqrt      <- prj op = liftStruct (sugarSymPrim Sqrt)      <$> goAST a
        | Just Sin       <- prj op = liftStruct (sugarSymPrim Sin)       <$> goAST a
        | Just Cos       <- prj op = liftStruct (sugarSymPrim Cos)       <$> goAST a
        | Just Tan       <- prj op = liftStruct (sugarSymPrim Tan)       <$> goAST a
        | Just Asin      <- prj op = liftStruct (sugarSymPrim Asin)      <$> goAST a
        | Just Acos      <- prj op = liftStruct (sugarSymPrim Acos)      <$> goAST a
        | Just Atan      <- prj op = liftStruct (sugarSymPrim Atan)      <$> goAST a
        | Just Sinh      <- prj op = liftStruct (sugarSymPrim Sinh)      <$> goAST a
        | Just Cosh      <- prj op = liftStruct (sugarSymPrim Cosh)      <$> goAST a
        | Just Tanh      <- prj op = liftStruct (sugarSymPrim Tanh)      <$> goAST a
        | Just Asinh     <- prj op = liftStruct (sugarSymPrim Asinh)     <$> goAST a
        | Just Acosh     <- prj op = liftStruct (sugarSymPrim Acosh)     <$> goAST a
        | Just Atanh     <- prj op = liftStruct (sugarSymPrim Atanh)     <$> goAST a
        | Just Real      <- prj op = liftStruct (sugarSymPrim Real)      <$> goAST a
        | Just Imag      <- prj op = liftStruct (sugarSymPrim Imag)      <$> goAST a
        | Just Magnitude <- prj op = liftStruct (sugarSymPrim Magnitude) <$> goAST a
        | Just Phase     <- prj op = liftStruct (sugarSymPrim Phase)     <$> goAST a
        | Just Conjugate <- prj op = liftStruct (sugarSymPrim Conjugate) <$> goAST a
        | Just I2N       <- prj op = liftStruct (sugarSymPrim I2N)       <$> goAST a
        | Just I2B       <- prj op = liftStruct (sugarSymPrim I2B)       <$> goAST a
        | Just B2I       <- prj op = liftStruct (sugarSymPrim B2I)       <$> goAST a
        | Just Round     <- prj op = liftStruct (sugarSymPrim Round)     <$> goAST a
        | Just Not       <- prj op = liftStruct (sugarSymPrim Not)       <$> goAST a
        | Just BitCompl  <- prj op = liftStruct (sugarSymPrim BitCompl)  <$> goAST a
    go _ op (a :* b :* Nil)
        | Just Add     <- prj op = liftStruct2 (sugarSymPrim Add)     <$> goAST a <*> goAST b
        | Just Sub     <- prj op = liftStruct2 (sugarSymPrim Sub)     <$> goAST a <*> goAST b
        | Just Mul     <- prj op = liftStruct2 (sugarSymPrim Mul)     <$> goAST a <*> goAST b
        | Just FDiv    <- prj op = liftStruct2 (sugarSymPrim FDiv)    <$> goAST a <*> goAST b
        | Just Quot    <- prj op = liftStruct2 (sugarSymPrim Quot)    <$> goAST a <*> goAST b
        | Just Rem     <- prj op = liftStruct2 (sugarSymPrim Rem)     <$> goAST a <*> goAST b
        | Just Div     <- prj op = liftStruct2 (sugarSymPrim Div)     <$> goAST a <*> goAST b
        | Just Mod     <- prj op = liftStruct2 (sugarSymPrim Mod)     <$> goAST a <*> goAST b
        | Just Complex <- prj op = liftStruct2 (sugarSymPrim Complex) <$> goAST a <*> goAST b
        | Just Polar   <- prj op = liftStruct2 (sugarSymPrim Polar)   <$> goAST a <*> goAST b
        | Just Pow     <- prj op = liftStruct2 (sugarSymPrim Pow)     <$> goAST a <*> goAST b
        | Just Eq      <- prj op = liftStruct2 (sugarSymPrim Eq)      <$> goAST a <*> goAST b
        | Just And     <- prj op = liftStruct2 (sugarSymPrim And)     <$> goAST a <*> goAST b
        | Just Or      <- prj op = liftStruct2 (sugarSymPrim Or)      <$> goAST a <*> goAST b
        | Just Lt      <- prj op = liftStruct2 (sugarSymPrim Lt)      <$> goAST a <*> goAST b
        | Just Gt      <- prj op = liftStruct2 (sugarSymPrim Gt)      <$> goAST a <*> goAST b
        | Just Le      <- prj op = liftStruct2 (sugarSymPrim Le)      <$> goAST a <*> goAST b
        | Just Ge      <- prj op = liftStruct2 (sugarSymPrim Ge)      <$> goAST a <*> goAST b
        | Just BitAnd  <- prj op = liftStruct2 (sugarSymPrim BitAnd)  <$> goAST a <*> goAST b
        | Just BitOr   <- prj op = liftStruct2 (sugarSymPrim BitOr)   <$> goAST a <*> goAST b
        | Just BitXor  <- prj op = liftStruct2 (sugarSymPrim BitXor)  <$> goAST a <*> goAST b
        | Just ShiftL  <- prj op = liftStruct2 (sugarSymPrim ShiftL)  <$> goAST a <*> goAST b
        | Just ShiftR  <- prj op = liftStruct2 (sugarSymPrim ShiftR)  <$> goAST a <*> goAST b
    go _ arrIx (i :* Nil)
        | Just (ArrIx arr) <- prj arrIx = do
            i' <- goSmallAST i
            return $ Single $ sugarSymPrim (ArrIx arr) i'
    -- Conditional: if both branches translate to a plain return of a single
    -- primitive expression, emit a conditional expression; otherwise emit an
    -- `iff` statement that writes the chosen branch's result to fresh
    -- references.
    go ty cond (c :* t :* f :* Nil)
        | Just Cond <- prj cond = do
            env <- ask
            let t' = flip runReaderT env $ goAST t
                f' = flip runReaderT env $ goAST f
            tView <- lift $ lift $ Oper.viewT t'
            fView <- lift $ lift $ Oper.viewT f'
            case (tView,fView) of
                (Oper.Return (Single tExp), Oper.Return (Single fExp)) -> do
                    c' <- goSmallAST c
                    return $ Single $ sugarSymPrim Cond c' tExp fExp
                _ -> do
                    c'  <- goSmallAST c
                    res <- newRefV ty "v"
                    lift $ iff c'
                        (flip runReaderT env . setRefV res =<< t')
                        (flip runReaderT env . setRefV res =<< f')
                    unsafeFreezeRefV res
    go _ divBal (a :* b :* Nil)
        | Just DivBalanced <- prj divBal
        = liftStruct2 (sugarSymPrim Quot) <$> goAST a <*> goAST b
    -- Guarded value: the assertion is only emitted if its label is selected
    -- by the compiler options
    go _ guard (cond :* a :* Nil)
        | Just (GuardVal c msg) <- prj guard
        = do cs <- asks (compilerAssertions . envOptions)
             when (cs `includes` c) $ do
               cond' <- extractSingle <$> goAST cond
               lift $ assert cond' msg
             goAST a
    -- For loop: the state lives in references initialized from `initial` and
    -- updated once per iteration with the translated body
    go t loop (len :* initial :* (lami :$ (lams :$ body)) :* Nil)
        | Just ForLoop   <- prj loop
        , Just (LamT iv) <- prj lami
        , Just (LamT sv) <- prj lams
        = do len'  <- goSmallAST len
             state <- initRefV "state" =<< goAST initial
             ReaderT $ \env -> for (0, 1, Excl len') $ \i -> flip runReaderT env $ do
                s <- case t of
                    -- Non-compound state: read the reference directly.
                    -- NOTE(review): compound states are presumably copied
                    -- with 'getRefV' because 'setRefV' writes its references
                    -- one at a time, which later fields could otherwise observe.
                    Single _ -> unsafeFreezeRefV state
                    _        -> getRefV state
                s' <- localAlias iv (Single i) $
                        localAlias sv s $
                          goAST body
                setRefV state s'
             unsafeFreezeRefV state
    go _ free Nil
        | Just (FreeVar v) <- prj free = return $ Single $ sugarSymPrim $ FreeVar v
    go _ unsPerf Nil
        | Just (UnsafePerform prog) <- prj unsPerf
        = translateExp =<<
            Oper.reexpressEnv unsafeTransSmallExp (Oper.liftProgram $ unComp prog)
    go _ s _ = error $ "translateExp: no handling of symbol " ++ renderSym s

-- | Translate an expression that is assumed to fulfill @`PrimType` a@
--
-- Calls 'error' if the translated expression is not a single primitive
-- expression, i.e. if the assumption does not hold.
unsafeTransSmallExp :: Monad m => Data a -> TargetT m (Prim a)
unsafeTransSmallExp a = do
    b <- translateExp a
    -- An explicit `case` instead of a failable `Single b <- ...` bind, so that
    -- no `MonadFail` instance is required for `TargetT m`
    case b of
      Single e -> return e
      _        -> error "unsafeTransSmallExp: expected a single primitive expression"
  -- This function should ideally have a `PrimType' a` constraint, but that is
  -- not allowed when passing it to `reexpressEnv`. It should be possible to
  -- make it work by changing the interface to `reexpressEnv`.

-- | Lower a 'Run' program to a 'ProgC' program, using the given environment
-- for aliases and compiler options
translate :: Env -> Run a -> ProgC a
translate env prog
    -- 3. fuse the resulting monad stack into a single program
    = Oper.interpretWithMonadT Oper.singleton id
    -- 2. compile the outer monad, then supply the environment
    $ flip runReaderT env
    $ Oper.reexpressEnv unsafeTransSmallExp
    -- 1. compile the inner monad
    $ Oper.interpretWithMonadT Oper.singleton
        (lift . flip runReaderT env . Oper.reexpressEnv unsafeTransSmallExp)
    $ unRun prog

-- | Lowering of assertions: an assertion becomes an 'Imp.assert' in the target
-- program only if its label is included in 'compilerAssertions' of the
-- environment's options; otherwise it is dropped.
instance (Imp.ControlCMD Oper.:<: instr) =>
    Oper.Reexpressible AssertCMD instr Env
  where
    reexpressInstrEnv reexp (Assert c cond msg) = do
        cs <- asks (compilerAssertions . envOptions)
        when (cs `includes` c) $
          reexp cond >>= lift . flip Imp.assert msg

-- * Back ends

-- | Interpret a program in the 'IO' monad
--
-- The program is first lowered with 'translate' (starting from 'env0') and the
-- result is then interpreted.
runIO :: MonadRun m => m a -> IO a
runIO prog = Imp.runIO $ translate env0 $ liftRun prog

-- | Interpret a program in the 'IO' monad directly, without first lowering it
-- with 'translate' (compare 'runIO')
runIO' :: MonadRun m => m a -> IO a
runIO'
    = Oper.interpretWithMonadBiT
        (return . evalExp)
        (Imp.interpretBi (return . evalExp))
    . unRun
    . liftRun
  -- Unlike `runIO`, this function does the interpretation directly, without
  -- first lowering the program. This might be faster, but I haven't done any
  -- measurements to see if it is.
  -- One disadvantage with `runIO'` is that it cannot handle expressions
  -- involving `IOSym`. But at the moment of writing this, we're not using those
  -- symbols for anything anyway.

-- | Variant of 'runIO' whose @stdin@ is fed from a string and whose @stdout@
-- is captured and returned as a string
captureIO :: MonadRun m
    => m a        -- ^ Program to run
    -> String     -- ^ Input to send to @stdin@
    -> IO String  -- ^ Result from @stdout@
captureIO prog = Imp.captureIO (translate env0 (liftRun prog))

-- | Compile a program to C code represented as a string, using the given
-- compiler options. To compile the resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
--
-- This function returns only the first (main) module. To get all C
-- translation units, use 'compileAll''.
compile' :: MonadRun m => CompilerOpts -> m a -> String
compile' opts prog = Imp.compile $ translate (env0 {envOptions = opts}) $ liftRun prog

-- | Compile a program to C code represented as a string. To compile the
-- resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
--
-- This function returns only the first (main) module. To get all C
-- translation units, use 'compileAll'.
--
-- Only assertions labeled with 'UserAssertion' are included in the generated
-- code.
compile :: MonadRun m => m a -> String
compile = compile' userOnly
  where
    userOnly = def {compilerAssertions = onlyUserAssertions}

-- | Compile a program, using the given compiler options, to a list of C
-- modules, each given as a pair of its name and its code. To compile the
-- resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
compileAll' :: MonadRun m => CompilerOpts -> m a -> [(String, String)]
compileAll' opts prog = Imp.compileAll (translate (env0 {envOptions = opts}) (liftRun prog))

-- | Compile a program to a list of C modules, each given as a pair of its name
-- and its code. To compile the resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
--
-- Only assertions labeled with 'UserAssertion' are included in the generated
-- code.
compileAll :: MonadRun m => m a -> [(String, String)]
compileAll prog = compileAll' opts prog
  where
    opts = def {compilerAssertions = onlyUserAssertions}

-- | Compile a program to C code, using the given compiler options, and print
-- it on the screen. To compile the resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
icompile' :: MonadRun m => CompilerOpts -> m a -> IO ()
icompile' opts prog = Imp.icompile $ translate (env0 {envOptions = opts}) $ liftRun prog

-- | Compile a program to C code and print it on the screen. To compile the
-- resulting C code, use something like
--
-- > cc -std=c99 YOURPROGRAM.c
--
-- Only assertions labeled with 'UserAssertion' are included in the generated
-- code.
icompile :: MonadRun m => m a -> IO ()
icompile prog = icompile' (def {compilerAssertions = onlyUserAssertions}) prog

-- | Generate C code with the given compiler options and use the external C
-- compiler (configured by the 'ExternalCompilerOpts') to check that it
-- compiles (no linking)
compileAndCheck' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> IO ()
compileAndCheck' opts eopts prog =
    Imp.compileAndCheck' eopts (translate (env0 {envOptions = opts}) (liftRun prog))

-- | Generate C code and use CC to check that it compiles (no linking).
-- Default options are used, so all assertions are included in the generated
-- code.
compileAndCheck :: MonadRun m => m a -> IO ()
compileAndCheck prog = compileAndCheck' def def prog

-- | Generate C code with the given compiler options, compile it with the
-- external C compiler, and run the resulting executable
runCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a
    -> IO ()
runCompiled' opts eopts prog =
    Imp.runCompiled' eopts $ translate (env0 {envOptions = opts}) $ liftRun prog

-- | Generate C code, use CC to compile it, and run the resulting executable.
-- Default options are used, so all assertions are included in the generated
-- code.
runCompiled :: MonadRun m => m a -> IO ()
runCompiled prog = runCompiled' def def prog

-- | Compile a program and make it available as an 'IO' function from 'String'
-- to 'String' (connected to @stdin@/@stdout@, respectively). Compilation
-- happens only once, even if the 'IO' function is used many times in the
-- body.
withCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a  -- ^ Program to compile
    -> ((String -> IO String) -> IO b)
         -- ^ Function that has access to the compiled executable as a function
    -> IO b
withCompiled' opts eopts prog =
    Imp.withCompiled' eopts (translate (env0 {envOptions = opts}) (liftRun prog))

-- | Compile a program and make it available as an 'IO' function from 'String'
-- to 'String' (connected to @stdin@/@stdout@, respectively). Compilation
-- happens only once, even if the 'IO' function is used many times in the
-- body.
--
-- Default compiler options are used (all assertions are included), and the
-- external compiler is run with 'externalSilent' set.
withCompiled :: MonadRun m
    => m a  -- ^ Program to compile
    -> ((String -> IO String) -> IO b)
         -- ^ Function that has access to the compiled executable as a function
    -> IO b
withCompiled = withCompiled' def silent
  where
    silent = def {externalSilent = True}

-- | Variant of 'runCompiled'' whose @stdin@ is fed from a string and whose
-- @stdout@ is captured and returned. The program is compiled anew each time
-- the function is applied to a string; to compile once and run many times,
-- use 'withCompiled''.
captureCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a        -- ^ Program to run
    -> String     -- ^ Input to send to @stdin@
    -> IO String  -- ^ Result from @stdout@
captureCompiled' opts eopts prog =
    Imp.captureCompiled' eopts $ translate (env0 {envOptions = opts}) $ liftRun prog

-- | Variant of 'runCompiled' whose @stdin@ is fed from a string and whose
-- @stdout@ is captured and returned. The program is compiled anew each time
-- the function is applied to a string; to compile once and run many times,
-- use 'withCompiled'. Default options are used, so all assertions are
-- included in the generated code.
captureCompiled :: MonadRun m
    => m a        -- ^ Program to run
    -> String     -- ^ Input to send to @stdin@
    -> IO String  -- ^ Result from @stdout@
captureCompiled prog = captureCompiled' def def prog

-- | Compile the program with the given options, run it, and compare what it
-- writes to @stdout@ with what the reference program writes
compareCompiled' :: MonadRun m
    => CompilerOpts
    -> ExternalCompilerOpts
    -> m a     -- ^ Program to run
    -> IO a    -- ^ Reference program
    -> String  -- ^ Input to send to @stdin@
    -> IO ()
compareCompiled' opts eopts prog =
    Imp.compareCompiled' eopts (translate (env0 {envOptions = opts}) (liftRun prog))

-- | Compile the program, run it, and compare what it writes to @stdout@ with
-- what the reference program writes. Default options are used, so all
-- assertions are included in the generated code.
compareCompiled :: MonadRun m
    => m a     -- ^ Program to run
    -> IO a    -- ^ Reference program
    -> String  -- ^ Input to send to @stdin@
    -> IO ()
compareCompiled prog = compareCompiled' def def prog