module Data.Array.Accelerate.Prelude (
  
  zip, unzip,
  
  
  foldAll, fold1All,
  
  
  prescanl, postscanl, prescanr, postscanr, 
  
  scanlSeg, scanlSeg', scanl1Seg, prescanlSeg, postscanlSeg, 
  scanrSeg, scanrSeg', scanr1Seg, prescanrSeg, postscanrSeg
  
) where
import Prelude   hiding (replicate, zip, unzip, map, scanl, scanl1, scanr, scanr1, zipWith,
                         filter, max, min, not, fst, snd, curry, uncurry)
import qualified Prelude
import Data.Array.Accelerate.Array.Sugar hiding ((!), ignore, shape, size, index)
import Data.Array.Accelerate.Language
-- | Pair up two arrays of the same shape element by element.
zip :: (Shape sh, Elt a, Elt b)
    => Acc (Array sh a)
    -> Acc (Array sh b)
    -> Acc (Array sh (a, b))
zip xs ys = zipWith (\x y -> lift (x, y)) xs ys
-- | Split an array of pairs into the array of first components and the
-- array of second components.
unzip :: (Shape sh, Elt a, Elt b)
      => Acc (Array sh (a, b))
      -> (Acc (Array sh a), Acc (Array sh b))
unzip pairs = (firsts, seconds)
  where
    firsts  = map fst pairs
    seconds = map snd pairs
-- | Reduce an array of any shape to a single value. The array is first
-- flattened to a vector with 'reshape' and then reduced with 'fold', starting
-- from the given seed.
foldAll :: (Shape sh, Elt a)
        => (Exp a -> Exp a -> Exp a) 
        -> Exp a 
        -> Acc (Array sh a)
        -> Acc (Scalar a)
foldAll f e arr = fold f e flattened
  where
    flattened = reshape (index1 (size arr)) arr
-- | Seedless variant of 'foldAll'. The array is flattened to a vector with
-- 'reshape' and then reduced with 'fold1'.
fold1All :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a) 
         -> Acc (Array sh a)
         -> Acc (Scalar a)
fold1All f arr = fold1 f flattened
  where
    flattened = reshape (index1 (size arr)) arr
-- | Exclusive left scan: the vector component of 'scanl'', with the final
-- reduction value discarded.
prescanl :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanl f e arr = scanned
  where
    (scanned, _) = scanl' f e arr
-- | Inclusive left scan seeded with @e@. The seedless inclusive scan
-- ('scanl1') is computed first, and then @e@ is combined on the left of every
-- result.
postscanl :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanl f e arr = map (f e) inclusive
  where
    inclusive = scanl1 f arr
-- | Exclusive right scan: the vector component of 'scanr'', with the final
-- reduction value discarded.
prescanr :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanr f e arr = scanned
  where
    (scanned, _) = scanr' f e arr
-- | Inclusive right scan seeded with @e@. The seedless inclusive scan
-- ('scanr1') is computed first, and then @e@ is combined on the right of every
-- result.
postscanr :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanr f e arr = map (\x -> f x e) inclusive
  where
    inclusive = scanr1 f arr
-- | Segmented version of 'scanl'. Each segment of @arr@ (lengths given by
-- @seg@) is scanned independently, starting from the seed @e@. A segment of
-- length @s@ produces @s + 1@ results, so the output has
-- @size arr + size seg@ elements.
--
-- Implementation outline:
--
-- 1. Spread the input into a vector with one extra slot at the head of every
--    segment.
-- 2. Fill those slots with @e@.
-- 3. Run an inclusive segmented scan ('scanl1Seg') in which every segment
--    length is increased by one.
--
-- NOTE(review): an empty segment at the very end of @seg@ has a start offset
-- equal to @size arr@, which lies outside @zerosArr@ in 'offsetArr'. Such
-- input is not handled here; confirm whether callers can pass it.
scanlSeg :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc Segments
         -> Acc (Vector a)
