{-# LANGUAGE TypeFamilies #-}
{- |
Operations on rectangular matrices in \"split\" representation:
one array that holds a part strictly below the diagonal (read back with
an implicit unit diagonal, see 'extractTriangle') and a part on and above
the diagonal.
-}
module Numeric.LAPACK.Split where

import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import qualified Numeric.LAPACK.Matrix.Shape.Box as Box
import qualified Numeric.LAPACK.Matrix.Triangular.Private as TriPriv
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Tri
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Triangular.Private (diagonalPointers, unpack)
import Numeric.LAPACK.Matrix.Triangular.Basic (UnitLower, Upper)
import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder,
          swapOnRowMajor, sideSwapFromOrder, Triangle, uploFromOrder,
          flipOrder)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition, transposeOrder,
          Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Private (copyBlock, conjugateToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)
import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | A rectangular matrix stored in split representation.
-- The @lower@ parameter tags what the part below the diagonal holds
-- (e.g. 'Triangle' or 'Corrupt').
type Split lower vert horiz height width =
   Array (MatrixShape.Split lower vert horiz height width)

-- | A square 'Split' matrix.
type Square lower sh = Split lower Extent.Small Extent.Small sh sh


-- | Product of the diagonal elements of the upper part.
-- The diagonal has @min m n@ entries.
determinantR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower vert Extent.Small height width a -> a
determinantR (Array (MatrixShape.Split _ order extent) a) =
   let (height,width) = Extent.dimensions extent
       m = Shape.size height
       n = Shape.size width
       -- leading dimension; consecutive diagonal entries are k+1 apart
       k = case order of RowMajor -> n; ColumnMajor -> m
   in unsafePerformIO $
      -- the array is immutable and only read here
      withForeignPtr a $ \aPtr -> Private.product (min m n) aPtr (k+1)

