{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{- |
Shape types for code-generated arrays.

The class 'C' provides the code generators for intersecting shapes,
computing their size, mapping an index to a linear offset
and iterating over all indices.
Instances are given for @()@, 'ZeroBased', 'Range', 'Shifted'
and for pairs and triples of shapes.
-}
module Data.Array.Knead.Shape (
  C(..),
  Index,
  Size,
  value,
  paramWith,
  load,
  intersect,
  offset,

  ZeroBased(ZeroBased),
  zeroBased,
  zeroBasedSize,

  Range(Range),
  range,
  rangeFrom,
  rangeTo,

  Shifted(Shifted),
  shifted,
  shiftedOffset,
  shiftedSize,

  Scalar(..),
  Sequence(..),
  ) where

import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape
         (Index, ZeroBased(ZeroBased), Range(Range), Shifted(Shifted))
import Data.Ix (Ix)

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)

import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)

import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))


{- |
Integer type in which 'size', 'sizeOffset' and 'offset'
report element counts and linear offsets.
-}
type Size = Word64


{- |
Turn a Haskell-side shape into a constant
in any 'Expr.Value' context, using 'MultiValue.cons'.
-}
value :: (C sh, Expr.Value val) => sh -> val sh
value sh = Expr.lift0 (MultiValue.cons sh)

{- |
Like 'Param.withMulti', but the accessor handed to the continuation
yields its result in an arbitrary 'Expr.Value' context
(it is post-processed with 'Expr.lift0').
-}
paramWith ::
  (Storable b, MultiMem.C b, Expr.Value val) =>
  Param.T p b ->
  (forall parameters.
   (Storable parameters, MultiMem.C parameters) =>
   (p -> parameters) ->
   (MultiValue.T parameters -> val b) ->
   a) ->
  a
paramWith param cont =
  Param.withMulti param
    (\get extract -> cont get (\params -> Expr.lift0 (extract params)))

{- |
Load a shape from memory with 'MultiMem.load'.
The first argument is ignored; it only fixes the shape type.
-}
load ::
  (MultiMem.C sh) =>
  f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
  LLVM.CodeGenFunction r (MultiValue.T sh)
load _proxy ptr = MultiMem.load ptr

{- |
'intersectCode' lifted to the expression level.
-}
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect sha shb = Expr.liftM2 intersectCode sha shb

{- |
Linear offset of an index within a shape.
This runs 'sizeOffset', discards the size
and applies the offset function to the index.
-}
offset ::
  (C sh) =>
  MultiValue.T sh -> MultiValue.T (Index sh) ->
  LLVM.CodeGenFunction r (LLVM.Value Size)
offset sh ix = do
  (_size, indexToOffset) <- sizeOffset sh
  indexToOffset ix


