{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RebindableSyntax    #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.Scan
  where
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.Analysis.Match
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR                        ( IR )
import Data.Array.Accelerate.LLVM.CodeGen.Loop
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Generate
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )
import Control.Applicative
import Control.Monad
import Data.String                                                  ( fromString )
import Data.Coerce                                                  as Safe
import Prelude                                                      as P
-- | The direction in which a scan walks along the innermost dimension.
data Direction
  = L   -- ^ left-to-right: indices are advanced by adding one
  | R   -- ^ right-to-left: indices are advanced by subtracting one
-- | Left-to-right scan with a seed, i.e. an exclusive scan in which each output
-- segment holds one more element than the input segment (the seed comes
-- first).
--
-- Rank-one input gets the sequential kernel, the three-phase parallel kernels
-- and a seed-fill kernel. Any other rank gets only the sequential and
-- seed-fill kernels.
mkScanl
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK  <- mkScanS L aenv combine (Just seed) arr
      parK  <- mkScanP L aenv combine (Just seed) arr
      fillK <- mkScanFill aenv seed
      return (seqK +++ (parK +++ fillK))
    Nothing   -> do
      seqK  <- mkScanS L aenv combine (Just seed) arr
      fillK <- mkScanFill aenv seed
      return (seqK +++ fillK)
-- | Left-to-right inclusive scan (no seed).
--
-- Rank-one input gets the sequential kernel plus the three-phase parallel
-- kernels. Any other rank gets only the sequential kernel.
mkScanl1
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanl1 aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK <- mkScanS L aenv combine Nothing arr
      parK <- mkScanP L aenv combine Nothing arr
      return (seqK +++ parK)
    Nothing   -> mkScanS L aenv combine Nothing arr
-- | Left-to-right exclusive scan that also yields the per-segment reduction.
-- The output array has the input's shape. The second result holds one
-- combined value per segment.
--
-- Rank-one input gets the sequential, three-phase parallel and seed-fill
-- kernels. Any other rank gets the sequential and seed-fill kernels.
mkScanl'
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanl' aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK  <- mkScan'S L aenv combine seed arr
      parK  <- mkScan'P L aenv combine seed arr
      fillK <- mkScan'Fill aenv seed
      return (seqK +++ (parK +++ fillK))
    Nothing   -> do
      seqK  <- mkScan'S L aenv combine seed arr
      fillK <- mkScan'Fill aenv seed
      return (seqK +++ fillK)
-- | Right-to-left scan with a seed, i.e. an exclusive scan in which each
-- output segment holds one more element than the input segment (the seed
-- comes last).
--
-- Rank-one input gets the sequential kernel, the three-phase parallel kernels
-- and a seed-fill kernel. Any other rank gets only the sequential and
-- seed-fill kernels.
mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK  <- mkScanS R aenv combine (Just seed) arr
      parK  <- mkScanP R aenv combine (Just seed) arr
      fillK <- mkScanFill aenv seed
      return (seqK +++ (parK +++ fillK))
    Nothing   -> do
      seqK  <- mkScanS R aenv combine (Just seed) arr
      fillK <- mkScanFill aenv seed
      return (seqK +++ fillK)
-- | Right-to-left inclusive scan (no seed).
--
-- Rank-one input gets the sequential kernel plus the three-phase parallel
-- kernels. Any other rank gets only the sequential kernel.
mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanr1 aenv combine arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK <- mkScanS R aenv combine Nothing arr
      parK <- mkScanP R aenv combine Nothing arr
      return (seqK +++ parK)
    Nothing   -> mkScanS R aenv combine Nothing arr
