{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.FoldSeg
  where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Exp                       ( indexHead )
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )
import Control.Applicative
import Control.Monad
import Prelude                                                      as P
-- | Code generation for a seeded segmented reduction.
--
-- Two kernels are generated from the same arguments, the sequential
-- @foldSegS@ (see 'mkFoldSegS') followed by the parallel @foldSegP@ (see
-- 'mkFoldSegP'), and the results are joined with '+++'. The seed expression
-- is handed to both kernels as @Just seed@.
--
mkFoldSeg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSeg aenv combine seed arr seg = liftA2 (+++) seqKernel parKernel
  where
    initialValue = Just seed
    seqKernel    = mkFoldSegS aenv combine initialValue arr seg
    parKernel    = mkFoldSegP aenv combine initialValue arr seg
-- | Code generation for a segmented reduction without a seed value.
--
-- Identical to 'mkFoldSeg' except that 'Nothing' is passed as the seed to
-- both 'mkFoldSegS' and 'mkFoldSegP', which makes them reduce each segment
-- with 'reduce1FromTo' instead of 'reduceFromTo'. The sequential kernel is
-- generated first and the two results are joined with '+++'.
--
mkFold1Seg
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh :. Int) e)
    -> IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFold1Seg aenv combine arr seg = liftA2 (+++) seqKernel parKernel
  where
    seqKernel = mkFoldSegS aenv combine Nothing arr seg
    parKernel = mkFoldSegP aenv combine Nothing arr seg
-- | Generate the @foldSegS@ kernel: a single loop that walks the output
-- indices @[start, end)@ one after another and reduces one segment of the
-- source array per output element.
--
-- In this kernel each entry of the segment descriptor is treated as the
-- /length/ of a segment: the loop carries a running offset into the source
-- array, and the upper bound of every segment is that offset plus the entry
-- read from @seg@. The upper bound then becomes the lower bound of the next
-- iteration.
--
-- When a seed is supplied each segment is reduced with 'reduceFromTo'
-- starting from the seed; otherwise 'reduce1FromTo' is used.
--
-- NOTE(review): the running offset always begins at zero, independent of
-- @start@, and is never reset between rows of a higher-dimensional input.
-- Presumably the kernel is expected to be launched over the whole output
-- range and with segment lengths that sum to the innermost extent of the
-- source -- confirm against the caller.
--
mkFoldSegS
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    =>          Gamma            aenv
    ->          IRFun2    Native aenv (e -> e -> e)
    -> Maybe   (IRExp     Native aenv e)
    ->          IRDelayed Native aenv (Array (sh :. Int) e)
    ->          IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSegS aenv combine mseed arr seg =
  makeOpenAcc "foldSegS" (paramGang ++ paramOut ++ paramEnv) $ do
    -- Innermost extent of the segment descriptor. 'step' only consults it to
    -- wrap the segment index when the outer shape is not of rank zero.
    nseg <- indexHead <$> delayedExtent seg
    _    <- while inRange (step nseg) (A.pair start (lift 0))
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
    paramEnv                = envParam aenv

    -- True when the outer shape 'sh' has rank zero. The output index can then
    -- be used to index the segment descriptor directly, without a remainder.
    flat :: Bool
    flat = rank (undefined :: sh) == 0

    -- Read one element of the (delayed) source array by linear index.
    fetch :: IR Int -> CodeGen (IR e)
    fetch = app1 (delayedLinearIndex arr)

    -- Loop condition: keep going while the output index is below 'end'.
    inRange :: IR (Int,Int) -> CodeGen (IR Bool)
    inRange st = A.lt scalarType (A.fst st) end

    -- Loop body. The state is (output index, lower bound in the source); the
    -- result is the next output index paired with this segment's upper bound.
    step :: IR Int -> IR (Int,Int) -> CodeGen (IR (Int,Int))
    step nseg (A.unpair -> (ix, lo)) = do
      -- Position in the segment descriptor: wraps around for every new
      -- lower-dimensional index unless the outer shape is rank zero.
      segIx <- if flat
                 then return ix
                 else A.rem integralType ix nseg
      len   <- app1 (delayedLinearIndex seg) segIx >>= A.fromIntegral integralType numType
      hi    <- A.add numType lo len
      acc   <- case mseed of
                 Nothing   -> reduce1FromTo lo hi (app2 combine) fetch
                 Just seed -> seed >>= \z -> reduceFromTo lo hi (app2 combine) z fetch
      writeArray arrOut ix acc
      next  <- A.add numType ix (lift 1)
      return (A.pair next hi)
