{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeOperators #-}
{- |
Simplify running @accelerate@ functions
with multiple curried array and expression arguments.
-}
module Data.Array.Accelerate.Utility.Lift.Run (
   C(..), with,
   Argument(..),
   ) where

import qualified Data.Array.Accelerate.Utility.Lift.Acc as Acc

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Acc, Exp, Z, (:.))

import Data.Tuple.HT (mapPair, mapTriple)


{- |
Turn a curried function over an array and an unpacked argument
into a function over a single @accelerate@ pair.
The first pair component is handed to the function as is,
the second one is first converted by the given @unpack@ function.
-}
merge ::
   (A.Arrays a, A.Arrays packed) =>
   (Acc packed -> b) ->
   (Acc a -> b -> c) -> (Acc (a,packed) -> c)
merge unpack f pair = f headArr tailArg
   where
      -- leading component, passed through unchanged
      headArr = A.afst pair
      -- trailing component, restored to its unpacked form
      tailArg = unpack (A.asnd pair)

{- |
Uncurry a function of two array arguments
so that it accepts one @accelerate@ pair of arrays.
-}
_mergeAcc ::
   (A.Arrays a, A.Arrays b) =>
   (Acc a -> Acc b -> c) -> (Acc (a,b) -> c)
_mergeAcc f pair = f left right
   where
      left = A.afst pair
      right = A.asnd pair

{- |
Uncurry a function whose first argument is an expression.
The expression is read with 'A.the' from the first component of the pair,
the second component is passed on unchanged.
-}
_mergeExp ::
   (A.Elt a, A.Arrays b) =>
   (Exp a -> Acc b -> c) -> (Acc (A.Scalar a, b) -> c)
_mergeExp f pair = f scalarValue rest
   where
      scalarValue = A.the (A.afst pair)
      rest = A.asnd pair

{- |
Uncurry a function whose second argument is an expression.
The first component of the pair is passed on unchanged,
the expression is read with 'A.the' from the second component.
-}
_mergeExpR ::
   (A.Arrays a, A.Elt b) =>
   (Acc a -> Exp b -> c) -> (Acc (a, A.Scalar b) -> c)
_mergeExpR f pair = f rest scalarValue
   where
      rest = A.afst pair
      scalarValue = A.the (A.asnd pair)


{- |
Curry a function over a pair
while converting the second argument with the given @pack@ function
before it is stored in the pair.

The function is plain tuple manipulation,
so its type does not demand any class constraints on the components.
-}
split ::
   (unpacked -> packed) ->
   ((a, packed) -> c) -> (a -> unpacked -> c)
split pack f a b = f (a, pack b)

{- |
Curry a function over a pair into a function of two separate arguments.
-}
_splitAcc :: ((a,b) -> c) -> (a -> b -> c)
_splitAcc f a b = f (a, b)

{- |
Curry a function over a pair whose first component is a scalar array.
The plain first argument is wrapped with 'Acc.singleton'
before the pair is built.
-}
_splitExp :: (A.Elt a) => ((A.Scalar a, b) -> c) -> (a -> b -> c)
_splitExp f a b = f (wrapped, b)
   where
      wrapped = Acc.singleton a

{- |
Curry a function over a pair whose second component is a scalar array.
The plain second argument is wrapped with 'Acc.singleton'
before the pair is built.
-}
_splitExpR :: (A.Elt b) => ((a, A.Scalar b) -> c) -> (a -> b -> c)
_splitExpR f a b = f (a, wrapped)
   where
      wrapped = Acc.singleton b



{-
(Acc ((a,b),c) -> Acc d) -> (((a,b),c) -> d)
(Acc (a,b) -> Acc c -> Acc d) -> ((a,b) -> c -> d)
(Acc a -> Acc b -> Acc c -> Acc d) -> (a -> b -> c -> d)
-}
{- |
Class of curried @accelerate@ functions that can be executed
by means of a @run@-like function that only handles
a single array argument and a single array result.
-}
class C f where
   -- | Array type that bundles the already collected arguments @a@
   -- with all arguments that @f@ still expects.
   type Arguments a f
   -- | Array type of the result as it is passed to the @run@ function.
   type Result f
   -- | The type that corresponds to @f@ after running.
   type Plain f
   -- | Like 'with', but the function gets an additional leading
   -- array argument of type @a@.
   with1 ::
      (A.Arrays a) =>
      ((Acc (Arguments a f) -> Acc (Result f)) ->
       (Arguments a f -> Result f)) ->
      (Acc a -> f) ->
      a -> Plain f

{- |
Run a curried @accelerate@ function using the given @run@ function.
It is implemented via 'with1' with a unit value as leading dummy argument
that the function ignores.
-}
with ::
   (C f) =>
   ((Acc (Arguments () f) -> Acc (Result f)) ->
    (Arguments () f -> Result f)) ->
   f -> Plain f
with run f = with1 run ignoreUnit ()
   where
      -- the unit array only serves as start value for argument collection
      ignoreUnit _ = f


-- | Base case: the function result is a single array computation.
instance C (Acc r) where
   -- the result needs no conversion in any direction
   type Plain (Acc r) = r
   type Result (Acc r) = r
   -- no arguments remain, thus the collected ones are complete
   type Arguments a (Acc r) = a
   -- the function already has the shape that 'run' expects
   with1 run accFunc = run accFunc

