{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
  where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Constant
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )
import Control.Applicative
import Prelude                                                      as P hiding ( length )
-- | Generate the kernels for a seeded 'fold' over the innermost dimension.
--
-- When the result is a scalar (@sh ~ Z@) the reduction kernels come from
-- 'mkFoldAll'; otherwise they come from 'mkFoldDim'. In both cases the
-- 'mkFoldFill' kernel, which writes the seed value, is generated as well and
-- combined with the reduction kernels via '+++'.
mkFold
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv                        -- ^ array environment
    -> IRFun2    Native aenv (e -> e -> e)          -- ^ combination function
    -> IRExp     Native aenv e                      -- ^ seed element
    -> IRDelayed Native aenv (Array (sh :. Int) e)  -- ^ input array
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFold aenv f z acc = (+++) <$> reduction <*> mkFoldFill aenv z
  where
    -- Pick the reduction strategy by testing whether the output shape is Z.
    reduction :: CodeGen (IROpenAcc Native aenv (Array sh e))
    reduction
      | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
      = mkFoldAll aenv f (Just z) acc
      | otherwise
      = mkFoldDim aenv f (Just z) acc
-- | Generate the kernels for an unseeded 'fold1' over the innermost dimension.
--
-- Scalar results (@sh ~ Z@) use 'mkFoldAll' and all other shapes use
-- 'mkFoldDim', both without a seed. Unlike 'mkFold', no fill kernel is
-- generated here.
-- NOTE(review): presumably because fold1 over an empty dimension is
-- undefined. Confirm against the execution side.
mkFold1
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv                        -- ^ array environment
    -> IRFun2    Native aenv (e -> e -> e)          -- ^ combination function
    -> IRDelayed Native aenv (Array (sh :. Int) e)  -- ^ input array
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFold1 aenv f acc =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> mkFoldAll aenv f Nothing acc
    Nothing   -> mkFoldDim aenv f Nothing acc
-- | Kernel reducing along the innermost dimension (@fold@).
--
-- For each segment index @seg@ between the gang bounds, the kernel reduces the
-- input elements at linear indices @seg*stride@ up to @seg*stride + stride@.
-- It then writes the result to index @seg@ of the output array. With a seed,
-- the seed is the initial accumulator. Without one, the first element of the
-- segment is used instead.
-- NOTE(review): @ix.stride@ is presumably the extent of the innermost
-- dimension, supplied by the caller. Confirm.
mkFoldDim
  :: forall aenv sh e. (Shape sh, Elt e)
  =>          Gamma            aenv                       -- ^ array environment
  ->          IRFun2    Native aenv (e -> e -> e)         -- ^ combination function
  -> Maybe   (IRExp     Native aenv e)                    -- ^ optional seed
  ->          IRDelayed Native aenv (Array (sh :. Int) e) -- ^ input array
  -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFoldDim aenv combine mseed IRDelayed{..} =
  makeOpenAcc "fold" (paramGang ++ paramStride : paramOut ++ paramEnv) $ do
    imapFromTo start end $ \seg -> do
      lo <- mul numType seg stride
      hi <- add numType lo  stride
      r  <- reduceSegment lo hi
      writeArray arrOut seg r
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array sh e))
    paramEnv                = envParam aenv
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int)
    stride                  = local           scalarType ("ix.stride" :: Name Int)

    -- Reduce one segment, choosing the seeded or unseeded loop.
    reduceSegment :: IR Int -> IR Int -> CodeGen (IR e)
    reduceSegment lo hi =
      case mseed of
        Nothing   -> reduce1FromTo lo hi (app2 combine) (app1 delayedLinearIndex)
        Just seed -> do
          z <- seed
          reduceFromTo lo hi (app2 combine) z (app1 delayedLinearIndex)
-- | Generate every kernel needed to reduce a vector to a scalar.
--
-- Three kernels are produced, in this order:
--
--   * 'mkFoldAllS': a single sequential reduction (uses the seed, if any);
--   * 'mkFoldAllP1': per-chunk partial reductions into a temporary array
--     (never uses the seed);
--   * 'mkFoldAllP2': reduction of the temporary array (uses the seed, if any).
--
-- NOTE(review): which kernel(s) actually run is decided outside this module.
mkFoldAll
    :: forall aenv e. Elt e
    =>          Gamma            aenv                           -- ^ array environment
    ->          IRFun2    Native aenv (e -> e -> e)             -- ^ combination function
    -> Maybe   (IRExp     Native aenv e)                        -- ^ optional seed
    ->          IRDelayed Native aenv (Vector e)                -- ^ input vector
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAll aenv combine mseed arr = do
  sequential <- mkFoldAllS  aenv combine mseed arr
  phase1     <- mkFoldAllP1 aenv combine       arr
  phase2     <- mkFoldAllP2 aenv combine mseed
  return (sequential +++ (phase1 +++ phase2))
-- | Sequential all-reduction kernel (@foldAllS@).
--
-- Reduces the input elements between the gang bounds @start@ and @end@ in a
-- single loop. The accumulator starts at the seed when one is given and at
-- the first element otherwise. The result is written to index 0 of the
-- scalar output array.
mkFoldAllS
    :: forall aenv e. Elt e
    =>          Gamma            aenv                           -- ^ array environment
    ->          IRFun2    Native aenv (e -> e -> e)             -- ^ combination function
    -> Maybe   (IRExp     Native aenv e)                        -- ^ optional seed
    ->          IRDelayed Native aenv (Vector e)                -- ^ input vector
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllS aenv combine mseed IRDelayed{..} =
  makeOpenAcc "foldAllS" (paramGang ++ paramOut ++ paramEnv) $ do
    result <- case mseed of
                Nothing   -> reduce1FromTo start end (app2 combine) (app1 delayedLinearIndex)
                Just seed -> seed >>= \z ->
                               reduceFromTo start end (app2 combine) z (app1 delayedLinearIndex)
    writeArray arrOut zero result
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Scalar e))
    zero                    = lift 0 :: IR Int
