module Data.Array.Accelerate.Utility.Loop (
   nest,
   nestLog2,
   ) where

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Acc, Exp)



{- |
@nest n f x@ applies @f@ to @x@ repeatedly inside an Accelerate 'A.awhile'
loop. The remaining iteration count travels through the loop as a scalar
array paired with the current value; the loop keeps running while that
count is greater than zero and decrements it by one per step. A
non-positive @n@ therefore yields @x@ unchanged.

The unexported '_nest' shares this signature and expresses the same loop
with 'A.the' and 'A.unit' instead of 'A.map' on the counter.
-}
nest, _nest ::
   (A.Arrays a) =>
   Exp Int -> (Acc a -> Acc a) -> Acc a -> Acc a
nest n0 f x0 =
   A.asnd
      (A.awhile
         -- continue as long as the remaining count is positive
         (\state -> A.map (A.>* 0) (A.afst state))
         -- count down by one and advance the payload by one application of f
         (\state ->
            A.lift (A.map (\k -> k - 1) (A.afst state), f (A.asnd state)))
         -- initial state: the full count paired with the start value
         (A.lift (A.unit n0, x0)))

{-
Alternative formulation of 'nest' that is not exported.
Instead of mapping over the scalar counter array,
it extracts the counter with A.the, works on the plain expression
and wraps the result again with A.unit.
-}
_nest n0 f x0 =
   A.asnd
      (A.awhile
         -- continue as long as the remaining count is positive
         (\state -> A.unit (A.the (A.afst state) A.>* 0))
         -- count down by one and advance the payload by one application of f
         (\state ->
            A.lift (A.unit (A.the (A.afst state) - 1), f (A.asnd state)))
         -- initial state: the full count paired with the start value
         (A.lift (A.unit n0, x0)))


{- |
@nestLog2 n f x@ applies @f@ to @x@ inside an Accelerate 'A.awhile' loop
whose counter starts at @n@ and is replaced by @(1 + k) \`div\` 2@ in every
step. The loop runs while the counter is greater than one, so for @n >= 1@
the function @f@ is applied @ceiling (logBase 2 n)@ times, and for @n <= 1@
the start value @x@ is returned unchanged.

The unexported '_nestLog2' shares this signature.
-}
nestLog2, _nestLog2 ::
   (A.Arrays a) =>
   Exp Int -> (Acc a -> Acc a) -> Acc a -> Acc a
nestLog2 n0 f x0 =
   A.asnd
      (A.awhile
         -- continue as long as the counter has not shrunk to one
         (\state -> A.map (A.>* 1) (A.afst state))
         -- halve the counter (rounding up) and apply f once
         (\state ->
            A.lift
               (A.map (\k -> (1 + k) `div` 2) (A.afst state),
                f (A.asnd state)))
         -- initial state: the full count paired with the start value
         (A.lift (A.unit n0, x0)))

{-
Do not use: this definition never terminates.
The recursion of 'go' happens in Haskell while the Accelerate expression
is being constructed, and A.acond always has to generate code for both
branches. The recursive branch is thus expanded again and again,
no matter what the counter is.
The definition is kept only to document why 'nestLog2' is based on A.awhile.
-}
_nestLog2 n0 f x0 = go n0 x0
   where
      go n x =
         A.acond (n A.<=* 1) x (go ((n + 1) `div` 2) (f x))