-- | Copy one part of a split matrix into a fresh full matrix of the same
-- shape.
--
-- * @Left _@: the part strictly below the diagonal, with ones on the
--   diagonal and zeros above it.
--
-- * @Right _@: the part on and above the diagonal, with zeros below it.
extractTriangle ::
   (Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Either lower Triangle ->
   Split lower vert horiz height width a ->
   Full vert horiz height width a
extractTriangle part (Array (MatrixShape.Split _ order extent) qr) =
   Array.unsafeCreate (MatrixShape.Full order extent) $ \rPtr -> do
      let (height,width) = Extent.dimensions extent
      -- For row-major storage LAPACK sees the transpose,
      -- hence 'swapOnRowMajor' exchanges the triangle letters and sizes.
      let ((loup,m), (uplo,n)) =
            swapOnRowMajor order
               (('L', Shape.size height), ('U', Shape.size width))
      evalContT $ do
         loupPtr <- Call.char loup
         uploPtr <- Call.char uplo
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         qrPtr <- ContT $ withForeignPtr qr
         ldqrPtr <- Call.leadingDim m
         ldrPtr <- Call.leadingDim m
         zeroPtr <- Call.number zero
         onePtr <- Call.number one
         liftIO $ case part of
            Left _ -> do
               -- copy lower triangle (incl. diagonal), then overwrite the
               -- strict upper part with 0 and the diagonal with 1
               LapackGen.lacpy loupPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr
               LapackGen.laset uploPtr mPtr nPtr zeroPtr onePtr rPtr ldrPtr
            Right _ -> do
               -- zero the lower triangle (incl. diagonal), then copy the
               -- upper triangle (incl. diagonal) on top
               LapackGen.laset loupPtr mPtr nPtr zeroPtr zeroPtr rPtr ldrPtr
               LapackGen.lacpy uploPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr


-- | The unit lower triangular factor of a split matrix with
-- @vert = Small@.
-- The diagonal of the result is explicitly set to one.
wideExtractL ::
   (Extent.C horiz, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower Extent.Small horiz height width a -> UnitLower height a
wideExtractL =
   TriPriv.takeLower
      (MatrixShape.Unit,
       \order m lPtr ->
         mapM_ (flip poke one) $ diagonalPointers order m lPtr)
   . toFull

-- | The upper triangular factor of a split matrix with @horiz = Small@.
tallExtractR ::
   (Extent.C vert, Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower vert Extent.Small height width a -> Upper width a
tallExtractR = Tri.takeUpper . toFull

-- | Reinterpret the storage of a split matrix as a full matrix
-- (no copying, only the shape changes).
toFull :: Split lower vert horiz height width a -> Full vert horiz height width a
toFull =
   Array.mapShape
      (\(MatrixShape.Split _ order extent) -> MatrixShape.Full order extent)


-- | Multiply the unit lower triangle of a wide split matrix
-- (optionally transposed) with a full matrix from the left.
-- Fails with 'error' if the heights do not match.
wideMultiplyL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C widthA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split Triangle Extent.Small horizA height widthA a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
wideMultiplyL transposed a b =
   if MatrixShape.splitHeight (Array.shape a) == Matrix.height b
     then multiplyTriangular ('L','U') 'U' transposed a b
     else error "wideMultiplyL: height shapes mismatch"

-- | Multiply the upper triangle of a tall split matrix
-- (optionally transposed) with a full matrix from the left.
-- Fails with 'error' if the width of the split matrix does not match
-- the height of the full matrix.
tallMultiplyR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C heightA, Shape.C widthB, Class.Floating a) =>
   Transposition ->
   Split lower vertA Extent.Small heightA height a ->
   Full vert horiz height widthB a ->
   Full vert horiz height widthB a
tallMultiplyR transposed a b =
   if MatrixShape.splitWidth (Array.shape a) == Matrix.height b
     then multiplyTriangular ('U','L') 'N' transposed a b
     -- the message names this function (it used to say "wideMultiplyR")
     else error "tallMultiplyR: height shapes mismatch"

-- | Compute @op(T) * B@ with BLAS @trmm@ on a copy of @B@.
--
-- * @(normalPart, transposedPart)@: the @uplo@ letter to use when @A@ is
--   column-major resp. row-major.
--
-- * @diag@: the @diag@ letter passed to @trmm@ (@'U'@ or @'N'@).
--
-- * @transposed@: whether @op@ transposes @T@.
--
-- Shapes are not checked here; callers do that.
multiplyTriangular ::
   (Extent.C vertA, Extent.C horizA, Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   (Char,Char) -> Char -> Transposition ->
   Split lower vertA horizA heightA widthA a ->
   Full vertB horizB heightB widthB a ->
   Full vertB horizB heightB widthB a
multiplyTriangular (normalPart,transposedPart) diag transposed
      (Array (MatrixShape.Split _ orderA extentA) a)
      (Array (MatrixShape.Full orderB extentB) b) =
   Array.unsafeCreate (MatrixShape.Full orderB extentB) $ \cPtr -> do
      let (heightA,widthA) = Extent.dimensions extentA
      let (heightB,widthB) = Extent.dimensions extentB
      let transOrderB = transposeOrder transposed orderB
      -- A row-major A is seen by BLAS as its transpose: select the
      -- opposite triangle, flip the operation, and use the width as
      -- leading dimension.
      let ((uplo, transa), lda) =
            case orderA of
               RowMajor ->
                  ((transposedPart, flipOrder transOrderB), Shape.size widthA)
               ColumnMajor ->
                  ((normalPart, transOrderB), Shape.size heightA)
      -- for row-major B the triangle is applied from the right side of
      -- the transposed view
      let (side,(m,n)) =
            sideSwapFromOrder orderB (Shape.size heightB, Shape.size widthB)
      evalContT $ do
         sidePtr <- Call.char side
         uploPtr <- Call.char uplo
         transaPtr <- Call.char $ transposeFromOrder transa
         diagPtr <- Call.char diag
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr b
         ldcPtr <- Call.leadingDim m
         alphaPtr <- Call.number one
         liftIO $ do
            -- trmm works in place, so start from a copy of B
            copyBlock (m*n) bPtr cPtr
            BlasGen.trmm sidePtr uploPtr transaPtr diagPtr
               mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldcPtr


-- | Solve a system with the unit lower triangle of a wide split matrix
-- (optionally transposed and/or conjugated) using LAPACK @trtrs@.
wideSolveL ::
   (Extent.C horizA, Extent.C vert, Extent.C horiz, Shape.C height, Eq height,
    Shape.C width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split Triangle Extent.Small horizA height width a ->
   Full vert horiz height nrhs a -> Full vert horiz height nrhs a
wideSolveL transposed conjugated
      (Array (MatrixShape.Split _ orderA extentA) a) =
   let (heightA,widthA) = Extent.dimensions extentA in
   solver "Split.wideSolveL" heightA $ \_n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder $ flipOrder orderA
      diagPtr <- Call.char 'U'
      -- Pass the real storage dimensions (height, width), as 'tallSolveR'
      -- does: for row-major storage the leading dimension is the full
      -- width, not the size of the triangle.
      solveTriangular transposed conjugated orderA
         (Shape.size heightA) (Shape.size widthA) a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Solve a system with the upper triangle of a tall split matrix
-- (optionally transposed and/or conjugated) using LAPACK @trtrs@.
tallSolveR ::
   (Extent.C vertA, Extent.C vert, Extent.C horiz, Shape.C height,
    Shape.C width, Eq width, Shape.C nrhs, Class.Floating a) =>
   Transposition -> Conjugation ->
   Split lower vertA Extent.Small height width a ->
   Full vert horiz width nrhs a -> Full vert horiz width nrhs a
tallSolveR transposed conjugated
      (Array (MatrixShape.Split _ orderA extentA) a) =
   let (heightA,widthA) = Extent.dimensions extentA in
   solver "Split.tallSolveR" widthA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder orderA
      diagPtr <- Call.char 'N'
      let m = Shape.size heightA
      solveTriangular transposed conjugated orderA m n a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Common driver for LAPACK @trtrs@ on a split matrix.
--
-- @m@ and @n@ are the height and width of the storage of @a@.
-- They determine the leading dimension (@m@ for column-major, @n@ for
-- row-major) and the size of the conjugated temporary copy.
-- The triangle size and right-hand side count come in via
-- @nPtr@ / @nrhsPtr@.
-- A non-zero @info@ is handled by 'withInfo'.
solveTriangular ::
   Class.Floating a =>
   Transposition -> Conjugation ->
   Order -> Int -> Int -> ForeignPtr a ->
   Ptr CChar -> Ptr CChar ->
   Ptr CInt -> Ptr CInt -> Ptr a -> Ptr CInt -> ContT r IO ()
solveTriangular transposed conjugated orderA m n a
      uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr = do
   let (trans, getA) =
         case (transposeOrder transposed orderA, conjugated) of
            (RowMajor, NonConjugated) -> ('T', ContT $ withForeignPtr a)
            (RowMajor, Conjugated) -> ('C', ContT $ withForeignPtr a)
            (ColumnMajor, NonConjugated) -> ('N', ContT $ withForeignPtr a)
            -- trtrs has no "conjugate without transposing" mode,
            -- so solve with a conjugated temporary copy instead
            (ColumnMajor, Conjugated) -> ('N', conjugateToTemp (m*n) a)
   transPtr <- Call.char trans
   aPtr <- getA
   ldaPtr <- Call.leadingDim $ case orderA of ColumnMajor -> m; RowMajor -> n
   liftIO $ withInfo "trtrs" $
      LapackGen.trtrs uploPtr transPtr diagPtr
         nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr


-- | Tag for the lower part of the 'Split' matrix returned by 'takeHalf'.
-- NOTE(review): presumably marks that this part is not a proper
-- unit-lower factor; confirm against callers.
data Corrupt = Corrupt deriving (Eq)

-- The storage order is passed in as a function ('shapeOrder') instead of
-- using Plain.Class.ShapeOrder, because importing that class would
-- currently cause an import cycle.

{- |
Unpack a matrix with a 'Box.Box' shape into a square array and halve its
diagonal.

> let b = takeHalf a
> ==>
> isTriangular b && a == addTransposed b
-}
takeHalf ::
   (Box.Box symShape, Box.HeightOf symShape ~ sh,
    Shape.C sh, Class.Floating a) =>
   (symShape -> Order) -> Array symShape a -> Square Corrupt sh a
takeHalf shapeOrder (Array symShape a) =
   let sh = Box.height symShape
       order = shapeOrder symShape
   in Array.unsafeCreate
         (MatrixShape.Split Corrupt order (Extent.square sh)) $ \bPtr ->
      evalContT $ do
         let n = Shape.size sh
         aPtr <- ContT $ withForeignPtr a
         nPtr <- Call.cint n
         alphaPtr <- Call.number 0.5
         -- stride n+1 visits exactly the diagonal of the n*n result
         incxPtr <- Call.cint (n+1)
         liftIO $ do
            unpack order n aPtr bPtr
            BlasGen.scal nPtr alphaPtr bPtr incxPtr