module Data.Array.Accelerate.Array.Lifted (
Vector'(..), LiftedArray,
LiftedTupleRepr,
IsConstrained(..),
isArraysFlat,
elements', shapes', empty', length', drop', vec2Vec', fromList', toList'
) where
import Prelude hiding ( concat )
import Data.Typeable
import Data.Array.Accelerate.Product
import Data.Array.Accelerate.Array.Sugar
import qualified Data.Array.Accelerate.Array.Representation as Repr
-- | A lifted (nested) collection of values of type @a@: conceptually a list
-- of @a@s (see 'toList'' / 'fromList''), stored in the flattened form that
-- 'LiftedRepr' computes from @a@'s array representation.
newtype Vector' a = Vector' (LiftedRepr (ArrRepr a) a)
  deriving Typeable
-- | Flattened representation of a lifted collection, indexed by the array
-- representation @r@ of the element type @a@:
--
--   * unit elements are represented only by their count, a @'Scalar' Int@;
--   * arrays are represented by a segment descriptor holding the shape of
--     each array, plus one 'Vector' holding all of their elements;
--   * tuples of arrays are represented component-wise by 'LiftedTupleRepr'.
type family LiftedRepr r a where
  LiftedRepr () () = ((),Scalar Int)
  LiftedRepr (Array sh e) (Array sh e) = (((),Segments sh), Vector e)
  LiftedRepr (l,r) a = LiftedTupleRepr (TupleRepr a)
