{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
Pattern-driven unlifting and re-lifting of Accelerate scalar expressions.

A pattern value such as @(atom, atom)@ tells 'unlift' how far to take an
@'Exp' (a, b)@ apart; for that pattern the result is @('Exp' a, 'Exp' b)@.
The @modify@ family unlifts its arguments with the given patterns, applies
a plain Haskell function to the unlifted structures and lifts the result
back into an 'Exp'.
-}
module Data.Array.Accelerate.Utility.Lift.Exp (
   Unlift,
   Unlifted,
   Tuple,
   unlift,
   modify,
   modify2,
   modify3,
   modify4,
   Atom(Atom),
   atom,
   unliftPair,
   unliftTriple,
   unliftQuadruple,
   asExp,
   mapFst,
   mapSnd,
   fst3,
   snd3,
   thd3,
   indexCons,
   ) where

import qualified Data.Array.Accelerate.Data.Complex as Complex
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Exp, (:.)((:.)))

import qualified Data.Tuple.HT as Tuple
import Data.Tuple.HT (mapTriple)

import Data.Complex (Complex((:+)))


-- | A pattern describing how an @'Exp' ('Tuple' pattern)@ is split into
-- the Haskell-level structure @'Unlifted' pattern@.
--
-- The superclass equation states that applying 'A.lift' to an
-- @'Unlifted' pattern@ value yields an expression of element type
-- @'Tuple' pattern@ again.
class
   (A.Elt (Tuple pattern), A.Plain (Unlifted pattern) ~ Tuple pattern) =>
      Unlift pattern where
   -- | Haskell-side structure produced by 'unlift'.
   type Unlifted pattern
   -- | Accelerate element type that the pattern decomposes.
   type Tuple pattern
   -- | Take an expression apart as described by the pattern value.
   unlift :: pattern -> Exp (Tuple pattern) -> Unlifted pattern


-- | Unlift one expression with the given pattern, apply a function to the
-- unlifted structure and lift the function result back into 'Exp'.
modify ::
   (A.Lift Exp a, Unlift pattern) =>
   pattern ->
   (Unlifted pattern -> a) ->
   Exp (Tuple pattern) -> Exp (A.Plain a)
modify p f = A.lift . f . unlift p

-- | Two-argument version of 'modify': each argument is unlifted with its
-- own pattern before the function is applied.
modify2 ::
   (A.Lift Exp a, Unlift patternA, Unlift patternB) =>
   patternA ->
   patternB ->
   (Unlifted patternA -> Unlifted patternB -> a) ->
   Exp (Tuple patternA) -> Exp (Tuple patternB) -> Exp (A.Plain a)
modify2 pa pb f a b = A.lift $ f (unlift pa a) (unlift pb b)

-- | Three-argument version of 'modify'.
modify3 ::
   (A.Lift Exp a, Unlift patternA, Unlift patternB, Unlift patternC) =>
   patternA ->
   patternB ->
   patternC ->
   (Unlifted patternA -> Unlifted patternB -> Unlifted patternC -> a) ->
   Exp (Tuple patternA) -> Exp (Tuple patternB) -> Exp (Tuple patternC) ->
   Exp (A.Plain a)
modify3 pa pb pc f a b c =
   A.lift $ f (unlift pa a) (unlift pb b) (unlift pc c)

-- | Four-argument version of 'modify'.
modify4 ::
   (A.Lift Exp a,
    Unlift patternA, Unlift patternB, Unlift patternC, Unlift patternD) =>
   patternA ->
   patternB ->
   patternC ->
   patternD ->
   (Unlifted patternA -> Unlifted patternB ->
    Unlifted patternC -> Unlifted patternD -> a) ->
   Exp (Tuple patternA) -> Exp (Tuple patternB) ->
   Exp (Tuple patternC) -> Exp (Tuple patternD) ->
   Exp (A.Plain a)
modify4 pa pb pc pd f a b c d =
   A.lift $ f (unlift pa a) (unlift pb b) (unlift pc c) (unlift pd d)


-- | Leaf pattern: 'unlift' returns the expression unchanged.
instance (A.Elt a) => Unlift (Atom a) where
   type Unlifted (Atom a) = Exp a
   type Tuple (Atom a) = a
   unlift _ = id

-- | The leaf pattern type; see the 'Unlift' instance for 'Atom'.
data Atom a = Atom

-- | The leaf pattern value, with the element type left to inference.
atom :: Atom a
atom = Atom

-- | Pair pattern: the components are extracted with 'A.fst' and 'A.snd'
-- and each is unlifted with its own sub-pattern.
instance (Unlift pa, Unlift pb) => Unlift (pa,pb) where
   type Unlifted (pa,pb) = (Unlifted pa, Unlifted pb)
   type Tuple (pa,pb) = (Tuple pa, Tuple pb)
   unlift (pa,pb) ab = (unlift pa $ A.fst ab, unlift pb $ A.snd ab)

