{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Shapes of knead arrays together with the LLVM code they need:
number of elements, linear offsets of indices,
intersection of two shapes, and iteration over all indices.
-}
module Data.Array.Knead.Shape (
   C(..),
   Index, Size,
   value,
   paramWith,
   load,
   intersect,
   offset,
   ZeroBased(ZeroBased), zeroBased, zeroBasedSize,
   Range(Range), range, rangeFrom, rangeTo,
   Shifted(Shifted), shifted, shiftedOffset, shiftedSize,
   Enumeration(Enumeration),
   EnumBounded(..),
   Scalar(..),
   Sequence(..),
   ) where

import qualified Data.Array.Knead.Expression as Expr
import qualified Data.Array.Knead.Parameter as Param
import Data.Array.Knead.Shape.Orphan
         (zeroBased, zeroBasedSize,
          singletonRange, unzipRange,
          singletonShifted, unzipShifted)
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Shape
         (Index, ZeroBased, Range(Range), Shifted(Shifted),
          Enumeration(Enumeration))
import Data.Ix (Ix)

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.ScalarOrVector as SoV
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable)
import Foreign.Ptr (Ptr)

import qualified Data.Enum.Storable as Enum
import Data.Tagged (Tagged)
import Data.Tuple.HT (mapSnd)
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Int (Int8, Int16, Int32, Int64)
import qualified Control.Monad.HT as Monad
import Control.Applicative ((<$>))

import Prelude2010
import Prelude ()


{- |
Type of element counts and linear offsets in generated code.
-}
type Size = Word64

{- |
Embed a constant Haskell shape value into an expression.
-}
value :: (C sh, Expr.Value val) => sh -> val sh
value = Expr.lift0 . MultiValue.cons

{- |
Continuation-passing access to a parameter of type @b@.
Delegates to 'Param.withMulti'.
The continuation receives the function @p -> parameters@
provided by 'Param.withMulti'.
It also receives a function that turns the loaded multi-value
into a @val b@ via 'Expr.lift0'.
-}
paramWith ::
   (Storable b, MultiMem.C b, Expr.Value val) =>
   Param.T p b ->
   (forall parameters.
      (Storable parameters, MultiMem.C parameters) =>
      (p -> parameters) ->
      (MultiValue.T parameters -> val b) ->
      a) ->
   a
paramWith p f = Param.withMulti p (\get val -> f get (Expr.lift0 . val))

{- |
Load a shape from memory.
The first argument is ignored; it only fixes the shape type.
-}
load ::
   (MultiMem.C sh) =>
   f sh -> LLVM.Value (Ptr (MultiMem.Struct sh)) ->
   LLVM.CodeGenFunction r (MultiValue.T sh)
load _ = MultiMem.load

{- |
Intersection of two shape expressions: 'intersectCode' lifted to 'Exp'.
-}
intersect :: (C sh) => Exp sh -> Exp sh -> Exp sh
intersect = Expr.liftM2 intersectCode

{- |
Linear offset of an index within a shape.
Uses the offset function returned by 'sizeOffset' and discards the size.
-}
offset ::
   (C sh) =>
   MultiValue.T sh -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Size)
offset sh ix = ($ix) . snd =<< sizeOffset sh