scanlSeg f e arr seg = scans
  where
    -- Inclusive segmented scan over the expanded vector. Segment heads hold
    -- @e@, so every segment's scan starts from the seed.
    scans      = scanl1Seg f idInjArr seg'
    -- Use @e@ where a head flag is set and the spread-out input elsewhere.
    idInjArr   = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
    -- 1 at the first slot of every expanded segment. Expanded segment lengths
    -- are all >= 1, so these offsets never collide.
    headFlags  = permute (+) zerosArr' (\ix -> index1 $ segOffsets' ! ix)
               $ generate (shape seg) (const 1)
    -- Input elements gathered into their expanded positions. Head slots read
    -- element 0 (the permute default), but that value is replaced by @e@ above.
    arrShifted = backpermute nSh (\ix -> index1 $ shiftCoords ! ix) arr
    idsArr     = generate nSh (const e)

    -- For each expanded position, the input index that lands there. Input
    -- element i belonging to segment k moves to position i + k + 1.
    shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1 i) + 1) coords
    coords      = Prelude.fst $ scanl' (+) 0 onesArr
    -- Segment index of every input element. Each segment index is written at
    -- the segment's start offset and then propagated with a running maximum.
    -- Several empty segments may share one offset. Combining with 'max' keeps
    -- the largest index there, i.e. the segment that actually owns the
    -- element; summing them overshoots.
    offsetArr   = scanl1 max $ permute max zerosArr (\ix -> index1 $ segOffsets ! ix) segIxs
    segIxs      = Prelude.fst $ scanl' (+) 0 $ generate (index1 $ size seg) (const 1)
    -- Segment start offsets in the expanded and the original vector.
    segOffsets' = Prelude.fst $ scanl' (+) 0 seg'
    segOffsets  = Prelude.fst $ scanl' (+) 0 seg

    -- Expanded length: one extra slot per segment.
    nSh       = index1 $ size arr + size seg
    seg'      = map (+ 1) seg
    onesArr   = generate (shape arr) (const 1)
    zerosArr  = generate (shape arr) (const 0)
    zerosArr' = generate nSh (const 0)
-- | Segmented exclusive left scan, together with the per-segment reduction.
--
-- Returns @(scans, sums)@:
--
-- * @scans@ has the same length as @arr@. Inside each segment, position j
--   holds @e@ combined with the segment's first j elements.
-- * @sums@ has one entry per segment: the last element of that segment's
--   inclusive scan, combined on the right with @e@.
--
-- NOTE(review): for an empty segment, @sumOffsets@ points at the previous
-- segment's last element (or at index -1 for a leading empty segment). A
-- trailing empty segment's start offset equals @size arr@, outside
-- @zerosArr@. Confirm whether callers can pass empty segments.
scanlSeg' :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc Segments
          -> (Acc (Vector a), Acc (Vector a))
scanlSeg' f e arr seg = (scans, sums)
  where
    -- Exclusive scan = inclusive scan of the input shifted right by one, with
    -- @e@ injected at the head of every segment.
    scans      = scanl1Seg f idInjArr seg
    idInjArr   = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
    -- Head flags are set with a constant rather than summed. Empty segments
    -- share their start offset with the following segment, and summing would
    -- produce 2 there, which fails the @==* 1@ test above.
    headFlags  = permute (\_ _ -> 1) zerosArr (\ix -> index1 $ segOffsets ! ix)
               $ generate (shape seg) (const (1 :: Exp Int))
    segOffsets = Prelude.fst $ scanl' (+) 0 seg
    -- Element i reads arr[i - 1]. Position 0 reads itself, but it is always a
    -- segment head and is therefore replaced by @e@.
    arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* 0 ? (i, i - 1)) arr
    idsArr     = generate (shape arr) (const e)
    zerosArr   = generate (shape arr) (const 0)

    -- Per-segment totals: pick the last element of each segment of the
    -- inclusive segmented scan, then combine it with @e@ on the right.
    sums       = map (`f` e)
               $ backpermute (shape seg) (\ix -> index1 $ sumOffsets ! ix)
               $ scanl1Seg f arr seg
    sumOffsets = map (subtract 1) $ scanl1 (+) seg
-- | Inclusive segmented left scan without a seed. Every element is paired with
-- its head flag from 'mkHeadFlags'. The pairs are then scanned with the
-- operator lifted by 'mkSegApply', which restarts accumulation at each flagged
-- position, and finally the flags are dropped.
scanl1Seg :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc Segments
          -> Acc (Vector a)
scanl1Seg f arr seg = map snd scanned
  where
    flagged = zip (mkHeadFlags seg) arr
    scanned = scanl1 (mkSegApply f) flagged
-- | Segmented exclusive left scan: the scan component of 'scanlSeg'', with the
-- per-segment totals discarded.
prescanlSeg :: Elt a
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc Segments
            -> Acc (Vector a)
prescanlSeg f e arr seg = exclusive
  where
    (exclusive, _) = scanlSeg' f e arr seg
-- | Segmented inclusive left scan seeded with @e@. The seedless segmented scan
-- ('scanl1Seg') is computed first, and then @e@ is combined on the left of
-- every result.
postscanlSeg :: Elt a
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc Segments
             -> Acc (Vector a)
