{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}
{- |
Simplify running @accelerate@ functions
with multiple curried array and expression arguments.
-}
module Data.Array.Accelerate.Utility.Lift.Run (
C(..), with,
Argument(..),
) where
import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc
import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Acc, Exp, Z, (:.))
import Data.Tuple.HT (mapPair, mapTriple)
-- | Feed a packed pair to the curried function @f@.
-- The first component is passed through as is; the second component
-- is turned back into its embedded form by @unpack@.
merge ::
  (A.Arrays a, A.Arrays packed) =>
  (Acc packed -> b) ->
  (Acc a -> b -> c) ->
  (Acc (a, packed) -> c)
merge unpack f pair = f prefix (unpack rest)
  where
    prefix = A.afst pair
    rest = A.asnd pair
-- | Pass both halves of an embedded pair of arrays to a curried function.
_mergeAcc ::
  (A.Arrays a, A.Arrays b) =>
  (Acc a -> Acc b -> c) ->
  (Acc (a, b) -> c)
_mergeAcc f = \pair -> f (A.afst pair) (A.asnd pair)
-- | Like '_mergeAcc', but the first half is a scalar array whose single
-- element is extracted with 'A.the' before being passed on.
_mergeExp ::
  (A.Elt a, A.Arrays b) =>
  (Exp a -> Acc b -> c) ->
  (Acc (A.Scalar a, b) -> c)
_mergeExp f pair = f scalar arrays
  where
    scalar = A.the (A.afst pair)
    arrays = A.asnd pair
-- | Mirror image of '_mergeExp': here the second half is the scalar array.
_mergeExpR ::
  (A.Arrays a, A.Elt b) =>
  (Acc a -> Exp b -> c) ->
  (Acc (a, A.Scalar b) -> c)
_mergeExpR f pair =
  let arrays = A.afst pair
      scalar = A.the (A.asnd pair)
  in  f arrays scalar
-- | Host-side counterpart of 'merge': take the arguments curried, convert
-- the second one with @pack@ and hand the resulting pair to @f@.
--
-- No 'A.Arrays' constraints are required, because the body only builds a
-- tuple. This keeps 'split' as polymorphic as '_splitAcc' and its siblings.
split ::
  (unpacked -> packed) ->
  ((a, packed) -> c) ->
  (a -> unpacked -> c)
split pack f a b = f (a, pack b)
-- | Curry a function on pairs (no conversion of either argument).
_splitAcc :: ((a, b) -> c) -> (a -> b -> c)
_splitAcc f a b = f (a, b)
-- | Curry a function on pairs, wrapping the first argument with
-- 'Acc.singleton' so it becomes an 'A.Scalar'.
_splitExp :: (A.Elt a) => ((A.Scalar a, b) -> c) -> (a -> b -> c)
_splitExp f a = curry f (Acc.singleton a)
-- | Curry a function on pairs, wrapping the second argument with
-- 'Acc.singleton' so it becomes an 'A.Scalar'.
_splitExpR :: (A.Elt b) => ((a, A.Scalar b) -> c) -> (a -> b -> c)
_splitExpR f a b = curry f a (Acc.singleton b)
{- |
If you have a function:
> f :: Exp a -> (Acc b, Acc c) -> (Acc d, Acc e)
you cannot run it directly with 'Data.Array.Accelerate.Interpreter.run1',
since @run1@ expects a function
with a single 'Acc' parameter and a single 'Acc' result.
Using the 'with' function you can just run @f@ as is:
> with run1 f :: a -> (b,c) -> (d,e)
-}
{-
Types of @with run1@ for a few shapes of the embedded function:
(Acc ((a,b),c) -> Acc d) -> (((a,b),c) -> d)
(Acc (a,b) -> Acc c -> Acc d) -> ((a,b) -> c -> d)
(Acc a -> Acc b -> Acc c -> Acc d) -> (a -> b -> c -> d)
-}
class C f where
  -- | Left-nested tuple of packed arguments. @a@ is the innermost prefix,
  -- and each curried parameter of @f@ appends its packed form to it.
  type Arguments a f

  -- | Array type(s) delivered by the final 'Acc' result of @f@.
  type Result f

  -- | The plain function (or value) obtained by running @f@.
  type Plain f

  -- | Like 'with', but starting from a function that already receives a
  -- packed prefix of type @a@.
  with1 ::
    A.Arrays a =>
    ( (Acc (Arguments a f) -> Acc (Result f)) ->
      (Arguments a f -> Result f)
    ) ->
    (Acc a -> f) ->
    a ->
    Plain f
-- | Run an embedded function that takes curried 'Acc' / 'Exp' parameters
-- with a runner that only accepts a single 'Acc' argument, such as @run1@.
--
-- All parameters are packed into one argument tuple whose innermost prefix
-- is @()@, and the runner is applied to the resulting single-argument
-- function.
with ::
  C f =>
  ( (Acc (Arguments () f) -> Acc (Result f)) ->
    (Arguments () f -> Result f)
  ) ->
  f ->
  Plain f
with run f = with1 run (\_unit -> f) ()
-- A bare 'Acc' result: the function already has the shape the runner
-- expects, so it is handed over unchanged.
instance C (Acc r) where
  type Arguments a (Acc r) = a
  type Result (Acc r) = r
  type Plain (Acc r) = r
  with1 run = run