{- |
Shapes for which LLVM code can be generated.
-}
class
   (MultiValue.C sh, MultiValue.C (Index sh), Shape.Indexed sh) =>
      C sh where
   {-
   It would be better to restrict zipWith to matching shapes
   and turn shape intersection into a bound check.
   -}
   {- |
   Generate code for the intersection of two shapes of the same type.
   -}
   intersectCode ::
      MultiValue.T sh -> MultiValue.T sh ->
      LLVM.CodeGenFunction r (MultiValue.T sh)
   {- |
   Generate code computing the number of elements of the shape.
   -}
   size :: MultiValue.T sh -> LLVM.CodeGenFunction r (LLVM.Value Size)
   {- |
   Result is @(size, offset)@.
   @size@ must equal the result of 'size'.
   @offset@ maps an index to its linear position within the shape.
   Returning both at once lets instances share intermediate results
   between the two computations.
   -}
   sizeOffset ::
      MultiValue.T sh ->
      LLVM.CodeGenFunction r
         (LLVM.Value Size,
          MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (LLVM.Value Size))
   {- |
   Iterator over all indices of the shape.
   -}
   iterator :: (Index sh ~ ix) => MultiValue.T sh -> Iter.T r (MultiValue.T ix)
   {- |
   Run a stateful action for every index of the shape.
   The default runs the action over 'iterator'.
   The unit, pair and triple instances below override it.
   -}
   loop ::
      (Index sh ~ ix, MultiValue.C ix, Loop.Phi state) =>
      (MultiValue.T ix -> state -> LLVM.CodeGenFunction r state) ->
      MultiValue.T sh -> state -> LLVM.CodeGenFunction r state
   loop f sh = Iter.mapState_ f (iterator sh)

{- |
The unit shape has exactly one element, whose offset is zero.
-}
instance C () where
   intersectCode _ _ = return $ MultiValue.cons ()
   size _ = return A.one
   sizeOffset _ = return (A.one, \_ -> return A.zero)
   iterator = Iter.singleton
   -- There is only one index, and the shape value itself serves as it,
   -- so looping means applying the body once.
   loop = id


{- |
Shapes that can be produced as a constant ('scalar')
and that provide a designated zero index ('zeroIndex').
-}
class C sh => Scalar sh where
   scalar :: (Expr.Value val) => val sh
   zeroIndex :: (Expr.Value val) => f sh -> val (Index sh)

instance Scalar () where
   scalar = Expr.lift0 $ MultiValue.Cons ()
   zeroIndex _ = Expr.lift0 $ MultiValue.Cons ()


{- |
Shapes that can be built from a single index value,
see 'sequenceShapeFromIndex'.
For 'ZeroBased' the index value is used as the shape size.
-}
class
   (C sh, MultiValue.IntegerConstant (Index sh), MultiValue.Additive (Index sh)) =>
      Sequence sh where
   sequenceShapeFromIndex ::
      MultiValue.T (Index sh) -> LLVM.CodeGenFunction r (MultiValue.T sh)


{- |
Conversion of a size or index value of type @n@
to 'Size' ('Word64'), the type used for element counts and offsets.
-}
class
   (MultiValue.Additive n, MultiValue.Real n, MultiValue.IntegerConstant n) =>
      ToSize n where
   toSize :: MultiValue.T n -> LLVM.CodeGenFunction r (LLVM.Value Size)

-- Unsigned types are widened with 'LLVM.ext'.
-- 'Word64' already is 'Size' and is passed through unchanged.
instance ToSize Word8  where toSize (MultiValue.Cons n) = LLVM.ext n
instance ToSize Word16 where toSize (MultiValue.Cons n) = LLVM.ext n
instance ToSize Word32 where toSize (MultiValue.Cons n) = LLVM.ext n
instance ToSize Word64 where toSize (MultiValue.Cons n) = return n

-- Signed types are zero-extended ('LLVM.zext'),
-- or, for 'Int64', reinterpreted bit-for-bit ('LLVM.bitcast').
-- They are never sign-extended.
-- Consequence: a size or offset difference that wrapped around in the signed type
-- is still converted correctly.
-- For example, 100 - (-100) computed in Int8 has bit pattern 0xC8,
-- and zero-extension turns it into 200.
-- A genuinely negative value becomes a large positive 'Size'.
instance ToSize Int8  where toSize (MultiValue.Cons n) = LLVM.zext n
instance ToSize Int16 where toSize (MultiValue.Cons n) = LLVM.zext n
instance ToSize Int32 where toSize (MultiValue.Cons n) = LLVM.zext n
instance ToSize Int64 where toSize (MultiValue.Cons n) = LLVM.bitcast n