{- |
Code generators that every shape type must provide.
-}
class (MultiValue.C sh, MultiValue.C (Index sh), Shape.Indexed sh) => C sh where
  {-
  It would be better to restrict zipWith to matching shapes
  and to turn shape intersection into a bounds check.
  -}
  intersectCode ::
    MultiValue.T sh -> MultiValue.T sh ->
    LLVM.CodeGenFunction r (MultiValue.T sh)
  size :: MultiValue.T sh -> LLVM.CodeGenFunction r (LLVM.Value Size)
  {- |
  Result is @(size, offset)@.
  @size@ must equal the result of 'size'.
  We use this for sharing intermediate results.
  -}
  sizeOffset ::
    MultiValue.T sh ->
    LLVM.CodeGenFunction r
      (LLVM.Value Size,
       MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (LLVM.Value Size))
  iterator :: (Index sh ~ ix) => MultiValue.T sh -> Iter.T r (MultiValue.T ix)
  {- |
  Thread a state through a loop body that is run for the indices of a shape.
  The default implementation feeds 'iterator' to 'Iter.mapState_'.
  -}
  loop ::
    (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
    (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
    MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
  loop body sh = Iter.mapState_ body (iterator sh)

instance C () where
  intersectCode _ _ = return (MultiValue.cons ())
  size _ = return A.one
  sizeOffset _ = return (A.one, const (return A.zero))
  iterator = Iter.singleton
  -- The body receives the shape value itself as its only index;
  -- this type-checks because shape and index type coincide for @()@.
  loop body = body


{- |
Shapes for which a shape value and an index value
can be produced without any further input.
-}
class C sh => Scalar sh where
  scalar :: (Expr.Value val) => val sh
  zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

instance Scalar () where
  scalar = Expr.lift0 (MultiValue.Cons ())
  zeroIndex _ = Expr.lift0 (MultiValue.Cons ())


{- |
Shapes that can be built from a single index value.
-}
class
  (C sh, MultiValue.IntegerConstant (Index sh), MultiValue.Additive (Index sh)) =>
    Sequence sh where
  sequenceShapeFromIndex ::
    MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (MultiValue.T sh)


{- |
Integer types that can be converted to 'Size'.
-}
class
  (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
    ToSize n where
  toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)

-- Unsigned types narrower than 'Size' are widened with 'LLVM.ext',
-- 'Word64' is passed through unchanged.
instance ToSize Word8 where
  toSize (MultiValue.Cons n) = LLVM.ext n

instance ToSize Word16 where
  toSize (MultiValue.Cons n) = LLVM.ext n

instance ToSize Word32 where
  toSize (MultiValue.Cons n) = LLVM.ext n

instance ToSize Word64 where
  toSize (MultiValue.Cons n) = return n

-- Signed types narrower than 'Size' are widened with 'LLVM.zext'
-- and 'Int64' is reinterpreted with 'LLVM.bitcast'.
-- NOTE(review): presumably this relies on sizes never being negative.
instance ToSize Int8 where
  toSize (MultiValue.Cons n) = LLVM.zext n

instance ToSize Int16 where
  toSize (MultiValue.Cons n) = LLVM.zext n

instance ToSize Int32 where
  toSize (MultiValue.Cons n) = LLVM.zext n

instance ToSize Int64 where
  toSize (MultiValue.Cons n) = LLVM.bitcast n


{- |
Move the 'ZeroBased' wrapper from inside the multi-value to the outside.
-}
unzipZeroBased :: MultiValue.T (ZeroBased n) -> ZeroBased (MultiValue.T n)
unzipZeroBased (MultiValue.Cons (ZeroBased size_)) =
  ZeroBased (MultiValue.Cons size_)

{- |
Extract the size stored in a 'ZeroBased' shape.
-}
zeroBasedSize :: (Expr.Value val) => val (ZeroBased n) -> val n
zeroBasedSize =
  Expr.lift1 (\sh -> Shape.zeroBasedSize (unzipZeroBased sh))

{- |
Wrap a size as a 'ZeroBased' shape.
-}
zeroBased :: (Expr.Value val) => val n -> val (ZeroBased n)
zeroBased =
  Expr.lift1 (\(MultiValue.Cons size_) -> MultiValue.Cons (ZeroBased size_))

-- All methods delegate to the instance of the wrapped size type.
instance (MultiValue.C n) => MultiValue.C (ZeroBased n) where
  type Repr f (ZeroBased n) = ZeroBased (MultiValue.Repr f n)
  cons (ZeroBased size_) = zeroBased (MultiValue.cons size_)
  undef = zeroBased MultiValue.undef
  zero = zeroBased MultiValue.zero
  phis bb sh = Monad.lift zeroBased (MultiValue.phis bb (zeroBasedSize sh))
  addPhis bb sha shb =
    MultiValue.addPhis bb (zeroBasedSize sha) (zeroBasedSize shb)

type instance MultiValue.Decomposed f (ZeroBased pn) =
  ZeroBased (MultiValue.Decomposed f pn)
type instance MultiValue.PatternTuple (ZeroBased pn) =
  ZeroBased (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (ZeroBased n) where
  type Composed (ZeroBased n) = ZeroBased (MultiValue.Composed n)
  compose (ZeroBased size_) = zeroBased (MultiValue.compose size_)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (ZeroBased pn) where
  decompose (ZeroBased pat) sh =
    MultiValue.decompose pat <$> unzipZeroBased sh

instance (Expr.Compose n) => Expr.Compose (ZeroBased n) where
  type Composed (ZeroBased n) = ZeroBased (Expr.Composed n)
  compose (ZeroBased size_) = Expr.lift1 zeroBased (Expr.compose size_)

instance (Expr.Decompose pn) => Expr.Decompose (ZeroBased pn) where
  decompose (ZeroBased pat) sh =
    ZeroBased (Expr.decompose pat (zeroBasedSize sh))

-- In memory a 'ZeroBased' shape is stored exactly like its size.
instance (MultiMem.C n) => MultiMem.C (ZeroBased n) where
  type Struct (ZeroBased n) = MultiMem.Struct n
  decompose struct = zeroBased <$> MultiMem.decompose struct
  compose sh = MultiMem.compose (zeroBasedSize sh)

{- |
Array dimensions and indices cannot be negative,
but computations on indices may temporarily yield negative values,
or we may want to add negative values to indices.
So maybe we had better have @type Index (ZeroBased Word64) = Int64@.
This is not possible.
Maybe we need an additional @ZeroBased@ type for unsigned array sizes.
-}
instance
  (Integral n, ToSize n, MultiValue.Comparison n) =>
    C (ZeroBased n) where
  -- The intersection is the shape with the smaller of both sizes.
  intersectCode sha shb =
    fmap zeroBased (MultiValue.min (zeroBasedSize sha) (zeroBasedSize shb))
  size sh = toSize (zeroBasedSize sh)
  -- An index is converted to its offset with 'toSize' alone.
  sizeOffset sh = do
    count <- toSize (zeroBasedSize sh)
    return (count, toSize)
  -- Counts upwards from zero and takes as many indices as the size says.
  iterator sh =
    IterMV.take (zeroBasedSize sh)
      (Iter.iterate MultiValue.inc MultiValue.zero)

instance
  (Integral n, ToSize n, MultiValue.Comparison n) =>
    Sequence (ZeroBased n) where
  sequenceShapeFromIndex ix = return (zeroBased ix)


{- |
A range that uses the same value for both bounds.
-}
singletonRange :: n -> Range n
singletonRange bound = Range bound bound

{- |
Number of elements of a range with inclusive bounds: @to - from + 1@.
-}
rangeSize ::
  (ToSize n) =>
  Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) = do
  distance <- MultiValue.sub to from
  count <- MultiValue.inc distance
  toSize count

{- |
Split a multi-value range into a range of multi-values.
-}
unzipRange :: MultiValue.T (Range n) -> Range (MultiValue.T n)
unzipRange (MultiValue.Cons (Range lower upper)) =
  Range (MultiValue.Cons lower) (MultiValue.Cons upper)

{- |
Combine two bounds into a multi-value range; inverse of 'unzipRange'.
-}
zipRange :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Range n)
zipRange (MultiValue.Cons lower) (MultiValue.Cons upper) =
  MultiValue.Cons (Range lower upper)

{- |
First bound of a 'Range'.
-}
rangeFrom :: (Expr.Value val) => val (Range n) -> val n
rangeFrom = Expr.lift1 (\rng -> Shape.rangeFrom (unzipRange rng))

{- |
Second bound of a 'Range'.
-}
rangeTo :: (Expr.Value val) => val (Range n) -> val n
rangeTo = Expr.lift1 (\rng -> Shape.rangeTo (unzipRange rng))

{- |
Build a 'Range' from its two bounds.
-}
range :: (Expr.Value val) => val n -> val n -> val (Range n)
range = Expr.lift2 zipRange

-- Every method treats both bounds separately, first bound first.
instance (MultiValue.C n) => MultiValue.C (Range n) where
  type Repr f (Range n) = Range (MultiValue.Repr f n)
  cons (Range lower upper) =
    zipRange (MultiValue.cons lower) (MultiValue.cons upper)
  undef = MultiValue.compose (singletonRange MultiValue.undef)
  zero = MultiValue.compose (singletonRange MultiValue.zero)
  phis bb rng =
    case unzipRange rng of
      Range lower upper -> do
        lowerPhi <- MultiValue.phis bb lower
        upperPhi <- MultiValue.phis bb upper
        return (zipRange lowerPhi upperPhi)
  addPhis bb rngA rngB =
    case (unzipRange rngA, unzipRange rngB) of
      (Range lowerA upperA, Range lowerB upperB) -> do
        MultiValue.addPhis bb lowerA lowerB
        MultiValue.addPhis bb upperA upperB

type instance MultiValue.Decomposed f (Range pn) =
  Range (MultiValue.Decomposed f pn)
type instance MultiValue.PatternTuple (Range pn) =
  Range (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Range n) where
  type Composed (Range n) = Range (MultiValue.Composed n)
  compose (Range lower upper) =
    zipRange (MultiValue.compose lower) (MultiValue.compose upper)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Range pn) where
  decompose (Range patLower patUpper) rng =
    case unzipRange rng of
      Range lower upper ->
        Range
          (MultiValue.decompose patLower lower)
          (MultiValue.decompose patUpper upper)

