{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- |
Shapes and indices that are built as snoc-lists of dimensions:
'Z' is the empty (rank zero) shape and @sh ':.' i@ appends one dimension.
The same structure is used for shapes and for indices;
the phantom @tag@ of 'T' tells them apart.

Most functions in this module are defined by structural recursion
over the shape type using 'switch' or 'switchInt'
together with a small helper newtype per function.
-}
module Data.Array.Knead.Shape.Cubic (
   C(switch), switchInt,
   intersect,
   value,
   constant,
   paramWith,
   tunnel,
   offsetCode,
   peek,
   poke,
   computeSize,

   Struct,
   T(..),
   Z(Z), z,
   (:.)((:.)),
   Shape, shape,
   Index, index,
   cons, (#:.),
   head, tail,
   switchR,
   loadMultiValue,
   storeMultiValue,
   ) where

import qualified Data.Array.Knead.Shape as Shape
import qualified Data.Array.Knead.Shape.Cubic.Int as Index
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified Data.Array.Comfort.Shape as ComfortShape
import Data.Array.Comfort.Shape (ZeroBased(ZeroBased))

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Multi.Iterator as IterMV
import qualified LLVM.Extra.Iterator as Iter
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )

import Control.Monad (liftM2, )

import Prelude hiding (min, head, tail, )


-- | Class of shape types built from 'Z' and ':.'.
class C ix where
   -- | Case analysis on the structure of the shape type.
   -- The first argument is selected for 'Z',
   -- the second one for @ix0 ':.' i@.
   switch ::
      f Z ->
      (forall ix0 i. (C ix0, Index.Single i) => f (ix0 :. i)) ->
      f ix

instance C Z where
   switch x _ = x

instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _ x = x


-- | Helper for 'switchInt' that exposes the last dimension type
-- as the final type parameter, as needed by 'Index.switchSingle'.
newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

-- | Like 'switch' but the non-empty case may assume
-- that the last dimension has type 'Index.Int'.
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt z0 cons0 =
   switch z0 (runSwitchInt $ Index.switchSingle (SwitchInt cons0))


-- | Helper for 'intersect': a binary operation on expressions of 'T'.
newtype Op2 tag sh =
   Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

-- | Intersection of two shapes:
-- every dimension of the result is the 'Expr.min'
-- of the corresponding dimensions of the arguments.
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2 $
   switchInt
      (Op2 $ \z0 _ -> z0)
      (Op2 $ switchR $ \is i -> switchR $ \js j ->
         intersect is js #:. Expr.min i j)

-- | Currently unused: lift any 'MultiValue.C' value to an expression.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value = Expr.lift0 . MultiValue.cons

-- | Helper for 'value'.
newtype MakeValue val tag sh =
   MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

-- | Turn a Haskell-side shape or index into a constant 'Expr.Value'.
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue $
   switchInt
      (MakeValue $ \(Cons Z) -> z)
      (MakeValue $ \(Cons (t:.h)) ->
         value (Cons t) #:. Expr.lift0 (MultiValue.cons h))

-- | Pass a shape parameter through 'tunnel'
-- and hand the resulting accessor pair to the continuation,
-- with the value accessor lifted to @val@ by 'Expr.lift0'.
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
    (St.Storable parameters, MultiValueMemory.C parameters) =>
    (p -> parameters) ->
    (MultiValue.T parameters -> val (T tag sh)) ->
    a) ->
   a
paramWith p f =
   case tunnel p of
      Param.Tunnel get val -> f get (Expr.lift0 . val)

-- | 'Param.tunnel' specialised to 'T'.
-- The 'LLVM.StructFields' evidence that is needed for this
-- is obtained from 'structFieldsPropF'.
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel p =
   case structFieldsPropF p of
      StructFieldsProp -> Param.tunnel value p


-- | Evidence that @'Struct' sh@ is an instance of 'LLVM.StructFields'.
-- Pattern matching on the constructor brings the instance into scope.
data StructFieldsProp sh =
   LLVM.StructFields (Struct sh) => StructFieldsProp

-- | Currently unused variant of 'structFieldsPropF' for a proxy @f sh@.
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp _p = structFieldsRec

-- | Obtain the evidence for @sh@ from a proxy of the form @f (g sh)@.
-- The argument is only used for its type.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF _p = structFieldsRec

-- | Provide the evidence for @sh@ to a computation
-- whose result type has the form @f (g (h sh))@.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF f = f structFieldsRec

-- | Construct the evidence by recursion over the shape type.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt StructFieldsProp (succStructFieldsProp structFieldsRec)

-- | Extend the evidence by one 'Index.Int' dimension.
succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp StructFieldsProp = StructFieldsProp


-- | The shape or index of rank zero.
data Z = Z
   deriving (Eq, Ord, Read, Show)

infixl 3 :., #:.

-- | Append one dimension to a shape or index. Both fields are strict.
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)


