{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Execute
-- Copyright   : [2014..2017] Trevor L. McDonell
--               [2014..2014] Vinod Grover (NVIDIA Corporation)
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <tmcdonell@cse.unsw.edu.au>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Execute (

  Execute(..), Gamma,
  executeAcc, executeAfun1,

) where

-- accelerate
import Data.Array.Accelerate.AST
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Representation               ( SliceIndex(..) )
import Data.Array.Accelerate.Array.Sugar                        hiding ( Foreign )
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Interpreter                        ( evalPrim, evalPrimConst, evalPrj )
import qualified Data.Array.Accelerate.Array.Sugar              as S
import qualified Data.Array.Accelerate.Array.Representation     as R

import Data.Array.Accelerate.LLVM.Array.Data
import Data.Array.Accelerate.LLVM.Compile
import Data.Array.Accelerate.LLVM.Foreign
import Data.Array.Accelerate.LLVM.State

import Data.Array.Accelerate.LLVM.CodeGen.Environment           ( Gamma )

import Data.Array.Accelerate.LLVM.Execute.Async                 hiding ( join )
import Data.Array.Accelerate.LLVM.Execute.Environment

-- library
import Control.Monad
import Control.Applicative                                      hiding ( Const )
import Prelude                                                  hiding ( exp, map, unzip, scanl, scanr, scanl1, scanr1 )


-- The skeleton operations that a backend must provide in order to execute
-- compiled array computations.
--
-- Each method is invoked by 'executeOpenAcc' for the AST node of the
-- corresponding name. Every method receives the same four leading arguments:
-- the compiled kernel and the 'Gamma' stored in the 'ExecAcc' node, the
-- current array environment, and the stream the operation is executed in. The
-- remaining arguments are described per method below.
--
class (Remote arch, Foreign arch) => Execute arch where
  -- Called for 'Map' nodes. The final argument is the extent of the argument
  -- array; the array itself is not passed.
  --
  map           :: (Shape sh, Elt b)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh
                -> LLVM arch (Array sh b)

  -- Called for 'Generate' nodes with the evaluated shape expression.
  --
  generate      :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh
                -> LLVM arch (Array sh e)

  -- Called for 'Transform' nodes with the evaluated shape expression.
  --
  transform     :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh
                -> LLVM arch (Array sh e)

  -- Called for 'Backpermute' nodes with the evaluated shape expression.
  --
  backpermute   :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh
                -> LLVM arch (Array sh e)

  -- Called for 'Fold' nodes with the extent of the argument array. The result
  -- has one dimension fewer than that extent.
  --
  fold          :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array sh e)

  -- Called for 'Fold1' nodes with the extent of the argument array.
  --
  fold1         :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array sh e)

  -- Called for 'FoldSeg' nodes with the extents of the node's two array
  -- arguments, in order. The second (one-dimensional) extent presumably
  -- belongs to the segment descriptor -- confirm against the AST definition.
  --
  foldSeg       :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> DIM1
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Fold1Seg' nodes; arguments as for 'foldSeg'.
  --
  fold1Seg      :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> DIM1
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Scanl' nodes with the extent of the argument array.
  --
  scanl         :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Scanl1' nodes with the extent of the argument array.
  --
  scanl1        :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Scanl'' nodes with the extent of the argument array. Returns
  -- a pair of arrays, the second having one dimension fewer than the first.
  --
  scanl'        :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e, Array sh e)

  -- Called for 'Scanr' nodes with the extent of the argument array.
  --
  scanr         :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Scanr1' nodes with the extent of the argument array.
  --
  scanr1        :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e)

  -- Called for 'Scanr'' nodes with the extent of the argument array. Returns
  -- a pair of arrays, the second having one dimension fewer than the first.
  --
  scanr'        :: (Shape sh, Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> sh :. Int
                -> LLVM arch (Array (sh:.Int) e, Array sh e)

  -- Called for 'Permute' nodes. The 'Bool' is 'False' when the node's first
  -- array argument is a plain array variable ('Avar') and 'True' otherwise;
  -- presumably it tells the backend whether that array may be updated in
  -- place -- confirm against the backend implementations. It is followed by
  -- the extent of the node's last array argument and the evaluated first
  -- array argument (presumably the array of default values).
  --
  permute       :: (Shape sh, Shape sh', Elt e)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> Bool
                -> sh
                -> Array sh' e
                -> LLVM arch (Array sh' e)

  -- Called for 'Stencil' nodes. Unlike the operations above, the argument
  -- array itself is evaluated and passed, not just its extent.
  --
  stencil1      :: (Shape sh, Elt a, Elt b)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> Array sh a
                -> LLVM arch (Array sh b)

  -- Called for 'Stencil2' nodes with both evaluated argument arrays.
  --
  stencil2      :: (Shape sh, Elt a, Elt b, Elt c)
                => ExecutableR arch
                -> Gamma aenv
                -> AvalR arch aenv
                -> StreamR arch
                -> Array sh a
                -> Array sh b
                -> LLVM arch (Array sh c)


-- Array expression evaluation
-- ---------------------------

-- Computations are evaluated by traversing the AST bottom up, and for each node
-- distinguishing between three cases:
--
--  1. If it is a Use node, we return a reference to the array data.
--
--  2. If it is a non-skeleton node, such as a let-binding or shape conversion,
--     then execute directly by updating the environment or similar.
--
--  3. If it is a skeleton node, then we need to execute the compiled kernel for
--     that node.
--
-- Execute a closed array computation: run it asynchronously in an empty array
-- environment, then retrieve the result of that asynchronous computation.
--
{-# INLINEABLE executeAcc #-}
executeAcc
    :: forall arch a. Execute arch
    => ExecAcc arch a
    -> LLVM arch a
executeAcc acc = do
  result <- async (executeOpenAcc acc Aempty)
  get result

-- Execute a closed array function of one argument. The argument is first
-- passed through 'useRemoteAsync', and the resulting asynchronous array is
-- handed to the function body, which starts from an empty array environment.
--
{-# INLINEABLE executeAfun1 #-}
executeAfun1
    :: forall arch a b. (Execute arch, Arrays a)
    => ExecAfun arch (a -> b)
    -> a
    -> LLVM arch b
executeAfun1 afun arrs =
  async (useRemoteAsync arrs) >>= \(AsyncR _ remote) ->
    executeOpenAfun1 afun Aempty remote


-- Execute an open array function of one argument.
--
-- The argument is pushed onto the array environment and the function body is
-- run asynchronously; the result of that asynchronous computation is then
-- retrieved with 'get'.
--
-- Only a single lambda wrapping a body is accepted. Any other shape (no
-- lambda, or more than one) is reported as an internal error, in the same
-- manner as the other invariant violations in this module.
--
{-# INLINEABLE executeOpenAfun1 #-}
executeOpenAfun1
    :: Execute arch
    => ExecOpenAfun arch aenv (a -> b)
    -> AvalR arch aenv
    -> AsyncR arch a
    -> LLVM arch b
executeOpenAfun1 (Alam (Abody f)) aenv a = get =<< async (executeOpenAcc f (aenv `Apush` a))
executeOpenAfun1 _                _    _ = $internalError "executeOpenAfun1" "expected an array function of exactly one argument"


-- Execute an open array computation.
--
-- The computation is evaluated in the given array environment, and any
-- kernels are executed in the given stream. There are three kinds of node:
--
--  * 'EmbedAcc' (a delayed array) can not be executed and is an internal
--    error here; it is only valid as the argument of a skeleton, where just
--    its extent is required (see 'extent' below).
--
--  * 'ExecAcc' carries the compiled kernel and 'Gamma' for the node. Nodes
--    which manipulate the environment or control flow are handled directly,
--    and skeleton nodes are dispatched to the matching 'Execute' method.
--
--  * 'UnzipAcc' projects one component out of an array of tuples that is
--    already in the environment, without executing a kernel.
--
{-# INLINEABLE executeOpenAcc #-}
executeOpenAcc
    :: forall arch aenv arrs. Execute arch
    => ExecOpenAcc arch aenv arrs
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch arrs
executeOpenAcc EmbedAcc{} _ _ =
  $internalError "execute" "unexpected delayed array"
executeOpenAcc (ExecAcc kernel gamma pacc) aenv stream =
  case pacc of

    -- Array introduction
    Use arr                     -> return (toArr arr)
    Unit x                      -> newRemote Z . const =<< travE x

    -- Environment manipulation
    --
    -- Avar: make this stream wait on the event associated with the array
    -- before handing the array out.
    --
    -- Alet: the binding is executed asynchronously, whereas the body continues
    -- in the current stream with the binding pushed onto the environment.
    Avar ix                     -> do let AsyncR event arr = aprj ix aenv
                                      after stream event
                                      return arr
    Alet bnd body               -> do bnd'  <- async (executeOpenAcc bnd aenv)
                                      executeOpenAcc body (aenv `Apush` bnd') stream
    Apply f a                   -> executeOpenAfun1 f aenv =<< async (executeOpenAcc a aenv)
    Atuple tup                  -> toAtuple <$> travT tup
    Aprj ix tup                 -> evalPrj ix . fromAtuple <$> travA tup
    Acond p t e                 -> acond t e =<< travE p
    Awhile p f a                -> awhile p f =<< travA a

    -- Foreign function
    Aforeign asm _ a            -> foreignA asm =<< travA a

    -- Producers
    Map _ a                     -> map kernel gamma aenv stream         =<< extent a
    Generate sh _               -> generate kernel gamma aenv stream    =<< travE sh
    Transform sh _ _ _          -> transform kernel gamma aenv stream   =<< travE sh
    Backpermute sh _ _          -> backpermute kernel gamma aenv stream =<< travE sh
    Reshape sh a                -> reshape <$> travE sh <*> travA a

    -- Consumers
    Fold _ _ a                  -> fold  kernel gamma aenv stream =<< extent a
    Fold1 _ a                   -> fold1 kernel gamma aenv stream =<< extent a
    FoldSeg _ _ a s             -> join $ foldSeg  kernel gamma aenv stream <$> extent a <*> extent s
    Fold1Seg _ a s              -> join $ fold1Seg kernel gamma aenv stream <$> extent a <*> extent s
    Scanl _ _ a                 -> scanl kernel gamma aenv stream =<< extent a
    Scanr _ _ a                 -> scanr kernel gamma aenv stream =<< extent a
    Scanl1 _ a                  -> scanl1 kernel gamma aenv stream =<< extent a
    Scanr1 _ a                  -> scanr1 kernel gamma aenv stream =<< extent a
    Scanl' _ _ a                -> scanl' kernel gamma aenv stream =<< extent a
    Scanr' _ _ a                -> scanr' kernel gamma aenv stream =<< extent a
    Permute _ d _ a             -> join $ permute kernel gamma aenv stream (inplace d) <$> extent a <*> travA d
    Stencil _ _ a               -> stencil1 kernel gamma aenv stream =<< travA a
    Stencil2 _ _ a _ b          -> join $ stencil2 kernel gamma aenv stream <$> travA a <*> travA b

    -- Removed by fusion
    Replicate{}                 -> fusionError
    Slice{}                     -> fusionError
    ZipWith{}                   -> fusionError

  where
    -- Raised for nodes which are not expected to reach execution.
    fusionError :: error
    fusionError = $internalError "execute" $ "unexpected fusible material: " ++ showPreAccOp pacc

    -- Term traversals
    -- ---------------

    -- Evaluate a sub-computation in the current environment and stream
    travA :: ExecOpenAcc arch aenv a -> LLVM arch a
    travA acc = executeOpenAcc acc aenv stream

    -- Evaluate a closed scalar expression in the current environment and stream
    travE :: ExecExp arch aenv t -> LLVM arch t
    travE exp = executeExp exp aenv stream

    -- Evaluate each component of a tuple of arrays, building nested pairs
    travT :: Atuple (ExecOpenAcc arch aenv) t -> LLVM arch t
    travT NilAtup        = return ()
    travT (SnocAtup t a) = (,) <$> travT t <*> travA a

    -- Get the extent of an embedded array. The argument must be either a
    -- delayed array, whose shape expression is evaluated, or an unzipped
    -- array, whose shape is read from the array in the environment. A node
    -- which would need executing is an internal error.
    extent :: Shape sh => ExecOpenAcc arch aenv (Array sh e) -> LLVM arch sh
    extent ExecAcc{}       = $internalError "executeOpenAcc" "expected delayed array"
    extent (EmbedAcc sh)   = travE sh
    extent (UnzipAcc _ ix) = let AsyncR _ a = aprj ix aenv
                             in  return $ shape a

    -- The flag passed on to 'permute': False if the array is a plain array
    -- variable, True for any other node.
    inplace :: ExecOpenAcc arch aenv a -> Bool
    inplace (ExecAcc _ _ Avar{}) = False
    inplace _                    = True

    -- Skeleton implementation
    -- -----------------------

    -- Change the shape of an array without altering its contents. This does not
    -- execute any kernel programs. The new shape must contain the same number
    -- of elements as the old; this is verified with a bounds check.
    reshape :: Shape sh => sh -> Array sh' e -> Array sh e
    reshape sh (Array sh' adata)
      = $boundsCheck "reshape" "shape mismatch" (size sh == R.size sh')
      $ Array (fromElt sh) adata

    -- Array level conditional: only the selected branch is executed
    acond :: ExecOpenAcc arch aenv a -> ExecOpenAcc arch aenv a -> Bool -> LLVM arch a
    acond yes _  True  = travA yes
    acond _   no False = travA no

    -- Array loops. On each iteration the predicate function is applied to the
    -- current value and the first element of its result is read back. While
    -- that element is True the body function is applied and the loop repeats
    -- with its result; otherwise the current value is returned.
    awhile :: ExecOpenAfun arch aenv (a -> Scalar Bool)
           -> ExecOpenAfun arch aenv (a -> a)
           -> a
           -> LLVM arch a
    awhile p f a = do
      e   <- checkpoint stream
      r   <- executeOpenAfun1 p aenv (AsyncR e a)
      ok  <- indexRemote r 0
      if ok then awhile p f =<< executeOpenAfun1 f aenv (AsyncR e a)
            else return a

    -- Foreign functions. The implementation for this architecture is looked up
    -- with 'foreignAcc' and applied in the current stream; it is an internal
    -- error if none is found.
    foreignA :: (Arrays a, Arrays b, Foreign arch, S.Foreign asm)
             => asm (a -> b)
             -> a
             -> LLVM arch b
    foreignA asm a =
      case foreignAcc (undefined :: arch) asm of
        Just f  -> f stream a
        Nothing -> $internalError "foreignA" "failed to recover foreign function the second time"

executeOpenAcc (UnzipAcc tup v) aenv stream = do
  let AsyncR event arr = aprj v aenv
  after stream event
  return $ unzip tup arr
  where
    -- Project one component out of an array of tuples. The shape is reused
    -- unchanged, and the component's data is selected from the nested pairs
    -- of array data by walking the tuple index.
    unzip :: forall t sh e. (Elt t, Elt e) => TupleIdx (TupleRepr t) e -> Array sh t -> Array sh e
    unzip tix (Array sh adata) = Array sh $ go tix (eltType (undefined::t)) adata
      where
        -- Walk the tuple index and the array data together. On reaching the
        -- requested component, its representation type must match that of
        -- 'e' before the data can be returned at the expected type.
        go :: TupleIdx v e -> TupleType t' -> ArrayData t' -> ArrayData (EltRepr e)
        go (SuccTupIdx ix) (PairTuple t _) (AD_Pair x _)           = go ix t x
        go ZeroTupIdx      (PairTuple _ t) (AD_Pair _ x)
          | Just Refl <- matchTupleType t (eltType (undefined::e)) = x
        go _ _ _                                                   = $internalError "unzip" "inconsistent valuation"


-- Scalar expression evaluation
-- ----------------------------

-- Evaluate a closed scalar expression, i.e. one with no free scalar
-- variables, by delegating to 'executeOpenExp' with an empty scalar
-- environment.
--
{-# INLINEABLE executeExp #-}
executeExp
    :: Execute arch
    => ExecExp arch aenv t
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch t
executeExp expr = executeOpenExp expr Empty

-- Evaluate an open scalar expression.
--
-- The expression is evaluated with respect to a scalar environment and an
-- array environment. Array computations embedded in the expression are
-- executed in the given stream via 'executeOpenAcc'.
--
{-# INLINEABLE executeOpenExp #-}
executeOpenExp
    :: forall arch env aenv exp. Execute arch
    => ExecOpenExp arch env aenv exp
    -> Val env
    -> AvalR arch aenv
    -> StreamR arch
    -> LLVM arch exp
executeOpenExp rootExp env aenv stream = travE rootExp
  where
    travE :: ExecOpenExp arch env aenv t -> LLVM arch t
    travE exp = case exp of
      Var ix                    -> return (prj ix env)
      Let bnd body              -> travE bnd >>= \x -> executeOpenExp body (env `Push` x) aenv stream
      Const c                   -> return (toElt c)
      PrimConst c               -> return (evalPrimConst c)
      PrimApp f x               -> evalPrim f <$> travE x
      Tuple t                   -> toTuple <$> travT t
      Prj ix e                  -> evalPrj ix . fromTuple <$> travE e
      Cond p t e                -> travE p >>= \x -> if x then travE t else travE e
      While p f x               -> while p f =<< travE x
      IndexAny                  -> return Any
      IndexNil                  -> return Z
      IndexCons sh sz           -> (:.) <$> travE sh <*> travE sz
      IndexHead sh              -> (\(_  :. ix) -> ix) <$> travE sh
      IndexTail sh              -> (\(ix :.  _) -> ix) <$> travE sh
      IndexSlice ix slix sh     -> indexSlice ix <$> travE slix <*> travE sh
      IndexFull ix slix sl      -> indexFull  ix <$> travE slix <*> travE sl
      ToIndex sh ix             -> toIndex   <$> travE sh  <*> travE ix
      FromIndex sh ix           -> fromIndex <$> travE sh  <*> travE ix
      Intersect sh1 sh2         -> intersect <$> travE sh1 <*> travE sh2
      Union sh1 sh2             -> union <$> travE sh1 <*> travE sh2
      ShapeSize sh              -> size  <$> travE sh
      Shape acc                 -> shape <$> travA acc
      -- Both forms of array indexing go through 'linearIndex', so that both
      -- wait on the stream in the same way before reading the element.
      Index acc ix              -> join $ index       <$> travA acc <*> travE ix
      LinearIndex acc ix        -> join $ linearIndex <$> travA acc <*> travE ix
      Foreign _ f x             -> foreignE f x

    -- Helpers
    -- -------

    -- Evaluate each component of a tuple, building nested pairs
    travT :: Tuple (ExecOpenExp arch env aenv) t -> LLVM arch t
    travT tup = case tup of
      NilTup            -> return ()
      SnocTup t e       -> (,) <$> travT t <*> travE e

    -- Execute an embedded array computation in the current stream
    travA :: ExecOpenAcc arch aenv a -> LLVM arch a
    travA acc = executeOpenAcc acc aenv stream

    -- Apply a foreign function. The function is closed, so its body is
    -- evaluated with only the argument in the scalar environment and with an
    -- empty array environment.
    foreignE :: ExecFun arch () (a -> b) -> ExecOpenExp arch env aenv a -> LLVM arch b
    foreignE (Lam (Body f)) x = travE x >>= \e -> executeOpenExp f (Empty `Push` e) Aempty stream
    foreignE _              _ = $internalError "foreignE" "expected a scalar function of exactly one argument"

    -- Apply a scalar function of one argument: evaluate its body with the
    -- argument pushed onto the current scalar environment.
    travF1 :: ExecOpenFun arch env aenv (a -> b) -> a -> LLVM arch b
    travF1 (Lam (Body f)) x = executeOpenExp f (env `Push` x) aenv stream
    travF1 _              _ = $internalError "travF1" "expected a scalar function of exactly one argument"

    -- Scalar loops: apply the body function for as long as the predicate
    -- function holds of the current value.
    while :: ExecOpenFun arch env aenv (a -> Bool) -> ExecOpenFun arch env aenv (a -> a) -> a -> LLVM arch a
    while p f x = do
      ok <- travF1 p x
      if ok then while p f =<< travF1 f x
            else return x

    -- Restrict a full shape to a slice shape: dimensions marked 'SliceAll' are
    -- kept and those marked 'SliceFixed' are dropped.
    indexSlice :: (Elt slix, Elt sh, Elt sl)
               => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
               -> slix
               -> sh
               -> sl
    indexSlice ix slix sh = toElt $ restrict ix (fromElt slix) (fromElt sh)
      where
        restrict :: SliceIndex slix sl co sh -> slix -> sh -> sl
        restrict SliceNil              ()        ()       = ()
        restrict (SliceAll   sliceIdx) (slx, ()) (sl, sz) = (restrict sliceIdx slx sl, sz)
        restrict (SliceFixed sliceIdx) (slx,  _) (sl,  _) = restrict sliceIdx slx sl

    -- Extend a slice shape to a full shape: dimensions marked 'SliceAll' are
    -- taken from the slice shape, and those marked 'SliceFixed' are taken from
    -- the slice index.
    indexFull :: (Elt slix, Elt sh, Elt sl)
              => SliceIndex (EltRepr slix) (EltRepr sl) co (EltRepr sh)
              -> slix
              -> sl
              -> sh
    indexFull ix slix sl = toElt $ extend ix (fromElt slix) (fromElt sl)
      where
        extend :: SliceIndex slix sl co sh -> slix -> sl -> sh
        extend SliceNil              ()        ()       = ()
        extend (SliceAll sliceIdx)   (slx, ()) (sh, sz) = (extend sliceIdx slx sh, sz)
        extend (SliceFixed sliceIdx) (slx, sz) sh       = (extend sliceIdx slx sh, sz)

    -- Read the element at a multidimensional index, by converting it to a
    -- linear index with respect to the shape of the array.
    index :: Shape sh => Array sh e -> sh -> LLVM arch e
    index arr ix = linearIndex arr (toIndex (shape arr) ix)

    -- Read the element at a linear index. This first blocks on a checkpoint of
    -- the stream, presumably so that outstanding work in the stream (such as
    -- the computation producing the array) has finished before the read.
    linearIndex :: Array sh e -> Int -> LLVM arch e
    linearIndex arr ix = do
      block =<< checkpoint stream
      indexRemote arr ix