{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Array.Knead.Shape.Nested (
   C(..),
   Size,
   value,
   paramWith,
   load,
   intersect,
   flattenIndex,

   Range(..),
   Shifted(..),

   Scalar(..),
   Sequence(..),
   ) where

import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Storable
         (Storable, sizeOf, alignment, poke, peek, pokeElemOff, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)

import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


{- |
Machine type in which generated code reports sizes and flattened indices
(see 'sizeCode' and 'flattenIndexRec').
-}
type Size = Word64


{- |
Embed a Haskell-side shape into any 'Expr.Value' context
by way of 'MultiValue.cons' and 'Expr.lift0'.
-}
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)

{- |
Run 'Param.withMulti' on the given parameter and pass its two results
to the continuation, the second one post-composed with 'Expr.lift0'.
-}
paramWith ::
   (Storable b, MultiMem.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
    (Storable parameters, MultiMem.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val b) ->
    a) ->
   a
paramWith param cont =
   Param.withMulti param
      (\extract valueOf -> cont extract (\ps -> Expr.lift0 (valueOf ps)))

{- |
Load a shape through a pointer with 'MultiMem.load'.
The first argument is ignored; it only fixes the shape type.
-}
load ::
   (MultiMem.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiMem.load ptr

{- |
Expression-level counterpart of 'intersectCode'.
-}
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect a b = Expr.liftM2 intersectCode a b

{- |
Flattened index only, i.e. the second component of 'flattenIndexRec'.
-}
flattenIndex ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Size)
flattenIndex sh ix = snd <$> flattenIndexRec sh ix