-- | Wrapper around a ':.'-list that carries a phantom @tag@
-- for distinguishing shapes ('ShapeTag') from indices ('IndexTag').
newtype T tag sh = Cons {decons :: sh}

data ShapeTag
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag

-- | Tag a ':.'-list as shape.
shape :: sh -> Shape sh
shape = Cons

-- | Tag a ':.'-list as index.
index :: ix -> Index ix
index = Cons


-- | Infix synonym of 'cons'.
(#:.) ::
   (Expr.Value val) =>
   val (T tag sh) -> val i -> val (T tag (sh:.i))
(#:.) = cons

-- | Value-level counterpart of ':.'.
-- The result is represented by the pair of
-- the representations of tail and head.
cons ::
   (Expr.Value val) =>
   val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2 $ \(MultiValue.Cons t) (MultiValue.Cons h) ->
      MultiValue.Cons (t,h)

-- | Value-level counterpart of 'Z', represented by @()@.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 $ MultiValue.Cons ()

-- | The last dimension of a non-empty shape or index value.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head = Expr.lift1 $ \(MultiValue.Cons (_t,h)) -> MultiValue.Cons h

-- | All but the last dimension of a non-empty shape or index value.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail = Expr.lift1 $ \(MultiValue.Cons (t,_h)) -> MultiValue.Cons t

-- | Split a non-empty value and pass 'tail' and 'head',
-- in this order, to the continuation.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) -> val (T tag (sh :. i)) -> a
switchR f ix = f (tail ix) (head ix)


instance (tag ~ ShapeTag, sh ~ Z) => Shape.Scalar (T tag sh) where
   scalar = Expr.lift0 $ MultiValue.Cons ()
   zeroIndex _ = Expr.lift0 $ MultiValue.Cons ()


-- Pattern matching support for Expr.Decompose / Expr.Compose

type family PatternTuple pattern
type family Decomposed (f :: * -> *) tag pattern

type instance PatternTuple (sh:.s) =
   PatternTuple sh :. MultiValue.PatternTuple s
type instance Decomposed f tag (sh:.s) =
   Decomposed f tag sh :. MultiValue.Decomposed f s

type instance PatternTuple (Atom sh) = sh
type instance Decomposed f tag (Atom sh) = f (T tag sh)

-- | Decompose a shape or index expression according to a pattern
-- that consists of ':.' and 'Atom'.
class
   (Expr.Composed (Decomposed Exp tag pattern) ~
      T tag (PatternTuple pattern)) =>
      Decompose tag pattern where
   decompose ::
      T tag pattern ->
      Exp (T tag (PatternTuple pattern)) ->
      Decomposed Exp tag pattern

instance Decompose tag (Atom sh) where
   decompose (Cons _atom) x = x

instance
   (Decompose tag sh, Expr.Decompose s) =>
      Decompose tag (sh :. s) where
   decompose (Cons (psh:.ps)) x =
      decompose (Cons psh) (tail x) :. Expr.decompose ps (head x)

type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)
type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh

