{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Utility.Lift.Exp (
   Unlift,
   Unlifted,
   Tuple,
   unlift,
   modify,
   modify2,
   modify3,
   modify4,
   Exp(Exp), expr, atom,
   unliftPair,
   unliftTriple,
   unliftQuadruple,
   asExp,
   mapFst,
   mapSnd,
   fst3,
   snd3,
   thd3,
   indexCons,
   ) where

import qualified Data.Array.Accelerate.Data.Complex as Complex
import qualified Data.Array.Accelerate as A
import Data.Complex (Complex((:+)))
import Data.Array.Accelerate ((:.)((:.)))

import qualified Data.Tuple.HT as Tuple
import Data.Tuple.HT (mapTriple)


{- |
This class simplifies untupling of expressions.
If you have a function

> g :: ((Exp a, Exp b), Exp (c,d)) -> (Exp e, Exp f)

you cannot apply it to an array @arr :: Array sh ((a,b),(c,d))@ using 'A.map'.
Here, the 'modify' function helps:

> modify ((expr,expr),expr) g :: Exp ((a,b),(c,d)) -> Exp (e,f)

The 'expr' pattern determines how deeply the tuple is unlifted.
This way you can write:

> A.map
>    (Exp.modify ((expr,expr),expr) $ \((a,b), cd) -> g ((a,b), cd))
>    arr

'modify' is based on 'unlift'.
In contrast to 'A.unlift', it does not merely unlift one level of tuples,
but is guided by an 'expr' pattern.
The example above demonstrates
how the pair @(a,b)@ is unlifted while the pair @(c,d)@ is not.
For the result tuple, 'modify' simply calls 'A.lift'.
In contrast to 'A.unlift',
'A.lift' lifts over all tuple levels until it obtains a single 'Exp'.
-}
class
   (A.Elt (Tuple pat), A.Plain (Unlifted pat) ~ Tuple pat) =>
      Unlift pat where
   -- | The element type of the expression that 'unlift' consumes.
   type Tuple pat
   -- | The structure of expressions that 'unlift' produces.
   --   By the superclass context, its 'A.Plain' type is 'Tuple'.
   type Unlifted pat
   -- | Unlift an expression exactly as deep as the pattern prescribes.
   unlift :: pat -> A.Exp (Tuple pat) -> Unlifted pat

-- | Unlift the argument according to the pattern, apply the function,
--   and lift the (possibly tupled) result back into a single 'A.Exp'.
modify ::
   (A.Lift A.Exp a, Unlift pattern) =>
   pattern ->
   (Unlifted pattern -> a) ->
   A.Exp (Tuple pattern) -> A.Exp (A.Plain a)
modify pat g x =
   A.lift (g (unlift pat x))

-- | Two-argument variant of 'modify':
--   each argument is unlifted according to its own pattern.
modify2 ::
   (A.Lift A.Exp a, Unlift patternA, Unlift patternB) =>
   patternA ->
   patternB ->
   (Unlifted patternA -> Unlifted patternB -> a) ->
   A.Exp (Tuple patternA) -> A.Exp (Tuple patternB) -> A.Exp (A.Plain a)
modify2 patA patB g x y =
   let ux = unlift patA x
       uy = unlift patB y
   in  A.lift (g ux uy)

-- | Three-argument variant of 'modify':
--   each argument is unlifted according to its own pattern.
modify3 ::
   (A.Lift A.Exp a, Unlift patternA, Unlift patternB, Unlift patternC) =>
   patternA ->
   patternB ->
   patternC ->
   (Unlifted patternA -> Unlifted patternB -> Unlifted patternC -> a) ->
   A.Exp (Tuple patternA) -> A.Exp (Tuple patternB) ->
   A.Exp (Tuple patternC) -> A.Exp (A.Plain a)
modify3 patA patB patC g x y z =
   A.lift (g ux uy uz)
   where
      ux = unlift patA x
      uy = unlift patB y
      uz = unlift patC z

-- | Four-argument variant of 'modify':
--   each argument is unlifted according to its own pattern.
modify4 ::
   (A.Lift A.Exp a,
    Unlift patternA, Unlift patternB, Unlift patternC, Unlift patternD) =>
   patternA ->
   patternB ->
   patternC ->
   patternD ->
   (Unlifted patternA -> Unlifted patternB ->
    Unlifted patternC -> Unlifted patternD -> a) ->
   A.Exp (Tuple patternA) -> A.Exp (Tuple patternB) ->
   A.Exp (Tuple patternC) -> A.Exp (Tuple patternD) -> A.Exp (A.Plain a)
modify4 patA patB patC patD g w x y z =
   let uw = unlift patA w
       ux = unlift patB x
       uy = unlift patC y
       uz = unlift patD z
   in  A.lift (g uw ux uy uz)


instance (A.Elt a) => Unlift (Exp a) where
   type Tuple (Exp a) = a
   type Unlifted (Exp a) = A.Exp a
   -- An 'Exp' leaf stops unlifting: the expression is passed through untouched.
   -- The pattern is ignored and never forced.
   unlift _ e = e

-- | Leaf of an unlift pattern: at this position the value stays
--   a single 'A.Exp' instead of being unlifted further.
data Exp e = Exp

-- | The leaf pattern value, to be used in patterns for 'unlift' and 'modify'.
expr :: Exp e
expr = Exp

{-# DEPRECATED atom "use expr instead" #-}
-- | for compatibility with accelerate-utility-0.0
atom :: Exp e
atom = Exp


instance (Unlift pa, Unlift pb) => Unlift (pa,pb) where
   type Tuple (pa,pb) = (Tuple pa, Tuple pb)
   type Unlifted (pa,pb) = (Unlifted pa, Unlifted pb)
   -- Project both components out of the pair expression
   -- and unlift each one by its own sub-pattern.
   unlift (patA, patB) pair = (left, right)
      where
         left = unlift patA (A.fst pair)
         right = unlift patB (A.snd pair)

instance
   (Unlift pa, Unlift pb, Unlift pc) =>
      Unlift (pa,pb,pc) where
   type Tuple (pa,pb,pc) = (Tuple pa, Tuple pb, Tuple pc)
   type Unlifted (pa,pb,pc) = (Unlifted pa, Unlifted pb, Unlifted pc)
   -- Split the triple expression one level with 'A.unlift',
   -- then unlift each component by its own sub-pattern.
   -- The 'let' pattern is lazy, like the tuple mapping it replaces.
   unlift (patA, patB, patC) = \abc ->
      let (a, b, c) = A.unlift abc
      in  (unlift patA a, unlift patB b, unlift patC c)


instance (Unlift pa, A.Slice (Tuple pa), int ~ Exp Int) => Unlift (pa :. int) where
   type Tuple (pa :. int) = Tuple pa :. Int
   type Unlifted (pa :. int) = Unlifted pa :. A.Exp Int
   -- Because the head only says @pa :. int@, this instance is chosen
   -- before the last component is known; the equality constraint
   -- then fixes that component to be an 'Exp' 'Int' pattern.
   unlift (patTail :. patHead) ix =
      unlift patTail (A.indexTail ix) :. unlift patHead (A.indexHead ix)


instance (Unlift p) => Unlift (Complex p) where
   type Tuple (Complex p) = Complex (Tuple p)
   type Unlifted (Complex p) = Complex (Unlifted p)
   -- Unlift real and imaginary part separately, each by its own sub-pattern.
   unlift (patRe :+ patIm) z = re :+ im
      where
         re = unlift patRe (Complex.real z)
         im = unlift patIm (Complex.imag z)


-- | Split a pair expression into its two component expressions.
unliftPair :: (A.Elt a, A.Elt b) => A.Exp (a,b) -> (A.Exp a, A.Exp b)
unliftPair ab = A.unlift ab

-- | Split a triple expression into its three component expressions.
unliftTriple ::
   (A.Elt a, A.Elt b, A.Elt c) => A.Exp (a,b,c) -> (A.Exp a, A.Exp b, A.Exp c)
unliftTriple abc = A.unlift abc

-- | Split a quadruple expression into its four component expressions.
unliftQuadruple ::
   (A.Elt a, A.Elt b, A.Elt c, A.Elt d) =>
   A.Exp (a,b,c,d) -> (A.Exp a, A.Exp b, A.Exp c, A.Exp d)
unliftQuadruple abcd = A.unlift abcd

-- | Identity restricted to expressions.
--   Applying it only serves to tell the type checker that a value is an 'A.Exp'.
asExp :: A.Exp a -> A.Exp a
asExp e = e

-- | Apply a function to the first component of a pair expression,
--   leaving the second component unchanged.
mapFst ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   (A.Exp a -> A.Exp b) -> A.Exp (a,c) -> A.Exp (b,c)
mapFst f = modify (expr, expr) (Tuple.mapFst f)

-- | Apply a function to the second component of a pair expression,
--   leaving the first component unchanged.
mapSnd ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   (A.Exp b -> A.Exp c) -> A.Exp (a,b) -> A.Exp (a,c)
mapSnd f = modify (expr, expr) (Tuple.mapSnd f)


-- | Extract the first component of a triple expression.
fst3 ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   A.Exp (a,b,c) -> A.Exp a
fst3 = modify (expr, expr, expr) (\(x, _, _) -> x)

-- | Extract the second component of a triple expression.
snd3 ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   A.Exp (a,b,c) -> A.Exp b
snd3 = modify (expr, expr, expr) (\(_, y, _) -> y)

-- | Extract the third component of a triple expression.
thd3 ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   A.Exp (a,b,c) -> A.Exp c
thd3 = modify (expr, expr, expr) (\(_, _, z) -> z)



-- | Attach an 'Int' component to an index expression,
--   building the lifted index @ix :. n@.
indexCons ::
   (A.Slice ix) => A.Exp ix -> A.Exp Int -> A.Exp (ix :. Int)
indexCons tl hd = A.lift (tl :. hd)