class (MultiValue.C sh, MultiValue.C (Index sh)) => C sh where
   type Index sh :: *
   {-
   It would be better to restrict zipWith to matching shapes
   and turn shape intersection into a bound check.
   -}
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   sizeCode ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r (LLVM.Value Size)
   size :: sh -> Int
   {- |
   Result is @(size, flattenedIndex)@.
   @size@ must equal the result of 'sizeCode'.
   We use this for sharing intermediate results.
   -}
   flattenIndexRec ::
      MultiValue.T sh -> MultiValue.T (Index sh) ->
      LLVM.CodeGenFunction r (LLVM.Value Size, LLVM.Value Size)
   iterator ::
      (Index sh ~ ix) =>
      MultiValue.T sh -> Iter.T r (MultiValue.T ix)
   loop ::
      (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
   -- Default: thread the state through every index produced by 'iterator'.
   loop body shape = Iter.mapState_ body (iterator shape)

-- The unit shape: size one, the only flattened index is zero.
instance C () where
   type Index () = ()
   intersectCode _ _ = return (MultiValue.cons ())
   sizeCode _ = return A.one
   size _ = 1
   flattenIndexRec _ _ = return (A.one, A.zero)
   iterator unit = Iter.singleton unit
   -- Shape and index coincide here, so the body is applied to the shape.
   loop body = body


class C sh => Scalar sh where
   scalar :: (Expr.Value val) => val sh
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

instance Scalar () where
   scalar = Expr.lift0 (MultiValue.Cons ())
   zeroIndex _ = Expr.lift0 (MultiValue.Cons ())


class
   (C sh,
    MultiValue.IntegerConstant (Index sh),
    MultiValue.Additive (Index sh)) =>
      Sequence sh where
   sequenceShapeFromIndex ::
      MultiValue.T (Index sh) ->
      LLVM.CodeGenFunction r (MultiValue.T sh)


{- |
Iterator for shapes that are plain counts:
'iteratorStart' beginning at 'MultiValue.zero'.
-}
iteratorPrimitive ::
   (MultiValue.Repr LLVM.Value j ~ LLVM.Value j,
    Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive j, MultiValue.IntegerConstant j) =>
   MultiValue.T j -> Iter.T r (MultiValue.T j)
iteratorPrimitive (MultiValue.Cons count) =
   iteratorStart count MultiValue.zero

{- |
'Iter.take' the given number of elements
from the sequence obtained by iterating 'MultiValue.inc' on the start value.
-}
iteratorStart ::
   (Num j, LLVM.IsConst j, LLVM.IsInteger j,
    LLVM.CmpRet j, LLVM.CmpResult j ~ Bool,
    MultiValue.Additive i, MultiValue.IntegerConstant i) =>
   LLVM.Value j -> MultiValue.T i -> Iter.T r (MultiValue.T i)
iteratorStart count first =
   Iter.take count (Iter.iterate MultiValue.inc first)


instance C Word32 where
   type Index Word32 = Word32
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.ext n
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      extent <- LLVM.ext n
      position <- LLVM.ext i
      return (extent, position)
   iterator n = iteratorPrimitive n

instance Sequence Word32 where
   sequenceShapeFromIndex ix = return ix

instance C Word64 where
   type Index Word64 = Word64
   intersectCode a b = MultiValue.min a b
   -- 'Size' is 'Word64', hence no conversion is needed.
   sizeCode (MultiValue.Cons n) = return n
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = return (n, i)
   iterator n = iteratorPrimitive n

instance Sequence Word64 where
   sequenceShapeFromIndex ix = return ix

{- |
Array dimensions and indexes cannot be negative,
but computations in indices may temporarily yield negative values
or we want to add negative values to indices.
Maybe we should better have type Index Word64 = Int64?
-}
unsigned8 :: LLVM.Value Int8 -> LLVM.CodeGenFunction r (LLVM.Value Word8)
unsigned8 signed = LLVM.bitcast signed

-- Sizes and indices are reinterpreted as 'Word8' before being extended.
instance C Int8 where
   type Index Int8 = Int8
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = unsigned8 n >>= LLVM.ext
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      extent <- unsigned8 n >>= LLVM.ext
      position <- unsigned8 i >>= LLVM.ext
      return (extent, position)
   iterator n = iteratorPrimitive n

instance Sequence Int8 where
   sequenceShapeFromIndex ix = return ix

{- |
Reinterpret a signed 16 bit value as an unsigned one via 'LLVM.bitcast'.
-}
unsigned16 :: LLVM.Value Int16 -> LLVM.CodeGenFunction r (LLVM.Value Word16)
unsigned16 signed = LLVM.bitcast signed

-- Sizes and indices are reinterpreted as 'Word16' before being extended.
instance C Int16 where
   type Index Int16 = Int16
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = unsigned16 n >>= LLVM.ext
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      extent <- unsigned16 n >>= LLVM.ext
      position <- unsigned16 i >>= LLVM.ext
      return (extent, position)
   iterator n = iteratorPrimitive n

instance Sequence Int16 where
   sequenceShapeFromIndex ix = return ix

-- Sizes and indices are widened to 'Size' with 'LLVM.zext'.
instance C Int32 where
   type Index Int32 = Int32
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.zext n
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      extent <- LLVM.zext n
      position <- LLVM.zext i
      return (extent, position)
   iterator n = iteratorPrimitive n

instance Sequence Int32 where
   sequenceShapeFromIndex ix = return ix

-- Same width as 'Size', so a 'LLVM.bitcast' suffices.
instance C Int64 where
   type Index Int64 = Int64
   intersectCode a b = MultiValue.min a b
   sizeCode (MultiValue.Cons n) = LLVM.bitcast n
   size n = fromIntegral n
   flattenIndexRec (MultiValue.Cons n) (MultiValue.Cons i) = do
      extent <- LLVM.bitcast n
      position <- LLVM.bitcast i
      return (extent, position)
   iterator n = iteratorPrimitive n

instance Sequence Int64 where
   sequenceShapeFromIndex ix = return ix


{- |
'Range' denotes an inclusive range like
those of the Haskell 98 standard @Array@ type from the @array@ package.
E.g. the shape type @(Range Int32, Range Int64)@
is equivalent to the ix type @(Int32, Int64)@ for @Array@s.
-}
data Range n = Range n n

{- |
A 'Range' that carries the same value in both fields.
-}
singletonRange :: n -> Range n
singletonRange n = Range n n

{-# INLINE castToElemPtr #-}
{- |
View a pointer to a container as a pointer to its element type.
-}
castToElemPtr :: Ptr (f a) -> Ptr a
castToElemPtr ptr = castPtr ptr

-- cf. sample-frame:Stereo
instance Storable n => Storable (Range n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- first field, padding up to the alignment of the second, second field
   sizeOf ~(Range l r) =
      sizeOf l + mod (- sizeOf l) (alignment r) + sizeOf r
   alignment ~(Range l _) = alignment l
   poke p (Range l r) = do
      let q = castToElemPtr p
      poke q l
      pokeElemOff q 1 r
   peek p = do
      let q = castToElemPtr p
      l <- peek q
      r <- peekElemOff q 1
      return (Range l r)


-- Index types that can be converted to a 'Size' in generated code.
class
   (MultiValue.Additive n, MultiValue.Real n,
    MultiValue.IntegerConstant n) =>
      ToSize n where
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)

instance ToSize Word32 where
   toSize (MultiValue.Cons n) = LLVM.ext n

instance ToSize Word64 where
   toSize (MultiValue.Cons n) = return n

instance ToSize Int32 where
   toSize (MultiValue.Cons n) = LLVM.zext n

instance ToSize Int64 where
   toSize (MultiValue.Cons n) = LLVM.bitcast n

{- |
Number of elements of an inclusive range, computed as @to - from + 1@.
-}
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) = do
   difference <- MultiValue.sub to from
   count <- MultiValue.inc difference
   toSize count

unzipRange :: MultiValue.T (Range n) -> Range (MultiValue.T n)
unzipRange (MultiValue.Cons (Range lower upper)) =
   Range (MultiValue.Cons lower) (MultiValue.Cons upper)

zipRange :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Range n)
zipRange (MultiValue.Cons lower) (MultiValue.Cons upper) =
   MultiValue.Cons (Range lower upper)