-- | Extract the ':.'-list from a 'T' type.
type family Unwrap sh
type instance Unwrap (T tag sh) = sh

-- | Extract the tag from a 'T' type.
type family Tag sh
type instance Tag (T tag sh) = tag

instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh))
         (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (sh :. s) = cons (Expr.compose sh) (Expr.compose s)

instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   -- the right-hand side is the method of the local class 'Decompose'
   decompose = decompose


-- | Stored as an array of 'Shape.Size' with one element per dimension,
-- see 'peek' and 'poke' for the order of the elements.
instance (C sh) => St.Storable (T tag sh) where
   sizeOf (Cons sh) = sizeOfArray (rank sh) (0::Shape.Size)
   alignment (Cons _sh) = St.alignment (0::Shape.Size)
   -- the right-hand sides refer to the top-level 'poke' and 'peek'
   poke ptr = poke (castPtr ptr) . decons
   peek = fmap Cons . peek . castPtr


-- | Representation of a ':.'-list as nested pairs.
type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (tail :. head) = (Repr f tail, MultiValue.Repr f head)

instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons = value
   undef = constant $ MultiValue.undef
   zero = constant $ MultiValue.zero
   -- the right-hand sides refer to the top-level functions of this module
   addPhis = addPhis
   phis = phis


instance (tag ~ ShapeTag, C sh) => ComfortShape.C (T tag sh) where
   size = fromIntegral . size . decons

instance (tag ~ ShapeTag, C sh) => ComfortShape.Indexed (T tag sh) where
   type Index (T tag sh) = Index sh
   indices (Cons ix) = map index $ indices ix
   inBounds (Cons sh) (Cons ix) = inBounds sh ix
   offset (Cons sh) (Cons ix) = offset sh ix


-- | Helper for 'indices'.
newtype Indices sh = Indices {runIndices :: sh -> [sh]}

-- | List all indices of a shape.
-- The last dimension varies fastest.
indices :: (C sh) => sh -> [sh]
indices =
   runIndices $
   switchInt
      (Indices $ \Z -> [Z])
      (Indices $ \(t :. Index.Int h) ->
         liftM2 (:.)
            (indices t)
            (map Index.Int $ ComfortShape.indices $ ZeroBased h))

-- | Helper for 'inBounds'.
newtype InBounds sh = InBounds {runInBounds :: sh -> sh -> Bool}

-- | Check that every component of the index (second argument)
-- is within the 'ZeroBased' range given by
-- the corresponding dimension of the shape (first argument).
inBounds :: (C sh) => sh -> sh -> Bool
inBounds =
   runInBounds $
   switchInt
      (InBounds $ \Z Z -> True)
      (InBounds $ \(sh :. Index.Int s) (ix :. Index.Int i) ->
         inBounds sh ix && ComfortShape.inBounds (ZeroBased s) i)

-- | Helper for 'offset'.
newtype Offset sh = Offset {runOffset :: sh -> sh -> Int}

-- | Linear offset of an index (second argument)
-- within a shape (first argument).
-- The last dimension has stride one.
offset :: (C sh) => sh -> sh -> Int
offset =
   runOffset $
   switchInt
      (Offset $ \Z Z -> 0)
      (Offset $ \(sh :. Index.Int s) (ix :. Index.Int i) ->
         offset sh ix * fromIntegral s + fromIntegral i)


instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   size = computeSize
   intersectCode = Expr.unliftM2 intersect
   sizeOffset sh =
      -- would a joint implementation be more efficient?
      liftM2 (,) (computeSize sh) (return $ offsetCode sh)
   iterator = iterator
   loop = loop


-- | Content of the LLVM struct that stores a shape:
-- one 'Shape.Size' per dimension, the last dimension first.
type family Struct sh
type instance Struct Z = ()
type instance Struct (sh :. Index.Int) = (Shape.Size, Struct sh)

instance
   (C sh, LLVM.StructFields (Struct sh)) =>
      MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load = loadMultiValue
   store = storeMultiValue

-- | Load a shape or index from an LLVM struct.
-- The struct pointer is bitcast to a pointer to 'Shape.Size'
-- and the dimensions are read one after another by 'load'.
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue ptr =
   withStructFieldsPropFF $ \StructFieldsProp ->
      load =<< castPtrValue ptr

-- | Store a shape or index to an LLVM struct.
-- Counterpart to 'loadMultiValue'.
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r ()
storeMultiValue x ptr =
   case structFieldsPropF x of
      StructFieldsProp -> store x =<< castPtrValue ptr


-- | Helper for 'offsetCode'.
newtype OffsetCode r sh =
   OffsetCode {
      runOffsetCode ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code that computes the linear offset of an index
-- within a shape. This is the LLVM counterpart to 'offset'.
offsetCode ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
offsetCode =
   runOffsetCode $
   switchInt
      (OffsetCode $ \_zerosh _zeroix -> return A.zero)
      (OffsetCode $
       switchR $ \sh (MultiValue.Cons s) ->
       switchR $ \ix (MultiValue.Cons i) ->
         A.add i =<< A.mul s =<< offsetCode sh ix)


-- | Helper for 'rank'.
newtype Rank sh = Rank {runRank :: sh -> Int}

-- | Number of dimensions.
-- The dimension values themselves are not inspected.
rank :: (C sh) => sh -> Int
rank =
   runRank $
   switch
      (Rank $ const 0)
      (Rank $ succ . rank . (\(sh :. _s) -> sh))


-- | Helper for 'peek'.
newtype Peek sh = Peek {runPeek :: Ptr Shape.Size -> IO sh}

-- | Read a ':.'-list from an array of 'Shape.Size'.
-- The element at the given pointer becomes the last dimension,
-- the remaining dimensions are read from the following elements.
peek :: (C sh) => Ptr Shape.Size -> IO sh
peek =
   runPeek $
   switchInt
      (Peek $ const $ return Z)
      (Peek $ \ptr -> do
         h <- St.peek ptr
         t <- peek $ advancePtr ptr 1
         return (t :. Index.Int h))

-- | Helper for 'poke'.
newtype Poke sh = Poke {runPoke :: Ptr Shape.Size -> sh -> IO ()}

-- | Write a ':.'-list to an array of 'Shape.Size'
-- in the order expected by 'peek'.
poke :: (C sh) => Ptr Shape.Size -> sh -> IO ()
poke =
   runPoke $
   switchInt
      (Poke $ const $ const $ return ())
      (Poke $ \ptr (sh :. Index.Int i) -> do
         St.poke ptr i
         poke (advancePtr ptr 1) sh)


-- | Reinterpret a pointer to a struct as pointer to its first 'Shape.Size'.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Shape.Size))
castPtrValue = LLVM.bitcast


