{-# LANGUAGE TypeFamilies #-}
-- | Singular value computations for general matrices,
-- implemented on top of the LAPACK routines @gesvd@ and @gelss@.
module Numeric.LAPACK.Singular (
   values,
   decompose,
   decomposeNarrow,
   decomposeSquat,
   determinantAbsolute,
   leastSquaresMinimumNormRCond,
   pseudoInverseRCond,
   RealOf,
   ) where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Square as Square
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix.Shape.Private (Order(RowMajor,ColumnMajor))
import Numeric.LAPACK.Matrix.Private (General, ZeroInt, zeroInt)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Private
         (RealOf, withAutoWorkspace, fromReal,
          allocArray, allocHigherArray,
          copyToTemp, copyToColumnMajor, copyToSubColumnMajor)

import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.LAPACK.FFI.Real as LapackReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.Marshal.Array (allocaArray)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, nullPtr)
import Foreign.Storable (Storable, peek)

import Control.Monad.Trans.Cont (evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))

import Text.Printf (printf)

import Data.Complex (Complex)


-- | Singular values of a general matrix.
--
-- The result has @min height width@ elements.
-- No singular vectors are computed.
values ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> Vector ZeroInt (RealOf a)
values =
   getValues $
   Class.switchFloating
      (Values valuesAux) (Values valuesAux)
      (Values valuesAux) (Values valuesAux)

-- | Type of 'values', named so that it can be wrapped for 'Class.switchFloating'.
type Values_ height width a =
   General height width a -> Vector ZeroInt (RealOf a)

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype Values height width a = Values {getValues :: Values_ height width a}

-- | Worker for 'values': runs @gesvd@ with both job flags set to @\'N\'@
-- on a temporary copy of the matrix data.
valuesAux ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   Values_ height width a
valuesAux (Array shape@(MatrixShape.General _order height width) a) =
   Array.unsafeCreateWithSize
      (zeroInt $ min (Shape.size height) (Shape.size width)) $
   \mn sPtr -> do
      let (m,n) = MatrixShape.dimensions shape
      evalContT $ do
         jobuPtr <- Call.char 'N'
         jobvtPtr <- Call.char 'N'
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         aPtr <- copyToTemp (m*n) a
         -- LAPACK rejects leading dimensions below 1,
         -- which would otherwise happen for empty matrices.
         -- The result vector is empty in that case, so clamping is safe.
         ldaPtr <- Call.cint (max 1 m)
         -- U and VT are not requested, hence no storage is passed.
         let uPtr = nullPtr
         let vtPtr = nullPtr
         lduPtr <- Call.cint (max 1 m)
         ldvtPtr <- Call.cint (max 1 n)
         liftIO $ withInfo "gesvd" $ \infoPtr ->
            gesvd jobuPtr jobvtPtr mPtr nPtr
               aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr mn infoPtr


-- | Product of all singular values of the matrix.
determinantAbsolute ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a -> RealOf a
determinantAbsolute =
   getDeterminantAbsolute $
   Class.switchFloating
      (DeterminantAbsolute determinantAbsoluteAux)
      (DeterminantAbsolute determinantAbsoluteAux)
      (DeterminantAbsolute determinantAbsoluteAux)
      (DeterminantAbsolute determinantAbsoluteAux)

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype DeterminantAbsolute f a =
   DeterminantAbsolute {
      getDeterminantAbsolute :: f a -> RealOf a
   }

-- | Worker for 'determinantAbsolute'.
determinantAbsoluteAux ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Class.Floating ar) =>
   General height width a -> ar
determinantAbsoluteAux = Vector.product . values


-- | Full singular value decomposition.
--
-- Returns @(u, s, vt)@ where @u@ and @vt@ are square matrices
-- over the height and the width shape, respectively,
-- and @s@ holds @min height width@ singular values.
decompose ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a ->
   (Square height a, Vector ZeroInt (RealOf a), Square width a)
