module Data.Array.Knead.Shape.Nested (
C(..),
Size,
value,
paramWith,
load,
intersect,
flattenIndex,
Range(..),
Shifted(..),
Scalar(..),
Sequence(..),
) where
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )
import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)
import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM
import qualified Type.Data.Num.Decimal as TypeNum
import Foreign.Storable
(Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)
import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))
-- | Integer type of element counts and linear (flattened) offsets in
-- generated code; 'sizeCode' and 'flattenIndexRec' produce values of it.
type Size = Word64
-- | Embed a shape given at Haskell level as a constant of any
-- expression-like container @val@.
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)
-- | Thin wrapper around 'Param.withMulti': the continuation receives the
-- same parameter accessor, and a reconstruction function whose result is
-- additionally lifted into @val@ with 'Expr.lift0'.
paramWith ::
   (Storable b, MultiMem.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
      (Storable parameters, MultiMem.C parameters) =>
      (p -> parameters) ->
      (MultiValue.T parameters -> val b) ->
      a) ->
   a
paramWith param cont =
   Param.withMulti param
      (\getParams fromParams ->
         cont getParams (\x -> Expr.lift0 (fromParams x)))
-- | Load a shape value from memory.
-- The first argument only fixes the shape type @sh@; its value is ignored.
load ::
   (MultiMem.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiMem.load ptr
-- | Intersection of two shapes of the same type, as an expression.
-- Delegates to the per-instance 'intersectCode'.
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect a b = Expr.liftM2 intersectCode a b
-- | Generate code for the linear offset of an index within a shape.
-- Only the offset component of 'flattenIndexRec' is kept; the element
-- count it also returns is discarded.
flattenIndex ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Size)
flattenIndex sh ix =
   offsetOnly <$> flattenIndexRec sh ix
   where offsetOnly (_count, offset) = offset
-- | Array shapes for which LLVM code can be generated.
-- Every shape type @sh@ has an associated index type 'Index'.
class (MultiValue.C sh, MultiValue.C (Index sh)) => C sh where
   -- | Type of the indices that address elements of a shape @sh@.
   type Index sh :: *
   -- | Generate code for the intersection of two shapes.
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   -- | Generate code for the number of elements of a shape.
   sizeCode ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r (LLVM.Value Size)
   -- | Number of elements of a shape, computed in Haskell.
   -- Expected to agree with 'sizeCode'.
   size :: sh -> Int
   -- | Generate code returning the pair
   -- @(number of elements, linear offset of the index)@.
   -- Returning the element count alongside the offset lets instances for
   -- composite shapes combine the offsets of their components.
   flattenIndexRec ::
      MultiValue.T sh -> MultiValue.T (Index sh) ->
      LLVM.CodeGenFunction r (LLVM.Value Size, LLVM.Value Size)
   -- | Iterator over the indices of a shape.
   iterator :: (Index sh ~ ix) => MultiValue.T sh -> Iter.T r (MultiValue.T ix)
   -- | Run a state-threading body for every index of a shape.
   -- The default drives 'iterator'; instances may override it,
   -- e.g. with nested loops.
   loop ::
      (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
   loop f sh = Iter.mapState_ f (iterator sh)
-- | The zero-dimensional shape: exactly one element, addressed by the
-- index @()@, which flattens to offset zero.
instance C () where
   type Index () = ()
   intersectCode _ _ = return (MultiValue.cons ())
   sizeCode _ = return A.one
   size _ = 1
   flattenIndexRec _ _ = return (A.one, A.zero)
   iterator unit = Iter.singleton unit
   -- The shape value itself is the only index, so the body runs once on it.
   loop body = body
-- | Shape types with a designated shape value 'scalar' and a designated
-- index 'zeroIndex'. The only instance in this module is @()@.
class C sh => Scalar sh where
   -- | The designated shape value.
   scalar :: (Expr.Value val) => val sh
   -- | The designated index; the argument only fixes the shape type @sh@.
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)
-- | For the unit shape, both the shape and its index are @()@.
instance Scalar () where
   scalar = Expr.lift0 (MultiValue.Cons ())
   zeroIndex _proxy = Expr.lift0 (MultiValue.Cons ())
