{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE TypeFamilies #-}
module KneadShape where

import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr

import qualified LLVM.Extra.Multi.Value.Memory as MultiMem
import qualified LLVM.Extra.Multi.Value as MultiValue
import qualified LLVM.Extra.Arithmetic as A
import LLVM.Extra.Multi.Value (atom)

import qualified LLVM.Core as LLVM

import qualified Type.Data.Num.Decimal as TypeNum

import Foreign.Storable
         (Storable, sizeOf, alignment, poke, pokeElemOff, peek, peekElemOff)
import Foreign.Ptr (Ptr, castPtr)

import qualified Control.Monad.HT as Monad
import Control.Monad (join)

import Data.Int (Int64)


{- |
I chose a somewhat complicated 'Dim2' definition
in order to keep it distinct from size pairs
with width and height swapped.
Alternatives would be Index.Linear
or intentionally complicated Shape types like:

   type Dim0 = ()
   type Dim1 = ((), Size)
   type Dim2 = ((), Size, Size)

The problem with Index.Linear is that it is fixed to Word32 dimensions,
which causes trouble with the negative coordinates
that we encounter on rotations.

The custom shape type requires lots of new definitions
but it is certainly the cleanest solution.
-}
type Size = Int64
type Dim0 = ()
type Dim1 = Size
type Dim2 = Shape2 Size
type Ix2 = Index2 Size


-- | A pair of a vertical and a horizontal component.
-- The phantom @tag@ separates shapes ('ShapeTag') from indices ('IndexTag').
data Vec2 tag i = Vec2 {vertical, horizontal :: i}

data ShapeTag
data IndexTag

type Shape2 = Vec2 ShapeTag
type Index2 = Vec2 IndexTag

-- | Build a 'Vec2' whose vertical and horizontal components are the same value.
squareShape :: n -> Vec2 tag n
squareShape side = Vec2 {vertical = side, horizontal = side}

-- | Reinterpret a pointer to a 'Vec2' as a pointer to its component type.
castToElemPtr :: Ptr (Vec2 tag a) -> Ptr a
castToElemPtr ptr = castPtr ptr

-- cf. sample-frame:Frame.Stereo
instance (Storable n) => Storable (Vec2 tag n) where
   -- Size of the first component, padded up to the alignment of the second,
   -- followed by the size of the second component.
   sizeOf ~(Vec2 v h) = sizeOf v + padding + sizeOf h
      where padding = mod (negate (sizeOf v)) (alignment h)
   alignment ~(Vec2 v _) = alignment v
   -- The components are stored as elements 0 and 1 behind the element pointer.
   poke ptr (Vec2 v h) = do
      let elemPtr = castToElemPtr ptr
      poke elemPtr v
      pokeElemOff elemPtr 1 h
   peek ptr = do
      let elemPtr = castToElemPtr ptr
      v <- peek elemPtr
      h <- peekElemOff elemPtr 1
      return (Vec2 v h)


-- | Component-wise lifting of the component's 'MultiValue.C' instance.
instance (MultiValue.C n) => MultiValue.C (Vec2 tag n) where
   type Repr f (Vec2 tag n) = Vec2 tag (MultiValue.Repr f n)
   cons (Vec2 v h) = zipShape (MultiValue.cons v) (MultiValue.cons h)
   undef = zipShape MultiValue.undef MultiValue.undef
   zero = zipShape MultiValue.zero MultiValue.zero
   phis bb vec =
      case unzipShape vec of
         Vec2 v h ->
            Monad.lift2 zipShape
               (MultiValue.phis bb v)
               (MultiValue.phis bb h)
   addPhis bb x y =
      case (unzipShape x, unzipShape y) of
         (Vec2 xv xh, Vec2 yv yh) -> do
            MultiValue.addPhis bb xv yv
            MultiValue.addPhis bb xh yh

type instance
   MultiValue.Decomposed f (Vec2 tag pat) =
      Vec2 tag (MultiValue.Decomposed f pat)
type instance
   MultiValue.PatternTuple (Vec2 tag pat) =
      Vec2 tag (MultiValue.PatternTuple pat)

instance (MultiValue.Compose n) => MultiValue.Compose (Vec2 tag n) where
   type Composed (Vec2 tag n) = Vec2 tag (MultiValue.Composed n)
   -- Compose both components, then re-wrap their representations in one 'Vec2'.
   compose (Vec2 v h) =
      case MultiValue.compose v of
         MultiValue.Cons reprV ->
            case MultiValue.compose h of
               MultiValue.Cons reprH -> MultiValue.Cons (Vec2 reprV reprH)

instance (MultiValue.Decompose pn) => MultiValue.Decompose (Vec2 tag pn) where
   -- Decompose every component against the corresponding component pattern.
   decompose (Vec2 patV patH) (MultiValue.Cons (Vec2 reprV reprH)) =
      Vec2 {
         vertical = MultiValue.decompose patV (MultiValue.Cons reprV),
         horizontal = MultiValue.decompose patH (MultiValue.Cons reprH)
      }

-- | In memory a 'Vec2' is an LLVM struct with two members of the same type:
-- member 0 is the vertical and member 1 the horizontal component.
instance (MultiMem.C i) => MultiMem.C (Vec2 tag i) where
   type Struct (Vec2 tag i) =
           LLVM.Struct (MultiMem.Struct i, (MultiMem.Struct i, ()))
   decompose struct = do
      v <- MultiMem.decompose =<< LLVM.extractvalue struct TypeNum.d0
      h <- MultiMem.decompose =<< LLVM.extractvalue struct TypeNum.d1
      return (zipShape v h)
   compose vec =
      case unzipShape vec of
         Vec2 v h -> do
            structV <- MultiMem.compose v
            structH <- MultiMem.compose h
            -- Start from an undefined struct and fill in both members.
            withV <-
               LLVM.insertvalue (LLVM.value LLVM.undef) structV TypeNum.d0
            LLVM.insertvalue withV structH TypeNum.d1

-- | Split a 'Vec2' multi-value into a 'Vec2' of component multi-values.
unzipShape :: MultiValue.T (Vec2 tag n) -> Vec2 tag (MultiValue.T n)
unzipShape vec = MultiValue.decompose (squareShape atom) vec

-- | Combine a vertical and a horizontal multi-value (in this order)
-- into a 'Vec2' multi-value.
zipShape :: MultiValue.T n -> MultiValue.T n -> MultiValue.T (Vec2 tag n)
zipShape v h = MultiValue.compose (Vec2 v h)

-- | Two-dimensional shape built from two shapes of the same component type.
-- Its indices are 'Index2' values over the component's index type.
instance (tag ~ ShapeTag, Shape.C i) => Shape.C (Vec2 tag i) where
   type Index (Vec2 tag i) = Index2 (Shape.Index i)
   -- Intersect the vertical and the horizontal extents independently.
   intersectCode shapeA shapeB =
      case (unzipShape shapeA, unzipShape shapeB) of
         (Vec2 av ah, Vec2 bv bh) -> do
            v <- Shape.intersectCode av bv
            h <- Shape.intersectCode ah bh
            return (zipShape v h)
   -- The size is the product of the sizes of both components.
   sizeCode shape =
      case unzipShape shape of
         Vec2 v h ->
            join (Monad.lift2 A.mul (Shape.sizeCode v) (Shape.sizeCode h))
   size (Vec2 v h) = Shape.size v * Shape.size h
   -- Returns the total size together with the linear index, which is
   -- the horizontal index plus the horizontal size times the vertical index.
   flattenIndexRec shape index =
      case (unzipShape shape, unzipShape index) of
         (Vec2 shapeV shapeH, Vec2 indexV indexH) -> do
            (sizeV, linearV) <- Shape.flattenIndexRec shapeV indexV
            (sizeH, linearH) <- Shape.flattenIndexRec shapeH indexH
            totalSize <- A.mul sizeV sizeH
            rowOffset <- A.mul sizeH linearV
            linear <- A.add linearH rowOffset
            return (totalSize, linear)
   -- The outer loop runs over the vertical component,
   -- the inner loop over the horizontal component.
   loop code shape =
      case unzipShape shape of
         Vec2 shapeV shapeH ->
            Shape.loop
               (\ixV -> Shape.loop (\ixH -> code (zipShape ixV ixH)) shapeH)
               shapeV


instance (Expr.Compose n) => Expr.Compose (Vec2 tag n) where
   type Composed (Vec2 tag n) = Vec2 tag (Expr.Composed n)
   compose (Vec2 v h) =
      Expr.lift2 zipShape (Expr.compose v) (Expr.compose h)

instance (Expr.Decompose p) => Expr.Decompose (Vec2 tag p) where
   decompose (Vec2 patV patH) vec =
      Vec2 {
         vertical = Expr.decompose patV (verticalVal vec),
         horizontal = Expr.decompose patH (horizontalVal vec)
      }

-- | Project the vertical component out of a 'Vec2' valued expression.
verticalVal :: (Expr.Value val) => val (Vec2 tag n) -> val n
verticalVal vec = Expr.lift1 (MultiValue.lift1 vertical) vec

-- | Project the horizontal component out of a 'Vec2' valued expression.
horizontalVal :: (Expr.Value val) => val (Vec2 tag n) -> val n
horizontalVal vec = Expr.lift1 (MultiValue.lift1 horizontal) vec