-- | Base case: the function result is a Haskell pair
-- that can be lifted to a single 'Acc' value.
instance
   (A.Lift Acc r, A.Arrays (A.Plain r),
    A.Lift Acc s, A.Arrays (A.Plain s)) =>
      C (r,s) where
   type Plain (r,s) = (A.Plain r, A.Plain s)
   type Result (r,s) = (A.Plain r, A.Plain s)
   type Arguments a (r,s) = a
   -- lift the pair of results before handing the function to 'run'
   with1 run f = run $ \arr -> A.lift (f arr)

-- | Base case: the function result is a Haskell triple
-- that can be lifted to a single 'Acc' value.
instance
   (A.Lift Acc r, A.Arrays (A.Plain r),
    A.Lift Acc s, A.Arrays (A.Plain s),
    A.Lift Acc t, A.Arrays (A.Plain t)) =>
      C (r,s,t) where
   type Plain (r,s,t) = (A.Plain r, A.Plain s, A.Plain t)
   type Result (r,s,t) = (A.Plain r, A.Plain s, A.Plain t)
   type Arguments a (r,s,t) = a
   -- lift the triple of results before handing the function to 'run'
   with1 run f = run $ \arr -> A.lift (f arr)



-- | Inductive case: one more curried argument.
-- Its packed form is appended to the arguments collected so far.
instance (Argument arg, C f) => C (arg -> f) where
   type Plain (arg -> f) = Unpacked arg -> Plain f
   type Result (arg -> f) = Result f
   type Arguments a (arg -> f) = Arguments (a, Packed arg) f
   with1 run f =
      case tunnel of
         -- 'toPacked' converts the plain argument outside of 'run',
         -- 'fromPacked' restores the argument inside the computation
         (toPacked, fromPacked) ->
            split toPacked $ with1 run (merge fromPacked f)


{- |
Class of argument types that can be transferred
into an @accelerate@ computation in the form of an array value.
-}
class A.Arrays (Packed a) => Argument a where
   -- | Array representation of the argument.
   type Packed a
   -- | Plain type of the argument as supplied by the caller.
   type Unpacked a
   -- | A pair of conversions:
   -- The first one converts the plain value to the array representation,
   -- the second one restores the argument from the array representation
   -- within the computation.
   tunnel ::
      (Unpacked a -> Packed a,
       Acc (Packed a) -> a)

-- | Array arguments are passed through without any conversion.
instance (A.Arrays a) => Argument (Acc a) where
   type Unpacked (Acc a) = a
   type Packed (Acc a) = a
   tunnel = (pack, unpack)
      where
         pack = id
         unpack = id

-- | Expression arguments are transferred as scalar arrays.
instance (A.Elt a) => Argument (Exp a) where
   type Unpacked (Exp a) = a
   type Packed (Exp a) = A.Scalar a
   tunnel =
      (\plain -> Acc.singleton plain,
       \scalarArr -> A.the scalarArr)

-- | Pairs of arguments are converted component-wise.
instance (Argument a, Argument b) => Argument (a,b) where
   type Unpacked (a,b) = (Unpacked a, Unpacked b)
   type Packed (a,b) = (Packed a, Packed b)
   tunnel =
      case (tunnel, tunnel) of
         ((toPackedA, fromPackedA), (toPackedB, fromPackedB)) ->
            let pack = mapPair (toPackedA, toPackedB)
                -- unlift the pair before restoring the components
                unpack arr = mapPair (fromPackedA, fromPackedB) (A.unlift arr)
            in  (pack, unpack)

-- | Triples of arguments are converted component-wise.
instance (Argument a, Argument b, Argument c) => Argument (a,b,c) where
   type Unpacked (a,b,c) = (Unpacked a, Unpacked b, Unpacked c)
   type Packed (a,b,c) = (Packed a, Packed b, Packed c)
   tunnel =
      case (tunnel, tunnel, tunnel) of
         ((toPackedA, fromPackedA),
          (toPackedB, fromPackedB),
          (toPackedC, fromPackedC)) ->
            let pack = mapTriple (toPackedA, toPackedB, toPackedC)
                -- unlift the triple before restoring the components
                unpack arr =
                   mapTriple
                      (fromPackedA, fromPackedB, fromPackedC)
                      (A.unlift arr)
            in  (pack, unpack)


-- | The empty shape is transferred as a scalar array
-- and unlifted within the computation.
instance Argument Z where
   type Unpacked Z = Z
   type Packed Z = A.Scalar Z
   tunnel =
      (Acc.singleton,
       \scalarArr -> A.unlift (A.the scalarArr))

-- | Shapes with an 'Int' expression as last dimension
-- are transferred as a scalar array containing the plain shape
-- and are unlifted within the computation.
instance
   (A.Unlift Exp a, A.Lift Exp a, A.Slice (A.Plain a), b ~ Exp Int) =>
      Argument (a:.b) where
   type Unpacked (a:.b) = A.Plain a :. Int
   type Packed (a:.b) = A.Scalar (A.Plain a :. Int)
   tunnel =
      (Acc.singleton,
       \scalarArr -> A.unlift (A.the scalarArr))