-- | Shapes whose index type supports integer constants and addition, and
-- whose shape value can be computed from an index value.
-- The instances in this module are primitive integer types.
class
   (C sh,
    MultiValue.IntegerConstant (Index sh),
    MultiValue.Additive (Index sh)) =>
      Sequence sh where
   -- | Generate code converting an index value into a shape value.
   sequenceShapeFromIndex ::
      MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (MultiValue.T sh)
-- | Iterator over the indices of a primitive integer shape @n@:
-- @n@ values counting up from zero.
iteratorPrimitive ::
   (MultiValue.Repr LLVM.Value j ~ LLVM.Value j,
    Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive j, MultiValue.IntegerConstant j) =>
   MultiValue.T j -> Iter.T r (MultiValue.T j)
iteratorPrimitive shape =
   case shape of
      MultiValue.Cons n -> iteratorStart n MultiValue.zero
-- | Iterator over @count@ successive values, beginning at @start@ and
-- advancing with 'MultiValue.inc'.
iteratorStart ::
   (Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i) =>
   LLVM.Value j -> MultiValue.T i -> Iter.T r (MultiValue.T i)
iteratorStart count start =
   Iter.take count (Iter.iterate MultiValue.inc start)
-- | 'Word32' as a one-dimensional shape of @n@ elements; shapes and
-- indices are widened to 'Size' with 'LLVM.ext'.
instance C Word32 where
   type Index Word32 = Word32
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = LLVM.ext n
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Word32 where
   sequenceShapeFromIndex ix = return ix
-- | 'Word64' as a one-dimensional shape of @n@ elements. 'Size' is
-- 'Word64', so no conversion code is emitted.
instance C Word64 where
   type Index Word64 = Word64
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = return n
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Word64 where
   sequenceShapeFromIndex ix = return ix
-- | Reinterpret the bits of an 'Int8' value as 'Word8'. Used before
-- widening shapes and indices to the unsigned 'Size' type with 'LLVM.ext'.
unsigned8 :: LLVM.Value Int8 -> LLVM.CodeGenFunction r (LLVM.Value Word8)
unsigned8 x = LLVM.bitcast x
-- | 'Int8' as a one-dimensional shape of @n@ elements. Values are first
-- reinterpreted as unsigned ('unsigned8') and then widened to 'Size'.
instance C Int8 where
   type Index Int8 = Int8
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = unsigned8 n >>= LLVM.ext
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Int8 where
   sequenceShapeFromIndex ix = return ix
-- | Reinterpret the bits of an 'Int16' value as 'Word16'. Used before
-- widening shapes and indices to the unsigned 'Size' type with 'LLVM.ext'.
unsigned16 :: LLVM.Value Int16 -> LLVM.CodeGenFunction r (LLVM.Value Word16)
unsigned16 x = LLVM.bitcast x
-- | 'Int16' as a one-dimensional shape of @n@ elements. Values are first
-- reinterpreted as unsigned ('unsigned16') and then widened to 'Size'.
instance C Int16 where
   type Index Int16 = Int16
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = unsigned16 n >>= LLVM.ext
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Int16 where
   sequenceShapeFromIndex ix = return ix
-- | 'Int32' as a one-dimensional shape of @n@ elements; values are
-- zero-extended ('LLVM.zext') to 'Size'.
instance C Int32 where
   type Index Int32 = Int32
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = LLVM.zext n
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Int32 where
   sequenceShapeFromIndex ix = return ix
-- | 'Int64' as a one-dimensional shape of @n@ elements; values are
-- reinterpreted ('LLVM.bitcast') as 'Size'.
instance C Int64 where
   type Index Int64 = Int64
   intersectCode = MultiValue.min
   sizeCode (MultiValue.Cons n) = LLVM.bitcast n
   size = fromIntegral
   -- The index has the shape's own type, so it is converted to 'Size'
   -- exactly like the shape.
   flattenIndexRec sh ix = do
      len <- sizeCode sh
      offset <- sizeCode ix
      return (len, offset)
   iterator = iteratorPrimitive