{- |
Array dimensions and indices cannot be negative.
However, intermediate index computations may temporarily yield negative values,
and we may want to add negative values to indices.

Ideally we would have @Index (ZeroBased Word64) = Int64@,
but that is not possible.
Maybe we need an additional ZeroBased type for unsigned array sizes.

In this instance:

* the intersection of two zero-based shapes is the smaller size;

* the offset of an index is the index itself, converted with 'toSize';

* the iterator counts up from zero for 'size' steps.
-}
instance (Integral n, ToSize n, MultiValue.Comparison n) => C (ZeroBased n) where
   intersectCode sha shb =
      zeroBased <$> MultiValue.min (zeroBasedSize sha) (zeroBasedSize shb)
   size = toSize . zeroBasedSize
   sizeOffset sh =
      Monad.lift2 (,) (toSize $ zeroBasedSize sh) (return toSize)
   iterator sh =
      IterMV.take (zeroBasedSize sh) $
      Iter.iterate MultiValue.inc MultiValue.zero

instance (Integral n, ToSize n, MultiValue.Comparison n) => Sequence (ZeroBased n) where
   -- The index value becomes the size of the zero-based shape.
   sequenceShapeFromIndex = return . zeroBased


{- |
Number of elements of an inclusive range, @to - from + 1@.
The result is computed in type @n@ and then converted with 'toSize'.
Nothing here checks that @from <= to@.
-}
rangeSize ::
   (ToSize n) =>
   Range (MultiValue.T n) -> LLVM.CodeGenFunction r (LLVM.Value Size)
rangeSize (Range from to) =
   toSize =<< MultiValue.inc =<< MultiValue.sub to from

{- |
Lower bound of a range expression.
-}
rangeFrom :: (Expr.Value val) => val (Range n) -> val n
rangeFrom = Expr.lift1 $ Shape.rangeFrom . unzipRange

{- |
Upper bound (inclusive) of a range expression.
-}
rangeTo :: (Expr.Value val) => val (Range n) -> val n
rangeTo = Expr.lift1 $ Shape.rangeTo . unzipRange

{- |
Construct a range expression from its lower and upper bound.
-}
range :: (Expr.Value val) => val n -> val n -> val (Range n)
range =
   Expr.lift2 $ \(MultiValue.Cons from) (MultiValue.Cons to) ->
      MultiValue.Cons (Range from to)

instance (Ix n, ToSize n, MultiValue.Comparison n) => C (Range n) where
   -- Intersection takes the larger lower bound and the smaller upper bound.
   -- For disjoint ranges this yields from > to; nothing here clamps that.
   intersectCode =
      MultiValue.modifyF2 (singletonRange atom) (singletonRange atom) $
      \(Range fromN toN) (Range fromM toM) ->
         Monad.lift2 Range (MultiValue.max fromN fromM) (MultiValue.min toN toM)
   size = rangeSize . unzipRange
   -- The offset of index i is i - from.
   sizeOffset rngValue =
      case unzipRange rngValue of
         rng@(Range from _to) ->
            Monad.lift2 (,) (rangeSize rng)
               (return $ \i -> toSize =<< MultiValue.sub i from)
   -- Counts up from 'from' for as long as the index does not exceed 'to'.
   -- NOTE(review): if 'to' is the largest value of n,
   -- MultiValue.inc presumably wraps around,
   -- and the termination test would never fail. Confirm before relying on such ranges.
   iterator rngValue =
      case MultiValue.decompose (singletonRange atom) rngValue of
         Range from to ->
            IterMV.takeWhile (MultiValue.cmp LLVM.CmpGE to) $
            Iter.iterate MultiValue.inc from


{- |
Start offset of a shifted-shape expression.
-}
shiftedOffset :: (Expr.Value val) => val (Shifted n) -> val n
shiftedOffset = Expr.lift1 $ Shape.shiftedOffset . unzipShifted

