{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.CodeGen.Loop
-- Copyright   : [2015..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.CodeGen.Loop
  where

import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Monad

import Prelude                                                  hiding ( fst, snd, uncurry )
import Control.Monad


-- | TODO: Iterate over a multidimensional index space.
--
-- Build nested loops that iterate over a hyper-rectangular index space between
-- the given coordinates. The LLVM optimiser will be able to vectorise nested
-- loops, including when we insert conversions to the corresponding linear index
-- (e.g., in order to index arrays).
--
-- iterate
--     :: Shape sh
--     => Operands sh                                    -- ^ starting index
--     -> Operands sh                                    -- ^ final index
--     -> (Operands sh -> CodeGen arch (Operands a))     -- ^ body of the loop
--     -> CodeGen arch (Operands a)
-- iterate from to body = error "CodeGen.Loop.iterate"


-- | Execute the given function at each index in the half-open range
-- @[start, end)@, advancing the index by @step@ after each iteration.
--
-- The bound is tested before the first iteration, so the body is not run at
-- all when @start@ is not less than @end@. The counter is advanced with 'add'
-- and compared against @end@ with 'lt'. No check is made on the sign of @step@.
-- NOTE(review): callers presumably must pass a positive step for the loop to
-- terminate; confirm.
--
imapFromStepTo
    :: forall i arch. IsNum i
    => Operands i                                     -- ^ starting index (inclusive)
    -> Operands i                                     -- ^ step size
    -> Operands i                                     -- ^ final index (exclusive)
    -> (Operands i -> CodeGen arch ())                -- ^ loop body
    -> CodeGen arch ()
imapFromStepTo start step end body =
  for (TupRsingle $ SingleScalarType $ NumSingleType num) start
      (\i -> lt (NumSingleType num) i end)    -- keep going while i < end
      (\i -> add num i step)                  -- i := i + step
      body
  where
    num :: NumType i
    num = numType @i


-- | Iterate with an accumulator over the half-open index range @[start, end)@,
-- advancing the index by @step@ after each iteration.
--
-- At each index the body receives the current index and accumulator and
-- returns the new accumulator. The final accumulator is returned. If the range
-- is empty (@start@ is not less than @end@), the body is never run and @seed@
-- is returned unchanged.
-- NOTE(review): as with 'imapFromStepTo', @step@ is presumably expected to be
-- positive; this is not checked.
--
iterFromStepTo
    :: forall i a arch. IsNum i
    => TypeR a
    -> Operands i                                     -- ^ starting index (inclusive)
    -> Operands i                                     -- ^ step size
    -> Operands i                                     -- ^ final index (exclusive)
    -> Operands a                                     -- ^ initial value
    -> (Operands i -> Operands a -> CodeGen arch (Operands a))    -- ^ loop body
    -> CodeGen arch (Operands a)
iterFromStepTo tp start step end seed body =
  iter (TupRsingle $ SingleScalarType $ NumSingleType num) tp start seed
       (\i -> lt (NumSingleType num) i end)   -- keep going while i < end
       (\i -> add num i step)                 -- i := i + step
       body
  where
    num :: NumType i
    num = numType @i


-- | A standard 'for' loop.
--
-- The test is evaluated on the starting index before the first iteration, so
-- the body may run zero times. Each iteration runs the body on the current
-- index and then computes the next index with the increment function. The
-- final index value is discarded.
--
for :: TypeR i
    -> Operands i                                         -- ^ starting index
    -> (Operands i -> CodeGen arch (Operands Bool))       -- ^ loop test to keep going
    -> (Operands i -> CodeGen arch (Operands i))          -- ^ increment loop counter
    -> (Operands i -> CodeGen arch ())                    -- ^ body of the loop
    -> CodeGen arch ()
for tp start test incr body =
  void $ while tp test (\i -> body i >> incr i) start


-- | A loop with an iteration counter and an accumulator.
--
-- The counter and accumulator are threaded through 'while' as a single pair,
-- and the test inspects only the counter. On each iteration, the body is
-- applied to the current counter and accumulator, and then the counter is
-- incremented. Returns the final accumulator. If the test fails on the starting
-- index, the seed is returned unchanged.
--
iter :: TypeR i
     -> TypeR a
     -> Operands i                                                -- ^ starting index
     -> Operands a                                                -- ^ initial value
     -> (Operands i -> CodeGen arch (Operands Bool))              -- ^ index test to keep looping
     -> (Operands i -> CodeGen arch (Operands i))                 -- ^ increment loop counter
     -> (Operands i -> Operands a -> CodeGen arch (Operands a))   -- ^ loop body
     -> CodeGen arch (Operands a)
iter tpi tpa start seed test incr body =
  snd <$> while (TupRpair tpi tpa)
                (test . fst)
                step
                (pair start seed)
  where
    -- Update the accumulator first, then compute the next index from the
    -- current one.
    step v = do
      v' <- uncurry body v
      i' <- incr (fst v)
      return $ pair i' v'


-- | A standard 'while' loop.
--
-- The test is evaluated on the initial value before the loop is entered, so
-- the body may run zero times. Otherwise the body runs repeatedly, and the test
-- is re-evaluated on each new value until it fails. The result is the value at
-- loop exit: the initial value if the loop was never entered, otherwise the
-- value produced by the last run of the body.
--
-- Generated control flow: @while.entry@ branches either to @while.top@ (the
-- loop body) or to @while.exit@. The end of @while.top@ branches back to itself
-- or to @while.exit@.
--
while :: TypeR a
      -> (Operands a -> CodeGen arch (Operands Bool))
      -> (Operands a -> CodeGen arch (Operands a))
      -> Operands a
      -> CodeGen arch (Operands a)
while tp test body start = do
  loop <- newBlock   "while.top"
  exit <- newBlock   "while.exit"
  _    <- beginBlock "while.entry"

  -- Entry: test the initial value and branch into the loop or straight out.
  p    <- test start
  top  <- cbr p loop exit

  -- Fresh operands that stand for the loop-carried value. They are bound by the
  -- phi node inserted below, once the value arriving on the back edge is known.
  prev <- fresh tp

  -- Loop body: compute the next value and decide whether to iterate again.
  setBlock loop
  next <- body prev
  p'   <- test next
  bot  <- cbr p' loop exit

  -- Bind 'prev' at the head of the loop block. It is 'start' when control
  -- arrives from the entry edge and 'next' when it arrives on the back edge.
  _    <- phi' tp loop prev [(start,top), (next,bot)]

  -- Exit: the result depends on which edge left the loop.
  setBlock exit
  phi tp [(start,top), (next,bot)]