-- | Helper for 'load'.
newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | LLVM counterpart to 'peek'.
load ::
   (C sh) =>
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad $
   switchInt
      (Load $ const $ return z)
      (Load $ \ptr -> do
         h <- LLVM.load ptr
         t <- load =<< A.advanceArrayElementPtr ptr
         return (t #:. MultiValue.Cons h))

-- | Helper for 'store'.
newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Shape.Size) ->
         LLVM.CodeGenFunction r ()
   }

-- | LLVM counterpart to 'poke'.
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Shape.Size) ->
   LLVM.CodeGenFunction r ()
store =
   runStore $
   switchInt
      (Store $ \_z _ptr -> return ())
      (Store $ switchR $ \sh (MultiValue.Cons k) ptr -> do
         LLVM.store k ptr
         store sh =<< A.advanceArrayElementPtr ptr)


-- | Helper for 'size'.
newtype Size sh = Size {runSize :: sh -> Shape.Size}

-- | Product of all dimensions. The rank zero shape has size one.
size :: (C sh) => sh -> Shape.Size
size =
   runSize $
   switchInt
      (Size $ \_z -> 1)
      (Size $ \(sh :. Index.Int k) -> k * size sh)

-- | Helper for 'computeSize'.
newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
   }

-- | Generate code that computes the product of all dimensions.
-- This is the LLVM counterpart to 'size'.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Shape.Size)
computeSize =
   runComputeSize $
   switchInt
      (ComputeSize $ \_z -> return A.one)
      (ComputeSize $ switchR $ \sh (MultiValue.Cons k) ->
         A.mul k =<< computeSize sh)