instance Sequence Int64 where
   sequenceShapeFromIndex ix = return ix
-- | One-dimensional shape given by inclusive lower and upper bounds:
-- @Range from to@ covers @from, from+1, ..., to@.
data Range n = Range n n
-- | A 'Range' whose two fields are the same value; used to build
-- patterns and default values that treat both bounds alike.
singletonRange :: n -> Range n
singletonRange x = Range x x
-- | Reinterpret a pointer to a container @f a@ as a pointer to an @a@ at
-- the same address; the two-field records use it to address their fields.
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr ptr = castPtr ptr
-- | A range is stored as two @n@ values, lower bound first.
-- Both fields are accessed as elements of an @n@ array ('pokeElemOff' /
-- 'peekElemOff' with index 1 for the upper bound).
instance Storable n => Storable (Range n) where
   -- First field, plus the padding needed to align the second field, plus
   -- the second field. The padding is the distance from @sizeOf l@ up to
   -- the next multiple of @alignment r@, i.e. @(-sizeOf l) `mod` alignment r@;
   -- it is zero whenever @sizeOf l@ is already a multiple of the alignment.
   sizeOf ~(Range l r) =
      sizeOf l + mod (negate (sizeOf l)) (alignment r) + sizeOf r
   alignment ~(Range l _) = alignment l
   poke p (Range l r) =
      let q = castToElemPtr p
      in  poke q l >> pokeElemOff q 1 r
   peek p =
      let q = castToElemPtr p
      in  Monad.lift2 Range (peek q) (peekElemOff q 1)
-- | Integer types whose values can be converted to 'Size' in generated
-- code; required for the component type of 'Range' and 'Shifted' shapes.
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
      ToSize n where
   -- | Generate code converting a value to 'Size'.
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)
-- Each conversion coincides with 'sizeCode' of the corresponding primitive
-- shape instance: 'LLVM.ext' for 'Word32', the value itself for 'Word64',
-- 'LLVM.zext' for 'Int32' and 'LLVM.bitcast' for 'Int64'.
instance ToSize Word32 where
   toSize = sizeCode

instance ToSize Word64 where
   toSize = sizeCode

instance ToSize Int32 where
   toSize = sizeCode

instance ToSize Int64 where
   toSize = sizeCode
-- | Generate code for the element count @to - from + 1@ of an inclusive
-- range, converted to 'Size'.
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) = do
   diff <- MultiValue.sub to from
   count <- MultiValue.inc diff
   toSize count
-- | Split a range-valued 'MultiValue.T' into a 'Range' of two values.
unzipRange :: MultiValue.T (Range n) -> Range (MultiValue.T n)
unzipRange (MultiValue.Cons rng) =
   case rng of
      Range lower upper -> Range (MultiValue.Cons lower) (MultiValue.Cons upper)
-- | Combine lower and upper bound values into one range-valued
-- 'MultiValue.T'; inverse of 'unzipRange'.
zipRange :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Range n)
zipRange (MultiValue.Cons lower) (MultiValue.Cons upper) =
   MultiValue.Cons (Range lower upper)
-- | LLVM value representation of a range: one representation per bound.
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons (Range from to) = zipRange (MultiValue.cons from) (MultiValue.cons to)
   -- Both bounds are initialised with the same kind of value.
   undef = MultiValue.compose $ singletonRange MultiValue.undef
   zero = MultiValue.compose $ singletonRange MultiValue.zero
   -- Phi nodes are created and extended separately for each bound.
   phis bb a =
      case unzipRange a of
         Range a0 a1 ->
            Monad.lift2 zipRange (MultiValue.phis bb a0) (MultiValue.phis bb a1)
   addPhis bb a b =
      case (unzipRange a, unzipRange b) of
         (Range a0 a1, Range b0 b1) ->
            MultiValue.addPhis bb a0 b0 >>
            MultiValue.addPhis bb a1 b1
