{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Matrix.Square.Linear (
   solve,
   inverse,
   determinant,
   ) where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Linear.Private
         (solver, withDeterminantInfo, withInfo, diagonalMsg)
import Numeric.LAPACK.Matrix.Layout.Private (transposeFromOrder)
import Numeric.LAPACK.Matrix.Private (Full, Square, SquareMeas, argSquare)
import Numeric.LAPACK.Private
         (withAutoWorkspaceInfo, copyBlock, copyToTemp, copyToColumnMajorTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (peekArray)
import Foreign.ForeignPtr (withForeignPtr)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)


-- | Solve the linear system @A * X = B@ for @X@, where @A@ is square and
-- every column of @B@ is one right-hand side.
--
-- 'solve' LU-factors a private copy of @A@ with LAPACK @getrf@ and then
-- back-substitutes with @getrs@.  The copy of @A@ is made without
-- reordering, so its storage order is compensated for by the transpose
-- flag that 'transposeFromOrder' derives from it (for row-major storage
-- LAPACK sees @A^T@).
--
-- '_solve' is an alternative, non-exported implementation that uses the
-- combined LAPACK driver @gesv@.
solve, _solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sh a ->
   Full meas vert horiz sh nrhs a -> Full meas vert horiz sh nrhs a
solve =
   argSquare $ \order shapeA matA ->
   solver "Square.solve" shapeA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      -- Tell getrs whether it must solve with A or with its transpose.
      transPtr <- Call.char (transposeFromOrder order)
      -- getrf overwrites its input with the LU factors, so factor a copy.
      luPtr <- copyToTemp (n*n) matA
      ldluPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      let factor =
            withInfo "getrf" $
            LapackGen.getrf nPtr nPtr luPtr ldluPtr pivotPtr
          substitute =
            withInfo "getrs" $
            LapackGen.getrs transPtr nPtr nrhsPtr
               luPtr ldluPtr pivotPtr xPtr ldxPtr
      -- The right-hand sides in xPtr are overwritten by the solution.
      liftIO (factor >> substitute)

-- Alternative to 'solve' that lets the LAPACK driver @gesv@ perform both the
-- LU factorisation and the back-substitution in a single call.  Rather than
-- passing a transpose flag, it first copies @A@ into a column-major temporary
-- with 'copyToColumnMajorTemp'.  Its type is the one declared together with
-- 'solve'.
_solve =
   argSquare $ \order shapeA matA ->
   solver "Square.solve" shapeA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      luPtr <- copyToColumnMajorTemp order n n matA
      ldluPtr <- Call.leadingDim n
      pivotPtr <- Call.allocaArray n
      liftIO . withInfo "gesv" $
         LapackGen.gesv nPtr nrhsPtr luPtr ldluPtr pivotPtr xPtr ldxPtr


-- | Compute the inverse of a square matrix.
--
-- The result shape comes from 'Layout.inverse', which swaps the height and
-- width types.  The input entries are copied into the freshly allocated
-- result buffer, which is then LU-factored in place with LAPACK @getrf@ and
-- inverted in place with @getri@ (using an automatically sized workspace).
-- Both LAPACK steps are skipped for an empty matrix (@n == 0@).
--
-- NOTE(review): the storage order is not inspected.  This presumably relies
-- on @inv (A^T) = (inv A)^T@, so a row-major matrix that LAPACK reads as
-- column-major still yields the right result.  Confirm that 'Layout.inverse'
-- keeps the order unchanged.
inverse ::
   (Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
   SquareMeas meas height width a -> SquareMeas meas width height a
inverse (Array shape@(Layout.Full _order extent) input) =
   Array.unsafeCreateWithSize (Layout.inverse shape) fillInverse
   where
      n = Shape.size (Extent.height extent)
      -- Fill the result buffer bPtr (blockSize elements) with the inverse.
      fillInverse blockSize bPtr = evalContT $ do
         nPtr <- Call.cint n
         srcPtr <- ContT (withForeignPtr input)
         ldbPtr <- Call.leadingDim n
         pivotPtr <- Call.allocaArray n
         liftIO . when (n > 0) $ do
            copyBlock blockSize srcPtr bPtr
            withInfo "getrf" $
               LapackGen.getrf nPtr nPtr bPtr ldbPtr pivotPtr
            withAutoWorkspaceInfo diagonalMsg "getri" $
               LapackGen.getri nPtr bPtr ldbPtr pivotPtr


-- | Compute the determinant of a square matrix.
--
-- A copy of the matrix is LU-factored with LAPACK @getrf@.  The determinant
-- is the product of the @n@ diagonal entries of the factor, read with
-- stride @n+1@, with its sign adjusted by 'Perm.condNegate' according to
-- the pivot indices.  The storage order is ignored, because
-- @det A = det A^T@.  How the status code of @getrf@ is handled is left to
-- 'withDeterminantInfo'.
--
-- 'unsafePerformIO' is used because every write goes to buffers allocated
-- here: the input is first copied by 'copyToTemp', and 'getrf' overwrites
-- only that copy.
determinant :: (Shape.C sh, Class.Floating a) => Square sh a -> a
determinant =
   argSquare $ \_order sh input -> unsafePerformIO $ do
      let n = Shape.size sh
          -- Combine the diagonal of the LU factor with the pivot parity.
          readResult luPtr pivotPtr = do
             diagProduct <- Private.product n luPtr (n+1)
             pivots <- peekArray n pivotPtr
             pure (Perm.condNegate pivots diagProduct)
      evalContT $ do
         nPtr <- Call.cint n
         luPtr <- copyToTemp (n*n) input
         ldaPtr <- Call.leadingDim n
         pivotPtr <- Call.allocaArray n
         liftIO $
            withDeterminantInfo "getrf"
               (LapackGen.getrf nPtr nPtr luPtr ldaPtr pivotPtr)
               (readResult luPtr pivotPtr)