{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{- |
Shapes and indices of multi-dimensional arrays whose components are
'Index.Int's, represented as left-nested snoc lists
@Z :. Index.Int n0 :. Index.Int n1 :. ...@.

* 'flattenIndex' and 'loop' treat the right-most component as the
  fastest-varying one (row-major order).

* The 'St.Storable' instance and the LLVM memory representation
  ('load', 'store', 'Struct') lay the components out as consecutive
  'Word32's, with the right-most component at the lowest address.
-}
module Data.Array.Knead.Index.Linear (
   C(switch),
   switchInt,
   intersect,
   value,
   constant,
   paramWith,
   tunnel,
   flattenIndex,
   peek,
   poke,
   computeSize,
   Struct,
   T(..),
   Z(Z),
   (:.)((:.)),
   Shape, shape,
   Index, index,
   cons, (#:.),
   head, tail,
   switchR,
   loadMultiValue,
   storeMultiValue,
   ) where

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Index.Linear.Int as Index
import qualified Data.Array.Knead.Parameter as Param
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp, )

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import qualified LLVM.Extra.Control as C
import LLVM.Extra.Multi.Value (Atom, )

import qualified LLVM.Util.Loop as Loop
import qualified LLVM.Core as LLVM

import qualified Foreign.Storable as St
import Foreign.Storable.FixedArray (sizeOfArray, )
import Foreign.Marshal.Array (advancePtr, )
import Foreign.Ptr (Ptr, castPtr, )

import Control.Monad (liftM2, )
import Data.Word (Word32, )

import Prelude hiding (min, head, tail, )


-- | Shape/index types that can be taken apart by induction:
-- either the empty shape 'Z' or a shape extended by one component.
class C ix where
   -- | Type-directed case analysis: the first argument is chosen for 'Z',
   -- the second one for @ix0 :. i@.
   switch ::
      f Z ->
      (forall ix0 i. (C ix0, Index.Single i) => f (ix0 :. i)) ->
      f ix

instance C Z where
   switch x _ = x

instance (C ix0, Index.Single i) => C (ix0 :. i) where
   switch _ x = x


newtype SwitchInt f ix i = SwitchInt {runSwitchInt :: f (ix :. i)}

-- | Like 'switch', but the non-empty case is stated for an 'Index.Int'
-- component; 'Index.switchSingle' adapts it to the general component type.
switchInt ::
   (C ix) =>
   f Z ->
   (forall ix0. (C ix0) => f (ix0 :. Index.Int)) ->
   f ix
switchInt z0 cons0 =
   switch z0 (runSwitchInt $ Index.switchSingle (SwitchInt cons0))


newtype Op2 tag sh =
   Op2 {runOp2 :: Exp (T tag sh) -> Exp (T tag sh) -> Exp (T tag sh)}

-- | Component-wise minimum of two shapes.
intersect :: C sh => Exp (Shape sh) -> Exp (Shape sh) -> Exp (Shape sh)
intersect =
   runOp2 $
   switchInt
      (Op2 $ \z0 _ -> z0)
      (Op2 $
         switchR $ \is i ->
         switchR $ \js j ->
            intersect is js #:. Expr.min i j)


-- Not exported and not referenced in this module.
_value :: (C sh, MultiValue.C sh) => sh -> Exp sh
_value = Expr.lift0 . MultiValue.cons


newtype MakeValue val tag sh =
   MakeValue {runMakeValue :: T tag sh -> val (T tag sh)}

-- | Lift a concrete shape or index into the expression type @val@,
-- one component at a time.
value :: (C sh, Expr.Value val) => T tag sh -> val (T tag sh)
value =
   runMakeValue $
   switchInt
      (MakeValue $ \(Cons Z) -> z)
      (MakeValue $ \(Cons (t:.h)) ->
         value (Cons t) #:. Expr.lift0 (MultiValue.cons h))


-- | Continuation-passing access to the 'Param.Tunnel' built by 'tunnel':
-- the callback receives the parameter getter and a function turning the
-- loaded parameter into a @val@ expression.
paramWith ::
   (C sh, Expr.Value val) =>
   Param.T p (T tag sh) ->
   (forall parameters.
      (St.Storable parameters, MultiValueMemory.C parameters) =>
      (p -> parameters) ->
      (MultiValue.T parameters -> val (T tag sh)) ->
      a) ->
   a
paramWith p f =
   case tunnel p of
      Param.Tunnel get val -> f get (Expr.lift0 . val)

-- | Build a 'Param.Tunnel' for a shape/index parameter via 'Param.tunnel'.
-- Matching on 'StructFieldsProp' brings @LLVM.StructFields (Struct sh)@
-- into scope, which (presumably) the memory instance needed by
-- 'Param.tunnel' requires.
tunnel :: (C sh) => Param.T p (T tag sh) -> Param.Tunnel p (T tag sh)
tunnel p =
   case structFieldsPropF p of
      StructFieldsProp -> Param.tunnel value p


-- | Evidence that @LLVM.StructFields (Struct sh)@ holds.
data StructFieldsProp sh =
   LLVM.StructFields (Struct sh) => StructFieldsProp

-- Not exported and not referenced in this module.
_structFieldsProp :: (C sh) => f sh -> StructFieldsProp sh
_structFieldsProp _p = structFieldsRec

-- | 'structFieldsRec' with @sh@ fixed by a (never evaluated) proxy argument.
structFieldsPropF :: (C sh) => f (g sh) -> StructFieldsProp sh
structFieldsPropF _p = structFieldsRec

-- | Run a computation that needs the 'StructFieldsProp' evidence.
withStructFieldsPropFF ::
   (C sh) => (StructFieldsProp sh -> f (g (h sh))) -> f (g (h sh))
withStructFieldsPropFF f = f structFieldsRec

-- | Construct the 'StructFieldsProp' evidence by induction on the shape.
structFieldsRec :: (C sh) => StructFieldsProp sh
structFieldsRec =
   switchInt StructFieldsProp (succStructFieldsProp structFieldsRec)

succStructFieldsProp ::
   StructFieldsProp sh -> StructFieldsProp (sh:.Index.Int)
succStructFieldsProp StructFieldsProp = StructFieldsProp


-- | The empty shape / index (rank 0).
data Z = Z
   deriving (Eq, Ord, Read, Show)

infixl 3 :., #:.

-- | A shape/index extended by one right-most component (strict fields).
data tail :. head = !tail :. !head
   deriving (Eq, Ord, Read, Show)

-- | A shape or an index; the phantom @tag@ keeps the two apart.
newtype T tag sh = Cons {decons :: sh}

data ShapeTag
data IndexTag

type Shape = T ShapeTag
type Index = T IndexTag

shape :: sh -> Shape sh
shape = Cons

index :: ix -> Index ix
index = Cons

-- | Operator synonym for 'cons'.
(#:.) :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
(#:.) = cons

-- | Append a right-most component to a shape/index expression.
cons :: (Expr.Value val) => val (T tag sh) -> val i -> val (T tag (sh:.i))
cons =
   Expr.lift2 $ \(MultiValue.Cons t) (MultiValue.Cons h) ->
      MultiValue.Cons (t,h)

-- | The empty shape/index as an expression.
z :: (Expr.Value val) => val (T tag Z)
z = Expr.lift0 $ MultiValue.Cons ()

-- | The right-most component.
head :: (Expr.Value val) => val (T tag (sh:.i)) -> val i
head = Expr.lift1 $ \(MultiValue.Cons (_t,h)) -> MultiValue.Cons h

-- | Everything but the right-most component.
tail :: (Expr.Value val) => val (T tag (sh:.i)) -> val (T tag sh)
tail = Expr.lift1 $ \(MultiValue.Cons (t,_h)) -> MultiValue.Cons t

-- | Split a non-empty shape/index expression into 'tail' and 'head'
-- and pass both to the continuation.
switchR ::
   Expr.Value val =>
   (val (T tag sh) -> val i -> a) ->
   val (T tag (sh :. i)) -> a
switchR f ix = f (tail ix) (head ix)


type family PatternTuple pattern
type family Decomposed (f :: * -> *) tag pattern

-- Pattern support: a shape pattern is either an 'Atom' (kept whole)
-- or a snoc of a shape pattern and a component pattern.
type instance PatternTuple (sh:.s) =
   PatternTuple sh :. MultiValue.PatternTuple s
type instance Decomposed f tag (sh:.s) =
   Decomposed f tag sh :. MultiValue.Decomposed f s
type instance PatternTuple (Atom sh) = sh
type instance Decomposed f tag (Atom sh) = f (T tag sh)

-- | Split a shape/index expression according to a pattern.
class
   (Expr.Composed (Decomposed Exp tag pattern) ~ T tag (PatternTuple pattern)) =>
      Decompose tag pattern where
   decompose ::
      T tag pattern ->
      Exp (T tag (PatternTuple pattern)) ->
      Decomposed Exp tag pattern

instance Decompose tag (Atom sh) where
   decompose (Cons _atom) x = x

instance (Decompose tag sh, Expr.Decompose s) => Decompose tag (sh :. s) where
   decompose (Cons (psh:.ps)) x =
      decompose (Cons psh) (tail x) :. Expr.decompose ps (head x)

type instance MultiValue.PatternTuple (T tag sh) = T tag (PatternTuple sh)
type instance MultiValue.Decomposed f (T tag sh) = Decomposed f tag sh

type family Unwrap sh
type instance Unwrap (T tag sh) = sh

type family Tag sh
type instance Tag (T tag sh) = tag

instance
   (Expr.Compose sh,
    Expr.Composed sh ~ T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh)),
    Expr.Compose s) =>
      Expr.Compose (sh :. s) where
   type Composed (sh :. s) =
      T (Tag (Expr.Composed sh)) (Unwrap (Expr.Composed sh) :. Expr.Composed s)
   compose (sh :. s) = cons (Expr.compose sh) (Expr.compose s)

instance (Decompose tag sh) => Expr.Decompose (T tag sh) where
   decompose = decompose


-- | Stored as @rank@ consecutive 'Word32's, right-most component first.
-- 'St.sizeOf' and 'St.alignment' never evaluate their argument
-- ('rank' is lazy), as the 'St.Storable' contract requires;
-- e.g. @alloca@ and @malloc@ call 'St.sizeOf' on 'undefined'.
instance (C sh) => St.Storable (T tag sh) where
   sizeOf (Cons sh) = sizeOfArray (rank sh) (0::Word32)
   alignment (Cons _sh) = St.alignment (0::Word32)
   poke ptr = poke (castPtr ptr) . decons
   peek = fmap Cons . peek . castPtr


type family Repr (f :: * -> *) sh
type instance Repr f Z = ()
type instance Repr f (tail :. head) = (Repr f tail, MultiValue.Repr f head)

instance (C sh) => MultiValue.C (T tag sh) where
   type Repr f (T tag sh) = Repr f sh
   cons = value
   undef = constant $ MultiValue.undef
   zero = constant $ MultiValue.zero
   addPhis = addPhis
   phis = phis

instance (tag ~ ShapeTag, C sh) => Shape.C (T tag sh) where
   type Index (T tag sh) = Index sh
   size = fromIntegral . size . decons
   sizeCode = computeSize
   intersectCode = Expr.unliftM2 intersect
   flattenIndexRec sh ix =
      -- a joint implementation would not be more efficient
      liftM2 (,) (computeSize sh) (flattenIndex sh ix)
   loop = loop


-- | LLVM struct layout of a shape: right-most component first.
type family Struct sh
type instance Struct Z = ()
type instance Struct (sh :. Index.Int) = (Word32, Struct sh)

instance (C sh, LLVM.StructFields (Struct sh)) => MultiValueMemory.C (T tag sh) where
   type Struct (T tag sh) = LLVM.Struct (Struct sh)
   load = loadMultiValue
   store = storeMultiValue

-- | Load a shape/index from its LLVM struct by treating the struct
-- as an array of 'Word32'.
loadMultiValue ::
   (C sh) =>
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
loadMultiValue ptr =
   withStructFieldsPropFF $ \StructFieldsProp -> load =<< castPtrValue ptr

-- | Store a shape/index into its LLVM struct by treating the struct
-- as an array of 'Word32'.
storeMultiValue ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr (LLVM.Struct (Struct sh))) ->
   LLVM.CodeGenFunction r ()
storeMultiValue x ptr =
   case structFieldsPropF x of
      StructFieldsProp -> store x =<< castPtrValue ptr


newtype FlattenIndex r sh =
   FlattenIndex {
      runFlattenIndex ::
         MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
         LLVM.CodeGenFunction r (LLVM.Value Word32)
   }

-- | Row-major linear offset of an index within a shape:
-- @flatten (sh:.s) (ix:.i) = flatten sh ix * s + i@, @flatten Z Z = 0@.
-- No bounds checking is performed.
flattenIndex ::
   (C sh) =>
   MultiValue.T (Shape sh) -> MultiValue.T (Index sh) ->
   LLVM.CodeGenFunction r (LLVM.Value Word32)
flattenIndex =
   runFlattenIndex $
   switchInt
      (FlattenIndex $ \_zerosh _zeroix -> return A.zero)
      (FlattenIndex $
         switchR $ \sh (MultiValue.Cons s) ->
         switchR $ \ix (MultiValue.Cons i) ->
            A.add i =<< A.mul s =<< flattenIndex sh ix)


newtype Rank sh = Rank {runRank :: sh -> Int}

-- | Number of components of a shape/index.
--
-- The recursion is driven by the type alone, and the lazy pattern
-- @~(sh :. _s)@ ensures the argument is never evaluated.  This is required
-- because the 'St.sizeOf' method of the 'St.Storable' instance calls
-- 'rank' and may be applied to 'undefined'.  With a strict pattern, any
-- shape of rank >= 2 crashed there.
rank :: (C sh) => sh -> Int
rank =
   runRank $
   switch
      (Rank $ const 0)
      (Rank $ \ ~(sh :. _s) -> succ (rank sh))


newtype Peek sh = Peek {runPeek :: Ptr Word32 -> IO sh}

-- | Read a shape/index from consecutive 'Word32's,
-- right-most component at the lowest address.
peek :: (C sh) => Ptr Word32 -> IO sh
peek =
   runPeek $
   switchInt
      (Peek $ const $ return Z)
      (Peek $ \ptr -> do
         h <- St.peek ptr
         t <- peek $ advancePtr ptr 1
         return (t :. Index.Int h))


newtype Poke sh = Poke {runPoke :: Ptr Word32 -> sh -> IO ()}

-- | Write a shape/index as consecutive 'Word32's,
-- right-most component at the lowest address.
poke :: (C sh) => Ptr Word32 -> sh -> IO ()
poke =
   runPoke $
   switchInt
      (Poke $ const $ const $ return ())
      (Poke $ \ptr (sh :. Index.Int i) -> do
         St.poke ptr i
         poke (advancePtr ptr 1) sh)


-- | Reinterpret a pointer to an LLVM struct as a pointer to 'Word32'.
castPtrValue ::
   (LLVM.StructFields sh) =>
   LLVM.Value (Ptr (LLVM.Struct sh)) ->
   LLVM.CodeGenFunction r (LLVM.Value (Ptr Word32))
castPtrValue = LLVM.bitcast


newtype Load r tag sh =
   Load {
      runLoad ::
         LLVM.Value (Ptr Word32) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Generate code that loads a shape/index from consecutive 'Word32's
-- (right-most component first).
load ::
   (C sh) =>
   LLVM.Value (Ptr Word32) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
load =
   runLoad $
   switchInt
      (Load $ const $ return z)
      (Load $ \ptr -> do
         h <- LLVM.load ptr
         t <- load =<< A.advanceArrayElementPtr ptr
         return (t #:. MultiValue.Cons h))


newtype Store r tag sh =
   Store {
      runStore ::
         MultiValue.T (T tag sh) ->
         LLVM.Value (Ptr Word32) ->
         LLVM.CodeGenFunction r ()
   }

-- | Generate code that stores a shape/index as consecutive 'Word32's
-- (right-most component first).
store ::
   (C sh) =>
   MultiValue.T (T tag sh) ->
   LLVM.Value (Ptr Word32) ->
   LLVM.CodeGenFunction r ()
store =
   runStore $
   switchInt
      (Store $ \_z _ptr -> return ())
      (Store $ switchR $ \sh (MultiValue.Cons k) ptr -> do
         LLVM.store k ptr
         store sh =<< A.advanceArrayElementPtr ptr)


newtype Size sh = Size {runSize :: sh -> Word32}

-- | Number of elements: the product of all components
-- ('Word32' arithmetic, so it wraps on overflow).
size :: (C sh) => sh -> Word32
size =
   runSize $
   switchInt
      (Size $ \_z -> 1)
      (Size $ \(sh :. Index.Int k) -> k * size sh)


newtype ComputeSize r sh =
   ComputeSize {
      runComputeSize ::
         MultiValue.T (Shape sh) -> LLVM.CodeGenFunction r (LLVM.Value Word32)
   }

-- | Generated-code counterpart of 'size'.
computeSize ::
   (C sh) =>
   MultiValue.T (Shape sh) -> LLVM.CodeGenFunction r (LLVM.Value Word32)
computeSize =
   runComputeSize $
   switchInt
      (ComputeSize $ \_z -> return A.one)
      (ComputeSize $ switchR $ \sh (MultiValue.Cons k) ->
         A.mul k =<< computeSize sh)


newtype Constant val tag sh =
   Constant {getConstant :: val Index.Int -> val (T tag sh)}

-- | Shape/index with every component set to the given value.
constant :: (C sh, Expr.Value val) => val Index.Int -> val (T tag sh)
constant =
   getConstant $
   switchInt
      (Constant $ const z)
      (Constant $ \x -> constant x #:. x)


newtype AddPhis r tag sh =
   AddPhis {
      runAddPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) -> MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r ()
   }

-- | Component-wise 'MultiValue.addPhis'.
addPhis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) -> MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r ()
addPhis =
   runAddPhis $
   switchInt
      (AddPhis $ \_ _ _ -> return ())
      (AddPhis $ \bb ->
         -- switchR passes (tail, head)
         switchR $ \shX kX ->
         switchR $ \shY kY ->
            MultiValue.addPhis bb kX kY >> addPhis bb shX shY)


newtype Phis r tag sh =
   Phis {
      runPhis ::
         LLVM.BasicBlock ->
         MultiValue.T (T tag sh) ->
         LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
   }

-- | Component-wise 'MultiValue.phis'.
phis ::
   (C sh) =>
   LLVM.BasicBlock ->
   MultiValue.T (T tag sh) ->
   LLVM.CodeGenFunction r (MultiValue.T (T tag sh))
phis =
   runPhis $
   switchInt
      (Phis $ \_ -> return)
      (Phis $ \bb -> switchR $ \sh k ->
         liftM2 (#:.) (phis bb sh) (MultiValue.phis bb k))


newtype Loop r state sh =
   Loop {
      runLoop ::
         (MultiValue.T (Index sh) -> state -> LLVM.CodeGenFunction r state) ->
         MultiValue.T (Shape sh) -> state -> LLVM.CodeGenFunction r state
   }

-- | Generate nested loops over all indices of a shape, threading @state@
-- through the body.  The left-most component drives the outermost loop
-- and the right-most one the innermost loop, so indices are visited in
-- the row-major order of 'flattenIndex'.  Each component counter starts
-- at zero and is incremented by one; the iteration count per level is
-- delegated to 'C.fixedLengthLoop' (assumed: exactly @n@ iterations).
loop ::
   (C sh, Loop.Phi state) =>
   (MultiValue.T (Index sh) -> state -> LLVM.CodeGenFunction r state) ->
   MultiValue.T (Shape sh) -> state -> LLVM.CodeGenFunction r state
loop =
   runLoop $
   switchInt
      (Loop $ \code _z -> code z)
      (Loop $ \code -> switchR $ \sh (MultiValue.Cons n) ->
         loop
            (\ix state0 ->
               fmap fst $
               C.fixedLengthLoop n (state0, A.zero) $ \(state, k) ->
                  liftM2 (,) (code (ix #:. MultiValue.Cons k) state) (A.inc k))
            sh)