-- The decomposed and pattern forms of a range are ranges of the
-- corresponding forms of its bound type.
type instance
   MultiValue.Decomposed f (Range pn) =
      Range (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Range pn) =
      Range (MultiValue.PatternTuple pn)
-- | Build a range value by composing each bound separately.
instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose (Range from to) =
      zipRange (MultiValue.compose from) (MultiValue.compose to)
-- | Split a range value according to a range-shaped pattern; each bound
-- is decomposed with its own sub-pattern.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose (Range pfrom pto) rng =
      case unzipRange rng of
         Range from to ->
            Range
               (MultiValue.decompose pfrom from)
               (MultiValue.decompose pto to)
-- | In-memory form of a range: a 'PairStruct' holding the lower bound in
-- field 0 and the upper bound in field 1.
instance (MultiMem.C n) => MultiMem.C (Range n) where
   type Struct (Range n) = PairStruct n
   decompose = fmap (uncurry zipRange) . decomposeGen
   compose x = case unzipRange x of Range n m -> composeGen n m
-- | Inclusive index range @from .. to@ as a one-dimensional shape.
-- Indices are values of the bound type @n@ itself.
instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Range n) where
   type Index (Range n) = n
   -- The intersection keeps the larger lower and the smaller upper bound.
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
      \(Range fromN toN) (Range fromM toM) ->
         Monad.lift2 Range (MultiValue.max fromN fromM) (MultiValue.min toN toM)
   sizeCode = rangeSize . unzipRange
   -- Both bounds are inclusive, matching 'rangeSize' (to - from + 1).
   size (Range from to) = fromIntegral $ to - from + 1
   -- (element count, index relative to the lower bound)
   flattenIndexRec rngValue i =
      case unzipRange rngValue of
         rng@(Range from _to) ->
            Monad.lift2 (,) (rangeSize rng) (toSize =<< MultiValue.sub i from)
   -- Count upwards from the lower bound while the upper bound is still
   -- greater than or equal to the current index.
   iterator rngValue =
      case MultiValue.decompose (singletonRange atom) rngValue of
         Range from to ->
            IterMV.takeWhile (MultiValue.cmp LLVM.CmpGE to) $
            Iter.iterate MultiValue.inc from
-- | One-dimensional shape of 'shiftedSize' elements whose indices start
-- at 'shiftedOffset', i.e. @offset, offset+1, ..., offset+size-1@.
data Shifted n = Shifted {shiftedOffset, shiftedSize :: n}
-- | A 'Shifted' whose offset and size fields are the same value; used to
-- build patterns and default values that treat both fields alike.
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted x x
-- | A shifted shape is stored as two @n@ values, offset first, then size.
-- Both fields are accessed as elements of an @n@ array ('pokeElemOff' /
-- 'peekElemOff' with index 1 for the size).
instance Storable n => Storable (Shifted n) where
   -- First field, plus the padding needed to align the second field, plus
   -- the second field. The padding is the distance from @sizeOf l@ up to
   -- the next multiple of @alignment n@, i.e. @(-sizeOf l) `mod` alignment n@;
   -- it is zero whenever @sizeOf l@ is already a multiple of the alignment.
   sizeOf ~(Shifted l n) =
      sizeOf l + mod (negate (sizeOf l)) (alignment n) + sizeOf n
   alignment ~(Shifted l _) = alignment l
   poke p (Shifted l n) =
      let q = castToElemPtr p
      in  poke q l >> pokeElemOff q 1 n
   peek p =
      let q = castToElemPtr p
      in  Monad.lift2 Shifted (peek q) (peekElemOff q 1)
-- | Split a 'Shifted'-valued 'MultiValue.T' into separate offset and size
-- values.
unzipShifted :: MultiValue.T (Shifted n) -> Shifted (MultiValue.T n)
unzipShifted (MultiValue.Cons shifted) =
   case shifted of
      Shifted offset len -> Shifted (MultiValue.Cons offset) (MultiValue.Cons len)
-- | Combine offset and size values into one 'Shifted'-valued
-- 'MultiValue.T'; inverse of 'unzipShifted'.
zipShifted :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Shifted n)
zipShifted (MultiValue.Cons offset) (MultiValue.Cons len) =
   MultiValue.Cons (Shifted offset len)
