module Data.Array.Knead.Shape.Cubic (
constant,
paramWith,
tunnel,
T(..),
Z(Z), z,
(:.)((:.)),
Shape,
Index,
cons, (#:.),
head,
tail,
switchR,
) where
import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )
import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))
import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom)
import qualified LLVM.Core as LLVM
import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Ptr (castPtr, )
import qualified Type.Data.Num.Decimal as Dec
import qualified Type.Data.Num.Unary as Unary
import Type.Base.Proxy (Proxy(Proxy))
import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import Data.FixedLength ((!:))
import Control.Monad (liftM2, )
import Control.Applicative (pure, (<$>), )
import Prelude hiding (min, head, tail, )
-- | Multi-dimensional shape or index whose rank is fixed at the type level.
--
-- The phantom parameter @tag@ tells shapes ('ShapeTag') apart from indices
-- ('IndexTag'), and @rank@ is a type-level unary natural number.  One
-- 'Index.Int' is stored per dimension; 'cons' puts the newly added
-- dimension at the head of the fixed-length list.
newtype T tag rank = Cons {decons :: FixedLength.T rank Index.Int}
-- | Phantom tag for array shapes (the extent of every dimension).
data ShapeTag
-- | Phantom tag for indices into an array shape.
data IndexTag
-- | A cubic shape of the given rank.
type Shape = T ShapeTag
-- | A cubic index of the given rank.
type Index = T IndexTag
-- | Make a shape/index parameter usable inside generated code.
--
-- The parameter is tunnelled with 'tunnel'.  The continuation receives two
-- functions: the one that extracts the marshallable value from the
-- parameter record, and one that lifts the marshalled multi-value into any
-- 'Expr.Value'-like expression type.
paramWith ::
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) ->
   (forall parameters.
    (Marshal.MV parameters) =>
    (p -> parameters) ->
    (forall val. (Expr.Value val) =>
     MultiValue.T parameters -> val (T tag rank)) ->
    a) ->
   a
paramWith param continue =
   case tunnel param of
      Param.Tunnel extract toMulti ->
         continue extract (\mv -> Expr.lift0 (toMulti mv))
-- | Tunnel a shape/index parameter into generated code.
--
-- The Haskell value is converted to its multi-value representation with
-- 'MultiValue.cons'.
tunnel ::
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) -> Param.Tunnel p (T tag rank)
tunnel = Param.tunnel MultiValue.cons
-- | Nullary type with the single constructor 'Z'.
--
-- NOTE(review): presumably the rank-zero end of a ':.' chain; its uses are
-- not visible in this module chunk — confirm.
data Z = Z
   deriving (Eq, Ord, Read, Show)