-- Pair result: the two components are lifted into a single 'Acc' pair
-- before the function is handed to the runner.
instance
  ( A.Lift Acc r, A.Arrays (A.Plain r)
  , A.Lift Acc s, A.Arrays (A.Plain s)
  ) =>
  C (r, s) where
  type Arguments a (r, s) = a
  type Result (r, s) = (A.Plain r, A.Plain s)
  type Plain (r, s) = (A.Plain r, A.Plain s)
  with1 run f = run (\args -> A.lift (f args))
-- Triple result: the three components are lifted into a single 'Acc'
-- triple before the function is handed to the runner.
instance
  ( A.Lift Acc r, A.Arrays (A.Plain r)
  , A.Lift Acc s, A.Arrays (A.Plain s)
  , A.Lift Acc t, A.Arrays (A.Plain t)
  ) =>
  C (r, s, t) where
  type Arguments a (r, s, t) = a
  type Result (r, s, t) = (A.Plain r, A.Plain s, A.Plain t)
  type Plain (r, s, t) = (A.Plain r, A.Plain s, A.Plain t)
  with1 run f = run (\args -> A.lift (f args))
-- One more curried parameter. On the embedded side its packed form is
-- appended to the argument tuple ('merge'). On the host side the plain
-- value is packed ('split').
instance (Argument arg, C f) => C (arg -> f) where
  type Arguments a (arg -> f) = Arguments (a, Packed arg) f
  type Result (arg -> f) = Result f
  type Plain (arg -> f) = Unpacked arg -> Plain f
  with1 run f =
    -- Binding both conversions from one 'tunnel' ties them to @arg@.
    -- 'Packed' and 'Unpacked' are not injective, so the packing function
    -- alone would not determine the instance.
    case tunnel of
      (packArg, unpackArg) ->
        split packArg (with1 run (merge unpackArg f))
-- | Types that may appear as curried parameters of a function passed to
-- 'with'. Each parameter travels through the runner in a packed,
-- array-typed form.
class A.Arrays (Packed a) => Argument a where
  -- | Array representation used to pass the parameter through the runner.
  type Packed a

  -- | Plain host-side type of the parameter.
  type Unpacked a

  -- | The two conversions for this parameter. The first component packs a
  -- plain value; the second recovers the embedded parameter from its
  -- packed form.
  tunnel :: (Unpacked a -> Packed a, Acc (Packed a) -> a)
-- An array parameter is already in packed form, so both conversions are
-- the identity.
instance A.Arrays a => Argument (Acc a) where
  type Packed (Acc a) = a
  type Unpacked (Acc a) = a
  tunnel = (\plain -> plain, \embedded -> embedded)
-- A scalar parameter travels as an 'A.Scalar' array holding its value.
instance A.Elt a => Argument (Exp a) where
  type Packed (Exp a) = A.Scalar a
  type Unpacked (Exp a) = a
  tunnel = (\x -> Acc.singleton x, \arr -> A.the arr)
-- A pair of parameters is packed component-wise into a pair of packed
-- values. Each component's conversions come from its own 'tunnel'.
instance (Argument a, Argument b) => Argument (a, b) where
  type Packed (a, b) = (Packed a, Packed b)
  type Unpacked (a, b) = (Unpacked a, Unpacked b)
  tunnel =
    case tunnel of
      (packA, unpackA) ->
        case tunnel of
          (packB, unpackB) ->
            ( mapPair (packA, packB)
            , mapPair (unpackA, unpackB) . A.unlift
            )
-- A triple of parameters is packed component-wise into a triple of packed
-- values. Each component's conversions come from its own 'tunnel'.
instance (Argument a, Argument b, Argument c) => Argument (a, b, c) where
  type Packed (a, b, c) = (Packed a, Packed b, Packed c)
  type Unpacked (a, b, c) = (Unpacked a, Unpacked b, Unpacked c)
  tunnel =
    case tunnel of
      (packA, unpackA) ->
        case tunnel of
          (packB, unpackB) ->
            case tunnel of
              (packC, unpackC) ->
                ( mapTriple (packA, packB, packC)
                , mapTriple (unpackA, unpackB, unpackC) . A.unlift
                )
-- The empty shape 'Z' travels as an 'A.Scalar' array holding 'Z'.
instance Argument Z where
  type Packed Z = A.Scalar Z
  type Unpacked Z = Z
  tunnel = (Acc.singleton, \arr -> A.unlift (A.the arr))
-- A shape index @sh :. Exp Int@ travels as an 'A.Scalar' array holding the
-- plain index. The head matches any @a :. b@, and the @b ~ Exp Int@
-- constraint then fixes the innermost component.
instance
  (A.Unlift Exp a, A.Lift Exp a, A.Slice (A.Plain a), b ~ Exp Int) =>
  Argument (a :. b) where
  type Packed (a :. b) = A.Scalar (A.Plain a :. Int)
  type Unpacked (a :. b) = A.Plain a :. Int
  tunnel = (\ix -> Acc.singleton ix, \arr -> A.unlift (A.the arr))