-- | LLVM value representation of a shifted shape: one representation each
-- for offset and size.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons (Shifted offset len) =
      zipShifted (MultiValue.cons offset) (MultiValue.cons len)
   -- Both fields are initialised with the same kind of value.
   undef = MultiValue.compose $ singletonShifted MultiValue.undef
   zero = MultiValue.compose $ singletonShifted MultiValue.zero
   -- Phi nodes are created and extended separately for each field.
   phis bb a =
      case unzipShifted a of
         Shifted a0 a1 ->
            Monad.lift2 zipShifted
               (MultiValue.phis bb a0) (MultiValue.phis bb a1)
   addPhis bb a b =
      case (unzipShifted a, unzipShifted b) of
         (Shifted a0 a1, Shifted b0 b1) ->
            MultiValue.addPhis bb a0 b0 >>
            MultiValue.addPhis bb a1 b1
-- The decomposed and pattern forms of a shifted shape are 'Shifted'
-- records of the corresponding forms of its field type.
type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)
-- | Build a shifted-shape value by composing offset and size separately.
instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose (Shifted offset len) =
      zipShifted (MultiValue.compose offset) (MultiValue.compose len)
-- | Split a shifted-shape value according to a 'Shifted'-shaped pattern;
-- offset and size are decomposed with their own sub-patterns.
instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose (Shifted poffset plen) rng =
      case unzipShifted rng of
         Shifted offset len ->
            Shifted
               (MultiValue.decompose poffset offset)
               (MultiValue.decompose plen len)
-- | In-memory form of a shifted shape: a 'PairStruct' holding the offset
-- in field 0 and the size in field 1.
instance (MultiMem.C n) => MultiMem.C (Shifted n) where
   type Struct (Shifted n) = PairStruct n
   decompose = fmap (uncurry zipShifted) . decomposeGen
   compose x = case unzipShifted x of Shifted n m -> composeGen n m
-- | LLVM struct with two fields of the memory representation of @n@;
-- the in-memory layout of both 'Range' and 'Shifted'.
type PairStruct n = LLVM.Struct (MultiMem.Struct n, (MultiMem.Struct n, ()))
-- | Extract both fields of a 'PairStruct' and convert each from its
-- memory representation, field 0 first.
decomposeGen ::
   (MultiMem.C n) =>
   LLVM.Value (PairStruct n) ->
   LLVM.CodeGenFunction r (MultiValue.T n, MultiValue.T n)
decomposeGen struct = do
   field0 <- MultiMem.decompose =<< LLVM.extractvalue struct TypeNum.d0
   field1 <- MultiMem.decompose =<< LLVM.extractvalue struct TypeNum.d1
   return (field0, field1)
-- | Convert two values to their memory representation and insert them
-- into fields 0 and 1 of an initially undefined 'PairStruct'.
composeGen ::
   (MultiMem.C n) =>
   MultiValue.T n -> MultiValue.T n ->
   LLVM.CodeGenFunction r (LLVM.Value (PairStruct n))
composeGen first second = do
   memFirst <- MultiMem.compose first
   memSecond <- MultiMem.compose second
   LLVM.insertvalue (LLVM.value LLVM.undef) memFirst TypeNum.d0
      >>= \withFirst -> LLVM.insertvalue withFirst memSecond TypeNum.d1