{- |
Length of a shifted-shape expression.
-}
shiftedSize :: (Expr.Value val) => val (Shifted n) -> val n
shiftedSize = Expr.lift1 $ Shape.shiftedSize . unzipShifted

{- |
Construct a shifted-shape expression from its start offset and its length.
The lambda's local names @from@ and @to@ are misleading:
the second argument is the length, as the @Shifted start len@ patterns below show.
-}
shifted :: (Expr.Value val) => val n -> val n -> val (Shifted n)
shifted =
   Expr.lift2 $ \(MultiValue.Cons from) (MultiValue.Cons to) ->
      MultiValue.Cons (Shifted from to)

instance (Integral n, ToSize n, MultiValue.Comparison n) => C (Shifted n) where
   -- Intersection of the half-open intervals [start, start+len):
   -- it starts at the larger start, ends at the smaller end,
   -- and its length is end - start.
   -- For disjoint shapes that length is negative (or wraps);
   -- nothing here clamps it.
   intersectCode =
      MultiValue.modifyF2 (singletonShifted atom) (singletonShifted atom) $
      \(Shifted startN lenN) (Shifted startM lenM) -> do
         start <- MultiValue.max startN startM
         endN <- MultiValue.add startN lenN
         endM <- MultiValue.add startM lenM
         end <- MultiValue.min endN endM
         Shifted start <$> MultiValue.sub end start
   size = toSize . shiftedSize
   -- The size is the length, and the offset of index i is i - start.
   sizeOffset shapeValue =
      case unzipShifted shapeValue of
         Shifted start len ->
            Monad.lift2 (,) (toSize len)
               (return $ \i -> toSize =<< MultiValue.sub i start)
   -- Takes 'len' consecutive indices, starting at the start offset.
   iterator rngValue =
      case MultiValue.decompose (singletonShifted atom) rngValue of
         Shifted from len ->
            IterMV.take len $ Iter.iterate MultiValue.inc from


{- |
Enumeration-like index types that can compute the offset of a value
relative to 'MultiValue.minBound'.
-}
class (IterMV.Enum enum, MultiValue.Bounded enum) => EnumBounded enum where
   enumOffset :: MultiValue.T enum -> LLVM.CodeGenFunction r (LLVM.Value Size)

instance
   (ToSize w, MultiValue.Additive w, LLVM.IsInteger w, SoV.IntegerConstant w,
    Num w, MultiValue.Repr LLVM.Value w ~ LLVM.Value w,
    LLVM.CmpRet w, LLVM.CmpResult w ~ Bool, Enum e, Bounded e) =>
      EnumBounded (Enum.T w e) where
   -- Offset = fromEnum ix - fromEnum minBound.
   -- It is computed in the storage type w and then converted with 'toSize'.
   enumOffset ix =
      toSize =<<
      MultiValue.sub
         (MultiValue.fromEnum ix)
         (MultiValue.fromEnum $ MultiValue.minBound `asTypeOf` ix)

instance (Enum enum, Bounded enum, EnumBounded enum) => C (Enumeration enum) where
   -- All values of 'Enumeration enum' are the same, because its constructor carries no data
   -- (see 'plainEnumeration'), so intersection can return either operand.
   intersectCode _sha shb = return shb
   -- The size is computed in Haskell from the plain 'Enumeration' value
   -- and embedded as a constant.
   size = return . A.fromInteger' . toInteger . Shape.size . plainEnumeration
   sizeOffset sh = do
      sz <- size sh
      return (sz, enumOffset)
   iterator _ = IterMV.enumFromTo MultiValue.minBound MultiValue.maxBound

{- |
Recover the plain 'Enumeration' value.
The argument is not inspected, because 'Enumeration' has a nullary constructor.
-}
plainEnumeration :: val (Enumeration enum) -> Enumeration enum
plainEnumeration _ = Enumeration