-- Both bounds are stored in a 'PairStruct'.
instance (MultiMem.C n) => MultiMem.C (Range n) where
  type Struct (Range n) = PairStruct n
  decompose struct = uncurry zipRange <$> decomposeGen struct
  compose rng =
    case unzipRange rng of
      Range lower upper -> composeGen lower upper

instance (Ix n, ToSize n, MultiValue.Comparison n) => C (Range n) where
  -- Larger of the two first bounds, smaller of the two second bounds.
  intersectCode =
    MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
      \(Range fromA toA) (Range fromB toB) -> do
        from <- MultiValue.max fromA fromB
        to <- MultiValue.min toA toB
        return (Range from to)
  size rngValue = rangeSize (unzipRange rngValue)
  -- The offset of an index is its distance from the first bound.
  sizeOffset rngValue =
    case unzipRange rngValue of
      rng@(Range from _to) -> do
        count <- rangeSize rng
        return (count, \ix -> toSize =<< MultiValue.sub ix from)
  -- Counts upwards from the first bound
  -- for as long as the second bound is >= the current index.
  iterator rngValue =
    case MultiValue.decompose (singletonRange atom) rngValue of
      Range from to ->
        IterMV.takeWhile (MultiValue.cmp LLVM.CmpGE to)
          (Iter.iterate MultiValue.inc from)