-- | A shape of @len@ elements whose indices start at @offset@.
-- Indices are values of the field type @n@.
instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Shifted n) where
   type Index (Shifted n) = n
   -- Intersect the index intervals [offset, offset+len): take the larger
   -- start and the smaller end, then convert back to (offset, length).
   -- NOTE(review): for disjoint intervals end < offset, so the subtraction
   -- yields a negative (or, for unsigned n, wrapped-around) length.
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
      \(Shifted offsetN lenN) (Shifted offsetM lenM) -> do
         offset <- MultiValue.max offsetN offsetM
         endN <- MultiValue.add offsetN lenN
         endM <- MultiValue.add offsetM lenM
         end <- MultiValue.min endN endM
         Shifted offset <$> MultiValue.sub end offset
   sizeCode = toSize . shiftedSize . unzipShifted
   size (Shifted _offset len) = fromIntegral len
   -- (element count, index relative to the offset)
   flattenIndexRec shapeValue i =
      case unzipShifted shapeValue of
         Shifted offset len ->
            Monad.lift2 (,) (toSize len) (toSize =<< MultiValue.sub i offset)
   -- Yield len consecutive indices starting at the offset.
   iterator rngValue =
      case MultiValue.decompose (singletonShifted atom) rngValue of
         Shifted from len ->
            IterMV.take len $ Iter.iterate MultiValue.inc from
-- | Two-dimensional product shape. Indices are pairs; in the linear
-- layout computed by 'flattenIndexRec' the second component varies fastest.
instance (C n, C m) => C (n,m) where
   type Index (n,m) = (Index n, Index m)
   -- Intersect component-wise.
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) ->
            Monad.lift2 MultiValue.zip
               (intersectCode an bn)
               (intersectCode am bm)
   sizeCode nm =
      case MultiValue.unzip nm of
         (n,m) -> Monad.liftJoin2 A.mul (sizeCode n) (sizeCode m)
   size (n,m) = size n * size m
   -- size = ns * ms;  offset = il * ms + jl
   flattenIndexRec nm ij =
      case (MultiValue.unzip nm, MultiValue.unzip ij) of
         ((n,m), (i,j)) -> do
            (ns, il) <- flattenIndexRec n i
            (ms, jl) <- flattenIndexRec m j
            Monad.lift2 (,)
               (A.mul ns ms)
               (A.add jl =<< A.mul ms il)
   -- Pairs of all first-component and second-component indices;
   -- presumably the second one varies fastest, like in 'loop' below (TODO
   -- confirm against Iter.cartesian).
   iterator nm =
      case MultiValue.unzip nm of
         (n,m) ->
            uncurry MultiValue.zip <$>
            Iter.cartesian (iterator n) (iterator m)
   -- Nested loops: outer over the first component, inner over the second.
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) -> loop (\i -> loop (\j -> code (MultiValue.zip i j)) m) n
-- | Three-dimensional product shape. Indices are triples; in the linear
-- layout computed by 'flattenIndexRec' the last component varies fastest.
instance (C n, C m, C l) => C (n,m,l) where
   type Index (n,m,l) = (Index n, Index m, Index l)
   -- Intersect component-wise.
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) ->
            Monad.lift3 MultiValue.zip3
               (intersectCode ai bi)
               (intersectCode aj bj)
               (intersectCode ak bk)
   sizeCode nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            Monad.liftJoin2 A.mul (sizeCode n) $
            Monad.liftJoin2 A.mul (sizeCode m) (sizeCode l)
   size (n,m,l) = size n * size m * size l
   -- size = ns * (ms * ls);  offset = (il * ms + jl) * ls + kl
   flattenIndexRec nml ijk =
      case (MultiValue.unzip3 nml, MultiValue.unzip3 ijk) of
         ((n,m,l), (i,j,k)) -> do
            (ns, il) <- flattenIndexRec n i
            (ms, jl) <- flattenIndexRec m j
            x0 <- A.add jl =<< A.mul ms il
            (ls, kl) <- flattenIndexRec l k
            x1 <- A.add kl =<< A.mul ls x0
            sz <- A.mul ns =<< A.mul ms ls
            return (sz, x1)
   -- Triples built from nested cartesian products of the component
   -- iterators; presumably the last component varies fastest, like in
   -- 'loop' below (TODO confirm against Iter.cartesian).
   iterator nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            fmap (\(a,(b,c)) -> MultiValue.zip3 a b c) $
            Iter.cartesian (iterator n) $
            Iter.cartesian (iterator m) (iterator l)
   -- Three nested loops; the innermost runs over the last component.
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            loop (\i -> loop (\j -> loop (\k ->
               code (MultiValue.zip3 i j k))
               l) m) n