decompose =
   getDecompose $
   Class.switchFloating
      (Decompose decomposeAux) (Decompose decomposeAux)
      (Decompose decomposeAux) (Decompose decomposeAux)

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
-- It is shared by 'decompose' and 'decomposeNarrow'.
newtype Decompose m f v g a =
   Decompose {
      getDecompose :: m a -> (f a, v (RealOf a), g a)
   }

-- | Worker for 'decompose': runs @gesvd@ with both job flags set to @\'A\'@.
--
-- For row-major input the dimensions and the two output buffers
-- are swapped before the call.
--
-- NOTE(review): leading dimensions are passed unclamped here.
-- LAPACK documents them as required to be at least 1,
-- so a matrix with an empty dimension presumably makes 'withInfo'
-- raise an \"illegal value\" error.
-- Clamping alone would not be enough, since one of the returned
-- square matrices is non-empty in that case and would have to be
-- filled explicitly - confirm the intended behaviour.
decomposeAux ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   General height width a ->
   (Square height a, Vector ZeroInt ar, Square width a)
decomposeAux (Array (MatrixShape.General order height width) a) =
   unsafePerformIO $ evalContT $ do
      (u,uPtr0) <- allocArray (MatrixShape.Square order height)
      (vt,vtPtr0) <- allocArray (MatrixShape.Square order width)
      let ((m,n),(uPtr,vtPtr)) =
            case order of
               RowMajor ->
                  ((Shape.size width, Shape.size height), (vtPtr0,uPtr0))
               ColumnMajor ->
                  ((Shape.size height, Shape.size width), (uPtr0,vtPtr0))
      let mn = min m n
      let lda = m
      jobuPtr <- Call.char 'A'
      jobvtPtr <- Call.char 'A'
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      aPtr <- copyToTemp (m*n) a
      ldaPtr <- Call.cint lda
      (s,sPtr) <- allocArray (zeroInt mn)
      lduPtr <- Call.cint m
      ldvtPtr <- Call.cint n
      liftIO $ withInfo "gesvd" $ \infoPtr ->
         gesvd jobuPtr jobvtPtr mPtr nPtr
            aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr mn infoPtr
      return (u, s, vt)


-- | Reduced decomposition for matrices that are at most as high as wide.
--
-- It is implemented by applying 'decomposeNarrow' to the transposed matrix
-- and transposing the factors back.
decomposeSquat ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a ->
   (Square height a, Vector height (RealOf a), General height width a)
decomposeSquat a =
   let (u,s,vt) = decomposeNarrow $ Matrix.transpose a
   in  (Square.transpose vt, s, Matrix.transpose u)

-- | Reduced decomposition for matrices that are at least as high as wide.
--
-- The input is checked with an assertion carrying the message
-- @\"Singular.decomposeThin: matrix is wider than high\"@.
decomposeNarrow ::
   (Shape.C height, Shape.C width, Class.Floating a) =>
   General height width a ->
   (General height width a, Vector width (RealOf a), Square width a)
decomposeNarrow =
   getDecompose $
   Class.switchFloating
      (Decompose decomposeThin) (Decompose decomposeThin)
      (Decompose decomposeThin) (Decompose decomposeThin)

-- | Worker for 'decomposeNarrow': runs @gesvd@ with both job flags
-- set to @\'S\'@.
--
-- For row-major input the dimensions and the two output buffers
-- are swapped before the call.
decomposeThin ::
   (Shape.C height, Shape.C width,
    Class.Floating a, RealOf a ~ ar, Storable ar) =>
   General height width a ->
   (General height width a, Vector width ar, Square width a)
