module Data.Array.Accelerate.Arithmetic.Interpolation (
   bisect,
   lookupInterval,
   Interpolator13, sampleBasisFunctions13,
   ) where

import qualified Data.Array.Accelerate.LinearAlgebra.Matrix.Sparse as Sparse
import qualified Data.Array.Accelerate.LinearAlgebra as LinAlg
import qualified Data.Array.Accelerate.Utility.Arrange as Arrange
import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import qualified Data.Array.Accelerate.Utility.Loop as Loop
import Data.Array.Accelerate.LinearAlgebra
          (Scalar, Vector, numElems, extrudeVector, )

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Exp, Any(Any), Z(Z), (:.)((:.)), )

import Data.Ord.HT (limit, )


{- |
Perform one bisection step for all lookups in parallel.

Every element of @bounds@ is a @(lower, upper)@ index pair
belonging to the query point at the same position in @xs@.
The pair is split at @center = div (lower+upper) 2@,
and the query point is compared with the node at index @center@
of the corresponding row of @nodes@.
If the query point is smaller, the pair becomes @(lower, center)@;
otherwise it becomes @(center, upper)@.
-}
bisect ::
   (A.Slice ix, A.Shape ix, A.Ord a, A.Elt a) =>
   Vector ix a ->
   Scalar ix a ->
   Scalar ix (Int, Int) ->
   Scalar ix (Int, Int)
bisect nodes xs bounds =
   A.zipWith3 narrow centers bounds goLeft
   where
      midpoint lower upper = div (lower+upper) 2
      centers = A.map (A.uncurry midpoint) bounds
      -- node at every center, to be compared against the query point
      pivots =
         Arrange.gather (Arrange.mapWithIndex Exp.indexCons centers) nodes
      goLeft = A.zipWith (A.<) xs pivots
      -- keep the half of the interval that contains the query point
      narrow center interval isLeft =
         A.cond isLeft
            (Exp.mapSnd (const center) interval)
            (Exp.mapFst (const center) interval)

{- |
For every query point in @x@, find the index of the node interval
that contains it, using repeated 'bisect' steps on the rows of @nodes@.

The search starts with the index pair @(0, numElems nodes)@.
Its upper bound equals the node count, so it is exclusive.
The result is the final lower bound of each pair.
Provided the nodes are ascending and enough steps are taken,
this is the largest index whose node is not greater than the query point,
or 0 if the query point lies below all nodes.
The number of steps is chosen by 'Loop.nestLog2'
(presumably about the binary logarithm of the node count).
-}
lookupInterval ::
   (A.Slice ix, A.Shape ix, A.Ord a, A.Elt a) =>
   Vector ix a ->
   Scalar ix a ->
   Scalar ix Int
lookupInterval nodes x =
   A.map A.fst $ Loop.nestLog2 count (bisect nodes x) start
   where
      count = numElems nodes
      start = A.fill (A.shape x) (A.lift (0 :: Exp Int, count))


{- |
Pair every element of @x@ with every element of the vector @y@ using @f@.

@x@ is replicated along a new innermost dimension of length @numElems y@,
and @y@ is extruded to the shape of @x@ with 'extrudeVector'.
Both are then combined element-wise with @f@.
This is a generalized outer product, assuming 'extrudeVector'
repeats @y@ for every index of @x@.
-}
outerVector ::
   (A.Shape ix, A.Slice ix, A.Elt a, A.Elt b, A.Elt c) =>
   (Exp a -> Exp b -> Exp c) ->
   Scalar ix a -> Vector Z b -> Vector ix c
outerVector f x y =
   let width = numElems y
       widened = A.replicate (A.lift $ Any :. width) x
       extruded = extrudeVector (A.shape x) y
   in  A.zipWith f widened extruded


{- |
Interpolation kernel based on four consecutive nodes:
one node before index 0 and three nodes starting from index 0.

The first four arguments are @(node position, value)@ pairs
for the nodes at relative indices -1, 0, 1 and 2.
The last argument is the position to evaluate at,
and the result is the interpolated value there.
'sampleBasisFunctions13' uses the result only for positions
in the interval between the nodes at relative indices 0 and 1.
It passes unit vectors as values in order to obtain per-node weights.
-}
type Interpolator13 a = (a,a) -> (a,a) -> (a,a) -> (a,a) -> a -> a

{- |
Sample the basis functions of a four-node interpolation scheme
at the positions @zs@, producing a sparse matrix with @numElems nodes@ columns
(presumably one sparse row per sample position).

For every sample position @z@, four entries are produced
as pairs of column index and weight.
They belong to the nodes at offsets -1, 0, 1 and 2
relative to the (clamped) interval index of @z@.
The weight of a node is obtained by calling @interpolate@
with a unit vector as node values.
This gives the basis function value when @interpolate@
is linear in the node values.

Boundary handling: the interval index is clamped to @[1, numElems nodes - 3]@
so that all four gathered nodes exist.
Positions whose interval index is below 1 put the full weight on node 1.
Positions whose interval index is above @numElems nodes - 3@
put the full weight on node @numElems nodes - 2@.
Thus the outermost nodes only serve as neighbours for the interpolation.

Preconditions, not checked here:
@nodes@ should be sorted ascending, as required by the bisection in 'lookupInterval',
and must contain at least four elements.
Otherwise the gathers at offset 2 read beyond the end of @nodes@.
-}
sampleBasisFunctions13 ::
   (A.Slice ix, A.Shape ix, A.Ord a, Num a) =>
   Interpolator13 (Exp a) ->
   Vector Z a -> Vector ix a -> Sparse.Rows ix a
sampleBasisFunctions13 interpolate nodes zs =
   Sparse.Rows (numElems nodes) $
   -- interval index of every sample position within the node vector
   let indices = lookupInterval (extrudeVector (A.shape zs) nodes) zs
       -- clamp range that keeps the gather offsets -1 to 2 inside the nodes
       minIx = 1
       maxIx = numElems nodes - 3
       limitIndices = A.map (limit (minIx, maxIx)) indices
       -- node positions at a fixed offset from the clamped interval index
       gatherFromNodes d =
          LinAlg.gatherFromVector (A.map (d+) limitIndices) nodes
   in  outerVector
          (A.lift2 $
           \(n, ln, z, x) (k, y) ->
              -- n: interval index, ln: clamped interval index,
              -- z: sample position, x: positions of the four neighbour nodes,
              -- k: node offset, y: unit vector selecting node ln+k
              case (Exp.unliftQuadruple x, Exp.unliftQuadruple y) of
                 ((xm1,x0,x1,x2), (ym1,y0,y1,y2)) ->
                    (ln+k :: Exp Int,
                     -- below the clamp range: all weight on node ln, i.e. node 1
                     A.cond (n A.< minIx) y0 $
                     -- above the clamp range: all weight on node ln+1,
                     -- i.e. node (numElems nodes - 2)
                     A.cond (n A.> maxIx) y1 $
                     interpolate (xm1,ym1) (x0,y0) (x1,y1) (x2,y2) z))
          (A.zip4 indices limitIndices zs
             (A.zip4
                (gatherFromNodes (-1))
                (gatherFromNodes 0)
                (gatherFromNodes 1)
                (gatherFromNodes 2)))
          -- each node offset paired with the unit vector that selects it
          (A.use $
           A.fromList (Z:.4)
              [(-1, (1,0,0,0)), (0, (0,1,0,0)), (1, (0,0,1,0)), (2, (0,0,0,1))])