-- | Generate the @foldSegP@ kernel: every output index in @[start, end)@ is
-- processed independently via 'imapFromTo'.
--
-- In this kernel the segment descriptor is treated as an array of /offsets/
-- into the source: the bounds of segment @k@ are the entries at positions
-- @k@ and @k+1@ of @seg@. Consequently the value used as the per-row segment
-- count is the extent of @seg@ minus one.
--
-- When the outer shape is not of rank zero, the output index is split into a
-- row (quotient by the segment count) and a position within the row
-- (remainder), and both bounds are shifted by @row * innermost extent of
-- arr@.
--
-- When a seed is supplied each segment is reduced with 'reduceFromTo'
-- starting from the seed; otherwise 'reduce1FromTo' is used.
--
mkFoldSegP
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    =>          Gamma            aenv
    ->          IRFun2    Native aenv (e -> e -> e)
    -> Maybe   (IRExp     Native aenv e)
    ->          IRDelayed Native aenv (Array (sh :. Int) e)
    ->          IRDelayed Native aenv (Segments i)
    -> CodeGen (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSegP aenv combine mseed arr seg =
  makeOpenAcc "foldSegP" (paramGang ++ paramOut ++ paramEnv) $ do
    -- Innermost extent of the source array.
    rowLen <- indexHead <$> delayedExtent arr
    -- One less than the extent of the segment descriptor.
    nseg   <- delayedExtent seg >>= \ext -> A.sub numType (indexHead ext) (lift 1)
    imapFromTo start end (segment rowLen nseg)
    return_
  where
    (start, end, paramGang) = gangParam
    (arrOut, paramOut)      = mutableArray ("out" :: Name (Array (sh :. Int) e))
    paramEnv                = envParam aenv

    -- True when the outer shape 'sh' has rank zero; no quotient/remainder
    -- arithmetic on the output index is emitted in that case.
    flat :: Bool
    flat = rank (undefined :: sh) == 0

    -- Read one element of the (delayed) source array by linear index.
    fetch :: IR Int -> CodeGen (IR e)
    fetch = app1 (delayedLinearIndex arr)

    -- Read one entry of the segment descriptor and convert it to 'Int'.
    offsetAt :: IR Int -> CodeGen (IR Int)
    offsetAt k = app1 (delayedLinearIndex seg) k >>= A.fromIntegral integralType numType

    -- Reduce the segment belonging to output index @ix@ and store the result.
    segment :: IR Int -> IR Int -> IR Int -> CodeGen ()
    segment rowLen nseg ix = do
      k        <- if flat
                    then return ix
                    else A.rem integralType ix nseg
      k'       <- A.add numType k (lift 1)
      u        <- offsetAt k
      v        <- offsetAt k'
      -- Shift the bounds to the row of the source that this output index
      -- belongs to (nothing to do when the outer shape is rank zero).
      (lo, hi) <- if flat
                    then return (u, v)
                    else do
                      row  <- A.quot integralType ix nseg
                      base <- A.mul numType row rowLen
                      (,) <$> A.add numType u base <*> A.add numType v base
      acc      <- case mseed of
                    Nothing   -> reduce1FromTo lo hi (app2 combine) fetch
                    Just seed -> seed >>= \z -> reduceFromTo lo hi (app2 combine) z fetch
      writeArray arrOut ix acc