{- |
A tagged shape behaves exactly like the underlying shape.
Every method untags, delegates, and re-tags where needed.
-}
instance (C sh) => C (Tagged tag sh) where
   intersectCode = MultiValue.liftTaggedM2 intersectCode
   size = size . MultiValue.untag
   sizeOffset = fmap (mapSnd (. MultiValue.untag)) . sizeOffset . MultiValue.untag
   iterator = fmap MultiValue.tag . iterator . MultiValue.untag


{- |
Two-dimensional shapes.

* Intersection works componentwise.

* The size is the product of the component sizes.

* Offsets are row-major: the first component is the most significant.
-}
instance (C n, C m) => C (n,m) where
   intersectCode a b =
      case (MultiValue.unzip a, MultiValue.unzip b) of
         ((an,am), (bn,bm)) ->
            Monad.lift2 MultiValue.zip
               (intersectCode an bn)
               (intersectCode am bm)
   size nm =
      case MultiValue.unzip nm of
         (n,m) -> Monad.liftJoin2 A.mul (size n) (size m)
   sizeOffset nm =
      case MultiValue.unzip nm of
         (n,m) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            sz <- A.mul ns ms
            -- offset (i,j) = jOffset + ms * iOffset
            return
               (sz,
                \ij ->
                  case MultiValue.unzip ij of
                     (i,j) -> do
                        il <- iOffset i
                        jl <- jOffset j
                        A.add jl =<< A.mul ms il)
   -- Presumably enumerates in the same order as 'loop' (first component outermost);
   -- confirm against Iter.cartesian.
   iterator nm =
      case MultiValue.unzip nm of
         (n,m) -> uncurry MultiValue.zip <$> Iter.cartesian (iterator n) (iterator m)
   -- Nested loops: the first component drives the outer loop.
   loop code nm =
      case MultiValue.unzip nm of
         (n,m) -> loop (\i -> loop (\j -> code (MultiValue.zip i j)) m) n

{- |
Three-dimensional shapes.
This is the analogue of the pair instance.
Offsets are row-major, with the first component the most significant.
-}
instance (C n, C m, C l) => C (n,m,l) where
   intersectCode a b =
      case (MultiValue.unzip3 a, MultiValue.unzip3 b) of
         ((ai,aj,ak), (bi,bj,bk)) ->
            Monad.lift3 MultiValue.zip3
               (intersectCode ai bi)
               (intersectCode aj bj)
               (intersectCode ak bk)
   size nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            Monad.liftJoin2 A.mul (size n) $
            Monad.liftJoin2 A.mul (size m) (size l)
   sizeOffset nml =
      case MultiValue.unzip3 nml of
         (n,m,l) -> do
            (ns, iOffset) <- sizeOffset n
            (ms, jOffset) <- sizeOffset m
            (ls, kOffset) <- sizeOffset l
            sz <- A.mul ns =<< A.mul ms ls
            -- offset (i,j,k) = kOffset + ls * (jOffset + ms * iOffset)
            return
               (sz,
                \ijk ->
                  case MultiValue.unzip3 ijk of
                     (i,j,k) -> do
                        il <- iOffset i
                        jl <- jOffset j
                        kl <- kOffset k
                        A.add kl =<< A.mul ls =<< A.add jl =<< A.mul ms il)
   -- Presumably enumerates in the same order as 'loop' (first component outermost);
   -- confirm against Iter.cartesian.
   iterator nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            fmap (\(a,(b,c)) -> MultiValue.zip3 a b c) $
            Iter.cartesian (iterator n) $
            Iter.cartesian (iterator m) (iterator l)
   -- Nested loops: the first component drives the outermost loop.
   loop code nml =
      case MultiValue.unzip3 nml of
         (n,m,l) ->
            loop (\i -> loop (\j -> loop (\k -> code (MultiValue.zip3 i j k)) l) m) n