{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{- |
Shapes of nested (tuple-structured) arrays together with the LLVM code
that measures them, intersects them, flattens indices into them and
iterates over all of their indices.
-}
module Data.Array.Knead.Index.Nested.Shape where

import Data.Word (Word32, Word64)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import qualified Control.Monad.HT as Monad
import qualified LLVM.Core as LLVM
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Monad (liftR2)
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Util.Loop as Loop

import Data.Array.Knead.Expression (Exp)
import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param


-- | Embed a concrete Haskell shape as a constant in any 'Expr.Value'
-- container.
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)

-- | Variant of 'Param.withMulti' whose continuation receives the value
-- accessor already lifted into the 'Expr.Value' container @val@.
paramWith ::
  (Storable b, MultiValueMemory.C b, Expr.Value val) =>
  Param.T p b ->
  (forall parameters.
     (Storable parameters, MultiValueMemory.C parameters) =>
     (p -> parameters) ->
     (MultiValue.T parameters -> val b) ->
     a) ->
  a
paramWith p f =
  -- The continuation must stay an inline lambda: it is checked against
  -- the rank-2 argument type of 'Param.withMulti'.
  Param.withMulti p (\extract toMulti -> f extract (Expr.lift0 . toMulti))

-- | Load a shape through a pointer.
-- The first argument only selects the shape type; its value is ignored.
load ::
  (MultiValueMemory.C sh) =>
  f sh ->
  LLVM.Value (Ptr (MultiValueMemory.Struct sh)) ->
  LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiValueMemory.load ptr

-- | Expression-level shape intersection, built from 'intersectCode'.
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect x y = Expr.liftM2 intersectCode x y

-- | Flattened (linear) position of an index within a shape.
-- This is the second component of 'flattenIndexRec'.
flattenIndex ::
  (C sh) =>
  MultiValue.T sh -> MultiValue.T (Index sh) ->
  LLVM.CodeGenFunction r (LLVM.Value Word32)
flattenIndex sh ix = fmap snd (flattenIndexRec sh ix)


-- | Array shapes that can be processed in generated LLVM code.
class (MultiValue.C sh) => C sh where
  -- | Type of the indices into a shape of type @sh@.
  type Index sh :: *

  -- | Intersection of two shapes.
  --
  -- It would be better to restrict zipWith to matching shapes
  -- and turn shape intersection into a bound check.
  intersectCode ::
    MultiValue.T sh -> MultiValue.T sh ->
    LLVM.CodeGenFunction r (MultiValue.T sh)

  -- | Number of elements of a shape, computed in generated code.
  sizeCode ::
    MultiValue.T sh -> LLVM.CodeGenFunction r (LLVM.Value Word32)

  -- | Number of elements of a shape, computed in Haskell.
  size :: sh -> Int

  -- | Result is @(size, flattenedIndex)@.
  -- @size@ must equal the result of 'sizeCode'.
  -- Returning both lets compound shapes share intermediate results.
  flattenIndexRec ::
    MultiValue.T sh -> MultiValue.T (Index sh) ->
    LLVM.CodeGenFunction r (LLVM.Value Word32, LLVM.Value Word32)

  -- | Run a body for every index of a shape, threading a state through.
  loop ::
    (Index sh ~ ix, Loop.Phi state) =>
    (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
    MultiValue.T sh -> state -> LLVM.CodeGenFunction r state

-- | The zero-dimensional shape: one element, found at index @()@.
instance C () where
  type Index () = ()
  intersectCode _ _ = return (MultiValue.cons ())
  sizeCode _ = return A.one
  size _ = 1
  flattenIndexRec _ _ = return (A.one, A.zero)
  -- The body runs exactly once, receiving the unit shape as its index.
  loop body = body


-- | Shapes with a canonical value ('scalar') and a zero index
-- ('zeroIndex').
class C sh => Scalar sh where
  scalar :: (Expr.Value val) => val sh
  zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

instance Scalar () where
  scalar = Expr.lift0 (MultiValue.Cons ())
  zeroIndex _proxy = Expr.lift0 (MultiValue.Cons ())


-- | 'loop' for one-dimensional integer shapes.
--
-- It drives 'C.fixedLengthLoop' with the shape as the loop length.
-- A counter starts at zero and is passed to the body as the index.
-- The counter is incremented after each body invocation.
loopPrimitive ::
  (MultiValue.Repr LLVM.Value i ~ LLVM.Value i,
   Num i, LLVM.IsConst i, LLVM.IsInteger i,
   LLVM.CmpRet i, LLVM.CmpResult i ~ Bool,
   Loop.Phi state) =>
  (MultiValue.T i -> state -> LLVM.CodeGenFunction r state) ->
  MultiValue.T i -> state -> LLVM.CodeGenFunction r state
loopPrimitive body (MultiValue.Cons count) start =
  fmap fst (C.fixedLengthLoop count (start, A.zero) step)
  where
    -- Emit the body first, then the counter increment.
    step (state, counter) = do
      state' <- body (MultiValue.Cons counter) state
      counter' <- A.inc counter
      return (state', counter')

-- | One-dimensional shape given by its length.
-- The index is used directly as the flattened index.
instance C Word32 where
  type Index Word32 = Word32
  intersectCode = MultiValue.min
  sizeCode (MultiValue.Cons len) = return len
  size = fromIntegral
  flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) =
    return (len, ix)
  loop = loopPrimitive