-- All methods work field-wise on the two bounds.
instance (MultiValue.C n) => MultiValue.C (Range n) where
   type Repr f (Range n) = Range (MultiValue.Repr f n)
   cons (Range from to) =
      zipRange (MultiValue.cons from) (MultiValue.cons to)
   undef = MultiValue.compose (singletonRange MultiValue.undef)
   zero = MultiValue.compose (singletonRange MultiValue.zero)
   phis bb a =
      case unzipRange a of
         Range a0 a1 -> do
            p0 <- MultiValue.phis bb a0
            p1 <- MultiValue.phis bb a1
            return (zipRange p0 p1)
   addPhis bb a b =
      case (unzipRange a, unzipRange b) of
         (Range a0 a1, Range b0 b1) -> do
            MultiValue.addPhis bb a0 b0
            MultiValue.addPhis bb a1 b1

type instance
   MultiValue.Decomposed f (Range pn) =
      Range (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Range pn) =
      Range (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
   type Composed (Range n) = Range (MultiValue.Composed n)
   compose (Range from to) =
      zipRange (MultiValue.compose from) (MultiValue.compose to)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
   decompose (Range pfrom pto) rng =
      case unzipRange rng of
         Range from to ->
            Range
               (MultiValue.decompose pfrom from)
               (MultiValue.decompose pto to)

-- Stored as a two-element struct, see 'PairStruct'.
instance (MultiMem.C n) => MultiMem.C (Range n) where
   type Struct (Range n) = PairStruct n
   decompose struct = uncurry zipRange <$> decomposeGen struct
   compose x =
      case unzipRange x of
         Range from to -> composeGen from to

instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Range n) where
   type Index (Range n) = n
   -- Largest lower bound paired with the smallest upper bound.
   -- NOTE(review): for disjoint ranges this yields @to < from@;
   -- confirm that callers never rely on the size of such a result.
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
         \(Range fromA toA) (Range fromB toB) -> do
            lower <- MultiValue.max fromA fromB
            upper <- MultiValue.min toA toB
            return (Range lower upper)
   sizeCode rngValue = rangeSize (unzipRange rngValue)
   size (Range from to) = fromIntegral (to-from+1)
   flattenIndexRec rngValue i =
      case unzipRange rngValue of
         rng@(Range from _to) -> do
            extent <- rangeSize rng
            position <- toSize =<< MultiValue.sub i from
            return (extent, position)
   -- Count upwards from the lower bound while @to >= index@ holds.
   iterator rngValue =
      case MultiValue.decompose (singletonRange atom) rngValue of
         Range from to ->
            IterMV.takeWhile (MultiValue.cmp LLVM.CmpGE to)
               (Iter.iterate MultiValue.inc from)


{- |
'Shifted' denotes a range defined by the start index and the length.
-}
data Shifted n = Shifted {shiftedOffset, shiftedSize :: n}