-- | Triple pattern: the expression is split with 'A.unlift' and each
-- component is unlifted with its own sub-pattern.
instance (Unlift pa, Unlift pb, Unlift pc) => Unlift (pa,pb,pc) where
   type Unlifted (pa,pb,pc) = (Unlifted pa, Unlifted pb, Unlifted pc)
   type Tuple (pa,pb,pc) = (Tuple pa, Tuple pb, Tuple pc)
   unlift (pa,pb,pc) =
      mapTriple (unlift pa, unlift pb, unlift pc) . A.unlift

-- | Quadruple pattern, analogous to the triple pattern: the expression is
-- split with 'A.unlift' (as in 'unliftQuadruple') and each component is
-- unlifted with its own sub-pattern.
instance (Unlift pa, Unlift pb, Unlift pc, Unlift pd) =>
      Unlift (pa,pb,pc,pd) where
   type Unlifted (pa,pb,pc,pd) =
           (Unlifted pa, Unlifted pb, Unlifted pc, Unlifted pd)
   type Tuple (pa,pb,pc,pd) = (Tuple pa, Tuple pb, Tuple pc, Tuple pd)
   unlift (pa,pb,pc,pd) abcd =
      case A.unlift abcd of
         (a,b,c,d) -> (unlift pa a, unlift pb b, unlift pc c, unlift pd d)

-- | Shape pattern @pa :. atom@: the tail is unlifted with @pa@, and the
-- head stays an @'Exp' 'Int'@.
--
-- The instance head accepts any right-hand pattern, and the equality
-- constraint then fixes it to @'Atom' 'Int'@. This lets GHC infer the
-- element type of an @atom@ used in that position.
instance (Unlift pa, A.Slice (Tuple pa), int ~ Atom Int) =>
      Unlift (pa :. int) where
   type Unlifted (pa :. int) = Unlifted pa :. Exp Int
   type Tuple (pa :. int) = Tuple pa :. Int
   unlift (pa:.pb) ab =
      unlift pa (A.indexTail ab) :. unlift pb (A.indexHead ab)

-- | Complex pattern: the real and imaginary parts are extracted with
-- 'Complex.real' and 'Complex.imag'. Each part is unlifted with its own
-- sub-pattern.
instance (Unlift p) => Unlift (Complex p) where
   type Unlifted (Complex p) = Complex (Unlifted p)
   type Tuple (Complex p) = Complex (Tuple p)
   unlift (preal:+pimag) z =
      unlift preal (Complex.real z) :+ unlift pimag (Complex.imag z)


-- | 'A.unlift' specialised to pairs.
unliftPair :: (A.Elt a, A.Elt b) => Exp (a,b) -> (Exp a, Exp b)
unliftPair = A.unlift

-- | 'A.unlift' specialised to triples.
unliftTriple ::
   (A.Elt a, A.Elt b, A.Elt c) => Exp (a,b,c) -> (Exp a, Exp b, Exp c)
unliftTriple = A.unlift

-- | 'A.unlift' specialised to quadruples.
unliftQuadruple ::
   (A.Elt a, A.Elt b, A.Elt c, A.Elt d) =>
   Exp (a,b,c,d) -> (Exp a, Exp b, Exp c, Exp d)
unliftQuadruple = A.unlift

-- | Identity restricted to 'Exp'. It is useful for pinning down that a
-- value is an 'Exp'.
asExp :: Exp a -> Exp a
asExp = id

-- | Apply a function to the first component of a pair expression.
mapFst ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b) -> Exp (a,c) -> Exp (b,c)
mapFst f = modify (atom,atom) $ \(a,c) -> (f a, c)

-- | Apply a function to the second component of a pair expression.
mapSnd ::
   (A.Elt a, A.Elt b, A.Elt c) =>
   (Exp b -> Exp c) -> Exp (a,b) -> Exp (a,c)
mapSnd f = modify (atom,atom) $ \(a,b) -> (a, f b)

-- | First component of a triple expression.
fst3 :: (A.Elt a, A.Elt b, A.Elt c) => Exp (a,b,c) -> Exp a
fst3 = modify (atom,atom,atom) Tuple.fst3

-- | Second component of a triple expression.
snd3 :: (A.Elt a, A.Elt b, A.Elt c) => Exp (a,b,c) -> Exp b
snd3 = modify (atom,atom,atom) Tuple.snd3

-- | Third component of a triple expression.
thd3 :: (A.Elt a, A.Elt b, A.Elt c) => Exp (a,b,c) -> Exp c
thd3 = modify (atom,atom,atom) Tuple.thd3

-- | Build the index @ix :. n@, extending @ix@ by one 'Int' dimension.
indexCons :: (A.Slice ix) => Exp ix -> Exp Int -> Exp (ix :. Int)
indexCons ix n = A.lift $ ix:.n