postscanlSeg f e arr seg = map (f e) inclusive
  where
    inclusive = scanl1Seg f arr seg
-- | Segmented version of 'scanr'. Each segment of @arr@ (lengths given by
-- @seg@) is scanned independently from the right, starting from the seed @e@.
-- A segment of length @s@ produces @s + 1@ results, so the output has
-- @size arr + size seg@ elements.
--
-- Implementation outline:
--
-- 1. Spread the input into a vector with one extra slot at the tail of every
--    segment.
-- 2. Fill those slots with @e@.
-- 3. Run an inclusive right segmented scan ('scanr1Seg') in which every
--    segment length is increased by one.
--
-- NOTE(review): an empty segment at the very end of @seg@ has a start offset
-- equal to @size arr@, which lies outside @zerosArr@ in 'offsetArr'. Such
-- input is not handled here; confirm whether callers can pass it.
scanrSeg :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc Segments
         -> Acc (Vector a)
scanrSeg f e arr seg = scans
  where
    -- Inclusive right segmented scan over the expanded vector. Segment tails
    -- hold @e@, so every segment's scan starts from the seed.
    scans      = scanr1Seg f idInjArr seg'
    -- Use @e@ where a tail flag is set and the spread-out input elsewhere.
    idInjArr   = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
    -- 1 at the last slot of every expanded segment. Expanded segment lengths
    -- are all >= 1, so these positions never collide.
    tailFlags  = permute (+) zerosArr' (\ix -> index1 $ (segOffsets' ! ix) - 1)
               $ generate (shape seg) (const 1)
    -- Input elements gathered into their expanded positions. Tail slots read
    -- element 0 (the permute default), but that value is replaced by @e@ above.
    arrShifted = backpermute nSh (\ix -> index1 $ shiftCoords ! ix) arr
    idsArr     = generate nSh (const e)

    -- For each expanded position, the input index that lands there. Input
    -- element i belonging to segment k moves to position i + k.
    shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1 i)) coords
    coords      = Prelude.fst $ scanl' (+) 0 onesArr
    -- Segment index of every input element. Each segment index is written at
    -- the segment's start offset and then propagated with a running maximum.
    -- Several empty segments may share one offset. Combining with 'max' keeps
    -- the largest index there, i.e. the segment that actually owns the
    -- element; summing them overshoots.
    offsetArr   = scanl1 max $ permute max zerosArr (\ix -> index1 $ segOffsets ! ix) segIxs
    segIxs      = Prelude.fst $ scanl' (+) 0 $ generate (shape seg) (const 1)
    -- Inclusive (one-past-end) offsets in the expanded vector, and start
    -- offsets in the original vector.
    segOffsets' = scanl1 (+) seg'
    segOffsets  = Prelude.fst $ scanl' (+) 0 seg

    -- Expanded length: one extra slot per segment.
    nSh       = index1 $ size arr + size seg
    seg'      = map (+ 1) seg
    onesArr   = generate (shape arr) (const 1)
    zerosArr  = generate (shape arr) (const 0)
    zerosArr' = generate nSh (const 0)
-- | Segmented exclusive right scan, together with the per-segment reduction.
--
-- Returns @(scans, sums)@:
--
-- * @scans@ has the same length as @arr@. Inside each segment, every position
--   holds the elements to its right in that segment, combined with @e@.
-- * @sums@ has one entry per segment: the first element of that segment's
--   inclusive right scan, combined on the right with @e@.
--
-- NOTE(review): an empty segment before any element gives a tail index of -1
-- in 'tailFlags'. For empty segments, @sumOffsets@ points into a neighbouring
-- segment. Confirm whether callers can pass empty segments.
scanrSeg' :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc Segments
          -> (Acc (Vector a), Acc (Vector a))