-- | Right-to-left exclusive scan that also yields the per-segment reduction.
-- The output array has the input's shape. The second result holds one
-- combined value per segment.
--
-- Rank-one input gets the sequential, three-phase parallel and seed-fill
-- kernels. Any other rank gets the sequential and seed-fill kernels.
mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma            aenv
    -> IRFun2    Native aenv (e -> e -> e)
    -> IRExp     Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScanr' aenv combine seed arr =
  case matchShapeType (undefined::sh) (undefined::Z) of
    Just Refl -> do
      seqK  <- mkScan'S R aenv combine seed arr
      parK  <- mkScan'P R aenv combine seed arr
      fillK <- mkScan'Fill aenv seed
      return (seqK +++ (parK +++ fillK))
    Nothing   -> do
      seqK  <- mkScan'S R aenv combine seed arr
      fillK <- mkScan'Fill aenv seed
      return (seqK +++ fillK)
-- | A generate kernel whose element function ignores its index argument and
-- always evaluates the seed expression.
mkScanFill
    :: (Shape sh, Elt e)
    => Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array sh e))
mkScanFill aenv seed =
  mkGenerate aenv $ IRFun1 $ \_ix -> seed
-- | Seed-fill kernel for the scan-with-reduction variants. It builds an
-- @Array sh e@ generate kernel whose every element is the seed (see
-- 'mkScanFill'). That kernel is then relabelled with 'Safe.coerce' at the pair
-- result type @(Array (sh:.Int) e, Array sh e)@.
mkScan'Fill
    :: forall aenv sh e. (Shape sh, Elt e)
    => Gamma aenv
    -> IRExp Native aenv e
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'Fill aenv seed = fmap Safe.coerce seedFill
  where
    -- the element type is fixed here so that 'Safe.coerce' has a known source
    seedFill :: CodeGen (IROpenAcc Native aenv (Array sh e))
    seedFill = mkScanFill aenv seed
-- | Sequential scan over every segment (row of the innermost dimension) of a
-- multidimensional array. Each index produced by 'imapFromTo' over the gang
-- range is one segment.
--
-- With a seed (@Just seed@) the scan is exclusive. Each output segment then
-- has stride @sz+1@ (one more element than the input segment), and the seed is
-- written first in the scan direction. Without a seed the scan is inclusive,
-- and input and output share the same indices.
--
-- For 'L' the operator is applied as @combine acc x@ and for 'R' as
-- @combine x acc@. In both cases the arguments appear in array order.
--
-- The @Shape sh@ constraint matches 'mkScan'S'. It is presumably needed to
-- build the parameters of the @(sh:.Int)@-shaped output array via
-- 'mutableArray'.
mkScanS
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e))
mkScanS dir aenv combine mseed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      paramEnv                  = envParam aenv
      -- step an index one element in the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
  in
  makeOpenAcc "scanS" (paramGang ++ paramOut ++ paramEnv) $ do
    sz    <- indexHead <$> delayedExtent
    szp1  <- A.add numType sz (lift 1)
    szm1  <- A.sub numType sz (lift 1)

    -- one iteration per segment
    imapFromTo start end $ \seg -> do

      -- i0: first input index of this segment to read
      -- (leftmost element for L, rightmost element for R)
      i0 <- case dir of
              L -> A.mul numType sz seg
              R -> do x <- A.mul numType sz seg
                      y <- A.add numType szm1 x
                      return y

      -- j0: first output index to write. Exclusive scans use output segments
      -- of length sz+1; inclusive scans reuse the input index.
      j0 <- case mseed of
              Nothing -> return i0
              Just{}  -> case dir of
                           L -> A.mul numType szp1 seg
                           R -> do x <- A.mul numType szp1 seg
                                   y <- A.add numType x sz
                                   return y

      -- Initial carry: the seed (input index unchanged), or the first input
      -- element (input index advanced past it).
      (v0,i1) <- case mseed of
                   Just seed -> (,) <$> seed                       <*> pure i0
                   Nothing   -> (,) <$> app1 delayedLinearIndex i0 <*> next i0

      writeArray arrOut j0 v0
      j1 <- next j0
      -- iz: first index beyond this segment in the scan direction
      iz <- case dir of
              L -> A.add numType i0 sz
              R -> A.sub numType i0 sz
      let cont i = case dir of
                     L -> A.lt scalarType i iz
                     R -> A.gt scalarType i iz
      -- loop state: (read index, write index, running value)
      void $ while (cont . A.fst3)
                   (\(A.untrip -> (i,j,v)) -> do
                       u  <- app1 delayedLinearIndex i
                       v' <- case dir of
                               L -> app2 combine v u
                               R -> app2 combine u v
                       writeArray arrOut j v'
                       A.trip <$> next i <*> next j <*> pure v')
                   (A.trip i1 j1 v0)
    return_
-- | Sequential exclusive scan that also produces the reduction of every
-- segment (the scan-with-sum variant). Each index produced by 'imapFromTo'
-- over the gang range is one segment (row of the innermost dimension).
--
-- The output array has the same indices as the input. The seed is written at
-- the first position of the segment in the scan direction. The final
-- combined value of each segment goes to the separate @sum@ array at the
-- segment index.
mkScan'S
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Array (sh:.Int) e)
    -> CodeGen (IROpenAcc Native aenv (Array (sh:.Int) e, Array sh e))
