{-# LANGUAGE ForeignFunctionInterface #-}
module Numeric.LinearAlgebra.Banded (
   UpperMatrix(..),
   SymmetricMatrix(..),
   choleskyDecompose,
   CholeskySolve, choleskySolve,
   ) where

import qualified Data.Packed.Development as Dev
import qualified Data.Packed.Matrix as Matrix
import Data.Packed.Matrix (Matrix)
import Data.Packed.Vector (Vector)

import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.Marshal.Array as Array
import qualified Foreign.Marshal.Alloc as Alloc
import qualified Foreign.C.String as CStr
import qualified Foreign.C.Types as C
import Foreign.Storable (Storable, peek)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (pure, (<*>))

import Text.Printf (printf)

import System.IO.Unsafe (unsafePerformIO)


-- | LAPACK @SPBTRF@ (single precision): Cholesky factorization of a
-- symmetric positive definite band matrix, in place in @AB@.
-- Arguments in Fortran order, all passed by reference:
-- UPLO, N, KD, AB, LDAB, INFO.
-- NOTE(review): gfortran-compiled LAPACK also expects a hidden trailing
-- length argument for the CHARACTER parameter UPLO; it is not passed here.
-- This usually works in practice but is not guaranteed by the ABI — confirm.
foreign import ccall "spbtrf_"
   spbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Float -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()

-- | LAPACK @DPBTRF@ (double precision): Cholesky factorization of a
-- symmetric positive definite band matrix, in place in @AB@.
-- Arguments in Fortran order, all passed by reference:
-- UPLO, N, KD, AB, LDAB, INFO.
-- NOTE(review): the hidden CHARACTER-length argument for UPLO expected by
-- gfortran-compiled LAPACK is not passed here — confirm this is acceptable.
foreign import ccall "dpbtrf_"
   dpbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr Double -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()


-- | LAPACK @SPBTRS@ (single precision): solve @A X = B@ using the band
-- Cholesky factor computed by @SPBTRF@; @B@ is overwritten with @X@.
-- Arguments in Fortran order, all passed by reference:
-- UPLO, N, KD, NRHS, AB, LDAB, B, LDB, INFO.
-- NOTE(review): the hidden CHARACTER-length argument for UPLO expected by
-- gfortran-compiled LAPACK is not passed here — confirm this is acceptable.
foreign import ccall "spbtrs_"
   spbtrs ::
      Ptr C.CChar -> Ptr C.CInt ->
      Ptr C.CInt -> Ptr C.CInt ->
      Ptr Float -> Ptr C.CInt ->
      Ptr Float -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()

-- | LAPACK @DPBTRS@ (double precision): solve @A X = B@ using the band
-- Cholesky factor computed by @DPBTRF@; @B@ is overwritten with @X@.
-- Arguments in Fortran order, all passed by reference:
-- UPLO, N, KD, NRHS, AB, LDAB, B, LDB, INFO.
-- NOTE(review): the hidden CHARACTER-length argument for UPLO expected by
-- gfortran-compiled LAPACK is not passed here — confirm this is acceptable.
foreign import ccall "dpbtrs_"
   dpbtrs ::
      Ptr C.CChar -> Ptr C.CInt ->
      Ptr C.CInt -> Ptr C.CInt ->
      Ptr Double -> Ptr C.CInt ->
      Ptr Double -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()


{- |
Element types for which the banded Cholesky LAPACK routines are bound.
The methods have exactly the argument lists of the Fortran routines,
so the same call sites work for every instance.

NOTE(review): 'C' is not in the module's export list, although it appears
in the constraints of the exported 'choleskyDecompose' and 'choleskySolve';
callers therefore cannot write signatures mentioning it — confirm intended.
-}
class (Matrix.Element a) => C a where
   -- | Generic @?PBTRF@: UPLO, N, KD, AB, LDAB, INFO.
   pbtrf ::
      Ptr C.CChar -> Ptr C.CInt -> Ptr C.CInt ->
      Ptr a -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()
   -- | Generic @?PBTRS@: UPLO, N, KD, NRHS, AB, LDAB, B, LDB, INFO.
   pbtrs ::
      Ptr C.CChar -> Ptr C.CInt ->
      Ptr C.CInt -> Ptr C.CInt ->
      Ptr a -> Ptr C.CInt ->
      Ptr a -> Ptr C.CInt ->
      Ptr C.CInt -> IO ()

-- | Single precision: dispatches to @SPBTRF@ / @SPBTRS@.
instance C Float where
   pbtrf = spbtrf
   pbtrs = spbtrs

-- | Double precision: dispatches to @DPBTRF@ / @DPBTRS@.
instance C Double where
   pbtrf = dpbtrf
   pbtrs = dpbtrs


{- |
Stores an upper triangular band matrix in the form:

> a11 a12 a13
> a22 a23 a24
> a33 a34 a35
> a44 a45 a46
> a55 a56 a57

Row @i@ holds the diagonal element @a_ii@ followed by the @k@
super-diagonal elements @a_i(i+1) .. a_i(i+k)@.
An @n@ by @n@ matrix with @k@ super-diagonals is therefore stored
as an @n@ by @k+1@ matrix (above: @n = 5@, @k = 2@).
The entries that fall outside the matrix (above: @a46@, @a56@, @a57@)
are padding.
NOTE(review): per the LAPACK band-storage convention, padding should not
be referenced by the LAPACK routines used in this module — confirm.
-}
newtype UpperMatrix a = UpperMatrix (Matrix a)
   deriving (Show)

{- |
Stores the upper half of a symmetric band matrix
in the same layout as 'UpperMatrix':
row @i@ holds @a_ii, a_i(i+1), .., a_i(i+k)@.
By symmetry this also determines the lower half.
-}
newtype SymmetricMatrix a = SymmetricMatrix (Matrix a)
   deriving (Show)

{- |
Cholesky factorization of a symmetric positive definite band matrix.

The argument has @n@ rows and @k+1@ columns in the band layout of
'SymmetricMatrix', where @k@ is the number of super-diagonals.
The result @U@ is returned in the same layout and satisfies @A = U^T U@.
An 'error' is raised if LAPACK reports a non-zero @INFO@,
e.g. when the matrix is not positive definite.

Example (GHCi):

> import Numeric.Container ((<>))
> let m = Matrix.fromLists [[1,2,0,0],[0,3,4,0],[0,0,5,6],[0,0,0,7 :: Double]]
> let mm = Matrix.trans m <> m
> let a = SymmetricMatrix $ Matrix.fromLists [[1,2],[13,12],[41,30],[85,0::Double]]
> let u = choleskyDecompose a

Here @a@ is the band storage of @mm@, and @u@ should be the band storage
of @m@, i.e. @[[1,2],[3,4],[5,6],[7,0]]@.
-}
choleskyDecompose :: (C a) => SymmetricMatrix a -> UpperMatrix a
choleskyDecompose (SymmetricMatrix a) = unsafePerformIO $ do
   let n = Matrix.rows a
       width = Matrix.cols a
   -- ?PBTRF overwrites its AB argument, so it works on a private copy.
   -- The result then depends only on the argument, which is what makes
   -- the unsafePerformIO acceptable.
   ab <- cloneVector (Matrix.flatten a)
   -- Row-major upper band storage has the same memory layout as LAPACK's
   -- column-major lower band storage (UPLO = 'L'),
   -- with KD = width - 1 and LDAB = width.
   evalContT $
      runChecked "Banded.choleskyDecompose" $
         fmap pbtrf (string "L")
          <*> cint n
          <*> cint (width - 1)
          <*> vector ab
          <*> cint width
   return (UpperMatrix (Dev.matrixFromVector Dev.RowMajor width ab))


{- |
Solve a banded symmetric system given its factor from 'choleskyDecompose',
either for one right-hand side ('Vector')
or for several at once ('Matrix', one right-hand side per column).
-}
class CholeskySolve c where
   choleskySolve :: (C a) => UpperMatrix a -> c a -> c a

-- | A single right-hand side.
instance CholeskySolve Vector where choleskySolve = choleskySolveSingle
-- | One right-hand side per column.
instance CholeskySolve Matrix where choleskySolve = choleskySolveMany


-- | Solve for a single right-hand side by treating it as a one-column
-- matrix and flattening the one-column solution back into a vector.
choleskySolveSingle :: (C a) => UpperMatrix a -> Vector a -> Vector a
choleskySolveSingle u b =
   Matrix.flatten (choleskySolveMany u (Matrix.asColumn b))

{- |
Solve @A X = B@ for @X@, where @A = U^T U@ and @U@ is the band factor
returned by 'choleskyDecompose'.
Each column of @rhs@ is one right-hand side; the result has the same shape.
An 'error' is raised if the row counts of @U@ and @rhs@ differ,
or if LAPACK reports a non-zero @INFO@.

Example (GHCi, with @mm@ and @u@ as in the 'choleskyDecompose' example):

> let x = Matrix.fromLists [[2,3],[5,7],[11,13],[17,19::Double]]
> choleskySolveMany u (mm <> x)

which should give back @x@.
-}
choleskySolveMany :: (C a) => UpperMatrix a -> Matrix a -> Matrix a
choleskySolveMany (UpperMatrix u) rhs = unsafePerformIO $ do
   let n = Matrix.rows u
       width = Matrix.cols u
       nrhs = Matrix.cols rhs
       ldb = Matrix.rows rhs
   when (n /= ldb) $ error $
      printf "Banded.choleskySolve: number of rows mismatch (%i /= %i)"
         n ldb
   -- ?PBTRS overwrites B with the solution, and Fortran expects B in
   -- column-major order: flattening the transpose yields exactly that,
   -- in a private copy.
   b <- cloneVector (Matrix.flatten (Matrix.trans rhs))
   evalContT $
      runChecked "Banded.choleskySolveMany" $
         fmap pbtrs (string "L")
          <*> cint n
          <*> cint (width - 1)
          <*> cint nrhs
          <*> vector (Matrix.flatten u)
          <*> cint width
          <*> vector b
          <*> cint ldb
   return (Dev.matrixFromVector Dev.ColumnMajor nrhs b)


{- |
Allocate a fresh vector of the same length and copy the argument into it.
Used to obtain a buffer that a Fortran routine may overwrite in place
without touching the caller's vector.
-}
cloneVector :: (Storable a) => Vector a -> IO (Vector a)
cloneVector src = evalContT $ do
   (srcPtr, n) <- vectorLen src
   dst <- liftIO (Dev.createVector n)
   vector dst >>= \dstPtr -> liftIO (Array.copyArray dstPtr srcPtr n)
   return dst


{- |
Continuation-based 'IO' used to marshal Fortran arguments.
Each marshalling helper in this module brackets a temporary resource
(a pointer cell, a C string, or access to a vector's storage) around
the rest of the computation, so every pointer stays valid
until the surrounding 'evalContT' finishes.
-}
type FortranIO r = ContT r IO

-- | Execute the 'IO' action produced by a 'FortranIO' computation,
-- inside the continuation, so all marshalled pointers are still live.
run :: FortranIO r (IO a) -> FortranIO r a
run act = do
   io <- act
   liftIO io

{- |
Run a LAPACK call whose final parameter is the @INFO@ output.

The @INFO@ cell is allocated here, passed as the last argument,
and inspected after the call returns.
A non-zero value raises an 'error' prefixed with @name@:

* a negative @INFO@ of @-i@ means argument @i@ had an illegal value;

* a positive @INFO@ is a routine-specific failure,
  e.g. for @?PBTRF@ the leading minor of that order is not positive definite.

The value is decoded here rather than via hmatrix's @Dev.check@.
That function interprets non-zero codes as GSL or hmatrix error codes,
so it reports misleading messages for LAPACK @INFO@ values.
-}
runChecked :: String -> FortranIO r (Ptr C.CInt -> IO a) -> FortranIO r a
runChecked name act = do
   infoPtr <- alloca
   a <- run $ fmap ($ infoPtr) act
   liftIO $ do
      info <- peek infoPtr
      when (info /= 0) $
         error $ lapackInfoMessage name (fromIntegral info)
   return a

-- | Human-readable description of a non-zero LAPACK @INFO@ value,
-- prefixed with the name of the calling function.
lapackInfoMessage :: String -> Int -> String
lapackInfoMessage name info
   | info < 0 =
        printf "%s: argument %d had an illegal value (LAPACK info = %d)"
           name (negate info) info
   | otherwise =
        printf "%s: LAPACK routine failed (info = %d)" name info

-- | Marshal an 'Int' into a temporary 'C.CInt' cell that lives for the rest
-- of the computation (Fortran takes scalars by reference).
-- NOTE(review): 'fromIntegral' silently truncates values outside the
-- 'C.CInt' range.
cint :: Int -> FortranIO r (Ptr C.CInt)
cint n = ContT (Marshal.with (fromIntegral n))

-- | An uninitialised cell for one value, valid for the rest of the
-- computation; used for output arguments such as @INFO@.
alloca :: (Storable a) => FortranIO r (Ptr a)
alloca = ContT $ \withCell -> Alloc.alloca withCell

-- | A temporary NUL-terminated C string, valid for the rest of the
-- computation; used for Fortran CHARACTER flags such as UPLO.
string :: String -> FortranIO r (Ptr C.CChar)
string s = ContT (CStr.withCString s)

-- | Pointer to the first element of a vector (the slice offset is already
-- applied) together with its length.
-- The pointer is obtained under 'withForeignPtr' and is valid only for the
-- rest of the computation.
vectorLen :: (Storable a) => Vector a -> FortranIO r (Ptr a, Int)
vectorLen v =
   ContT $ \k ->
      withForeignPtr fp $ \base ->
         k (Array.advancePtr base off, n)
   where
      (fp, off, n) = Dev.unsafeToForeignPtr v

-- | Like 'vectorLen', but only the pointer to the first element.
vector :: (Storable a) => Vector a -> FortranIO r (Ptr a)
vector v = vectorLen v >>= \(ptr, _) -> return ptr