scanrSeg' f e arr seg = (scans, sums)
  where
    -- Exclusive right scan = inclusive right scan of the input shifted left
    -- by one, with @e@ injected at the tail of every segment.
    scans      = scanr1Seg f idInjArr seg
    idInjArr   = zipWith (\t x -> t ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
    -- Tail flags are set with a constant rather than summed. Empty segments
    -- share their tail position with the preceding segment, and summing would
    -- produce 2 there, which fails the @==* 1@ test above.
    tailFlags  = permute (\_ _ -> 1) zerosArr (\ix -> index1 $ (segOffsets ! ix) - 1)
               $ generate (shape seg) (const (1 :: Exp Int))
    segOffsets = scanl1 (+) seg
    -- Element i reads arr[i + 1]. The last position reads itself, but it is
    -- always a segment tail and is therefore replaced by @e@.
    arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* (size arr - 1) ? (i, i + 1)) arr
    idsArr     = generate (shape arr) (const e)
    zerosArr   = generate (shape arr) (const 0)

    -- Per-segment totals: pick the first element of each segment of the
    -- inclusive right segmented scan, then combine it with @e@ on the right.
    sums       = map (`f` e) $  backpermute (shape seg) (\ix -> index1 $ sumOffsets ! ix)
               $ scanr1Seg f arr seg
    sumOffsets = Prelude.fst $ scanl' (+) 0 seg
-- | Inclusive segmented right scan without a seed. Every element is paired
-- with its tail flag from 'mkTailFlags'. The pairs are then scanned from the
-- right with the operator lifted by 'mkSegApply', and finally the flags are
-- dropped.
scanr1Seg :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc Segments
          -> Acc (Vector a)
scanr1Seg f arr seg = map snd scanned
  where
    flagged = zip (mkTailFlags seg) arr
    scanned = scanr1 (mkSegApply f) flagged
-- | Segmented exclusive right scan: the scan component of 'scanrSeg'', with
-- the per-segment totals discarded.
prescanrSeg :: Elt a
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc Segments
            -> Acc (Vector a)
prescanrSeg f e arr seg = exclusive
  where
    (exclusive, _) = scanrSeg' f e arr seg
-- | Segmented inclusive right scan seeded with @e@. The seedless segmented
-- scan ('scanr1Seg') is computed first, and then @e@ is combined on the right
-- of every result.
postscanrSeg :: Elt a
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc Segments
             -> Acc (Vector a)
postscanrSeg f e arr seg = map (\x -> f x e) inclusive
  where
    inclusive = scanr1Seg f arr seg
-- | Flag vector for segmented left scans. It holds 1 at the first position of
-- every segment described by @seg@ and 0 elsewhere, and its length is the sum
-- of the segment lengths.
--
-- NOTE(review): an empty segment at the end of @seg@ has a start offset equal
-- to that sum, i.e. one past the end of the flag vector. Confirm that callers
-- never pass one.
mkHeadFlags :: Acc (Array DIM1 Int) -> Acc (Array DIM1 Int)
mkHeadFlags seg = permute setFlag blank toStart starts
  where
    -- Exclusive offsets (segment starts) and the total element count.
    (starts, total) = scanl' (+) 0 seg
    blank           = generate (index1 (the total)) (const 0)
    -- Several segments writing to the same slot still leave exactly 1.
    setFlag _ _     = 1
    -- Each start offset is sent to the position it names.
    toStart ix      = index1 (starts ! ix)
-- | Flag vector for segmented right scans. It holds 1 at the last position of
-- every segment described by @seg@ and 0 elsewhere, and its length is the sum
-- of the segment lengths.
--
-- NOTE(review): this assumes @seg@ is non-empty, because @len@ reads the last
-- inclusive offset. It also assumes no empty segment comes before every
-- element, since such a segment's tail index would be -1. Confirm that callers
-- guarantee both.
mkTailFlags :: Acc (Array DIM1 Int) -> Acc (Array DIM1 Int)
mkTailFlags seg
  = permute (\_ _ -> 1) zerosArr (ilift1 $ \i -> (segOffsets ! index1 i) - 1) segOffsets
  where
    -- Inclusive offsets: entry k is one past the last element of segment k.
    segOffsets = scanl1 (+) seg
    -- Total number of elements, i.e. the last inclusive offset.
    len        = segOffsets ! index1 (size seg - 1)
    zerosArr   = generate (index1 len) (const 0)
-- | Lift a binary operator to (flag, value) pairs for use in segmented scans.
--
-- * The result flag is 1 if either operand is flagged, and 0 otherwise.
-- * If the right operand is flagged, its value is taken unchanged, so
--   accumulation restarts there.
-- * Otherwise the two values are combined with @op@.
mkSegApply :: (Elt e)
           => (Exp e -> Exp e -> Exp e)
           -> (Exp (Int, e) -> Exp (Int, e) -> Exp (Int, e))
mkSegApply op = combine
  where
    combine l r = lift (newFlag, newValue)
      where
        rStarts  = fst r ==* 1
        newFlag  = ((fst l ==* 1) ||* rStarts) ? (1, 0)
        newValue = rStarts ? (snd r, snd l `op` snd r)