-- Both the pattern constructor ':.' and the expression-level operator '#:.'
-- associate to the left, so @a :. b :. c@ parses as @(a :. b) :. c@.
infixl 3 :., #:.
-- | Strict pair used to build ('Expr.Compose') and match ('Expr.Decompose')
-- shapes one dimension at a time.  The left component stands for the
-- remaining dimensions, the right one for the dimension added last.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)
-- | Infix synonym for 'cons': @sh #:. i@ extends @sh@ by the dimension @i@.
(#:.) ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
sh #:. i = cons sh i
-- | Extend a shape or index by one more dimension.
--
-- The new dimension becomes the head of the underlying fixed-length list.
-- As a result, 'head' returns it and 'tail' returns the original argument.
cons ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
cons =
   Expr.lift2
      (\(MultiValue.Cons dims) (MultiValue.Cons newDim) ->
         MultiValue.Cons (newDim !: dims))
-- | The rank-zero shape or index, which has no dimensions at all.
z :: (Expr.Value val) => val (T tag Unary.Zero)
z = Expr.lift0 (MultiValue.cons (Cons FixedLength.end))
-- | The dimension that was added last by 'cons'.
--
-- This is the head of the underlying fixed-length list.
head ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val Index.Int
head =
   Expr.lift1
      (\(MultiValue.Cons dims) -> MultiValue.Cons (FixedLength.head dims))
-- | All dimensions except the one that was added last by 'cons'.
tail ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val (T tag rank)
tail =
   Expr.lift1
      (\(MultiValue.Cons dims) -> MultiValue.Cons (FixedLength.tail dims))
-- | Split a non-empty shape or index and pass both parts to @f@.
--
-- The first argument is the remaining dimensions ('tail'); the second is
-- the dimension added last ('head').
switchR ::
   (Unary.Natural rank) =>
   Expr.Value val =>
   (val (T tag rank) -> val Index.Int -> a) ->
   val (T tag (Unary.Succ rank)) -> a
switchR f ix =
   let rest = tail ix
       lastDim = head ix
   in  f rest lastDim
-- | The type-level rank of a shape or index, as a 'Proxy'.
--
-- The argument only fixes the type; its value is never inspected.
rank :: T tag rank -> Proxy rank
rank _ = Proxy
-- | A rank-zero cubic shape is a scalar shape.
--
-- Both the shape and its only index are the empty list, i.e. 'z'.
instance (tag ~ ShapeTag, rank ~ Unary.Zero) => Shape.Scalar (T tag rank) where
   scalar = z
   zeroIndex _ = z
-- | Rank of the cubic value described by a decomposition pattern.
--
-- A pattern is either an 'Atom' of a whole shape/index, whose rank comes
-- from its type, or @sh ':.' s@, which adds one dimension to the rank of
-- the sub-pattern @sh@.  The right component @s@ describes a single
-- 'Index.Int' dimension and has no rank of its own, so the recursion must
-- go through @sh@.  This is exactly what the 'MultiValue.PatternTuple'
-- instance for @sh :. s@ and the 'Expr.Decompose' instance
-- (@PatternTuple sh ~ T (AtomTag sh) (AtomRank sh)@) expect.
type family AtomRank sh
type instance AtomRank (Atom (T tag rank)) = rank
type instance AtomRank (sh:.s) = Unary.Succ (AtomRank sh)
-- | Phantom tag ('ShapeTag' or 'IndexTag') of the cubic value described by
-- a decomposition pattern.  For @sh ':.' s@ it is inherited from @sh@.
type family AtomTag sh
type instance AtomTag (Atom (T tag rank)) = tag
type instance AtomTag (sh:.s) = AtomTag sh
-- A pattern @sh :. s@ matches a cubic value with the tag of @sh@ and one
-- dimension more than @sh@.
type instance MultiValue.PatternTuple (sh:.s) =
   T (AtomTag sh) (Unary.Succ (AtomRank sh))
-- Decomposing with @sh :. s@ yields the decomposition of the remaining
-- dimensions, paired via ':.' with the last dimension as an @f Index.Int@.
type instance MultiValue.Decomposed f (sh:.s) =
   MultiValue.Decomposed f sh :. f Index.Int
-- | Decompose a cubic expression dimension by dimension.
--
-- The sub-pattern @psh@ is applied to the remaining dimensions, and @ps@
-- to the dimension added last.
instance
   (Expr.Decompose sh, Expr.Decompose s,
    MultiValue.Decomposed Exp s ~ Exp Index.Int,
    MultiValue.PatternTuple s ~ Index.Int,
    MultiValue.PatternTuple sh ~ T (AtomTag sh) (AtomRank sh),
    Unary.Natural (AtomRank sh)) =>
      Expr.Decompose (sh :. s) where
   decompose (psh :. ps) =
      switchR
         (\rest lastDim ->
            Expr.decompose psh rest :. Expr.decompose ps lastDim)
-- | Extract the @rank@ parameter of a cubic type 'T'.
type family Rank sh
type instance Rank (T tag rank) = rank
-- | Extract the @tag@ parameter of a cubic type 'T'.
type family Tag sh
type instance Tag (T tag rank) = tag
-- | Compose a cubic expression from a ':.'-chain of components.
--
-- The left component must compose to a cubic value; the right one must
-- compose to a single 'Index.Int', which becomes the new dimension.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Rank (Expr.Composed sh)),
    Expr.Compose s,
    Expr.Composed s ~ Index.Int) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh)) (Unary.Succ (Rank (Expr.Composed sh)))
   compose (outer :. inner) = Expr.compose outer #:. Expr.compose inner
-- | Store a cubic value as a packed array of @rank@ 'Shape.Size' values.
--
-- The dimensions are written in the order of the fixed-length list.
instance (Unary.Natural rank) => St.Storable (T tag rank) where
   sizeOf sh =
      sizeOfArray (Unary.integralFromProxy (rank sh)) (0 :: Shape.Size)
   alignment _ = St.alignment (0 :: Shape.Size)
   poke ptr (Cons dims) = St.poke (castPtr ptr) (fmap indexFromInt dims)
   peek ptr = Cons . fmap Index.Int <$> St.peek (castPtr ptr)