mkScan'S dir aenv combine seed IRDelayed{..} =
  let
      (start, end, paramGang)   = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Array (sh:.Int) e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Array sh e))
      paramEnv                  = envParam aenv
      -- step an index one element in the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
  in
  makeOpenAcc "scanS" (paramGang ++ paramOut ++ paramSum ++ paramEnv) $ do
    sz    <- indexHead <$> delayedExtent
    szm1  <- A.sub numType sz (lift 1)
    -- one iteration per segment
    imapFromTo start end $ \seg -> do
      -- i0: first index of this segment in the scan direction
      -- (leftmost element for L, rightmost element for R)
      i0 <- case dir of
              L -> A.mul numType seg sz
              R -> do x <- A.mul numType sz seg
                      y <- A.add numType x szm1
                      return y
      -- the seed is the initial carry
      v0 <- seed
      -- iz: first index beyond this segment in the scan direction
      iz <- case dir of
              L -> A.add numType i0 sz
              R -> A.sub numType i0 sz
      let cont i  = case dir of
                      L -> A.lt scalarType i iz
                      R -> A.gt scalarType i iz
      -- The carry is written to out at the top of the loop body, before the
      -- element at that index is combined in. The value after the last
      -- element is therefore never written to out; it is the segment's
      -- reduction. If while tests the condition before the first iteration
      -- (presumably), an empty segment writes nothing to out and its sum is
      -- the seed.
      r  <- while (cont . A.fst)
                  (\(A.unpair -> (i,v)) -> do
                      writeArray arrOut i v
                      u  <- app1 delayedLinearIndex i
                      v' <- case dir of
                              L -> app2 combine v u
                              R -> app2 combine u v
                      i' <- next i
                      return $ A.pair i' v')
                  (A.pair i0 v0)
      -- final carry is the reduction of the whole segment
      writeArray arrSum seg (A.snd r)
    return_
-- | Parallel scan of a one-dimensional array, built as three kernels joined
-- with '+++'.
--
--   1. Phase one scans each chunk on its own and records its total in @tmp@.
--   2. Phase two scans the chunk totals in @tmp@.
--   3. Phase three combines each chunk's carry-in value into its outputs.
mkScanP
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP dir aenv combine mseed arr = do
  phase1 <- mkScanP1 dir aenv combine mseed arr
  phase2 <- mkScanP2 dir aenv combine
  phase3 <- mkScanP3 dir aenv combine mseed
  return (phase1 +++ (phase2 +++ phase3))
