{-# LANGUAGE TypeFamilies #-}
module Data.Array.Knead.Arithmetic.Interpolation (
   bisect,
   lookupInterval,
   Interpolator13,
   sampleBasisFunctions13,
   ) where

import qualified Data.Array.Knead.Arithmetic.LinearAlgebra as LinAlg
import qualified Data.Array.Knead.Arithmetic.Sparse as Sparse
import Data.Array.Knead.Arithmetic.LinearAlgebra
         (Scalar, Vector, Matrix, IOScalar)

import qualified Data.Array.Knead.Parameterized.Physical as Phys
import qualified Data.Array.Knead.Parameterized.Symbolic as SymP
import qualified Data.Array.Knead.Simple.Physical as SimPhys
import qualified Data.Array.Knead.Simple.ShapeDependent as ShapeDep
import qualified Data.Array.Knead.Simple.Symbolic as Sym
import qualified Data.Array.Knead.Index.Nested.Shape as Shape
import qualified Data.Array.Knead.Expression as Expr
import Data.Array.Knead.Expression (Exp)

import qualified LLVM.Extra.Multi.Value.Memory as MultiValueMemory
import qualified LLVM.Extra.Multi.Value as MultiValue
import LLVM.Extra.Multi.Value (atom)
import qualified LLVM.Core as LLVM

import Foreign.Storable (Storable)

import Control.Arrow (arr)
import Control.Monad.HT (chain)
import Control.Applicative (pure)

import qualified Data.List.Match as Match


{- |
One bisection step, performed independently for every collection element.

For each collection index, @bounds@ holds an index interval @(lower, upper)@
into the node dimension of @nodes@ and @xs@ holds the value to locate.
The interval is split at @center@, the integer half of @lower + upper@:
if @x@ is less than the node at @center@ the result is @(lower, center)@,
otherwise it is @(center, upper)@.

Repeating this step shrinks the interval around @x@.
NOTE(review): this is only meaningful if the nodes are sorted ascending
along the node dimension - confirm with callers.
-}
bisect ::
   (Shape.C coll, Shape.C nodes, Shape.Index nodes ~ i,
    MultiValue.IntegerConstant i, MultiValue.Integral i, MultiValue.Select i,
    MultiValue.Comparison a) =>
   Vector p coll nodes a -> Scalar p coll a ->
   Scalar p coll (i, i) -> Scalar p coll (i, i)
bisect nodes xs bounds =
   let centers =
          Sym.map
             (Expr.modify (atom, atom) $ \(lower, upper) ->
                Expr.idiv (Expr.add lower upper) $ Expr.fromInteger' 2)
             bounds
   in  Sym.zipWith3
          (Expr.liftM3 $ \center interval leftBranch ->
             -- x < node[center]: keep the left half, i.e. move the upper bound
             MultiValue.select leftBranch
                (MultiValue.mapSnd (const center) interval)
                (MultiValue.mapFst (const center) interval))
          centers bounds $
       Sym.zipWith (Expr.liftM2 $ MultiValue.cmp LLVM.CmpLT) xs $
       -- fetch node[(collection index, center)] for every collection element
       Sym.gather (Sym.mapWithIndex Expr.zip centers) nodes

{- |
@nestLog2 n f@ chains the monadic step @f@ exactly @ceiling (logBase 2 n)@
times (zero times for @n <= 1@).

This is the number of 'bisect' steps needed to shrink an index interval
of width @n@ to width 1.
One step turns width @w@ into @div w 2@ or @w - div w 2@,
so the remaining width is at most @div (w+1) 2@.

Halving with @div w 2@ instead would count only @floor (logBase 2 n)@ steps.
That leaves intervals of width 2 unresolved whenever @n@ is not a power of two.
For example, with @n = 6@ the interval goes @(0,6)@, @(0,3)@, @(1,3)@
and the search stops one step early.
-}
nestLog2 :: (Integral i, Monad m) => i -> (a -> m a) -> a -> m a
nestLog2 i f =
   chain $ Match.replicate (takeWhile (>1) $ iterate halveUp i) f
   where halveUp k = div (k+1) 2

{- |
For every element of @x@, find the index of the node interval containing it.

The search starts from the index interval @(0, numElems)@, where @numElems@ is
the extent of the node dimension of @nodes@. It then applies 'bisect' until
the interval has width 1 and returns its lower bound.

If @x@ is below every node, the lower bound never moves and the result is 0.
NOTE(review): assuming ascending nodes, the result is the largest @k@ with
@node k <= x@ - confirm with callers.
-}
lookupInterval ::
   (Shape.C coll, Shape.C nodes, Shape.Index nodes ~ i, nodes ~ i,
    MultiValue.IntegerConstant i, MultiValue.Integral i, MultiValue.Select i,
    Num i, MultiValue.Comparison a,
    MultiValueMemory.C nodes, Storable nodes,
    MultiValueMemory.Struct nodes ~ nodesStruct, LLVM.IsSized nodesStruct,
    MultiValueMemory.C i, Storable i,
    MultiValueMemory.Struct i ~ iStruct, LLVM.IsSized iStruct,
    MultiValueMemory.C coll, Storable coll,
    MultiValueMemory.C a, Storable a) =>
   Vector p coll nodes a -> Scalar p coll a -> IOScalar p coll i
lookupInterval nodes x = do
   -- initial bounds (0, numElems) for every collection element
   fill <- Phys.render $ SymP.fill (arr fst) (fmap ((,) 0) $ arr snd)
   -- one bisection step, parameterized by (p, current bounds)
   bis <-
      Phys.render $
      bisect
         (SymP.extendParameter fst nodes)
         (SymP.extendParameter fst x)
         (Phys.feed $ arr snd)
   -- keep only the lower bound of the final interval
   getFst <- Phys.render $ Sym.map Expr.fst $ Phys.feed $ arr id
   getNodesShape <- Phys.renderShape nodes
   getXShape <- Phys.renderShape x
   return $ \p -> do
      (_,numElems) <- getNodesShape p
      (xShape,_) <- getXShape p
      getFst =<< nestLog2 numElems (curry bis p) =<<
         fill (xShape, fromIntegral numElems)

{- |
Combine a scalar per collection element with every element of an array.

The result shape pairs both shapes via 'Expr.zip'.
The projections 'Expr.fst' and 'Expr.snd' pick the collection index for the
scalar and the @dim@ index for the array.
-}
outerVector ::
   (Shape.C coll, Shape.C dim) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar p coll a -> SymP.Array p dim b -> Vector p coll dim c
outerVector = ShapeDep.backpermute2 Expr.zip Expr.fst Expr.snd

{- |
Combine a single value, stored in an array of unit shape,
with every element of an array.
The result has the shape of the second array.
-}
zipWithScalar ::
   (Shape.C shape) =>
   (Exp a -> Exp b -> Exp c) ->
   SymP.Array p () a -> SymP.Array p shape b -> SymP.Array p shape c
zipWithScalar = ShapeDep.backpermute2 (flip const) (const Expr.unit) id

{- |
Interpolation formula on four consecutive nodes:
one node before index 0 and three nodes starting from index 0,
i.e. the nodes at offsets -1, 0, 1 and 2 relative to the interval index.
As used in this module, each pair is (node position, sample value)
and the last argument is the point at which to interpolate.
-}
type Interpolator13 a = (a,a) -> (a,a) -> (a,a) -> (a,a) -> a -> a

{- |
Sample the interpolation basis functions at the points @zs@.

For every collection element and row:

1. The interval index from @indices@ is clamped to @(minIx, maxIx)@ from
   @minMaxIx@.
2. The nodes at offsets -1, 0, 1 and 2 around the clamped index are gathered.
3. The interpolator is run once per entry of a four-element set of unit sample
   vectors. For offset @k@, only the sample at offset @k@ is 1.

Each set entry yields a pair @(column, weight)@.
The column is the clamped index plus @k@, and the weight is the interpolated
unit response.

If the unclamped index is below @minIx@, the weight is the unit's offset-0
component. If it is above @maxIx@, it is the offset-1 component. The
interpolator is not used in either case.
-}
sampleBasisFunctions13Aux ::
   (Shape.C coll, Shape.C rows, Shape.C nodes,
    Shape.C set, MultiValueMemory.C set, Storable set, Num set,
    Shape.Index nodes ~ i,
    MultiValue.Comparison i, MultiValue.PseudoRing i,
    MultiValue.IntegerConstant i,
    MultiValueMemory.C i, Storable i, Num i,
    MultiValue.Select a, MultiValue.Real a, MultiValue.Field a,
    MultiValue.RationalConstant a, Num a,
    Storable a, MultiValueMemory.C a,
    MultiValueMemory.Struct a ~ astruct, LLVM.IsSized astruct,
    MultiValueMemory.Struct i ~ istruct, LLVM.IsSized istruct) =>
   Interpolator13 (Exp a) ->
   SymP.Array p () (i,i) ->
   Vector p coll rows i ->
   SymP.Array p nodes a ->
   Vector p coll rows a ->
   IO (Matrix p coll rows set (i, a))
sampleBasisFunctions13Aux interpolate minMaxIx indices nodes zs = do
   -- clamp every interval index into [minIx, maxIx]
   let limitIndices =
          zipWithScalar
             (\mm ->
                case Expr.unzip mm of
                   (minIx,maxIx) -> Expr.max minIx . Expr.min maxIx)
             minMaxIx indices
       -- node positions at a fixed offset from the clamped index
       gatherFromNodes d = Sym.gather (Sym.map (d+) limitIndices) nodes
   -- (offset, unit samples at offsets -1, 0, 1, 2)
   units <-
      SimPhys.vectorFromList
         [(-1, (1,0,0,0)), (0, (0,1,0,0)), (1, (0,0,1,0)), (2, (0,0,0,1))]
   -- NOTE(review): balanceRight/balanceLeft presumably re-nest the shape
   -- ((coll,rows),set) into (coll,(rows,set)) - confirm in LinAlg.
   return $
      ShapeDep.backpermute LinAlg.balanceRight LinAlg.balanceLeft $
      outerVector
         (Expr.liftM2 $
          MultiValue.modifyF2
             (atom, atom, atom, (atom, atom, atom, atom))
             ((atom, atom), (atom, (atom, atom, atom, atom))) $
             \(n, ln, z, (xm1,x0,x1,x2)) ((minIx, maxIx), (k, (ym1,y0,y1,y2))) -> do
                -- column index of this basis entry
                lnk <- MultiValue.add ln k
                -- range checks use the unclamped index n
                tooSmall <- MultiValue.cmp LLVM.CmpLT n minIx
                tooLarge <- MultiValue.cmp LLVM.CmpGT n maxIx
                -- outside the range: take the offset-0 or offset-1 sample,
                -- otherwise evaluate the interpolation formula
                y <-
                   MultiValue.select tooSmall y0 =<<
                   MultiValue.select tooLarge y1 =<<
                   Expr.unExp
                      (interpolate
                         (Expr.lift0 xm1, Expr.lift0 ym1)
                         (Expr.lift0 x0, Expr.lift0 y0)
                         (Expr.lift0 x1, Expr.lift0 y1)
                         (Expr.lift0 x2, Expr.lift0 y2)
                         (Expr.lift0 z))
                return (lnk, y))
         (Sym.zip4 indices limitIndices zs
            (Sym.zip4
               (gatherFromNodes (-1)) (gatherFromNodes 0)
               (gatherFromNodes 1) (gatherFromNodes 2)))
         (zipWithScalar Expr.zip minMaxIx $ Phys.feed $ pure units)

{- |
Build a sparse matrix that maps node samples to interpolated values at @zs@.

Each point of @zs@ is first located among @nodes@ with 'lookupInterval'.
The node array is broadcast to every point via 'outerVector'.

The returned action takes the parameter @p@ and produces a row matrix with
four entries per row, whose columns index into the node dimension.

Interval indices are clamped to @(1, numElems - 3)@, so that offsets -1..2
stay within @0 .. numElems-1@. Consequences:

* Points with interval index below 1 get weight 1 on node 1.
* Points with interval index above @numElems - 3@ get weight 1 on node
  @numElems - 2@.
-}
sampleBasisFunctions13 ::
   (Shape.Index nodes ~ nodes,
    Shape.C coll, Shape.C rows, Shape.C nodes,
    Shape.C set, MultiValueMemory.C set, Storable set, Num set,
    MultiValue.Comparison nodes, MultiValue.PseudoRing nodes,
    MultiValue.IntegerConstant nodes, MultiValue.Integral nodes,
    MultiValue.Select nodes,
    MultiValueMemory.C nodes, Storable nodes, Num nodes,
    MultiValueMemory.C rows, Storable rows,
    MultiValueMemory.C coll, Storable coll,
    MultiValue.Select a, MultiValue.Comparison a, MultiValue.Field a,
    MultiValue.RationalConstant a, Num a,
    Storable a, MultiValueMemory.C a,
    MultiValueMemory.Struct a ~ astruct, LLVM.IsSized astruct,
    MultiValueMemory.Struct nodes ~ nodesstruct, LLVM.IsSized nodesstruct,
    MultiValueMemory.Struct coll ~ collstruct, LLVM.IsSized collstruct,
    MultiValueMemory.Struct rows ~ rowsstruct, LLVM.IsSized rowsstruct) =>
   Interpolator13 (Exp a) ->
   SymP.Array p nodes a ->
   Vector p coll rows a ->
   IO (p -> IO (Sparse.RowMatrix p coll rows set nodes a))
sampleBasisFunctions13 interpolate nodes zs = do
   indices <- lookupInterval (outerVector (flip const) zs nodes) zs
   return $ \p -> do
      indexArr <- indices p
      -- (minIx, maxIx) = (1, numElems - 3) keeps offsets -1..2 in range
      let minMaxIx =
             Sym.map (\numElems -> Expr.zip 1 (numElems - 3)) $
             ShapeDep.shape nodes
      basis <-
         sampleBasisFunctions13Aux interpolate minMaxIx
            (Phys.feed $ pure indexArr) nodes zs
      -- attach the node count as the sparse column dimension
      return $ Sparse.RowMatrix $
         ShapeDep.backpermuteExtra
            (Expr.modify2 (atom, (atom,atom)) atom $
             \(coll, (dim,set)) numElems ->
                (coll, (dim, Sparse.Dim set numElems)))
            id basis nodes