{- |
A 'Shifted' shape that uses the same value for both fields.
-}
singletonShifted :: n -> Shifted n
singletonShifted x = Shifted x x

{- |
Split a multi-value 'Shifted' into a 'Shifted' of multi-values.
-}
unzipShifted :: MultiValue.T (Shifted n) -> Shifted (MultiValue.T n)
unzipShifted (MultiValue.Cons (Shifted start len)) =
  Shifted (MultiValue.Cons start) (MultiValue.Cons len)

{- |
Combine start and length into a multi-value 'Shifted';
inverse of 'unzipShifted'.
-}
zipShifted :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Shifted n)
zipShifted (MultiValue.Cons start) (MultiValue.Cons len) =
  MultiValue.Cons (Shifted start len)

{- |
Start index of a 'Shifted' shape.
-}
shiftedOffset :: (Expr.Value val) => val (Shifted n) -> val n
shiftedOffset = Expr.lift1 (\sh -> Shape.shiftedOffset (unzipShifted sh))

{- |
Length of a 'Shifted' shape.
-}
shiftedSize :: (Expr.Value val) => val (Shifted n) -> val n
shiftedSize = Expr.lift1 (\sh -> Shape.shiftedSize (unzipShifted sh))

{- |
Build a 'Shifted' shape from its two fields.
-}
shifted :: (Expr.Value val) => val n -> val n -> val (Shifted n)
shifted = Expr.lift2 zipShifted