-- | Phase one of the parallel scan of a vector.
--
-- Gang index @chunk@ scans the input stripe
-- @[chunk*stride, min ((chunk+1)*stride) len)@ on its own. It writes partial
-- results to @out@ and the stripe's final combined value to @tmp[chunk]@.
-- @stride@ and @steps@ arrive as the scalar parameters @ix.stride@ and
-- @ix.steps@.
--
-- The chunk where the scan starts (@firstChunk@: 0 for 'L', @steps@ for 'R')
-- is special for exclusive scans (@Just seed@): it writes the seed first and
-- folds it into the stripe. Every other chunk starts from its first input
-- element. For an exclusive 'L' scan, chunks other than the first write one
-- index further right. For an exclusive 'R' scan, the first chunk writes the
-- seed at index @sup@, one past its stripe.
mkScanP1
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP1 dir aenv combine mseed IRDelayed{..} =
  let
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      -- extra scalar kernel parameters: chunk count index and chunk size
      steps                     = local           scalarType ("ix.steps"  :: Name Int)
      paramSteps                = scalarParameter scalarType ("ix.steps"  :: Name Int)
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- step an index one element in the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      -- chunk at which the scan starts; steps is presumably the index of the
      -- last chunk (confirm against the caller)
      firstChunk                = case dir of
                                    L -> lift 0
                                    R -> steps
  in
  makeOpenAcc "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do
    len <- indexHead <$> delayedExtent
    -- Stripe bounds for this chunk: [inf, sup), with sup clipped to the
    -- input length.
    --
    -- The if-expressions below go through RebindableSyntax (an ifThenElse in
    -- scope) with a generated comparison as condition, presumably emitting
    -- runtime branches.
    --
    inf <- A.mul numType chunk stride
    a   <- A.add numType inf   stride
    sup <- A.min scalarType a  len
    -- i0: first input index to read
    -- (inf for L; sup-1 for R, because sup is exclusive)
    i0  <- case dir of
             L -> return inf
             R -> next sup
    -- j0: first output index to write.
    -- Exclusive L: first chunk writes at i0 (the seed), later chunks at i0+1.
    -- Exclusive R: first chunk writes the seed at sup, later chunks at i0.
    j0  <- case mseed of
             Nothing -> return i0
             Just _  -> case dir of
                          L -> if A.eq scalarType chunk firstChunk
                                 then return i0
                                 else next i0
                          R -> if A.eq scalarType chunk firstChunk
                                 then return sup
                                 else return i0
    -- Initial carry: the seed in the first chunk of an exclusive scan (read
    -- index unchanged), otherwise the first input element (read index
    -- advanced past it).
    (v0,i1) <- A.unpair <$> case mseed of
                 Just seed -> if A.eq scalarType chunk firstChunk
                                then A.pair <$> seed                       <*> pure i0
                                else A.pair <$> app1 delayedLinearIndex i0 <*> next i0
                 Nothing   ->        A.pair <$> app1 delayedLinearIndex i0 <*> next i0
    -- write the first element, then scan the rest of the stripe
    writeArray arrOut j0 v0
    j1  <- next j0
    -- loop state: (read index, write index, running value)
    let cont i =
           case dir of
             L -> A.lt  scalarType i sup
             R -> A.gte scalarType i inf
    r   <- while (cont . A.fst3)
                 (\(A.untrip -> (i,j,v)) -> do
                     u  <- app1 delayedLinearIndex i
                     v' <- case dir of
                             L -> app2 combine v u
                             R -> app2 combine u v
                     writeArray arrOut j v'
                     A.trip <$> next i <*> next j <*> pure v')
                 (A.trip i1 j1 v0)
    -- the stripe's total, consumed by phase two
    writeArray arrTmp chunk (A.thd3 r)
    return_
-- | Phase two of the parallel scan: an in-place inclusive scan of the
-- per-chunk totals in @tmp@ over the gang range @[start, end)@, walking in the
-- scan direction.
--
-- The first element in the scan direction is left as is. Every later element
-- is replaced by the running combination.
--
-- NOTE(review): this is a complete scan of @tmp@ only if one invocation
-- covers the whole range; presumably the caller arranges that.
mkScanP2
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP2 dir aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      -- loop bound: [start, end) walked upward for L, downward for R
      cont i                    = case dir of
                                    L -> A.lt  scalarType i end
                                    R -> A.gte scalarType i start
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
  in
  makeOpenAcc "scanP2" (paramGang ++ paramTmp ++ paramEnv) $ do
    -- first index in scan direction (end is exclusive, hence end-1 for R)
    i0 <- case dir of
            L -> return start
            R -> next end
    v0 <- readArray arrTmp i0
    i1 <- next i0
    -- loop state: (index, running value); each tmp[i] becomes its prefix
    void $ while (cont . A.fst)
                 (\(A.unpair -> (i,v)) -> do
                    u  <- readArray arrTmp i
                    i' <- next i
                    v' <- case dir of
                            L -> app2 combine v u
                            R -> app2 combine u v
                    writeArray arrTmp i v'
                    return $ A.pair i' v')
                 (A.pair i1 v0)
    return_
-- | Phase three of the parallel scan: combine each chunk's carry-in value
-- (from @tmp@, after phase two has scanned it) with the partial results that
-- phase one wrote to @out@.
--
-- The chunk where the scan starts needs no carry and is skipped.
--
--   * For 'L', gang index @chunk@ processes chunk @chunk+1@ and reads its
--     carry from @tmp[chunk]@.
--   * For 'R', it processes chunk @chunk@ and reads its carry from
--     @tmp[chunk+1]@.
--
-- For an exclusive left-to-right scan (@mseed = Just _@) the output range is
-- shifted up by one, matching phase one. The shifted upper bound is clamped
-- to the length of @out@. Previously a final, partial chunk (@c > len@) gave
-- an upper bound of @length out + 1@ and read/wrote one element past the end.
mkScanP3
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> Maybe (IRExp Native aenv e)
    -> CodeGen (IROpenAcc Native aenv (Vector e))
mkScanP3 dir aenv combine mseed =
  let
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      -- chunk size, passed as a scalar kernel parameter
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- step one element with / against the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      prev i                    = case dir of
                                    L -> A.sub numType i (lift 1)
                                    R -> A.add numType i (lift 1)
  in
  makeOpenAcc "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do
    let outLen = indexHead (irArrayShape arrOut)

    -- which chunk to process (the starting chunk for L is skipped)
    a     <- case dir of
               L -> next chunk
               R -> pure chunk
    b     <- A.mul numType a stride
    c     <- A.add numType b stride
    d     <- A.min scalarType c outLen

    -- Output range [inf, sup). Exclusive left-to-right scans are shifted by
    -- one; the shifted upper bound must not exceed the output length.
    (inf,sup) <- case (dir,mseed) of
                   (L,Just _) -> do
                     lo  <- next b
                     hi  <- next d
                     hi' <- A.min scalarType hi outLen
                     return (lo, hi')
                   _          -> return (b, d)

    -- carry-in value for this chunk
    e     <- case dir of
               L -> pure chunk
               R -> prev chunk
    carry <- readArray arrTmp e
    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y
    return_
-- | Parallel scan-with-reduction of a one-dimensional array, built as three
-- kernels joined with '+++'.
--
--   1. Phase one scans each chunk and records its total in @tmp@.
--   2. Phase two scans @tmp@ and writes the overall total to @sum@.
--   3. Phase three combines each chunk's carry-in value into its outputs.
mkScan'P
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P dir aenv combine seed arr = do
  phase1 <- mkScan'P1 dir aenv combine seed arr
  phase2 <- mkScan'P2 dir aenv combine
  phase3 <- mkScan'P3 dir aenv combine
  return (phase1 +++ (phase2 +++ phase3))
-- | Phase one of the parallel scan-with-reduction of a vector.
--
-- Gang index @chunk@ scans the input stripe
-- @[chunk*stride, min ((chunk+1)*stride) len)@. It writes partial results to
-- @out@ and the stripe's final combined value to @tmp[chunk]@. The chunk where
-- the scan starts (@firstChunk@: 0 for 'L', @steps@ for 'R') begins with the
-- seed. Every other chunk begins with its first input element, and its output
-- indices are shifted one step in the scan direction.
--
-- NOTE(review): with that shift, each chunk's last write lands one slot
-- beyond its stripe edge: @out[sup]@ for 'L', @out[inf-1]@ for 'R'. For the
-- outermost chunk that is index @len@ (L) or @-1@ (R). If @out@ has the
-- input's length, this is outside the array; confirm the buffer size the
-- caller allocates.
mkScan'P1
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> IRExp Native aenv e
    -> IRDelayed Native aenv (Vector e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P1 dir aenv combine seed IRDelayed{..} =
  let
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      -- extra scalar kernel parameters: chunk count index and chunk size
      steps                     = local           scalarType ("ix.steps"  :: Name Int)
      paramSteps                = scalarParameter scalarType ("ix.steps"  :: Name Int)
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- step an index one element in the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      -- chunk at which the scan starts; steps is presumably the index of the
      -- last chunk (confirm against the caller)
      firstChunk                = case dir of
                                    L -> lift 0
                                    R -> steps
  in
  makeOpenAcc "scanP1" (paramGang ++ paramStride : paramSteps : paramOut ++ paramTmp ++ paramEnv) $ do
    -- stripe bounds for this chunk: [inf, sup), sup clipped to the input
    -- length
    len <- indexHead <$> delayedExtent
    inf <- A.mul numType chunk stride
    a   <- A.add numType inf   stride
    sup <- A.min scalarType a  len
    -- i0: first input index to read (sup is exclusive, hence sup-1 for R)
    i0 <- case dir of
            L -> return inf
            R -> next sup
    -- j0: first output index. The starting chunk writes at i0; every other
    -- chunk is shifted one step in the scan direction.
    -- (if/then/else here goes through RebindableSyntax's ifThenElse.)
    j0      <- if A.eq scalarType chunk firstChunk
                 then pure i0
                 else next i0
    -- Initial carry: the seed in the starting chunk (read index unchanged),
    -- otherwise the first input element. In the latter case j0 equals
    -- next i0, so the read index also advances past that element.
    (v0,i1) <- A.unpair <$> if A.eq scalarType chunk firstChunk
                              then A.pair <$> seed                       <*> pure i0
                              else A.pair <$> app1 delayedLinearIndex i0 <*> pure j0
    -- write the first element, then scan the rest of the stripe
    writeArray arrOut j0 v0
    j1 <- next j0
    -- loop state: (read index, write index, running value)
    let cont i =
           case dir of
             L -> A.lt  scalarType i sup
             R -> A.gte scalarType i inf
    r  <- while (cont . A.fst3)
                (\(A.untrip-> (i,j,v)) -> do
                    u  <- app1 delayedLinearIndex i
                    v' <- case dir of
                            L -> app2 combine v u
                            R -> app2 combine u v
                    writeArray arrOut j v'
                    A.trip <$> next i <*> next j <*> pure v')
                (A.trip i1 j1 v0)
    -- the stripe's total, consumed by phase two
    writeArray arrTmp chunk (A.thd3 r)
    return_
-- | Phase two of the parallel scan-with-reduction: an in-place inclusive scan
-- of the per-chunk totals in @tmp@ over the gang range @[start, end)@, walking
-- in the scan direction. The final running value (the overall total) is
-- written to @sum[0]@.
--
-- The parameter list places @sum@ before @tmp@. That order must match how
-- the caller (not visible here) passes its arguments.
--
-- NOTE(review): the result is the complete total only if one invocation
-- covers the whole range; presumably the caller arranges that.
mkScan'P2
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P2 dir aenv combine =
  let
      (start, end, paramGang)   = gangParam
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      (arrSum, paramSum)        = mutableArray ("sum" :: Name (Scalar e))
      paramEnv                  = envParam aenv
      -- loop bound: [start, end) walked upward for L, downward for R
      cont i                    = case dir of
                                    L -> A.lt  scalarType i end
                                    R -> A.gte scalarType i start
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
  in
  makeOpenAcc "scanP2" (paramGang ++ paramSum ++ paramTmp ++ paramEnv) $ do
    -- first index in scan direction (end is exclusive, hence end-1 for R)
    i0 <- case dir of
            L -> return start
            R -> next end
    v0 <- readArray arrTmp i0
    i1 <- next i0
    -- loop state: (index, running value); each tmp[i] becomes its prefix
    r  <- while (cont . A.fst)
                (\(A.unpair -> (i,v)) -> do
                   u  <- readArray arrTmp i
                   i' <- next i
                   v' <- case dir of
                           L -> app2 combine v u
                           R -> app2 combine u v
                   writeArray arrTmp i v'
                   return $ A.pair i' v')
                (A.pair i1 v0)
    -- the combination of all chunk totals goes into the single sum slot
    writeArray arrSum (lift 0 :: IR Int) (A.snd r)
    return_
-- | Phase three of the parallel scan-with-reduction: combine each chunk's
-- carry-in value (from @tmp@, after phase two has scanned it) with the
-- partial results that phase one wrote to @out@.
--
-- The chunk where the scan starts is skipped.
--
--   * For 'L', gang index @chunk@ processes chunk @chunk+1@ with carry
--     @tmp[chunk]@.
--   * For 'R', it processes chunk @chunk@ with carry @tmp[chunk+1]@.
--
-- As in phase one, output indices are shifted one step in the scan direction
-- relative to the input stripe. The shifted range is clamped to
-- @[0, length out)@. Without the clamp, the last chunk of an 'L' scan touched
-- index @length out@ and chunk 0 of an 'R' scan touched index @-1@, both
-- outside the array. Only out-of-range indices are dropped.
--
-- NOTE(review): phase one (mkScan'P1) as written also writes one slot past
-- the outermost chunk's edge; confirm the buffer size the caller allocates.
mkScan'P3
    :: forall aenv e. Elt e
    => Direction
    -> Gamma aenv
    -> IRFun2 Native aenv (e -> e -> e)
    -> CodeGen (IROpenAcc Native aenv (Vector e, Scalar e))
mkScan'P3 dir aenv combine =
  let
      (chunk, _, paramGang)     = gangParam
      (arrOut, paramOut)        = mutableArray ("out" :: Name (Vector e))
      (arrTmp, paramTmp)        = mutableArray ("tmp" :: Name (Vector e))
      paramEnv                  = envParam aenv
      -- chunk size, passed as a scalar kernel parameter
      stride                    = local           scalarType ("ix.stride" :: Name Int)
      paramStride               = scalarParameter scalarType ("ix.stride" :: Name Int)
      -- step one element with / against the scan direction
      next i                    = case dir of
                                    L -> A.add numType i (lift 1)
                                    R -> A.sub numType i (lift 1)
      prev i                    = case dir of
                                    L -> A.sub numType i (lift 1)
                                    R -> A.add numType i (lift 1)
  in
  makeOpenAcc "scanP3" (paramGang ++ paramStride : paramOut ++ paramTmp ++ paramEnv) $ do
    let outLen = indexHead (irArrayShape arrOut)

    -- which chunk to process (the starting chunk for L is skipped)
    a     <- case dir of
               L -> next chunk
               R -> pure chunk
    b     <- A.mul numType a stride
    c     <- A.add numType b stride
    d     <- A.min scalarType c outLen

    -- shift the range one step in the scan direction, then clamp it to the
    -- valid indices of out
    lo    <- next b
    hi    <- next d
    inf   <- A.max scalarType lo (lift 0 :: IR Int)
    sup   <- A.min scalarType hi outLen

    -- carry-in value for this chunk
    e     <- case dir of
               L -> pure chunk
               R -> prev chunk
    carry <- readArray arrTmp e
    imapFromTo inf sup $ \i -> do
      x <- readArray arrOut i
      y <- case dir of
             L -> app2 combine carry x
             R -> app2 combine x carry
      writeArray arrOut i y
    return_