decomposeThin (Array (MatrixShape.General order height width) a) =
   unsafePerformIO $ do
      Call.assert "Singular.decomposeThin: matrix is wider than high"
         (Shape.size height >= Shape.size width)
      evalContT $ do
         (u,uPtr0) <- allocArray (MatrixShape.General order height width)
         (vt,vtPtr0) <- allocArray (MatrixShape.Square order width)
         let ((m,n),(uPtr,vtPtr)) =
               case order of
                  RowMajor ->
                     ((Shape.size width, Shape.size height), (vtPtr0,uPtr0))
                  ColumnMajor ->
                     ((Shape.size height, Shape.size width), (uPtr0,vtPtr0))
         let mn = min m n
         jobuPtr <- Call.char 'S'
         jobvtPtr <- Call.char 'S'
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         aPtr <- copyToTemp (m*n) a
         -- LAPACK rejects leading dimensions below 1,
         -- which would otherwise happen for an empty width.
         -- Given the assertion above, all three results are empty
         -- whenever a dimension is zero, so clamping is safe.
         ldaPtr <- Call.cint (max 1 m)
         (s,sPtr) <- allocArray width
         lduPtr <- Call.cint (max 1 m)
         ldvtPtr <- Call.cint (max 1 mn)
         liftIO $ withInfo "gesvd" $ \infoPtr ->
            gesvd jobuPtr jobvtPtr mPtr nPtr
               aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr mn infoPtr
         return (u, s, vt)


-- | Common calling convention for the real and the complex @gesvd@ wrapper.
--
-- The 'Int' argument is @min m n@;
-- only the complex variant uses it, for sizing its real workspace.
type GESVD_ ar a =
   Ptr CChar -> Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr ar ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Int -> Ptr CInt -> IO ()

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype GESVD a = GESVD {getGESVD :: GESVD_ (RealOf a) a}

-- | @gesvd@ for any supported element type.
gesvd :: Class.Floating a => GESVD_ (RealOf a) a
gesvd =
   getGESVD $
   Class.switchFloating
      (GESVD gesvdReal) (GESVD gesvdReal)
      (GESVD gesvdComplex) (GESVD gesvdComplex)

-- | Real @gesvd@ with an automatically sized workspace.
gesvdReal ::
   (Class.Real a) =>
   GESVD_ a a
gesvdReal jobuPtr jobvtPtr mPtr nPtr
      aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr _mn infoPtr =
   withAutoWorkspace $ \workPtr lworkPtr ->
      LapackReal.gesvd jobuPtr jobvtPtr mPtr nPtr
         aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr
         workPtr lworkPtr infoPtr

-- | Complex @gesvd@ with an automatically sized workspace
-- and an additional real workspace of @5*mn@ elements.
gesvdComplex ::
   (Class.Real a) =>
   GESVD_ a (Complex a)
gesvdComplex jobuPtr jobvtPtr mPtr nPtr
      aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr mn infoPtr =
   allocaArray (5*mn) $ \rworkPtr ->
   withAutoWorkspace $ \workPtr lworkPtr ->
      LapackComplex.gesvd jobuPtr jobvtPtr mPtr nPtr
         aPtr ldaPtr sPtr uPtr lduPtr vtPtr ldvtPtr
         workPtr lworkPtr rworkPtr infoPtr