-- Every method treats both fields separately, start first.
instance (MultiValue.C n) => MultiValue.C (Shifted n) where
  type Repr f (Shifted n) = Shifted (MultiValue.Repr f n)
  cons (Shifted start len) =
    zipShifted (MultiValue.cons start) (MultiValue.cons len)
  undef = MultiValue.compose (singletonShifted MultiValue.undef)
  zero = MultiValue.compose (singletonShifted MultiValue.zero)
  phis bb sh =
    case unzipShifted sh of
      Shifted start len -> do
        startPhi <- MultiValue.phis bb start
        lenPhi <- MultiValue.phis bb len
        return (zipShifted startPhi lenPhi)
  addPhis bb shA shB =
    case (unzipShifted shA, unzipShifted shB) of
      (Shifted startA lenA, Shifted startB lenB) -> do
        MultiValue.addPhis bb startA startB
        MultiValue.addPhis bb lenA lenB

type instance MultiValue.Decomposed f (Shifted pn) =
  Shifted (MultiValue.Decomposed f pn)
type instance MultiValue.PatternTuple (Shifted pn) =
  Shifted (MultiValue.PatternTuple pn)

instance (MultiValue.Compose n) => MultiValue.Compose (Shifted n) where
  type Composed (Shifted n) = Shifted (MultiValue.Composed n)
  compose (Shifted start len) =
    zipShifted (MultiValue.compose start) (MultiValue.compose len)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Shifted pn) where
  decompose (Shifted patStart patLen) sh =
    case unzipShifted sh of
      Shifted start len ->
        Shifted
          (MultiValue.decompose patStart start)
          (MultiValue.decompose patLen len)

-- Start and length are stored in a 'PairStruct'.
instance (MultiMem.C n) => MultiMem.C (Shifted n) where
  type Struct (Shifted n) = PairStruct n
  decompose struct = uncurry zipShifted <$> decomposeGen struct
  compose sh =
    case unzipShifted sh of
      Shifted start len -> composeGen start len


{- |
Memory layout shared by 'Range' and 'Shifted':
a struct of two members of the same type.
-}
type PairStruct n = LLVM.Struct (MultiMem.Struct n, (MultiMem.Struct n, ()))

{- |
Read both members of a 'PairStruct', member 0 first.
-}
decomposeGen ::
  (MultiMem.C n) =>
  LLVM.Value (PairStruct n) ->
  LLVM.CodeGenFunction r (MultiValue.T n, MultiValue.T n)
decomposeGen pair = do
  first <- MultiMem.decompose =<< LLVM.extractvalue pair TypeNum.d0
  second <- MultiMem.decompose =<< LLVM.extractvalue pair TypeNum.d1
  return (first, second)

{- |
Build a 'PairStruct' by inserting both values
into an initially undefined struct, member 0 first.
-}
composeGen ::
  (MultiMem.C n) =>
  MultiValue.T n -> MultiValue.T n ->
  LLVM.CodeGenFunction r (LLVM.Value (PairStruct n))
composeGen first second = do
  firstStruct <- MultiMem.compose first
  secondStruct <- MultiMem.compose second
  partial <-
    LLVM.insertvalue (LLVM.value LLVM.undef) firstStruct TypeNum.d0
  LLVM.insertvalue partial secondStruct TypeNum.d1