-- | Component-wise lifting of a snoc-list product representation: each
-- component @a@ becomes its own lifted collection @'Vector'' a@.
type family LiftedTupleRepr t :: *
type instance LiftedTupleRepr () = ()
type instance LiftedTupleRepr (b, a) = (LiftedTupleRepr b, Vector' a)
-- | A lifted collection of arrays, each of shape type @sh@ with elements of
-- type @e@.
type LiftedArray sh e = Vector' (Array sh e)
-- | Views a lifted collection as a product of 'Arrays'. The product
-- structure depends on the flavour of the element type @t@:
--
--   * unit elements: a single component (the element count);
--   * arrays: two components (segment descriptor, flattened values);
--   * tuples: one @'Vector''@ component per component of @t@.
instance Arrays t => IsProduct Arrays (Vector' t) where
  type ProdRepr (Vector' t) = LiftedRepr (ArrRepr t) t
  fromProd _ (Vector' t) = t
  toProd _ = Vector'
  prod _ _ = case flavour (undefined :: t) of
      ArraysFunit -> ProdRsnoc ProdRunit
      ArraysFarray -> ProdRsnoc (ProdRsnoc ProdRunit)
      ArraysFtuple -> tup $ prod (Proxy :: Proxy Arrays) (undefined :: t)
    where
      -- Lift the product witness of @t@ component by component.
      tup :: forall a. ProdR Arrays a -> ProdR Arrays (LiftedTupleRepr a)
      tup ProdRunit = ProdRunit
      tup (ProdRsnoc t) = swiz
        where
          -- Each lifted component @Vector' r@ must itself be 'Arrays';
          -- 'isArraysFlat' supplies that evidence before extending the
          -- witness.
          swiz :: forall l r. (a ~ (l,r), Arrays r) => ProdR Arrays (LiftedTupleRepr a)
          swiz | IsC <- isArraysFlat (undefined :: r)
               = ProdRsnoc (tup t)
-- | The array representation of a lifted collection is the array
-- representation of its tuple representation.
type instance ArrRepr (Vector' a) = ArrRepr (TupleRepr (Vector' a))
-- | Lifted collections are themselves Accelerate 'Arrays'. Every method
-- walks the product witness that 'prod' returns for @'Vector'' t@.
instance (Arrays t, Typeable (ArrRepr (Vector' t))) => Arrays (Vector' t) where
  -- Build the array-representation witness one component at a time.
  arrays _ = arrs (prod (Proxy :: Proxy Arrays) (undefined :: Vector' t))
    where
      arrs :: forall a. ProdR Arrays a -> ArraysR (ArrRepr a)
      arrs ProdRunit = ArraysRunit
      arrs (ProdRsnoc t) = ArraysRpair (ArraysRpair ArraysRunit (arrs t)) (arrays t')
        where
          -- Only a type proxy for the newest component; never evaluated.
          t' :: (a ~ (l,r)) => r
          t' = undefined

  -- A lifted collection is always reported as a tuple, whatever the flavour
  -- of its elements. For tuple elements, an element product with no
  -- components is rejected with "Absurd".
  -- NOTE(review): the 'ProdRsnoc' match presumably also supplies type
  -- evidence that the result needs; confirm before simplifying.
  flavour _ = case flavour (undefined :: t) of
      ArraysFunit -> ArraysFtuple
      ArraysFarray -> ArraysFtuple
      ArraysFtuple | ProdRsnoc _ <- prod (Proxy :: Proxy Arrays) (undefined::t)
                   -> ArraysFtuple
                   | otherwise -> error "Absurd"

  -- Convert the lifted product representation into the nested-pair array
  -- representation, converting each component with its own 'fromArr'.
  fromArr (Vector' vt) = fa (prod (Proxy :: Proxy Arrays) (undefined :: Vector' t)) vt
    where
      fa :: forall a. ProdR Arrays a -> a -> ArrRepr a
      fa ProdRunit () = ()
      fa (ProdRsnoc t) (l,a) = (((), fa t l), fromArr a)

  -- Inverse of 'fromArr': rebuild the lifted product from the nested-pair
  -- array representation.
  toArr = Vector' . ta (prod (Proxy :: Proxy Arrays) (undefined :: Vector' t))
    where
      ta :: forall a. ProdR Arrays a -> ArrRepr a -> a
      ta ProdRunit () = ()
      ta (ProdRsnoc t) (((),l),a) = (ta t l, toArr a)
-- | A value-level witness for a constraint @c@: pattern matching on 'IsC'
-- brings @c@ into scope.
data IsConstrained c where
  IsC :: c => IsConstrained c
-- | Witness that the array representation of @t@ is 'Typeable'.
type IsTypeableArrRepr t = IsConstrained (Typeable (ArrRepr t))
-- | Witness that the lifted collection @'Vector'' t@ is an instance of 'Arrays'.
type IsArraysFlat t = IsConstrained (Arrays (Vector' t))
-- | Evidence that the array representation of the lifted collection
-- @'Vector'' t@ is 'Typeable'. Unit and array elements need no further
-- work. For tuple elements, the evidence is built by recursion over the
-- product witness of @'Vector'' t@. The argument is used only for its type.
isTypeableArrRepr :: forall t. Arrays t => t -> IsTypeableArrRepr (Vector' t)
isTypeableArrRepr _ =
  case flavour (undefined :: t) of
    ArraysFunit -> IsC
    ArraysFarray -> IsC
    ArraysFtuple | IsC <- isT (prod (Proxy :: Proxy Arrays) (undefined :: Vector' t))
                 -> IsC
  where
    -- Recurse over the snoc-list of components. Each step presumably relies
    -- on the constraints packed into 'ProdRsnoc' for the newest component.
    isT :: ProdR Arrays t' -> IsTypeableArrRepr t'
    isT ProdRunit = IsC
    isT (ProdRsnoc t) | IsC <- isT t = IsC
-- | Evidence that the lifted collection of any 'Arrays' type is itself an
-- instance of 'Arrays'. Unit and array elements need nothing extra. A tuple
-- element additionally needs the 'Typeable' evidence produced by
-- 'isTypeableArrRepr'.
isArraysFlat :: forall t. Arrays t => t -> IsArraysFlat t
isArraysFlat arrs =
  case flavour arrs of
    ArraysFarray -> IsC
    ArraysFunit  -> IsC
    ArraysFtuple
      | IsC <- isTypeableArrRepr arrs -> IsC
-- | Wrap a single value as a zero-dimensional array.
scalar :: Elt a => a -> Scalar a
scalar = fromList Z . (: [])
-- | A one-dimensional array with extent 0 and no elements.
emptyVec :: Elt a => Vector a
emptyVec = fromList (Z :. 0) []
-- | Reinterpret an array of any shape as a one-dimensional vector over the
-- same element data. The new extent is 'Repr.size' of the original shape.
flatten :: Array sh e -> Vector e
flatten arr =
  case arr of
    Array shapeRepr payload -> Array ((), Repr.size shapeRepr) payload
-- | The flat vector holding the elements of every array in a lifted
-- collection of arrays.
elements' :: Vector' (Array sh e) -> Vector e
elements' (Vector' lifted) = snd lifted
-- | The segment descriptor of a lifted collection of arrays: one shape per
-- array in the collection.
shapes' :: Vector' (Array sh a) -> Vector sh
shapes' (Vector' lifted) =
  case lifted of
    (((), segs), _values) -> segs
-- | The lifted collection containing no elements.
empty' :: forall a. Arrays a => Vector' a
empty' = Vector' $
  case flavour (undefined :: a) of
    -- A count of zero.
    ArraysFunit -> ((), scalar 0)
    -- No shapes and no values.
    ArraysFarray -> (((), emptyVec), emptyVec)
    ArraysFtuple -> tup (prod (Proxy :: Proxy Arrays) (undefined :: a))
  where
    -- Every component of a tuple is itself an empty lifted collection.
    tup :: forall t. ProdR Arrays t -> LiftedTupleRepr t
    tup ProdRunit = ()
    tup (ProdRsnoc t) = (tup t, empty')
-- | Number of elements in a lifted collection.
length' :: forall a. Arrays a => Vector' a -> Int
length' (Vector' x) =
  case flavour (undefined :: a) of
    -- Unit elements store their count directly.
    ArraysFunit | ((), n) <- x
      -> n ! Z
    -- Arrays: the extent of the segment descriptor (one entry per array).
    ArraysFarray | (((), Array ((), n) _), _) <- x
      -> n
    ArraysFtuple -> tup (prod (Proxy :: Proxy Arrays) (undefined :: a)) x
  where
    -- Measures only the last component. This assumes all components have
    -- the same length, as they do when built by 'fromList''.
    tup :: forall t. ProdR Arrays t -> LiftedTupleRepr t -> Int
    tup ProdRunit () = error "unreachable"
    tup (ProdRsnoc _) (_, b) = length' b
-- | Drop the first @k@ elements of a lifted collection.
--
-- The caller supplies the flat-vector primitives:
--
--   * @dropVec@ drops a prefix of a flat vector;
--   * @s2o@ converts a segment descriptor into a vector of offsets into the
--     flattened values.
--
-- As with Prelude 'drop', a negative @k@ drops nothing. Dropping at least
-- as many elements as the collection holds yields an empty collection.
drop' :: forall a. Arrays a
      => (forall e. Elt e => Int -> Vector e -> Vector e)
      -> (forall sh. Shape sh => Segments sh -> Vector Int)
      -> Int -> Vector' a -> Vector' a
drop' dropVec s2o k (Vector' x) = Vector' $
  case flavour (undefined :: a) of
    -- Unit elements carry only their count: subtract, clamping at zero.
    ArraysFunit
      | ((), n) <- x
      -> ((), scalar (max 0 (n ! Z - k0)))
    -- Arrays: drop the first @k0@ segment descriptors, and skip the values
    -- up to the offset that @s2o@ reports for segment @k0@.
    ArraysFarray
      | (((), segs), vals) <- x
      , Array ((), n) _ <- segs
      , k0 < n
      -> let offsets = s2o segs
             start   = offsets ! (Z :. k0)
         in (((), dropVec k0 segs), dropVec start vals)
    -- Dropping everything (or an already-empty collection).
    ArraysFarray -> (((), emptyVec), emptyVec)
    -- Tuples: drop from every component independently.
    ArraysFtuple -> tup (prod (Proxy :: Proxy Arrays) (undefined :: a)) x
  where
    -- Clamp so a negative count cannot index 'offsets' out of range.
    k0 = max 0 k

    tup :: forall t. ProdR Arrays t -> LiftedTupleRepr t -> LiftedTupleRepr t
    tup ProdRunit     ()     = ()
    tup (ProdRsnoc t) (a, b) = (tup t a, drop' dropVec s2o k0 b)
-- | View a flat vector as a lifted collection of scalars: element @i@ of the
-- vector becomes the @i@-th scalar in the collection.
--
-- Every scalar has shape 'Z', so the segment descriptor is @n@ copies of
-- 'Z', one per element. A real descriptor (rather than a placeholder) is
-- required because 'length'' and 'shapes'' read it.
vec2Vec' :: Elt e => Vector e -> Vector' (Scalar e)
vec2Vec' v = Vector' (((), segs), v)
  where
    Z :. n = shape v
    segs   = fromList (Z :. n) (replicate n Z)
-- | Convert a lifted collection back into a list of its elements.
--
-- @fetchAll@ (caller-supplied) rebuilds the list of arrays from a segment
-- descriptor and the flattened values. It is used for array elements.
toList' :: forall a. Arrays a
        => (forall sh e. (Shape sh, Elt e) => Segments sh -> Vector e -> [Array sh e])
        -> Vector' a -> [a]
toList' fetchAll (Vector' x) =
  case flavour (undefined :: a) of
    -- Unit elements: as many units as the stored count.
    ArraysFunit | ((), n) <- x -> replicate (n ! Z) ()
    ArraysFarray | (((), lens), vals) <- x
      -> fetchAll lens vals
    ArraysFtuple -> map (toProd (Proxy :: Proxy Arrays)) (tup (prod (Proxy :: Proxy Arrays) (undefined :: a)) x)
  where
    -- Convert each component separately and zip the results into product
    -- representations. The empty product is an infinite list of units, so
    -- the length is set by the real components.
    tup :: forall t. ProdR Arrays t -> LiftedTupleRepr t -> [t]
    tup ProdRunit () = repeat ()
    tup (ProdRsnoc t) (a, b) = tup t a `zip` toList' fetchAll b
-- | Build a lifted collection from a list of elements.
--
-- @concat@ (caller-supplied; it shadows the Prelude function hidden by this
-- module's import) joins a list of flat vectors into one. It is used to
-- combine the flattened values of array elements.
fromList' :: forall a. Arrays a
          => (forall e. Elt e => [Vector e] -> Vector e)
          -> [a] -> Vector' a
fromList' concat xs = Vector' $
  case flavour (undefined :: a) of
    -- Unit elements: only the count is stored.
    ArraysFunit -> ((), scalar (length xs))
    -- Arrays: record each array's shape, and concatenate the flattened
    -- contents of all arrays.
    ArraysFarray ->
      let segs = map shape xs
          vals = concat (map flatten xs)
      in (((), fromList (Z :. length segs) segs), vals)
    ArraysFtuple -> tup (prod (Proxy :: Proxy Arrays) (undefined :: a)) (map (fromProd (Proxy :: Proxy Arrays)) xs)
  where
    -- Unzip the product representations and lift each component separately.
    tup :: forall t. ProdR Arrays t -> [t] -> LiftedTupleRepr t
    tup ProdRunit _ = ()
    tup (ProdRsnoc t) a = (tup t (Prelude.map fst a), fromList' concat (map snd a))