{- |
A 'Shifted' that carries the same value in both fields.
-}
singletonShifted :: n -> Shifted n
singletonShifted n = Shifted n n

-- cf.
-- sample-frame:Stereo
instance Storable n => Storable (Shifted n) where
   {-# INLINE sizeOf #-}
   {-# INLINE alignment #-}
   {-# INLINE peek #-}
   {-# INLINE poke #-}
   -- offset field, padding up to the alignment of the length, length field
   sizeOf ~(Shifted l n) =
      sizeOf l + mod (- sizeOf l) (alignment n) + sizeOf n
   alignment ~(Shifted l _) = alignment l
   poke p (Shifted l n) =
      let q = castToElemPtr p
      in  poke q l >> pokeElemOff q 1 n
   peek p =
      let q = castToElemPtr p
      in  Monad.lift2 Shifted (peek q) (peekElemOff q 1)


unzipShifted :: MultiValue.T (Shifted n) -> Shifted (MultiValue.T n)
unzipShifted (MultiValue.Cons (Shifted from to)) =
   Shifted (MultiValue.Cons from) (MultiValue.Cons to)

zipShifted :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Shifted n)
zipShifted (MultiValue.Cons from) (MultiValue.Cons to) =
   MultiValue.Cons (Shifted from to)

-- All methods work field-wise on offset and length.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
   type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
   cons (Shifted offset len) =
      zipShifted (MultiValue.cons offset) (MultiValue.cons len)
   undef = MultiValue.compose $ singletonShifted MultiValue.undef
   zero = MultiValue.compose $ singletonShifted MultiValue.zero
   phis bb a =
      case unzipShifted a of
         Shifted a0 a1 ->
            Monad.lift2 zipShifted
               (MultiValue.phis bb a0) (MultiValue.phis bb a1)
   addPhis bb a b =
      case (unzipShifted a, unzipShifted b) of
         (Shifted a0 a1, Shifted b0 b1) ->
            MultiValue.addPhis bb a0 b0 >>
            MultiValue.addPhis bb a1 b1

type instance
   MultiValue.Decomposed f (Shifted pn) =
      Shifted (MultiValue.Decomposed f pn)
type instance
   MultiValue.PatternTuple (Shifted pn) =
      Shifted (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
   type Composed (Shifted n) = Shifted (MultiValue.Composed n)
   compose (Shifted offset len) =
      zipShifted (MultiValue.compose offset) (MultiValue.compose len)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
   decompose (Shifted poffset plen) rng =
      case unzipShifted rng of
         Shifted offset len ->
            Shifted
               (MultiValue.decompose poffset offset)
               (MultiValue.decompose plen len)

-- Stored as a two-element struct, see 'PairStruct'.
instance (MultiMem.C n) => MultiMem.C (Shifted n) where
   type Struct (Shifted n) = PairStruct n
   decompose = fmap (uncurry zipShifted) . decomposeGen
   compose x =
      case unzipShifted x of
         Shifted n m -> composeGen n m


{- |
Memory layout shared by the two-field shapes of this module:
an LLVM struct with two members of the element's struct type.
-}
type PairStruct n = LLVM.Struct (MultiMem.Struct n, (MultiMem.Struct n, ()))

{- |
Extract member 0 and member 1 of a 'PairStruct'
and decompose each of them with 'MultiMem.decompose'.
-}
decomposeGen ::
   (MultiMem.C n) =>
   LLVM.Value (PairStruct n) ->
   LLVM.CodeGenFunction r (MultiValue.T n, MultiValue.T n)
decomposeGen nm =
   Monad.lift2 (,)
      (MultiMem.decompose =<< LLVM.extractvalue nm TypeNum.d0)
      (MultiMem.decompose =<< LLVM.extractvalue nm TypeNum.d1)

{- |
Inverse direction of 'decomposeGen':
compose both values and insert them as member 0 and member 1
into an initially undefined 'PairStruct'.
-}
composeGen ::
   (MultiMem.C n) =>
   MultiValue.T n -> MultiValue.T n ->
   LLVM.CodeGenFunction r (LLVM.Value (PairStruct n))
composeGen n m = do
   sn <- MultiMem.compose n
   sm <- MultiMem.compose m
   rn <- LLVM.insertvalue (LLVM.value LLVM.undef) sn TypeNum.d0
   LLVM.insertvalue rn sm TypeNum.d1


instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Shifted n) where
   type Index (Shifted n) = n
   {-
   The intersection starts at the larger offset
   and ends at the smaller end (offset + length).
   The end is clamped to be at least the new offset.
   Without the clamp, disjoint shapes would give @end < offset@,
   the subtraction would become negative (or wrap for unsigned types)
   and 'toSize' would turn that into a huge size.
   With the clamp, disjoint shapes yield length zero;
   overlapping shapes are unaffected.
   -}
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
            \(Shifted offsetN lenN) (Shifted offsetM lenM) -> do
         offset <- MultiValue.max offsetN offsetM
         endN <- MultiValue.add offsetN lenN
         endM <- MultiValue.add offsetM lenM
         end <- MultiValue.max offset =<< MultiValue.min endN endM
         Shifted offset <$> MultiValue.sub end offset
   sizeCode = toSize . shiftedSize . unzipShifted
   size (Shifted _offset len) = fromIntegral len
   -- The flattened index is the distance of the index from the offset.
   flattenIndexRec shapeValue i =
      case unzipShifted shapeValue of
         Shifted offset len ->
            Monad.lift2 (,)
               (toSize len)
               (toSize =<< MultiValue.sub i offset)
   -- @len@ successive indices, starting at the offset.
   iterator rngValue =
      case MultiValue.decompose (singletonShifted atom) rngValue of
         Shifted from len ->
            IterMV.take len $ Iter.iterate MultiValue.inc from


-- Pairs of shapes; the second component varies fastest
-- in the flattened index (@flat = i * sizeM + j@).
instance (C n, C m) => C (n,m) where
   type Index (n,m) = (Index n, Index m)
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) ->
            Monad.lift2 MultiValue.zip
               (intersectCode an bn)
               (intersectCode am bm)
   sizeCode nm =
      case MultiValue.unzip nm of
         (n,m) -> Monad.liftJoin2 A.mul (sizeCode n) (sizeCode m)
   size (n,m) = size n * size m
   flattenIndexRec nm ij =
      case (MultiValue.unzip nm, MultiValue.unzip ij) of
         ((n,m), (i,j)) -> do
            (ns, il) <- flattenIndexRec n i
            (ms, jl) <- flattenIndexRec m j
            Monad.lift2 (,)
               (A.mul ns ms)
               (A.add jl =<< A.mul ms il)
   iterator nm =
      case MultiValue.unzip nm of
         (n,m) ->
            uncurry MultiValue.zip <$>
            Iter.cartesian (iterator n) (iterator m)
   -- The loop over the second component is nested
   -- inside the loop over the first one.
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) -> loop (\i -> loop (\j -> code (MultiValue.zip i j)) m) n

-- Triples of shapes; the last component varies fastest
-- in the flattened index (@flat = (i * sizeM + j) * sizeL + k@).
instance (C n, C m, C l) => C (n,m,l) where
   type Index (n,m,l) = (Index n, Index m, Index l)
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) ->
            Monad.lift3 MultiValue.zip3
               (intersectCode ai bi)
               (intersectCode aj bj)
               (intersectCode ak bk)
   sizeCode nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            Monad.liftJoin2 A.mul (sizeCode n) $
            Monad.liftJoin2 A.mul (sizeCode m) (sizeCode l)
   size (n,m,l) = size n * size m * size l
   flattenIndexRec nml ijk =
      case (MultiValue.unzip3 nml, MultiValue.unzip3 ijk) of
         ((n,m,l), (i,j,k)) -> do
            (ns, il) <- flattenIndexRec n i
            (ms, jl) <- flattenIndexRec m j
            x0 <- A.add jl =<< A.mul ms il
            (ls, kl) <- flattenIndexRec l k
            x1 <- A.add kl =<< A.mul ls x0
            sz <- A.mul ns =<< A.mul ms ls
            return (sz, x1)
   iterator nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            fmap (\(a,(b,c)) -> MultiValue.zip3 a b c) $
            Iter.cartesian (iterator n) $
            Iter.cartesian (iterator m) (iterator l)
   -- Loops are nested in the order first, second, third component.
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            loop (\i -> loop (\j -> loop (\k ->
               code (MultiValue.zip3 i j k)) l) m) n