{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Shapes and indices whose rank is tracked at the type level.
A value is a 'FixedLength.T' vector holding one 'Index.Int' per dimension;
a phantom tag separates shapes ('Shape') from indices ('Index').
-}
module Data.Array.Knead.Shape.Cubic (
   constant,
   paramWith,
   tunnel,
   T(..),
   Z(Z),
   z,
   (:.)((:.)),
   Shape,
   Index,
   cons,
   (#:.),
   head,
   tail,
   switchR,
   ) where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))

import qualified LLVM.DSL.Parameter as Param
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Marshal as Marshal
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Tuple as Tuple
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom)
import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Ptr (castPtr, )

import qualified Type.Data.Num.Decimal as Dec
import qualified Type.Data.Num.Unary as Unary
import Type.Base.Proxy (Proxy(Proxy))

import qualified Data.Traversable as Trav
import qualified Data.Foldable as Fold
import qualified Data.FixedLength as FixedLength
import Data.FixedLength ((!:))

import Control.Monad (liftM2, )
import Control.Applicative (pure, (<$>), )

import Prelude hiding (min, head, tail, )


-- | A vector of @rank@ components.
-- @tag@ is a phantom type: 'ShapeTag' for shapes, 'IndexTag' for indices.
newtype T tag rank = Cons {decons :: FixedLength.T rank Index.Int}

-- | Phantom tag marking a 'T' as a shape (the extent of every dimension).
data ShapeTag
-- | Phantom tag marking a 'T' as an index into a shape.
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag


-- | Continuation-passing access to a parameterized 'T'.
-- The continuation receives the parameter extractor and a function that
-- rebuilds the value as an expression from the marshaled parameters.
paramWith ::
   (Unary.Natural rank, Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) ->
   (forall parameters.
    (Marshal.MV parameters) =>
    (p -> parameters) ->
    (forall val. (Expr.Value val) =>
     MultiValue.T parameters -> val (T tag rank)) ->
    a) ->
   a
paramWith p f =
   case tunnel p of
      Param.Tunnel get val -> f get (Expr.lift0 . val)

-- | Build a 'Param.Tunnel' for 'T', converting values with 'MultiValue.cons'.
tunnel ::
   (Unary.Natural rank, Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
   Param.T p (T tag rank) -> Param.Tunnel p (T tag rank)
tunnel p = Param.tunnel MultiValue.cons p


-- | Nullary data type with the single constructor 'Z'.
data Z = Z
   deriving (Eq, Ord, Read, Show)

infixl 3 :., #:.

-- | Strict pair of a prefix and a last component.
-- The 'Expr.Decompose' and 'Expr.Compose' instances below use it to view a
-- shape as its 'tail' followed by its 'head'.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)

-- | Operator synonym for 'cons'.
(#:.) ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
(#:.) = cons

-- | Extend a shape or index by one component.
-- The new component becomes the 'head', i.e. the first element of the
-- underlying 'FixedLength.T'.
cons ::
   (Expr.Value val) =>
   val (T tag rank) -> val Index.Int -> val (T tag (Unary.Succ rank))
cons = Expr.lift2 $ \(MultiValue.Cons t) (MultiValue.Cons h) ->
   MultiValue.Cons (h!:t)

-- | The rank-zero shape or index (an empty vector).
z :: (Expr.Value val) => val (T tag Unary.Zero)
z = Expr.lift0 $ MultiValue.Cons FixedLength.end

-- | The component most recently added with 'cons'.
head ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val Index.Int
head = Expr.lift1 $ \(MultiValue.Cons sh) -> MultiValue.Cons $ FixedLength.head sh

-- | Everything except the 'head'.
tail ::
   (Expr.Value val, Unary.Natural rank) =>
   val (T tag (Unary.Succ rank)) -> val (T tag rank)
tail = Expr.lift1 $ \(MultiValue.Cons sh) -> MultiValue.Cons $ FixedLength.tail sh

-- | Split a non-empty vector into 'tail' and 'head' and pass both to @f@.
switchR ::
   (Unary.Natural rank) =>
   Expr.Value val =>
   (val (T tag rank) -> val Index.Int -> a) ->
   val (T tag (Unary.Succ rank)) -> a
switchR f ix = f (tail ix) (head ix)

-- | The type-level rank as a proxy.
-- Only the type is used, e.g. by 'St.sizeOf' below.
rank :: T tag rank -> Proxy rank
rank (Cons _) = Proxy


-- | Rank-zero shapes are scalars: both 'Shape.scalar' and 'Shape.zeroIndex'
-- produce the empty vector.
instance (tag ~ ShapeTag, rank ~ Unary.Zero) => Shape.Scalar (T tag rank) where
   scalar = Expr.lift0 $ MultiValue.Cons FixedLength.end
   zeroIndex _ = Expr.lift0 $ MultiValue.Cons FixedLength.end


-- | Rank of a (possibly nested) ':.' pattern of 'Atom's.
type family AtomRank sh
type instance AtomRank (Atom (T tag rank)) = rank
-- The last component of ':.' is a single 'Index.Int' and has no rank of its
-- own, so the rank is one more than that of the prefix @sh@.  This must agree
-- with 'AtomTag' and with the 'MultiValue.PatternTuple' instance below,
-- otherwise nested patterns cannot satisfy the 'Expr.Decompose' constraint.
type instance AtomRank (sh:.s) = Unary.Succ (AtomRank sh)

-- | Tag of a (possibly nested) ':.' pattern of 'Atom's, taken from its prefix.
type family AtomTag sh
type instance AtomTag (Atom (T tag rank)) = tag
type instance AtomTag (sh:.s) = AtomTag sh

type instance MultiValue.PatternTuple (sh:.s) =
   T (AtomTag sh) (Unary.Succ (AtomRank sh))
type instance MultiValue.Decomposed f (sh:.s) =
   MultiValue.Decomposed f sh :. f Index.Int

-- | Pattern-match a shape expression as @prefix :. lastComponent@:
-- the prefix pattern is applied to 'tail', the last pattern to 'head'.
instance
   (Expr.Decompose sh, Expr.Decompose s,
    MultiValue.Decomposed Exp s ~ Exp Index.Int,
    MultiValue.PatternTuple s ~ Index.Int,
    MultiValue.PatternTuple sh ~ T (AtomTag sh) (AtomRank sh),
    Unary.Natural (AtomRank sh)) =>
      Expr.Decompose (sh :. s) where
   decompose (psh:.ps) x =
      Expr.decompose psh (tail x) :. Expr.decompose ps (head x)


-- | Rank of a 'T'.
type family Rank sh
type instance Rank (T tag rank) = rank

-- | Tag of a 'T'.
type family Tag sh
type instance Tag (T tag rank) = tag

-- | Build a shape expression from @prefix :. lastComponent@ using 'cons'.
instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Rank (Expr.Composed sh)),
    Expr.Compose s, Expr.Composed s ~ Index.Int) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh)) (Unary.Succ (Rank (Expr.Composed sh)))
   compose (sh :. s) = cons (Expr.compose sh) (Expr.compose s)


