{-# LANGUAGE GADTs           #-}
{-# LANGUAGE MagicHash       #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections   #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Elt
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Elt
  where

import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Control.Monad.ST
import Data.List                                                    ( intercalate )
import Data.Primitive.ByteArray
import Foreign.Storable
import Language.Haskell.TH


-- | Build a value of the given element representation type in which every
-- leaf is zero.
--
-- Unit maps to @()@, pairs are built component-wise, every integral and
-- floating scalar is @0@, and a SIMD vector is backed by a freshly
-- allocated, zero-filled byte buffer of @lanes * laneSize@ bytes.
undefElt :: TypeR t -> t
undefElt = tuple
  where
    tuple :: TypeR t -> t
    tuple TupRunit         = ()
    tuple (TupRpair ta tb) = (tuple ta, tuple tb)
    tuple (TupRsingle t)   = scalar t

    scalar :: ScalarType t -> t
    scalar (SingleScalarType t) = single t
    scalar (VectorScalarType t) = vector t

    -- 'newByteArray' leaves its contents uninitialised. Zero the buffer
    -- before freezing it so that the result is deterministic and agrees
    -- with the @0@ produced for scalar leaves.
    vector :: VectorType t -> t
    vector (VectorType n t) = runST $ do
      let bytes = n * bytesElt (TupRsingle (SingleScalarType t))
      mba           <- newByteArray bytes
      fillByteArray mba 0 bytes 0
      ByteArray ba# <- unsafeFreezeByteArray mba
      return (Vec ba#)

    single :: SingleType t -> t
    single (NumSingleType t) = num t

    num :: NumType t -> t
    num (IntegralNumType t) = integral t
    num (FloatingNumType t) = floating t

    integral :: IntegralType t -> t
    integral TypeInt    = 0
    integral TypeInt8   = 0
    integral TypeInt16  = 0
    integral TypeInt32  = 0
    integral TypeInt64  = 0
    integral TypeWord   = 0
    integral TypeWord8  = 0
    integral TypeWord16 = 0
    integral TypeWord32 = 0
    integral TypeWord64 = 0

    floating :: FloatingType t -> t
    floating TypeHalf   = 0
    floating TypeFloat  = 0
    floating TypeDouble = 0

-- | Size in bytes of one element with the given representation.
--
-- Unit occupies no space, the components of a pair are summed with no
-- padding between them, and a SIMD vector occupies @lanes * laneSize@ bytes.
bytesElt :: TypeR e -> Int
bytesElt = go
  where
    go :: TypeR t -> Int
    go TupRunit       = 0
    go (TupRpair l r) = go l + go r
    go (TupRsingle s) = scalarBytes s

    scalarBytes :: ScalarType t -> Int
    scalarBytes (SingleScalarType s)                    = singleBytes s
    scalarBytes (VectorScalarType (VectorType lanes s)) = lanes * singleBytes s

    singleBytes :: SingleType t -> Int
    singleBytes (NumSingleType (IntegralNumType it)) = integralBytes it
    singleBytes (NumSingleType (FloatingNumType ft)) = floatingBytes ft

    -- 'Int' and 'Word' are not given a fixed width here, so their size is
    -- taken from 'Storable'; every other integral type has a fixed width.
    integralBytes :: IntegralType t -> Int
    integralBytes it = case it of
      TypeInt    -> sizeOf (undefined :: Int)
      TypeInt8   -> 1
      TypeInt16  -> 2
      TypeInt32  -> 4
      TypeInt64  -> 8
      TypeWord   -> sizeOf (undefined :: Word)
      TypeWord8  -> 1
      TypeWord16 -> 2
      TypeWord32 -> 4
      TypeWord64 -> 8

    floatingBytes :: FloatingType t -> Int
    floatingBytes ft = case ft of
      TypeHalf   -> 2
      TypeFloat  -> 4
      TypeDouble -> 8

-- | Render an element of the given representation type as a 'String'.
-- This runs the 'showsElt' renderer against an empty tail.
showElt :: TypeR e -> e -> String
showElt ty = flip (showsElt ty) ""

-- | 'ShowS'-style renderer for an element of the given representation type.
--
-- Unit renders as @()@ and pairs as @(a, b)@. Scalars use their 'Show'
-- instance, and SIMD vectors render as @<x0, x1, ...>@.
showsElt :: TypeR e -> e -> ShowS
showsElt = go
  where
    go :: TypeR e -> e -> ShowS
    go TupRunit       () = showString "()"
    go (TupRpair l r) (x, y) =
      showString "(" . go l x . showString ", " . go r y . showString ")"
    go (TupRsingle s) x  = scalar s x

    scalar :: ScalarType e -> e -> ShowS
    scalar (SingleScalarType s) = single s
    scalar (VectorScalarType v) = vector v

    single :: SingleType e -> e -> ShowS
    single (NumSingleType (IntegralNumType it)) = integral it
    single (NumSingleType (FloatingNumType ft)) = floating ft

    -- Matching each constructor fixes the concrete type, which brings its
    -- 'Show' instance into scope for 'shows'.
    integral :: IntegralType e -> e -> ShowS
    integral it = case it of
      TypeInt    -> shows
      TypeInt8   -> shows
      TypeInt16  -> shows
      TypeInt32  -> shows
      TypeInt64  -> shows
      TypeWord   -> shows
      TypeWord8  -> shows
      TypeWord16 -> shows
      TypeWord32 -> shows
      TypeWord64 -> shows

    floating :: FloatingType e -> e -> ShowS
    floating ft = case ft of
      TypeHalf   -> shows
      TypeFloat  -> shows
      TypeDouble -> shows

    -- Show each lane, join the lanes with ", ", and wrap them in angle
    -- brackets. Matching on 'singleDict' supplies the lane type's
    -- dictionaries, which 'listOfVec' requires.
    vector :: VectorType (Vec n a) -> Vec n a -> ShowS
    vector (VectorType _ s) vec
      | SingleDict <- singleDict s
      = let lanes = [ single s lane "" | lane <- listOfVec vec ]
        in  showString "<" . showString (intercalate ", " lanes) . showString ">"

-- | Lift an element of the given representation type into a typed Template
-- Haskell expression that rebuilds the same value.
--
-- Unit and pairs are quoted structurally. A scalar leaf is delegated
-- directly to 'liftScalar', since quoting a lone splice of it would produce
-- the same expression.
liftElt :: TypeR t -> t -> Q (TExp t)
liftElt TupRunit       ()     = [|| () ||]
liftElt (TupRsingle s) x      = liftScalar s x
liftElt (TupRpair l r) (x, y) = [|| ($$(liftElt l x), $$(liftElt r y)) ||]