-- | Solve a linear system with @gelss@.
--
-- The first argument is the @rcond@ threshold handed to @gelss@.
-- Returns the rank reported by @gelss@ together with the solution matrix.
-- The heights of both matrices must be equal; this is checked
-- by an assertion.
leastSquaresMinimumNormRCond ::
   (Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   RealOf a ->
   General height width a ->
   General height nrhs a ->
   (Int, General width nrhs a)
leastSquaresMinimumNormRCond =
   getLeastSquaresMinimumNormRCond $
   Class.switchFloating
      (LeastSquaresMinimumNormRCond leastSquaresMinimumNormRCondAux)
      (LeastSquaresMinimumNormRCond leastSquaresMinimumNormRCondAux)
      (LeastSquaresMinimumNormRCond leastSquaresMinimumNormRCondAux)
      (LeastSquaresMinimumNormRCond leastSquaresMinimumNormRCondAux)

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype LeastSquaresMinimumNormRCond f g h a =
   LeastSquaresMinimumNormRCond {
      getLeastSquaresMinimumNormRCond ::
         RealOf a -> f a -> g a -> (Int, h a)
   }

-- | Worker for 'leastSquaresMinimumNormRCond'.
--
-- Both inputs are copied into column-major temporaries;
-- the right-hand side is copied into the buffer of the result
-- which is always column-major.
--
-- NOTE(review): @lda@ is passed unclamped; see the note at 'decomposeAux'.
leastSquaresMinimumNormRCondAux ::
   (Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a, RealOf a ~ ar, Class.Floating ar) =>
   ar ->
   General height width a ->
   General height nrhs a ->
   (Int, General width nrhs a)
leastSquaresMinimumNormRCondAux rcond
      (Array (MatrixShape.General orderA heightA widthA) a)
      (Array (MatrixShape.General orderB heightB widthB) b) =
   unsafePerformIO $ do
      Call.assert "leastSquaresMinimumNorm: height shapes mismatch"
         (heightA == heightB)
      let shapeX = MatrixShape.General ColumnMajor widthA widthB
      let m = Shape.size heightA
      let n = Shape.size widthA
      let nrhs = Shape.size widthB
      let mn = min m n
      let aSize = m*n
      let lda = m
      evalContT $ do
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         nrhsPtr <- Call.cint nrhs
         aPtr <- Call.allocaArray aSize
         liftIO $ withForeignPtr a $ \asrcPtr ->
            copyToColumnMajor orderA m n asrcPtr aPtr
         ldaPtr <- Call.cint lda
         (x,(tmpPtr,ldtmp)) <- allocHigherArray shapeX m n nrhs
         ldtmpPtr <- Call.cint ldtmp
         liftIO $ withForeignPtr b $ \bPtr ->
            copyToSubColumnMajor orderB m nrhs bPtr ldtmp tmpPtr
         sPtr <- Call.allocaArray mn
         rcondPtr <- Call.number rcond
         rankPtr <- Call.alloca
         liftIO $ withInfo "gelss" $ \infoPtr ->
            gelss mPtr nPtr nrhsPtr aPtr ldaPtr tmpPtr ldtmpPtr
               sPtr rcondPtr rankPtr mn infoPtr
         rank <- liftIO $ fromIntegral <$> peek rankPtr
         return (rank, x)


-- | Common calling convention for the real and the complex @gelss@ wrapper.
--
-- The 'Int' argument is @min m n@;
-- only the complex variant uses it, for sizing its real workspace.
type GELSS_ ar a =
   Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr ar -> Ptr ar -> Ptr CInt ->
   Int -> Ptr CInt -> IO ()

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype GELSS a = GELSS {getGELSS :: GELSS_ (RealOf a) a}

-- | @gelss@ for any supported element type.
gelss :: Class.Floating a => GELSS_ (RealOf a) a
gelss =
   getGELSS $
   Class.switchFloating
      (GELSS gelssReal) (GELSS gelssReal)
      (GELSS gelssComplex) (GELSS gelssComplex)

-- | Real @gelss@ with an automatically sized workspace.
gelssReal ::
   (Class.Real a) =>
   GELSS_ a a
gelssReal mPtr nPtr nrhsPtr aPtr ldaPtr bPtr ldbPtr
      sPtr rcondPtr rankPtr _mn infoPtr =
   withAutoWorkspace $ \workPtr lworkPtr ->
      LapackReal.gelss mPtr nPtr nrhsPtr aPtr ldaPtr bPtr ldbPtr
         sPtr rcondPtr rankPtr workPtr lworkPtr infoPtr

-- | Complex @gelss@ with an automatically sized workspace
-- and an additional real workspace of @5*mn@ elements.
gelssComplex ::
   (Class.Real a) =>
   GELSS_ a (Complex a)
gelssComplex mPtr nPtr nrhsPtr aPtr ldaPtr bPtr ldbPtr
      sPtr rcondPtr rankPtr mn infoPtr =
   allocaArray (5*mn) $ \rworkPtr ->
   withAutoWorkspace $ \workPtr lworkPtr ->
      LapackComplex.gelss mPtr nPtr nrhsPtr aPtr ldaPtr bPtr ldbPtr
         sPtr rcondPtr rankPtr workPtr lworkPtr rworkPtr infoPtr


-- | Pseudo-inverse computed from a reduced singular value decomposition.
--
-- Singular values below @rcond@ times the first singular value
-- are treated as zero.
-- Returns the number of singular values that were kept
-- together with the pseudo-inverse.
pseudoInverseRCond ::
   (Shape.C height, Eq height, Shape.C width, Eq width, Class.Floating a) =>
   RealOf a ->
   General height width a ->
   (Int, General width height a)
pseudoInverseRCond =
   getPseudoInverseRCond $
   Class.switchFloating
      (PseudoInverseRCond pseudoInverseRCondAux)
      (PseudoInverseRCond pseudoInverseRCondAux)
      (PseudoInverseRCond pseudoInverseRCondAux)
      (PseudoInverseRCond pseudoInverseRCondAux)

-- | Wrapper that lets 'Class.switchFloating' dispatch on the element type.
newtype PseudoInverseRCond f g a =
   PseudoInverseRCond {
      getPseudoInverseRCond :: RealOf a -> f a -> (Int, g a)
   }

-- | Worker for 'pseudoInverseRCond'.
--
-- It uses 'decomposeSquat' if the matrix is wider than high
-- and 'decomposeNarrow' otherwise,
-- and then multiplies the adjoint factors with the
-- reciprocal singular values in between.
pseudoInverseRCondAux ::
   (Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   RealOf a ->
   General height width a ->
   (Int, General width height a)
pseudoInverseRCondAux rcond a =
   let (MatrixShape.General _ height width) = Array.shape a
   in  if Shape.size height < Shape.size width
         then
            let
               (u,s,vt) = decomposeSquat a
               (rank,recipS) = recipSigma rcond s
            in (rank,
                Matrix.multiply (Matrix.adjoint vt) $
                Matrix.scaleRows recipS $
                Square.toGeneral $ Square.adjoint u)
         else
            let
               (u,s,vt) = decomposeNarrow a
               (rank,recipS) = recipSigma rcond s
            in (rank,
                Matrix.multiply (Square.toGeneral $ Square.adjoint vt) $
                Matrix.scaleRows recipS $
                Matrix.adjoint u)

-- | Invert all singular values that reach the threshold
-- and replace the remaining ones by zero.
--
-- The threshold is the first singular value times @rcond@.
-- The returned count is the length of the leading run of
-- singular values that reach the threshold.
-- For an empty vector the count is zero.
recipSigma ::
   (Shape.C sh, Class.Floating a, RealOf a ~ ar, Class.Real ar) =>
   ar -> Array sh ar -> (Int, Array sh a)
recipSigma rcond sigmas =
   case Array.toList sigmas of
      [] -> (0, Array.map fromReal sigmas)
      xs@(x:_) ->
         let smin = x * rcond
         in  (length (takeWhile (>=smin) xs),
              Array.map
                 (\s -> if s>=smin then fromReal (recip s) else 0)
                 sigmas)


-- | Run a LAPACK call that reports its status through an @info@ pointer
-- and turn a non-zero status into an 'error'.
--
-- The first argument is the routine name used as prefix of the error message.
withInfo :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfo name computation = alloca $ \infoPtr -> do
   computation infoPtr
   info <- fromIntegral <$> peek infoPtr
   case compare info (0::Int) of
      EQ -> return ()
      LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
      GT -> error $ printf "%s: %d superdiagonals did not converge" name info