-- | Components are stored as a 'FixedLength.T' of 'Shape.Size' through that
-- type's 'St.Storable' instance.
-- 'St.sizeOf' reports an array of @rank@ 'Shape.Size' values.
instance (Unary.Natural rank) => St.Storable (T tag rank) where
   sizeOf sh = sizeOfArray (Unary.integralFromProxy $ rank sh) (0::Shape.Size)
   alignment (Cons _sh) = St.alignment (0::Shape.Size)
   poke ptr = St.poke (castPtr ptr) . fmap (\(Index.Int i) -> i) . decons
   peek = fmap (Cons . fmap Index.Int) . St.peek . castPtr

-- | Marshal as an 'LLVM.Array' with one element per component,
-- in 'Fold.toList' order.
instance
   (Unary.Natural rank, Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.C (T tag rank) where
   pack = LLVM.Array . map Marshal.pack . Fold.toList . decons
   unpack (LLVM.Array sh) = Cons $ toFixedList $ map Marshal.unpack sh

-- | Fill a fixed-length vector from the front of a list.
-- Surplus list elements are ignored. A list shorter than @n@ raises a
-- descriptive 'error' instead of an anonymous pattern-match failure.
toFixedList :: (Unary.Natural n) => [a] -> FixedLength.T n a
toFixedList xs0 = snd $ Trav.mapAccumL takeOne xs0 (pure ())
   where
      takeOne (y:ys) () = (ys, y)
      takeOne [] () =
         error "Data.Array.Knead.Shape.Cubic.toFixedList: list shorter than rank"

instance
   (Unary.Natural rank, Dec.Natural (Dec.FromUnary rank),
    Dec.Natural (Dec.FromUnary rank Dec.:*: LLVM.SizeOf Shape.Size)) =>
      Marshal.MV (T tag rank)

instance (Unary.Natural rank) => Tuple.Value (T tag rank) where
   type ValueOf (T tag rank) = FixedLength.T rank (Tuple.ValueOf Index.Int)
   valueOf = fmap Tuple.valueOf . decons

-- | Component-wise multi-value: every component is its own LLVM value.
-- 'undef' and 'zero' broadcast the corresponding 'Index.Int' value.
instance (Unary.Natural rank) => MultiValue.C (T tag rank) where
   cons = MultiValue.Cons . fmap (\(Index.Int i) -> LLVM.valueOf i) . decons
   undef = constant MultiValue.undef
   zero = constant MultiValue.zero
   addPhi bb (MultiValue.Cons a) (MultiValue.Cons b) =
      Tuple.addPhiFoldable bb a b
   phi bb (MultiValue.Cons a) =
      fmap MultiValue.Cons . Tuple.phiTraversable bb $ a

-- | Use the same index value for every component.
constant ::
   (Unary.Natural rank) =>
   MultiValue.T Index.Int -> MultiValue.T (T tag rank)
constant (MultiValue.Cons x) = MultiValue.Cons $ pure x


-- | The number of elements is the product of all extents.
instance (tag ~ ShapeTag, Unary.Natural rank) => ComfortShape.C (T tag rank) where
   size = Fold.product . fmap (ComfortShape.size . shapeFromInt) . decons

-- | Every component is a zero-based extent.
-- 'ComfortShape.indices' lists the first 'FixedLength' component slowest,
-- because it uses 'Trav.mapM' in the list monad. 'ComfortShape.offset' uses
-- the same order, so consecutive indices get consecutive offsets.
instance (tag ~ ShapeTag, Unary.Natural rank) => ComfortShape.Indexed (T tag rank) where
   type Index (T tag rank) = Index rank
   indices (Cons sh) =
      map (Cons . fmap Index.Int) $
      Trav.mapM (ComfortShape.indices . shapeFromInt) sh
   inBounds (Cons sh) (Cons ix) =
      Fold.and $
      FixedLength.zipWith ComfortShape.inBounds
         (shapeFromInt <$> sh) (indexFromInt <$> ix)
   offset (Cons sh) (Cons ix) =
      Fold.foldl' (\off (s,i) -> off * ComfortShape.size s + fromIntegral i) 0 $
      FixedLength.zipWith (,) (shapeFromInt <$> sh) (indexFromInt <$> ix)

-- | Interpret one shape component as a zero-based extent.
shapeFromInt :: Index.Int -> ZeroBased Shape.Size
shapeFromInt (Index.Int i) = ZeroBased i

-- | Unwrap one index component.
indexFromInt :: Index.Int -> Shape.Size
indexFromInt (Index.Int i) = i


-- | LLVM code generation for shapes:
--
-- * 'Shape.size' is the product of the extents.
-- * 'Shape.intersectCode' is the component-wise minimum.
-- * 'Shape.sizeOffset' pairs the size with 'offsetCode'.
instance (tag ~ ShapeTag, Unary.Natural rank) => Shape.C (T tag rank) where
   size (MultiValue.Cons sh) = Fold.foldlM A.mul A.one sh
   intersectCode (MultiValue.Cons sh0) (MultiValue.Cons sh1) =
      fmap MultiValue.Cons $ Trav.sequence $ FixedLength.zipWith A.min sh0 sh1
   sizeOffset sh =
      -- would a joint implementation be more efficient?
      liftM2 (,) (Shape.size sh) (return $ offsetCode sh)
   iterator = iterator
   loop = loop

-- | Linear offset of an index within a shape.
-- The fold matches 'ComfortShape.offset' above: the first 'FixedLength'
-- component (the 'head') is scaled by all remaining extents.
offsetCode ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> MultiValue.T (Index rank) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode (MultiValue.Cons sh) (MultiValue.Cons ix) =
   Fold.foldlM (\off (s,i) -> A.mul off s >>= A.add i) A.zero $
   FixedLength.zipWith (,) sh ix


-- | Wrapper that lets 'Unary.switchNat' pick an implementation per rank.
newtype Iterator r rank =
   Iterator {
      runIterator ::
         MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
   }

-- | Enumerate all indices of a shape, by recursion on the rank.
-- Rank zero yields the single empty index. Otherwise the prefix indices are
-- combined with @0 .. n-1@ for the 'head' extent @n@ via 'Iter.cartesian'.
-- NOTE(review): presumably 'Iter.cartesian' nests its second argument
-- innermost, which would make the 'head' vary fastest, unlike the memory
-- order of 'offsetCode'. Confirm that this is intended.
iterator ::
   (Unary.Natural rank) =>
   MultiValue.T (Shape rank) -> Iter.T r (MultiValue.T (Index rank))
iterator =
   runIterator $
   Unary.switchNat
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
         fmap (\(ix,i) -> ix#:.i) $
         Iter.cartesian (iterator sh)
            (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
             Iter.iterate MultiValue.inc MultiValue.zero))


-- | Wrapper that lets 'Unary.switchNat' pick an implementation per rank.
newtype Loop r state rank =
   Loop {
      runLoop ::
         (MultiValue.T (Index rank) -> state -> LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape rank) -> state -> LLVM.CodeGenFunction r state
   }

-- | Run @code@ for every index of a shape, threading a state through.
-- Rank zero runs @code@ once with the empty index. Otherwise the loop over the
-- prefix is outermost, and a 'C.fixedLengthLoop' over the 'head' extent is
-- nested inside it.
-- NOTE(review): the 'head' therefore varies fastest here, while 'offsetCode'
-- gives it the largest stride. Confirm that this access order is intended.
loop ::
   (Unary.Natural rank, Tuple.Phi state) =>
   (MultiValue.T (Index rank) -> state -> LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape rank) -> state -> LLVM.CodeGenFunction r state
loop =
   runLoop $
   Unary.switchNat
      (Loop $ \code _z -> code z)
      (Loop $ \code -> switchR $ \sh (MultiValue.Cons n) ->
         loop
            (\ix acc0 ->
               fmap fst $
               C.fixedLengthLoop n (acc0, A.zero) $ \(acc, k) ->
                  liftM2 (,) (code (ix #:. MultiValue.Cons k) acc) (A.inc k))
            sh)