-- | Helper for 'constant'.
newtype Constant val tag sh =
   Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | Shape or index where every dimension is the same given value.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant $
   switchInt
      (Constant $ const z)
      (Constant $ \x -> constant x #:. x)


-- | Helper for 'addPhis'.
newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) -> MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

-- | Apply 'MultiValue.addPhis' to every dimension, the last one first.
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) -> MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis $
   switchInt
      (AddPhis $ \_ _ _ -> return ())
      -- 'switchR' passes the tail first and the head second
      (AddPhis $ \bb -> switchR $ \tx hx -> switchR $ \ty hy ->
         MultiValue.addPhis bb hx hy >> addPhis bb tx ty)

-- | Helper for 'phis'.
newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Apply 'MultiValue.phis' to every dimension, the first one first.
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis $
   switchInt
      (Phis $ \_ -> return)
      -- 'switchR' passes the tail first and the head second
      (Phis $ \bb -> switchR $ \t h ->
         liftM2 (#:.) (phis bb t) (MultiValue.phis bb h))


-- | Helper for 'iterator'.
newtype Iterator r sh =
   Iterator {
      runIterator ::
         MultiValue.T (Shape sh) ->
         Iter.T r (MultiValue.T (Index sh))
   }

-- | Iterator over all indices of a shape.
-- For every dimension it counts up from 'MultiValue.zero'
-- as long as the dimension size is greater than the counter.
iterator ::
   (C sh) =>
   MultiValue.T (Shape sh) -> Iter.T r (MultiValue.T (Index sh))
iterator =
   runIterator $
   switchInt
      -- A rank zero shape has exactly one index, namely 'z'.
      -- This matches 'loop', which runs its body once in this case.
      -- Starting from an empty iterator would presumably leave
      -- every cartesian product below empty, too.
      (Iterator $ \ _z -> Iter.singleton z)
      (Iterator $ switchR $ \sh n ->
         fmap (\(ix,i) -> ix#:.i) $
         Iter.cartesian
            (iterator sh)
            (IterMV.takeWhile (MultiValue.cmp LLVM.CmpGT n) $
             Iter.iterate MultiValue.inc MultiValue.zero))


-- | Helper for 'loop'.
newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) -> state -> LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) -> state -> LLVM.CodeGenFunction r state
   }

-- | Generate nested loops over all indices of a shape
-- while threading a state through the loop body.
-- The loop over the last dimension is the innermost one.
-- For a rank zero shape the body is run once with index 'z'.
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) -> state -> LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) -> state -> LLVM.CodeGenFunction r state
loop =
   runLoop $
   switchInt
      (Loop $ \code _z -> code z)
      (Loop $ \code -> switchR $ \sh (MultiValue.Cons n) ->
         loop
            (\ix ptrStart ->
               -- the counter k is part of the loop state
               -- and dropped when the loop finishes
               fmap fst $
               C.fixedLengthLoop n (ptrStart, A.zero) $ \(ptr, k) ->
                  liftM2 (,)
                     (code (ix #:. MultiValue.Cons k) ptr)
                     (A.inc k))
            sh)