{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE UndecidableInstances  #-}
#endif
-- |
-- Module      : Data.Array.Accelerate.Lift
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Lifting and lowering surface expressions through constructors.
--

module Data.Array.Accelerate.Lift (

  -- * Lifting and unlifting
  Lift(..), Unlift(..),

  lift1, lift2, lift3,
  ilift1, ilift2, ilift3,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type

import Language.Haskell.TH                                          hiding ( Exp, tupP, tupE )
import Language.Haskell.TH.Extra


-- | Lift a unary function into 'Exp'.
--
-- The argument is taken apart with 'unlift', handed to the function, and the
-- function's result is rebuilt as a single 'Exp' with 'lift'.
--
lift1 :: (Unlift Exp a, Lift Exp b)
      => (a -> b)
      -> Exp (Plain a)
      -> Exp (Plain b)
lift1 f x = lift (f (unlift x))

-- | Lift a binary function into 'Exp'.
--
-- Both arguments are taken apart with 'unlift', handed to the function, and
-- the function's result is rebuilt as a single 'Exp' with 'lift'.
--
lift2 :: (Unlift Exp a, Unlift Exp b, Lift Exp c)
      => (a -> b -> c)
      -> Exp (Plain a)
      -> Exp (Plain b)
      -> Exp (Plain c)
lift2 f x y = lift (f x' y')
  where
    x' = unlift x
    y' = unlift y

-- | Lift a ternary function into 'Exp'.
--
-- All three arguments are taken apart with 'unlift', handed to the function,
-- and the function's result is rebuilt as a single 'Exp' with 'lift'.
--
lift3 :: (Unlift Exp a, Unlift Exp b, Unlift Exp c, Lift Exp d)
      => (a -> b -> c -> d)
      -> Exp (Plain a)
      -> Exp (Plain b)
      -> Exp (Plain c)
      -> Exp (Plain d)
lift3 f x y z = lift (f x' y' z')
  where
    x' = unlift x
    y' = unlift y
    z' = unlift z

-- | Lift a unary function to a computation over rank-1 indices.
--
-- The index is viewed as @Z :. i@, the function is applied to the single
-- component, and the result is rebuilt as a rank-1 index.
--
ilift1 :: (Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1
ilift1 f = lift1 step
  where
    step (Z :. i) = Z :. f i

-- | Lift a binary function to a computation over rank-1 indices.
--
-- Each index is viewed as @Z :. i@, the function is applied to the two
-- components, and the result is rebuilt as a rank-1 index.
--
ilift2 :: (Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1
ilift2 f = lift2 step
  where
    step (Z :. i) (Z :. j) = Z :. f i j

-- | Lift a ternary function to a computation over rank-1 indices.
--
-- Each index is viewed as @Z :. i@, the function is applied to the three
-- components, and the result is rebuilt as a rank-1 index.
--
ilift3 :: (Exp Int -> Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1 -> Exp DIM1
ilift3 f = lift3 step
  where
    step (Z :. i) (Z :. j) (Z :. k) = Z :. f i j k


-- | The class of types @e@ which can be lifted into the surface type @c@.
--
class Lift c e where
  -- | An associated type (i.e. a type-level function) that strips all
  --   occurrences of the surface type constructor @c@ from the input type
  --   @e@.
  --
  --   For example, the tuple types @(Exp Int, Int)@ and @(Int, Exp Int)@
  --   have the same \"Plain\" representation. That is, the following type
  --   equality holds:
  --
  --    @Plain (Exp Int, Int) ~ (Int,Int) ~ Plain (Int, Exp Int)@
  --
  type Plain e

  -- | Lift the given value into the surface type @c@ --- either 'Exp' for
  -- scalar expressions or 'Acc' for array computations. The value may already
  -- contain subexpressions in @c@.
  --
  lift :: e -> c (Plain e)

-- | A limited subset of the types which can be lifted into the surface type
-- @c@ can also be unlifted out of it.
class Lift c e => Unlift c e where

  -- | Unlift the outermost constructor through the surface type @c@. This is
  -- only possible if the constructor is fully determined by its type — i.e.,
  -- it is a singleton.
  --
  -- For example, a pair expression can be split into its components:
  --
  -- @unlift :: Exp (Int, Float) -> (Exp Int, Exp Float)@
  --
  unlift :: c (Plain e) -> e


-- Identity instances
-- ------------------

-- An 'Exp' is already a surface expression: lifting and unlifting it both
-- return the argument unchanged.
instance Lift Exp (Exp e) where
  type Plain (Exp e) = e
  lift = \expr -> expr

instance Unlift Exp (Exp e) where
  unlift = \expr -> expr

-- An 'Acc' is already a surface computation: lifting and unlifting it both
-- return the argument unchanged.
instance Lift Acc (Acc a) where
  type Plain (Acc a) = a
  lift = \acc -> acc

instance Unlift Acc (Acc a) where
  unlift = \acc -> acc

-- instance Lift Seq (Seq a) where
--   type Plain (Seq a) = a
--   lift = id

-- instance Unlift Seq (Seq a) where
--   unlift = id


-- Instances for indices
-- ---------------------

-- The rank-0 index carries no information, so neither direction inspects its
-- argument.
instance Lift Exp Z where
  type Plain Z = Z
  lift = const Z_

instance Unlift Exp Z where
  unlift = const Z

-- Lifting a snoc-list index: the tail is lifted recursively and the head is
-- embedded according to its type ('lift' for a plain 'Int', 'constant' for
-- 'All', and passed through unchanged when it is already an 'Exp').

instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where
  type Plain (ix :. Int) = Plain ix :. Int
  lift (sh :. n) = lift sh ::. lift n

instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where
  type Plain (ix :. All) = Plain ix :. All
  lift (sh :. a) = lift sh ::. constant a

instance (Elt e, Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where
  type Plain (ix :. Exp e) = Plain ix :. e
  lift (sh :. x) = lift sh ::. x

-- Unlifting a snoc-list index whose head is an 'Exp'. The general instance
-- recursively unlifts the tail; the more specific one, for a tail that is
-- itself an 'Exp', hands that tail back directly.
instance {-# OVERLAPPABLE #-} (Elt e, Elt (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where
  unlift (sh ::. x) = unlift sh :. x

instance {-# OVERLAPPABLE #-} (Elt e, Elt ix) => Unlift Exp (Exp ix :. Exp e) where
  unlift (sh ::. x) = sh :. x

-- 'Any' is embedded as a constant expression of itself.
instance (Shape sh, Elt (Any sh)) => Lift Exp (Any sh) where
  type Plain (Any sh) = Any sh
  lift whole@Any = constant whole

-- Instances for scalar types
-- --------------------------

-- Embed a value as a single 'Const' node. Usable only for types whose
-- representation 'EltR' is itself a scalar type (the 'IsScalar' constraint).
-- The value is converted with 'fromElt' first.
{-# INLINE expConst #-}
expConst :: forall e. Elt e => IsScalar (EltR e) => e -> Exp e
expConst = \x -> Exp (SmartExp (Const (scalarType @(EltR e)) (fromElt x)))

-- Integral scalar types: each value is embedded directly as a constant.

instance Lift Exp Int where
  type Plain Int = Int
  lift x = expConst x

instance Lift Exp Int8 where
  type Plain Int8 = Int8
  lift x = expConst x

instance Lift Exp Int16 where
  type Plain Int16 = Int16
  lift x = expConst x

instance Lift Exp Int32 where
  type Plain Int32 = Int32
  lift x = expConst x

instance Lift Exp Int64 where
  type Plain Int64 = Int64
  lift x = expConst x

instance Lift Exp Word where
  type Plain Word = Word
  lift x = expConst x

instance Lift Exp Word8 where
  type Plain Word8 = Word8
  lift x = expConst x

instance Lift Exp Word16 where
  type Plain Word16 = Word16
  lift x = expConst x

instance Lift Exp Word32 where
  type Plain Word32 = Word32
  lift x = expConst x

instance Lift Exp Word64 where
  type Plain Word64 = Word64
  lift x = expConst x

instance Lift Exp CShort where
  type Plain CShort = CShort
  lift x = expConst x

instance Lift Exp CUShort where
  type Plain CUShort = CUShort
  lift x = expConst x

instance Lift Exp CInt where
  type Plain CInt = CInt
  lift x = expConst x

instance Lift Exp CUInt where
  type Plain CUInt = CUInt
  lift x = expConst x

instance Lift Exp CLong where
  type Plain CLong = CLong
  lift x = expConst x

instance Lift Exp CULong where
  type Plain CULong = CULong
  lift x = expConst x

instance Lift Exp CLLong where
  type Plain CLLong = CLLong
  lift x = expConst x

instance Lift Exp CULLong where
  type Plain CULLong = CULLong
  lift x = expConst x

-- Floating-point scalar types: each value is embedded directly as a constant.

instance Lift Exp Half where
  type Plain Half = Half
  lift x = expConst x

instance Lift Exp Float where
  type Plain Float = Float
  lift x = expConst x

instance Lift Exp Double where
  type Plain Double = Double
  lift x = expConst x

instance Lift Exp CFloat where
  type Plain CFloat = CFloat
  lift x = expConst x

instance Lift Exp CDouble where
  type Plain CDouble = CDouble
  lift x = expConst x

-- 'Bool' is built by hand as a 'Pair' of a scalar constant and 'Nil'. The
-- scalar is 0 for 'False' and 1 for 'True'. NOTE(review): this presumably
-- mirrors the constructor tag in Bool's 'EltR' representation — confirm
-- against Sugar.Elt.
instance Lift Exp Bool where
  type Plain Bool = Bool
  lift b = case b of
    False -> Exp (SmartExp (SmartExp (Const scalarType 0) `Pair` SmartExp Nil))
    True  -> Exp (SmartExp (SmartExp (Const scalarType 1) `Pair` SmartExp Nil))

-- Character scalar types: each value is embedded directly as a constant.

instance Lift Exp Char where
  type Plain Char = Char
  lift x = expConst x

instance Lift Exp CChar where
  type Plain CChar = CChar
  lift x = expConst x

instance Lift Exp CSChar where
  type Plain CSChar = CSChar
  lift x = expConst x

instance Lift Exp CUChar where
  type Plain CUChar = CUChar
  lift x = expConst x

-- Instances for unit and arrays
-- -----------------------------

-- The unit value becomes the empty expression tuple ('Nil'); unlifting
-- ignores its argument.
instance Lift Exp () where
  type Plain () = ()
  lift = const (Exp (SmartExp Nil))

instance Unlift Exp () where
  unlift = const ()

-- The unit value becomes the empty array tuple ('Anil'); unlifting ignores
-- its argument.
instance Lift Acc () where
  type Plain () = ()
  lift = const (Acc (SmartAcc Anil))

instance Unlift Acc () where
  unlift = const ()

-- A concrete array is embedded into an 'Acc' computation through a 'Use'
-- node, tagged with its array representation.
instance (Shape sh, Elt e) => Lift Acc (Array sh e) where
  type Plain (Array sh e) = Array sh e
  lift (Array repr) = Acc (SmartAcc (Use (arrayR @sh @e) repr))

-- Lift and Unlift instances for tuples
--
-- The splice below generates, for every tuple width n in [2..16] and for each
-- of the surface types 'Acc' and 'Exp', two instances:
--
--   * 'Lift' for n-tuples whose components are each liftable into that
--     surface type, and
--   * 'Unlift' for n-tuples whose components are all wrapped in that surface
--     type.
--
runQ $ do
    let
        -- Build the Lift/Unlift instance pair for n-tuples over one surface
        -- type constructor. Arguments:
        --
        --   con   : name of the surface type (and its constructor), "Acc" or "Exp"
        --   cst   : per-component class constraint ('Arrays' or 'Elt')
        --   smart : smart-AST wrapper ('SmartAcc' or 'SmartExp')
        --   prj   : tuple projection node ('Aprj' or 'Prj')
        --   nil   : empty-tuple node ('Anil' or 'Nil')
        --   pair  : pairing node ('Apair' or 'Pair')
        --   n     : tuple width
        --
        mkInstances :: Name -> TypeQ -> ExpQ -> ExpQ -> ExpQ -> ExpQ -> Int -> Q [Dec]
        mkInstances con cst smart prj nil pair n = do
          let
              -- Names x0 .. x(n-1). Each one is used both as a type variable
              -- (one per tuple component) and, in the generated 'lift', as the
              -- term variable bound to that component.
              xs      = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
              ts      = map varT xs
              -- res1 = (x0, .., xn-1);  res2 = (con x0, .., con xn-1)
              res1    = tupT ts
              res2    = tupT (map (conT con `appT`) ts)
              -- Plain of a tuple is the tuple of each component's Plain.
              plain   = tupT (map (\t -> [t| Plain $t |]) ts)
              -- Lift context: every component is liftable, and every
              -- component's Plain type satisfies cst.
              ctx1    = tupT (map (\t -> [t| Lift $(conT con) $t |]) ts)
              ctx2    = tupT (map (\t -> [t| $cst (Plain $t) |]) ts)
              -- Unlift context: every component type satisfies cst.
              ctx3    = tupT (map (appT cst) ts)
              --
              -- 'get x i' projects from the representation 'x' by taking the
              -- left half i times and then the right half once, and rewraps
              -- the result in 'con'. Because the Lift instance below builds
              -- the representation by pairing left to right starting from
              -- 'nil', the last component sits outermost on the right; so
              -- component k is reached with i = n-1-k.
              get x 0 = [| $(conE con) ($smart ($prj PairIdxRight $x)) |]
              get x i = get [| $smart ($prj PairIdxLeft $x) |] (i-1)
          --
          -- Fresh name for the whole tuple representation in 'unlift'.
          _x <- newName "_x"
          -- In 'lift', each component is lifted and unwrapped from its
          -- single-field 'con' constructor (binding _v), then folded in with
          -- 'pair', starting from 'nil'. 'unlift' projects each component back
          -- out with 'get', listing indices n-1 down to 0 so that the tuple
          -- is rebuilt in its original order.
          [d| instance ($ctx1, $ctx2) => Lift $(conT con) $res1 where
                type Plain $res1 = $plain
                lift $(tupP (map varP xs)) =
                  $(conE con)
                  $(foldl (\vs v -> do _v <- newName "_v"
                                       [| let $(conP con [varP _v]) = lift $(varE v)
                                           in $smart ($pair $vs $(varE _v)) |]) [| $smart $nil |] xs)

              instance $ctx3 => Unlift $(conT con) $res2 where
                unlift $(conP con [varP _x]) =
                  $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0]))
            |]

        -- Specialisations of mkInstances for array and scalar surface types.
        mkAccInstances = mkInstances (mkName "Acc") [t| Arrays |] [| SmartAcc |] [| Aprj |] [| Anil |] [| Apair |]
        mkExpInstances = mkInstances (mkName "Exp") [t| Elt    |] [| SmartExp |] [| Prj  |] [| Nil  |] [| Pair  |]
    --
    -- Generate instances for tuple widths 2 through 16 for both surface types.
    as <- mapM mkAccInstances [2..16]
    es <- mapM mkExpInstances [2..16]
    return $ concat (as ++ es)