instance
  (Integral n, ToSize n, MultiValue.Comparison n) =>
    C (Shifted n) where
  -- The result starts at the larger start
  -- and ends at the smaller of the two ends (@start + length@).
  intersectCode =
    MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
      \(Shifted startA lenA) (Shifted startB lenB) -> do
        start <- MultiValue.max startA startB
        endA <- MultiValue.add startA lenA
        endB <- MultiValue.add startB lenB
        end <- MultiValue.min endA endB
        len <- MultiValue.sub end start
        return (Shifted start len)
  size sh = toSize (shiftedSize sh)
  -- The offset of an index is its distance from the start.
  sizeOffset shapeValue =
    case unzipShifted shapeValue of
      Shifted start len -> do
        count <- toSize len
        return (count, \ix -> toSize =<< MultiValue.sub ix start)
  -- Counts upwards from the start and takes as many indices as the length says.
  iterator shapeValue =
    case MultiValue.decompose (singletonShifted atom) shapeValue of
      Shifted start len ->
        IterMV.take len (Iter.iterate MultiValue.inc start)


instance (C n, C m) => C (n,m) where
  intersectCode a b =
    case (MultiValue.unzip a, MultiValue.unzip b) of
      ((an,am), (bn,bm)) ->
        Monad.lift2 MultiValue.zip
          (intersectCode an bn)
          (intersectCode am bm)
  size nm =
    case MultiValue.unzip nm of
      (n,m) -> do
        ns <- size n
        ms <- size m
        A.mul ns ms
  -- The offset of @(i,j)@ is @offset i * size m + offset j@.
  sizeOffset nm =
    case MultiValue.unzip nm of
      (n,m) -> do
        (ns, offsetN) <- sizeOffset n
        (ms, offsetM) <- sizeOffset m
        total <- A.mul ns ms
        return
          ( total
          , \ij ->
              case MultiValue.unzip ij of
                (i,j) -> do
                  io <- offsetN i
                  jo <- offsetM j
                  scaled <- A.mul ms io
                  A.add jo scaled
          )
  iterator nm =
    case MultiValue.unzip nm of
      (n,m) ->
        uncurry MultiValue.zip <$> Iter.cartesian (iterator n) (iterator m)
  -- Nested loops: the loop over @n@ is the outer one.
  loop code nm =
    case MultiValue.unzip nm of
      (n,m) ->
        loop (\i -> loop (\j -> code (MultiValue.zip i j)) m) n

instance (C n, C m, C l) => C (n,m,l) where
  intersectCode a b =
    case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
      ((an,am,al), (bn,bm,bl)) ->
        Monad.lift3 MultiValue.zip3
          (intersectCode an bn)
          (intersectCode am bm)
          (intersectCode al bl)
  size nml =
    case MultiValue.unzip3 nml of
      (n,m,l) -> do
        ns <- size n
        ms <- size m
        ls <- size l
        inner <- A.mul ms ls
        A.mul ns inner
  -- The offset of @(i,j,k)@ is
  -- @(offset i * size m + offset j) * size l + offset k@.
  sizeOffset nml =
    case MultiValue.unzip3 nml of
      (n,m,l) -> do
        (ns, offsetN) <- sizeOffset n
        (ms, offsetM) <- sizeOffset m
        (ls, offsetL) <- sizeOffset l
        inner <- A.mul ms ls
        total <- A.mul ns inner
        return
          ( total
          , \ijk ->
              case MultiValue.unzip3 ijk of
                (i,j,k) -> do
                  io <- offsetN i
                  jo <- offsetM j
                  ko <- offsetL k
                  scaledI <- A.mul ms io
                  ijo <- A.add jo scaledI
                  scaledIJ <- A.mul ls ijo
                  A.add ko scaledIJ
          )
  iterator nml =
    case MultiValue.unzip3 nml of
      (n,m,l) ->
        fmap (\(i,(j,k)) -> MultiValue.zip3 i j k) $
          Iter.cartesian (iterator n) $
            Iter.cartesian (iterator m) (iterator l)
  -- Nested loops: @n@ is the outermost, @l@ the innermost loop.
  loop code nml =
    case MultiValue.unzip3 nml of
      (n,m,l) ->
        loop
          (\i ->
            loop
              (\j -> loop (\k -> code (MultiValue.zip3 i j k)) l)
              m)
          n