{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Utility.Lift.Acc (
   Unlift,
   Unlifted,
   Tuple,
   unlift,
   modify,
   modify2,
   modify3,
   modify4,
   Acc(Acc), acc,
   Exp(Exp), expr,
   unliftPair,
   unliftTriple,
   unliftQuadruple,
   mapFst,
   mapSnd,
   singleton,
   the,
   ) where

import Data.Array.Accelerate.Utility.Lift.Exp (Exp(Exp), expr)
import qualified Data.Array.Accelerate as A

import qualified Data.Tuple.HT as Tuple
import Data.Tuple.HT (mapTriple)


{- |
This class is like 'Data.Array.Accelerate.Utility.Lift.Exp.Unlift'
but for the 'Acc' environment.
It allows you to unlift an 'Acc' of nested tuples
into tuples of 'Exp' and 'Acc' values.
It can be quite handy when working with 'A.acond' and 'A.awhile'.
It can also be useful in connection with running an @accelerate@ algorithm
at a certain backend,
like 'Data.Array.Accelerate.Interpreter.run1'.
But in this case you might prefer "Data.Array.Accelerate.Utility.Lift.Run".
-}
class (A.Arrays (Tuple pattern)) => Unlift pattern where
   type Unlifted pattern
   type Tuple pattern
   unlift :: pattern -> A.Acc (Tuple pattern) -> Unlifted pattern

modify ::
   (A.Lift A.Acc a, Unlift pattern) =>
   pattern ->
   (Unlifted pattern -> a) ->
   A.Acc (Tuple pattern) -> A.Acc (A.Plain a)
modify p f = A.lift . f . unlift p

modify2 ::
   (A.Lift A.Acc a, Unlift patternA, Unlift patternB) =>
   patternA ->
   patternB ->
   (Unlifted patternA -> Unlifted patternB -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) -> A.Acc (A.Plain a)
modify2 pa pb f a b = A.lift $ f (unlift pa a) (unlift pb b)

modify3 ::
   (A.Lift A.Acc a, Unlift patternA, Unlift patternB, Unlift patternC) =>
   patternA ->
   patternB ->
   patternC ->
   (Unlifted patternA -> Unlifted patternB -> Unlifted patternC -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) ->
   A.Acc (Tuple patternC) -> A.Acc (A.Plain a)
modify3 pa pb pc f a b c =
   A.lift $ f (unlift pa a) (unlift pb b) (unlift pc c)

modify4 ::
   (A.Lift A.Acc a,
    Unlift patternA, Unlift patternB, Unlift patternC, Unlift patternD) =>
   patternA ->
   patternB ->
   patternC ->
   patternD ->
   (Unlifted patternA -> Unlifted patternB ->
    Unlifted patternC -> Unlifted patternD -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) ->
   A.Acc (Tuple patternC) -> A.Acc (Tuple patternD) -> A.Acc (A.Plain a)
modify4 pa pb pc pd f a b c d =
   A.lift $ f (unlift pa a) (unlift pb b) (unlift pc c) (unlift pd d)


instance (A.Arrays a) => Unlift (Acc a) where
   type Unlifted (Acc a) = A.Acc a
   type Tuple (Acc a) = a
   unlift _ = id

data Acc a = Acc

acc :: Acc a
acc = Acc


instance (A.Elt a) => Unlift (Exp a) where
   type Unlifted (Exp a) = A.Exp a
   type Tuple (Exp a) = A.Scalar a
   unlift _ = A.the


-- | like 'A.unit' in the 'Acc' environment
singleton :: (A.Elt e) => e -> A.Scalar e
singleton x = A.fromList A.Z [x]

-- | like 'A.the' in the 'Acc' environment
the :: (A.Elt e) => A.Scalar e -> e
the arr = A.indexArray arr A.Z


instance (Unlift pa, Unlift pb) => Unlift (pa,pb) where
   type Unlifted (pa,pb) = (Unlifted pa, Unlifted pb)
   type Tuple (pa,pb) = (Tuple pa, Tuple pb)
   unlift (pa,pb) ab =
      (unlift pa $ A.afst ab, unlift pb $ A.asnd ab)

instance
   (Unlift pa, Unlift pb, Unlift pc) =>
      Unlift (pa,pb,pc) where
   type Unlifted (pa,pb,pc) = (Unlifted pa, Unlifted pb, Unlifted pc)
   type Tuple (pa,pb,pc) = (Tuple pa, Tuple pb, Tuple pc)
   unlift (pa,pb,pc) =
      mapTriple (unlift pa, unlift pb, unlift pc) . A.unlift



unliftPair :: (A.Arrays a, A.Arrays b) => A.Acc (a,b) -> (A.Acc a, A.Acc b)
unliftPair = A.unlift

unliftTriple ::
   (A.Arrays a, A.Arrays b, A.Arrays c) => A.Acc (a,b,c) -> (A.Acc a, A.Acc b, A.Acc c)
unliftTriple = A.unlift

unliftQuadruple ::
   (A.Arrays a, A.Arrays b, A.Arrays c, A.Arrays d) =>
   A.Acc (a,b,c,d) -> (A.Acc a, A.Acc b, A.Acc c, A.Acc d)
unliftQuadruple = A.unlift


mapFst ::
   (A.Arrays a, A.Arrays b, A.Arrays c) =>
   (A.Acc a -> A.Acc b) -> A.Acc (a,c) -> A.Acc (b,c)
mapFst f = modify (acc,acc) $ Tuple.mapFst f

mapSnd ::
   (A.Arrays a, A.Arrays b, A.Arrays c) =>
   (A.Acc b -> A.Acc c) -> A.Acc (a,b) -> A.Acc (a,c)
mapSnd f = modify (acc,acc) $ Tuple.mapSnd f