-- | Marshal a cubic value as an LLVM array with one element per dimension.
instance
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.C (T tag rank) where
   pack (Cons dims) = LLVM.Array (map Marshal.pack (Fold.toList dims))
   unpack (LLVM.Array elems) = Cons (toFixedList (map Marshal.unpack elems))
-- | Convert a list into a fixed-length list whose length is the type-level
-- natural @n@.
--
-- The first @n@ elements are used, in order; any surplus is ignored.  If
-- the list has fewer than @n@ elements, forcing a missing element raises
-- an error that names this function.  Previously this was an anonymous
-- pattern-match failure from a partial lambda.
toFixedList :: (Unary.Natural n) => [a] -> FixedLength.T n a
toFixedList xs0 = snd $ Trav.mapAccumL takeOne xs0 (pure ())
   where
      -- Pop one element per slot of the @pure ()@ skeleton.
      takeOne (y:ys) () = (ys, y)
      takeOne [] () =
         error "Data.Array.Knead.Shape.Cubic.toFixedList: list too short"
-- | Multi-value marshalling for cubic values.  The instance body is empty,
-- so every method uses the class default.
instance
   (Unary.Natural rank,
    Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.MV (T tag rank) where
-- | The LLVM value of a cubic constant is the fixed-length list of the
-- LLVM values of its dimensions.
instance (Unary.Natural rank) => Tuple.Value (T tag rank) where
   type ValueOf (T tag rank) = FixedLength.T rank (Tuple.ValueOf Index.Int)
   valueOf (Cons dims) = Tuple.valueOf <$> dims
-- | Multi-value operations applied dimension by dimension.
--
-- 'undef' and 'zero' replicate the corresponding 'Index.Int' value into
-- every dimension via 'constant'.
instance (Unary.Natural rank) => MultiValue.C (T tag rank) where
   cons (Cons dims) = MultiValue.Cons (LLVM.valueOf . indexFromInt <$> dims)
   undef = constant MultiValue.undef
   zero = constant MultiValue.zero
   addPhi bb (MultiValue.Cons xs) (MultiValue.Cons ys) =
      Tuple.addPhiFoldable bb xs ys
   phi bb (MultiValue.Cons dims) =
      MultiValue.Cons <$> Tuple.phiTraversable bb dims
-- | A cubic multi-value whose every dimension is the given value.
constant ::
   (Unary.Natural rank) => MultiValue.T Index.Int -> MultiValue.T (T tag rank)
constant = \(MultiValue.Cons d) -> MultiValue.Cons (pure d)
-- | The number of elements of a cubic shape.
--
-- It is the product of the sizes of all dimensions, each viewed as a
-- 'ZeroBased' shape.
instance
   (tag ~ ShapeTag, Unary.Natural rank) =>
      ComfortShape.C (T tag rank) where
   size (Cons dims) =
      product [ComfortShape.size (shapeFromInt d) | d <- Fold.toList dims]
-- | Pure (Haskell-side) indexing of cubic shapes.
--
-- Both 'ComfortShape.indices' and 'ComfortShape.offset' walk the
-- fixed-length list from its head.  The head dimension (the one added last
-- by 'cons') therefore varies slowest in 'ComfortShape.indices' and is the
-- most significant one in 'ComfortShape.offset'.
instance
   (tag ~ ShapeTag, Unary.Natural rank) =>
      ComfortShape.Indexed (T tag rank) where
   type Index (T tag rank) = Index rank
   indices (Cons dims) =
      [ Cons (Index.Int <$> ix)
      | ix <- Trav.traverse (ComfortShape.indices . shapeFromInt) dims ]
   inBounds (Cons sh) (Cons ix) =
      Fold.and $
      FixedLength.zipWith
         (\s i -> ComfortShape.inBounds (shapeFromInt s) (indexFromInt i))
         sh ix
   offset (Cons sh) (Cons ix) =
      Fold.foldl' accumulate 0 (FixedLength.zipWith (,) sh ix)
      where
         accumulate off (s, i) =
            off * ComfortShape.size (shapeFromInt s)
               + fromIntegral (indexFromInt i)
-- | View one stored dimension as a zero-based comfort-array shape.
shapeFromInt :: Index.Int -> ZeroBased Shape.Size
shapeFromInt = \(Index.Int n) -> ZeroBased n
-- | Unwrap one stored dimension to its raw 'Shape.Size'.
indexFromInt :: Index.Int -> Shape.Size
indexFromInt = \(Index.Int n) -> n
-- | LLVM code generation for cubic shapes.
--
-- 'Shape.size' multiplies all extents.  'Shape.intersectCode' takes the
-- dimension-wise minimum.  Iteration and looping are delegated to the
-- module-level 'iterator' and 'loop'.
instance (tag ~ ShapeTag, Unary.Natural rank) => Shape.C (T tag rank) where
   size (MultiValue.Cons dims) = Fold.foldlM A.mul A.one dims
   intersectCode (MultiValue.Cons xs) (MultiValue.Cons ys) =
      MultiValue.Cons <$> Trav.sequenceA (FixedLength.zipWith A.min xs ys)
   sizeOffset sh = do
      total <- Shape.size sh
      return (total, offsetCode sh)
   iterator = iterator
   loop = loop
-- | Linear offset of an index within a shape, generated as LLVM code.
--
-- The fold starts at the head of the fixed-length lists (the dimension
-- added last by 'cons') and computes @off * extent + i@ at every step.
-- The head dimension therefore ends up as the most significant one, the
-- same ordering as the fold in the pure 'ComfortShape.offset' instance.
--
-- NOTE(review): 'loop' runs its innermost counter over the head dimension
-- (so the head varies fastest there), which looks like the opposite
-- ordering.  Confirm the intended memory layout before relying on both
-- together.
offsetCode ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> MultiValue.T (Index rank) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode (MultiValue.Cons sh) (MultiValue.Cons ix) =
   Fold.foldlM
      (\acc (extent, i) -> A.add i =<< A.mul acc extent)
      A.zero
      (FixedLength.zipWith (,) sh ix)
-- | Wrapper around the type of 'iterator' with @rank@ as its last type
-- parameter, so that the definition can be passed to 'Unary.switchNat'.
newtype Iterator r rank =
   Iterator {
      runIterator ::
         MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
   }
-- | Enumerate all indices of a shape as an LLVM iterator.
--
-- Defined by recursion on the rank via 'Unary.switchNat':
--
-- * Rank zero yields the single empty index 'z'.
-- * Otherwise the shape is split with 'switchR', and the indices of the
--   remaining dimensions are combined (with 'Iter.cartesian') with a
--   counter @0, 1, ...@ that continues while it is below the extent @n@
--   of the last dimension.
iterator ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
iterator =
   runIterator $
   Unary.switchNat
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
         fmap (\(ix,i) -> ix#:.i) $
         Iter.cartesian
            (iterator sh)
            -- counter 0, 1, 2, ... as long as n > counter
            (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
             Iter.iterate MultiValue.inc MultiValue.zero))
-- | Wrapper around the type of 'loop' with @rank@ as its last type
-- parameter, so that the definition can be passed to 'Unary.switchNat'.
newtype Loop r state rank =
   Loop {
      runLoop ::
         (MultiValue.T (Index rank) ->
          state ->
          LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape rank) ->
         state ->
         LLVM.CodeGenFunction r state
   }
-- | Run @code@ once for every index of a shape, threading a state through
-- all invocations, and return the final state.
--
-- Rank zero calls @code@ exactly once with the empty index 'z'.  For
-- higher ranks the shape is split with 'switchR'.  The loop over the
-- remaining dimensions is outermost.  Inside it, 'C.fixedLengthLoop' runs
-- a counter @k@, starting at zero and incremented after every call, over
-- the extent @n@ of the last dimension (presumably @n@ iterations).  The
-- last dimension therefore varies fastest.
loop ::
   (Unary.Natural rank, Tuple.Phi state) =>
   (MultiValue.T (Index rank) ->
    state ->
    LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape rank) ->
   state ->
   LLVM.CodeGenFunction r state
loop =
   runLoop $
   Unary.switchNat
      (Loop $ \code _z -> code z)
      (Loop $ \code -> switchR $ \sh (MultiValue.Cons n) ->
         loop
            -- The state is called ptr here, but it may be any Tuple.Phi value.
            (\ix ptrStart ->
               -- Keep only the user state, dropping the counter k.
               fmap fst $
               C.fixedLengthLoop n (ptrStart, A.zero) $ \(ptr, k) ->
                  liftM2 (,)
                     (code (ix #:. MultiValue.Cons k) ptr)
                     (A.inc k))
            sh)