-- | Like the 'Word32' instance, but the size and the flattened index are
-- narrowed to 'Word32' with 'LLVM.trunc'. This discards their upper bits.
instance C Word64 where
  type Index Word64 = Word64
  intersectCode = MultiValue.min
  sizeCode (MultiValue.Cons len) = LLVM.trunc len
  size = fromIntegral
  flattenIndexRec (MultiValue.Cons len) (MultiValue.Cons ix) = do
    len32 <- LLVM.trunc len
    ix32 <- LLVM.trunc ix
    return (len32, ix32)
  loop = loopPrimitive

-- | Two-dimensional shapes, where the first component varies slowest.
-- The flattened index of @(i,j)@ is @i * sizeM + j@, with @sizeM@ being
-- the size of the second component.
instance (C n, C m) => C (n,m) where
  type Index (n,m) = (Index n, Index m)
  intersectCode x y =
    case (MultiValue.unzip x, MultiValue.unzip y) of
      ((x0, x1), (y0, y1)) ->
        Monad.lift2 MultiValue.zip (intersectCode x0 y0) (intersectCode x1 y1)
  sizeCode nm =
    case MultiValue.unzip nm of
      (n, m) -> liftR2 A.mul (sizeCode n) (sizeCode m)
  size (n, m) = size n * size m
  flattenIndexRec nm ij =
    case (MultiValue.unzip nm, MultiValue.unzip ij) of
      ((n, m), (i, j)) -> do
        (sizeN, flatI) <- flattenIndexRec n i
        (sizeM, flatJ) <- flattenIndexRec m j
        total <- A.mul sizeN sizeM
        offset <- A.mul sizeM flatI >>= A.add flatJ
        return (total, offset)
  loop body nm =
    case MultiValue.unzip nm of
      (n, m) -> loop (\i -> loop (\j -> body (MultiValue.zip i j)) m) n

-- | Three-dimensional shapes, where the first component varies slowest.
-- The flattened index of @(i,j,k)@ is @(i * sizeM + j) * sizeL + k@.
instance (C n, C m, C l) => C (n,m,l) where
  type Index (n,m,l) = (Index n, Index m, Index l)
  intersectCode x y =
    case (MultiValue.unzip3 x, MultiValue.unzip3 y) of
      ((x0, x1, x2), (y0, y1, y2)) ->
        Monad.lift3 MultiValue.zip3
          (intersectCode x0 y0) (intersectCode x1 y1) (intersectCode x2 y2)
  sizeCode nml =
    case MultiValue.unzip3 nml of
      (n, m, l) ->
        liftR2 A.mul (sizeCode n) (liftR2 A.mul (sizeCode m) (sizeCode l))
  size (n, m, l) = size n * size m * size l
  flattenIndexRec nml ijk =
    case (MultiValue.unzip3 nml, MultiValue.unzip3 ijk) of
      ((n, m, l), (i, j, k)) -> do
        (sizeN, flatI) <- flattenIndexRec n i
        (sizeM, flatJ) <- flattenIndexRec m j
        flatIJ <- A.mul sizeM flatI >>= A.add flatJ
        (sizeL, flatK) <- flattenIndexRec l k
        flatIJK <- A.mul sizeL flatIJ >>= A.add flatK
        total <- A.mul sizeM sizeL >>= A.mul sizeN
        return (total, flatIJK)
  loop body nml =
    case MultiValue.unzip3 nml of
      (n, m, l) ->
        loop (\i -> loop (\j -> loop (\k -> body (MultiValue.zip3 i j k)) l) m) n