{- |
This module demonstrates triangular matrices.

It verifies that the divided difference scheme
nicely fits into a triangular matrix,
where function addition is mapped to matrix addition
and function multiplication is mapped to matrix multiplication.

<http://en.wikipedia.org/wiki/Divided_difference>
-}
module Numeric.LAPACK.Example.DividedDifference where

import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix.Layout as Layout
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix (ShapeInt, (#+#), (#-#))
import Numeric.LAPACK.Vector (Vector, (|+|), (|-|))
import Numeric.LAPACK.Format ((##))

import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Array.Comfort.Storable as Array
import Foreign.Storable (Storable)

import qualified Data.List as List
import Data.Semigroup ((<>))


{- $setup
>>> import qualified Test.Utility as Util
>>> import Test.Utility (approxArray)
>>>
>>> import qualified Numeric.LAPACK.Example.DividedDifference as DD
>>> import qualified Numeric.LAPACK.Vector as Vector
>>> import Numeric.LAPACK.Example.DividedDifference (dividedDifferencesMatrix)
>>> import Numeric.LAPACK.Matrix (ShapeInt, shapeInt, (#+#))
>>> import Numeric.LAPACK.Vector ((|+|))
>>>
>>> import qualified Data.Array.Comfort.Storable as Array
>>>
>>> import qualified Test.QuickCheck as QC
>>>
>>> import Control.Monad (liftM2)
>>> import Data.Tuple.HT (mapPair)
>>> import Data.Semigroup ((<>))
>>>
>>> type Vector = Vector.Vector ShapeInt Float
>>>
>>> genDD :: QC.Gen (Vector, (Vector, Vector))
>>> genDD = do
>>>    (ys0,ys1) <-
>>>       fmap (mapPair (Vector.autoFromList, Vector.autoFromList) .
>>>             unzip . take 10) $
>>>       QC.listOf $ liftM2 (,) (Util.genElement 10) (Util.genElement 10)
>>>    xs <- Util.genDistinct [-10..10] [-10..10] $ Array.shape ys0
>>>    return (xs,(ys0,ys1))
-}


{- |
Number of elements of a vector with a zero-based integer shape,
read off the shape of the underlying array.
-}
size :: Vector ShapeInt a -> Int
size xs = Shape.zeroBasedSize (Array.shape xs)

{- |
Subtract a leading slice of a vector from a trailing slice of it.

@subSlices k xs@ takes @xs@ without its first @k@ elements
and subtracts the first @size xs - k@ elements of @xs@ from it.
That is, every element is paired with the one @k@ positions before it.
-}
subSlices :: Int -> Vector ShapeInt Float -> Vector ShapeInt Float
subSlices k xs =
   let shifted = Vector.drop k xs
       leading = Vector.take (size xs - k) xs
   in shifted |-| leading

{- |
Differences of the interpolation nodes for every positive offset.

The result contains @subSlices k xs@ for @k@ from @1@ to @size xs - 1@,
in this order. It is empty if @xs@ has fewer than two elements.
-}
parameterDifferences :: Vector ShapeInt Float -> [Vector ShapeInt Float]
parameterDifferences xs = [subSlices k xs | k <- [1 .. size xs - 1]]

{- |
Compute the divided difference scheme
for nodes @xs@ and function values @ys@.

The first list element is @ys@ itself.
Every following element is obtained from its predecessor
by taking differences of adjacent elements
and dividing them element-wise by the next vector
from @parameterDifferences xs@.
-}
dividedDifferences ::
   Vector ShapeInt Float -> Vector ShapeInt Float -> [Vector ShapeInt Float]
dividedDifferences xs ys =
   scanl nextOrder ys (parameterDifferences xs)
   where
      -- One step of the scheme: adjacent differences over node differences.
      nextOrder ddys dxs = Vector.divide (subSlices 1 ddys) dxs

{- |
Build an upper triangular matrix of shape @sh@ from a list of vectors.

The vectors are converted to lists, the list of lists is transposed
and the concatenated result is passed as row-major element list
to 'Triangular.upperFromList'.
For a pyramid of vectors with lengths @n, n-1, ..., 1@
this presumably places the k-th vector on the k-th superdiagonal.
-}
upperFromPyramid ::
   (Shape.C sh, Storable a) => sh -> [Vector sh a] -> Triangular.Upper sh a
upperFromPyramid sh pyramid =
   let rows = List.transpose (map Vector.toList pyramid)
   in Triangular.upperFromList Layout.RowMajor sh (concat rows)

{- |
Arrange the divided differences of the function values @ys@
at the nodes @xs@ in an upper triangular matrix
with the same shape as @xs@.

The properties check that adding function values corresponds
to adding the matrices and that element-wise multiplication
of function values corresponds to multiplying the matrices.

prop> QC.forAll genDD $ \(xs, (ys0,ys1)) -> approxArray (dividedDifferencesMatrix xs (ys0|+|ys1)) (dividedDifferencesMatrix xs ys0 #+# dividedDifferencesMatrix xs ys1)
prop> QC.forAll genDD $ \(xs, (ys0,ys1)) -> approxArray (dividedDifferencesMatrix xs (Vector.mul ys0 ys1)) (dividedDifferencesMatrix xs ys0 <> dividedDifferencesMatrix xs ys1)
-}
dividedDifferencesMatrix ::
   Vector ShapeInt Float -> Vector ShapeInt Float ->
   Triangular.Upper ShapeInt Float
dividedDifferencesMatrix xs =
   upperFromPyramid (Array.shape xs) . dividedDifferences xs


{- |
Upper triangular matrix of all node differences,
computed as the upper triangle of the difference
of two tensor products of @xs@ with an all-ones vector.

The property checks that this matrix equals the pyramid
consisting of a zero vector followed by 'parameterDifferences'.

prop> QC.forAll (QC.choose (0,10)) $ \n -> let sh = shapeInt n in QC.forAll (Util.genDistinct [-10..10] [-10..10] sh) $ \xs -> approxArray (DD.parameterDifferencesMatrix xs) (DD.upperFromPyramid sh (Vector.zero sh : DD.parameterDifferences xs))
-}
parameterDifferencesMatrix ::
   Vector ShapeInt Float -> Triangular.Upper ShapeInt Float
parameterDifferencesMatrix xs =
   Triangular.takeUpper (Square.fromFull (outer ones xs #-# outer xs ones))
   where
      ones = Vector.one (Array.shape xs)
      outer = Matrix.tensorProduct Layout.RowMajor


{- |
Print the node differences and the divided difference matrices
for a fixed example with five nodes and two sets of function values.

After the individual matrices, two pairs of matrices are printed:
the matrix of the summed function values next to the sum of the matrices,
and the matrix of the multiplied function values
next to the product of the matrices.
-}
main :: IO ()
main = do
   let xs  = Vector.autoFromList [0,1,4,9,16]
       ys0 = Vector.autoFromList [3,1,4,1,5]
       ys1 = Vector.autoFromList [2,7,1,8,1]
       ddys0 = dividedDifferencesMatrix xs ys0
       ddys1 = dividedDifferencesMatrix xs ys1

   -- node differences, first as separate vectors, then as one matrix
   mapM_ (\dxs -> dxs ## "%.4f") (parameterDifferences xs)
   parameterDifferencesMatrix xs ## "%.4f"

   ddys0 ## "%.4f"
   ddys1 ## "%.4f"
   putStrLn ""

   -- addition of function values versus addition of matrices
   let matrixSum = ddys0 #+# ddys1
   dividedDifferencesMatrix xs (ys0|+|ys1) ## "%.4f"
   matrixSum ## "%.4f"
   putStrLn ""

   -- multiplication of function values versus multiplication of matrices
   let matrixProduct = ddys0 <> ddys1
   dividedDifferencesMatrix xs (Vector.mul ys0 ys1) ## "%.4f"
   matrixProduct ## "%.4f"