-- | First phase of the parallel all-reduction (@foldAllP1@).
--
-- Each index @chunk@ between the gang bounds reduces the input elements from
-- @chunk*stride@ up to @min (chunk*stride + stride) length@. The partial
-- result is stored at index @chunk@ of the temporary array. The seed is never
-- used here; see 'mkFoldAllP2'.
-- NOTE(review): 'reduce1FromTo' reads the first element of each chunk
-- unconditionally, so every launched chunk is assumed to be non-empty.
-- Confirm the caller sizes the gang accordingly.
mkFoldAllP1
    :: forall aenv e. Elt e
    =>          Gamma            aenv                           -- ^ array environment
    ->          IRFun2    Native aenv (e -> e -> e)             -- ^ combination function
    ->          IRDelayed Native aenv (Vector e)                -- ^ input vector
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllP1 aenv combine IRDelayed{..} =
  makeOpenAcc "foldAllP1" (paramGang ++ paramLength : paramStride : paramTmp ++ paramEnv) $ do
    imapFromTo start end $ \chunk -> do
      lo      <- A.mul numType chunk stride
      unclamp <- A.add numType lo    stride
      -- clamp so the final chunk does not run past the input
      hi      <- A.min scalarType unclamp length
      partial <- reduce1FromTo lo hi (app2 combine) (app1 delayedLinearIndex)
      writeArray arrTmp chunk partial
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    paramLength             = scalarParameter scalarType ("ix.length" :: Name Int)
    paramStride             = scalarParameter scalarType ("ix.stride" :: Name Int)
    length                  = local           scalarType ("ix.length" :: Name Int)
    stride                  = local           scalarType ("ix.stride" :: Name Int)
-- | Second phase of the parallel all-reduction (@foldAllP2@).
--
-- Reduces the temporary array's elements between the gang bounds @start@ and
-- @end@. The accumulator starts at the seed when one is given and at the
-- first temporary element otherwise. The result is written to index 0 of the
-- scalar output array.
-- NOTE(review): the temporary array is presumably the one filled by
-- 'mkFoldAllP1' (both use the name "tmp"). Confirm on the execution side.
mkFoldAllP2
    :: forall aenv e. Elt e
    =>          Gamma            aenv                           -- ^ array environment
    ->          IRFun2    Native aenv (e -> e -> e)             -- ^ combination function
    -> Maybe   (IRExp     Native aenv e)                        -- ^ optional seed
    -> CodeGen (IROpenAcc Native aenv (Scalar e))
mkFoldAllP2 aenv combine mseed =
  makeOpenAcc "foldAllP2" (paramGang ++ paramTmp ++ paramOut ++ paramEnv) $ do
    total <- case mseed of
               Nothing   -> reduce1FromTo start end (app2 combine) (readArray arrTmp)
               Just seed -> seed >>= \z ->
                              reduceFromTo start end (app2 combine) z (readArray arrTmp)
    writeArray arrOut zero total
    return_
  where
    (start, end, paramGang) = gangParam
    paramEnv                = envParam aenv
    (arrTmp, paramTmp)      = mutableArray ("tmp" :: Name (Vector e))
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Scalar e))
    zero                    = lift 0 :: IR Int
-- | Kernel that fills the output array with the seed value, ignoring the
-- element index. It is built with 'mkGenerate'.
-- NOTE(review): 'mkFold' generates this alongside the reduction kernels,
-- presumably for the case where the reduced dimension is empty. Confirm.
mkFoldFill
    :: (Shape sh, Elt e)
    => Gamma aenv                 -- ^ array environment
    -> IRExp Native aenv e        -- ^ seed element
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkFoldFill aenv seed =
  mkGenerate aenv (IRFun1 (\_ -> seed))
-- | Sequentially reduce the elements with indices from @m@ to @n@, starting
-- from the accumulator @z@.
--
-- Each element is fetched with @get@ and combined as @f acc x@, so the
-- accumulator is always the left operand of @f@.
reduceFromTo
    :: Elt a
    => IR Int                                   -- ^ first index
    -> IR Int                                   -- ^ end index (callers in this module treat it as exclusive)
    -> (IR a -> IR a -> CodeGen (IR a))         -- ^ combination function (accumulator first)
    -> IR a                                     -- ^ initial accumulator
    -> (IR Int -> CodeGen (IR a))               -- ^ element accessor
    -> CodeGen (IR a)
reduceFromTo m n f z get =
  -- Fetch the element, then fold it into the accumulator. The former
  -- @y <- f acc x; return y@ was a redundant bind/return pair.
  iterFromTo m n z $ \i acc -> f acc =<< get i
-- | Reduce the elements with indices from @m@ to @n@ without a seed.
--
-- The element at @m@ becomes the initial accumulator, and the remaining
-- indices from @m + 1@ are folded in with 'reduceFromTo'.
-- NOTE(review): the element at @m@ is read unconditionally, so the range must
-- be non-empty. Callers are expected to guarantee this.
reduce1FromTo
    :: Elt a
    => IR Int                                   -- ^ first index (must be in range)
    -> IR Int                                   -- ^ end index
    -> (IR a -> IR a -> CodeGen (IR a))         -- ^ combination function (accumulator first)
    -> (IR Int -> CodeGen (IR a))               -- ^ element accessor
    -> CodeGen (IR a)
reduce1FromTo m n f get = do
  x0   <- get m
  next <- add numType m (ir numType (num numType 1))
  reduceFromTo next n f x0 get