{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Accelerate.Utility.Lift.Acc (
   Unlift,
   Unlifted,
   Tuple,
   unlift,
   modify,
   modify2,
   modify3,
   modify4,
   Acc(Acc), acc,
   Exp(Exp), expr,
   unliftPair,
   unliftTriple,
   unliftQuadruple,
   mapFst,
   mapSnd,
   singleton,
   the,
   ) where

import qualified Data.Array.Accelerate as A

import qualified Data.Tuple.HT as Tuple
import Data.Tuple.HT (mapTriple)


-- | Patterns that describe how an array-level value @'A.Acc' ('Tuple' pattern)@
-- is taken apart into its components.
class A.Arrays (Tuple pattern) => Unlift pattern where
   -- | The array type that, wrapped in 'A.Acc', is consumed by 'unlift'.
   type Tuple pattern
   -- | The type of the decomposed value returned by 'unlift'.
   type Unlifted pattern
   -- | Decompose the array-level value according to the given pattern.
   unlift :: pattern -> A.Acc (Tuple pattern) -> Unlifted pattern

-- | Decompose the input according to the pattern, apply the function to
-- the decomposed value and pack the result again with 'A.lift'.
modify ::
   (A.Lift A.Acc a, Unlift pattern) =>
   pattern ->
   (Unlifted pattern -> a) ->
   A.Acc (Tuple pattern) -> A.Acc (A.Plain a)
modify pat f arr = A.lift (f (unlift pat arr))

-- | Two-argument variant of 'modify': each input is decomposed with its
-- own pattern before the function is applied; the result is packed with
-- 'A.lift'.
modify2 ::
   (A.Lift A.Acc a, Unlift patternA, Unlift patternB) =>
   patternA ->
   patternB ->
   (Unlifted patternA -> Unlifted patternB -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) -> A.Acc (A.Plain a)
modify2 patA patB f arrA arrB =
   A.lift (f (unlift patA arrA) (unlift patB arrB))

-- | Three-argument variant of 'modify': each input is decomposed with its
-- own pattern before the function is applied; the result is packed with
-- 'A.lift'.
modify3 ::
   (A.Lift A.Acc a, Unlift patternA, Unlift patternB, Unlift patternC) =>
   patternA ->
   patternB ->
   patternC ->
   (Unlifted patternA -> Unlifted patternB -> Unlifted patternC -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) ->
   A.Acc (Tuple patternC) -> A.Acc (A.Plain a)
modify3 patA patB patC f arrA arrB arrC = A.lift result
   where
      -- the user function applied to the three decomposed inputs
      result = f (unlift patA arrA) (unlift patB arrB) (unlift patC arrC)

-- | Four-argument variant of 'modify': each input is decomposed with its
-- own pattern before the function is applied; the result is packed with
-- 'A.lift'.
modify4 ::
   (A.Lift A.Acc a,
    Unlift patternA, Unlift patternB, Unlift patternC, Unlift patternD) =>
   patternA ->
   patternB ->
   patternC ->
   patternD ->
   (Unlifted patternA -> Unlifted patternB ->
    Unlifted patternC -> Unlifted patternD -> a) ->
   A.Acc (Tuple patternA) -> A.Acc (Tuple patternB) ->
   A.Acc (Tuple patternC) -> A.Acc (Tuple patternD) -> A.Acc (A.Plain a)
modify4 patA patB patC patD f arrA arrB arrC arrD =
   A.lift (f ua ub uc ud)
   where
      -- every input decomposed with its own pattern
      ua = unlift patA arrA
      ub = unlift patB arrB
      uc = unlift patC arrC
      ud = unlift patD arrD


-- | Leaf pattern for a whole array: the value is passed through unchanged.
instance A.Arrays a => Unlift (Acc a) where
   type Tuple (Acc a) = a
   type Unlifted (Acc a) = A.Acc a
   -- The pattern argument only selects the instance; it is never inspected.
   unlift _ arr = arr

-- | Pattern value used with the 'Unlift' instance for 'Acc'.
-- It carries no data; the phantom parameter @a@ only appears at the type level.
data Acc a = Acc

-- | Lower-case synonym for the 'Acc' constructor,
-- for use when writing patterns such as @(acc,acc)@.
acc :: Acc a
acc = Acc


-- | Leaf pattern for a scalar: the value arrives as an 'A.Scalar' array
-- and is turned into an 'A.Exp' by 'A.the'.
instance A.Elt a => Unlift (Exp a) where
   type Tuple (Exp a) = A.Scalar a
   type Unlifted (Exp a) = A.Exp a
   -- The pattern argument only selects the instance; it is never inspected.
   unlift _ scalar = A.the scalar

-- | Pattern value used with the 'Unlift' instance for 'Exp'.
-- It carries no data; the phantom parameter @e@ only appears at the type level.
data Exp e = Exp

-- | Lower-case synonym for the 'Exp' constructor,
-- for use when writing patterns such as @(expr,acc)@.
expr :: Exp e
expr = Exp


-- | Wrap a plain value in a zero-dimensional array.
-- This is the counterpart of 'A.unit' for values outside of 'A.Acc'.
singleton :: (A.Elt e) => e -> A.Scalar e
singleton = A.fromList A.Z . (:[])

-- | Read the element of a zero-dimensional array.
-- This is the counterpart of 'A.the' for arrays outside of 'A.Acc'.
the :: (A.Elt e) => A.Scalar e -> e
the = flip A.indexArray A.Z


-- | Pair pattern: split the array-level pair with 'A.afst' and 'A.asnd'
-- and decompose each half with the corresponding sub-pattern.
instance (Unlift pa, Unlift pb) => Unlift (pa,pb) where
   type Tuple (pa,pb) = (Tuple pa, Tuple pb)
   type Unlifted (pa,pb) = (Unlifted pa, Unlifted pb)
   unlift (patA, patB) pair = (first, second)
      where
         first = unlift patA (A.afst pair)
         second = unlift patB (A.asnd pair)

-- | Triple pattern: split the array-level triple with 'A.unlift'
-- and decompose each component with the corresponding sub-pattern.
instance
   (Unlift pa, Unlift pb, Unlift pc) =>
      Unlift (pa,pb,pc) where
   type Tuple (pa,pb,pc) = (Tuple pa, Tuple pb, Tuple pc)
   type Unlifted (pa,pb,pc) = (Unlifted pa, Unlifted pb, Unlifted pc)
   unlift (patA, patB, patC) triple =
      mapTriple (unlift patA, unlift patB, unlift patC) (A.unlift triple)



-- | 'A.unlift' with its type fixed to pairs of arrays.
unliftPair :: (A.Arrays a, A.Arrays b) => A.Acc (a,b) -> (A.Acc a, A.Acc b)
unliftPair pair = A.unlift pair

-- | 'A.unlift' with its type fixed to triples of arrays.
unliftTriple ::
   (A.Arrays a, A.Arrays b, A.Arrays c) =>
   A.Acc (a,b,c) -> (A.Acc a, A.Acc b, A.Acc c)
unliftTriple triple = A.unlift triple

-- | 'A.unlift' with its type fixed to quadruples of arrays.
unliftQuadruple ::
   (A.Arrays a, A.Arrays b, A.Arrays c, A.Arrays d) =>
   A.Acc (a,b,c,d) -> (A.Acc a, A.Acc b, A.Acc c, A.Acc d)
unliftQuadruple quadruple = A.unlift quadruple


-- | Apply an array function to the first component of an array-level pair
-- and leave the second component as it is.
mapFst ::
   (A.Arrays a, A.Arrays b, A.Arrays c) =>
   (A.Acc a -> A.Acc b) -> A.Acc (a,c) -> A.Acc (b,c)
mapFst = modify (acc, acc) . Tuple.mapFst

-- | Apply an array function to the second component of an array-level pair
-- and leave the first component as it is.
mapSnd ::
   (A.Arrays a, A.Arrays b, A.Arrays c) =>
   (A.Acc b -> A.Acc c) -> A.Acc (a,b) -> A.Acc (a,c)
mapSnd = modify (acc, acc) . Tuple.mapSnd