{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Split where

import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Mosaic.Private as Mos
import qualified Numeric.LAPACK.Matrix.Triangular.Basic as Tri
import qualified Numeric.LAPACK.Matrix.Private as Matrix
import qualified Numeric.LAPACK.Matrix.Extent.Private as Extent
import qualified Numeric.LAPACK.Private as Private
import Numeric.LAPACK.Matrix.Mosaic.Private (diagonalPointers)
import Numeric.LAPACK.Matrix.Triangular.Basic (Lower, Upper)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder,
          swapOnRowMajor, sideSwapFromOrder,
          Triangle, uploFromOrder, flipOrder)
import Numeric.LAPACK.Matrix.Extent.Private (Extent)
import Numeric.LAPACK.Matrix.Modifier
         (Transposition, transposeOrder,
          Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Scalar (zero, one)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked))
import Numeric.LAPACK.Private (copyBlock, conjugateToTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array))

import System.IO.Unsafe (unsafePerformIO)

import Foreign.C.Types (CInt, CChar)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (poke)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)


-- | A dense array whose shape is a 'Layout.Split': one buffer from which
-- this module extracts a lower part and an upper triangle.
-- NOTE(review): presumably the packed output of a QR/LQ-style
-- factorization — confirm against the producers of this layout.
type Split lower meas vert horiz height width =
   Array (Layout.Split lower meas vert horiz height width)

-- | Square 'Split' array: 'Extent.Shape' measure, small in both directions.
type Square lower sh = Split lower Extent.Shape Extent.Small Extent.Small sh sh


-- | Change the extent type of a 'Split' array via 'Layout.splitMapExtent'.
-- Only the shape is changed (through 'Array.mapShape').
mapExtent ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA) =>
   (Extent.Measure measB, Extent.C vertB, Extent.C horizB) =>
   Extent.Map measA vertA horizA measB vertB horizB height width ->
   Split lower measA vertA horizA height width a ->
   Split lower measB vertB horizB height width a
mapExtent = Array.mapShape . Layout.splitMapExtent

-- | Apply a function to the extent stored in a 'Split' shape,
-- keeping the recorded part and the storage order unchanged.
mapExtentSizes ::
   (Extent measA vertA horizA heightA widthA ->
    Extent measB vertB horizB heightB widthB) ->
   Split lower measA vertA horizA heightA widthA a ->
   Split lower measB vertB horizB heightB widthB a
mapExtentSizes f =
   Array.mapShape
      (\(Layout.Split lowerPart order extent) ->
         Layout.Split lowerPart order $ f extent)

-- | Map the height shape of a 'Split' array with 'Extent.Size' measure.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   Split lower Extent.Size vert horiz heightA width a ->
   Split lower Extent.Size vert horiz heightB width a
mapHeight = mapExtentSizes . Extent.mapHeight

-- | Map the width shape of a 'Split' array with 'Extent.Size' measure.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   Split lower Extent.Size vert horiz height widthA a ->
   Split lower Extent.Size vert horiz height widthB a
mapWidth = mapExtentSizes . Extent.mapWidth

-- | Wrap both dimension shapes in 'Unchecked'.
uncheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz height width a ->
   Split lower meas vert horiz (Unchecked height) (Unchecked width) a
uncheck = mapExtentSizes $ Extent.mapWrap Unchecked Unchecked

-- | Inverse of 'uncheck': strip the 'Unchecked' wrappers via 'Extent.recheck'.
recheck ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   Split lower meas vert horiz (Unchecked height) (Unchecked width) a ->
   Split lower meas vert horiz height width a
recheck = mapExtentSizes Extent.recheck

-- | Replace the extent by the square extent built from the height.
-- No check is made here that the stored data fits the new extent.
heightToQuadratic ::
   (Extent.Measure meas) =>
   Split lower meas Extent.Small Extent.Small height width a ->
   Square lower height a
heightToQuadratic =
   Array.mapShape $ \(Layout.Split part order_ extent_) ->
      Layout.Split part order_ $ Extent.square $ Extent.height extent_

-- | Replace the extent by the square extent built from the width.
-- No check is made here that the stored data fits the new extent.
widthToQuadratic ::
   (Extent.Measure meas) =>
   Split lower meas Extent.Small Extent.Small height width a ->
   Square lower width a
widthToQuadratic =
   Array.mapShape $ \(Layout.Split part order_ extent_) ->
      Layout.Split part order_ $ Extent.square $ Extent.width extent_


-- | Product of the main diagonal of the stored array: @min m n@ entries
-- taken with stride @k+1@, where @k@ is the leading dimension.
--
-- 'unsafePerformIO' wraps a read of the input buffer via 'Private.product';
-- this is sound only as long as that helper does not mutate the buffer.
determinantR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas vert Extent.Small height width a -> a
determinantR (Array (Layout.Split _ order extent) a) =
   let (height,width) = Extent.dimensions extent
       m = Shape.size height
       n = Shape.size width
       -- distance between consecutive rows (RowMajor) or columns
       -- (ColumnMajor); one diagonal step is one more than that
       k = case order of RowMajor -> n; ColumnMajor -> m
   in unsafePerformIO $
      withForeignPtr a $ \aPtr ->
      Private.product (min m n) aPtr (k+1)

-- | Copy one triangle out of a 'Split' array into a 'Full' matrix
-- of the same order and extent.
--
-- * @Left _@: lower triangle, diagonal set to one, strictly upper part zero.
--
-- * @Right _@: upper triangle including the diagonal, strictly lower part zero.
extractTriangle ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Either lower Triangle ->
   Split lower meas vert horiz height width a ->
   Full meas vert horiz height width a
extractTriangle part (Array (Layout.Split _ order extent) qr) =
   Array.unsafeCreate (Layout.Full order extent) $ \rPtr -> do
      let (height,width) = Extent.dimensions extent
      -- LAPACK sees a RowMajor buffer as the transpose,
      -- so 'L'/'U' and the dimensions swap in that case.
      let ((loup,m), (uplo,n)) =
            swapOnRowMajor order
               (('L', Shape.size height), ('U', Shape.size width))
      evalContT $ do
         loupPtr <- Call.char loup
         uploPtr <- Call.char uplo
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         qrPtr <- ContT $ withForeignPtr qr
         ldqrPtr <- Call.leadingDim m
         ldrPtr <- Call.leadingDim m
         zeroPtr <- Call.number zero
         onePtr <- Call.number one
         liftIO $ case part of
            Left _ -> do
               -- copy lower part, then zero the strict upper part
               -- and overwrite the diagonal with one
               LapackGen.lacpy loupPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr
               LapackGen.laset uploPtr mPtr nPtr zeroPtr onePtr rPtr ldrPtr
            Right _ -> do
               -- zero the lower part including the diagonal,
               -- then copy the upper triangle including the diagonal
               LapackGen.laset loupPtr mPtr nPtr zeroPtr zeroPtr rPtr ldrPtr
               LapackGen.lacpy uploPtr mPtr nPtr qrPtr ldqrPtr rPtr ldrPtr

-- | Lower-triangular part of a wide 'Split' array with its diagonal
-- overwritten by ones.
wideExtractL ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas Extent.Small horiz height width a -> Lower height a
wideExtractL =
   Mos.fromLowerPart
      (\order m lPtr -> mapM_ (flip poke one) $ diagonalPointers order m lPtr)
      Layout.NoMirror
   . toFull

-- | Upper-triangular part of a tall 'Split' array, diagonal included.
tallExtractR ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Split lower meas vert Extent.Small height width a -> Upper width a
tallExtractR = Tri.takeUpper . toFull

-- | Reinterpret a 'Split' array as a 'Full' matrix with the same order
-- and extent; only the shape changes.
toFull ::
   Split lower meas vert horiz height width a ->
   Full meas vert horiz height width a
toFull =
   Array.mapShape (\(Layout.Split _ order extent) -> Layout.Full order extent)


-- | Multiply @b@ from the left by the unit-diagonal lower triangle stored
-- in the leading square of the wide split @a@ (optionally transposed).
--
-- Calls 'error' if the height of @a@ differs from the height of @b@.
wideMultiplyL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   Split Triangle measA Extent.Small horizA height widthA a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
wideMultiplyL transposed a b =
   if Layout.splitHeight (Array.shape a) == Matrix.height b
      then multiplyTriangular ('L','U') 'U' transposed a b
      else error "wideMultiplyL: height shapes mismatch"

-- | Multiply @b@ from the left by the upper triangle (with its stored
-- diagonal) of the tall split @a@ (optionally transposed).
--
-- Calls 'error' if the width of @a@ differs from the height of @b@.
tallMultiplyR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C heightA, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   Split lower measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyR transposed a b =
   if Layout.splitWidth (Array.shape a) == Matrix.height b
      then multiplyTriangular ('U','L') 'N' transposed a b
      else error "tallMultiplyR: height shapes mismatch"

-- | Shared BLAS @trmm@ driver: copy @b@ into the result, then multiply it
-- in place by the triangle of @a@.
--
-- @(normalPart, transposedPart)@ is the @uplo@ character to use when @a@ is
-- ColumnMajor resp. RowMajor (LAPACK sees a RowMajor buffer transposed);
-- @diag@ is passed to @trmm@ unchanged.
multiplyTriangular ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    Shape.C heightA, Shape.C widthA, Shape.C heightB, Shape.C widthB,
    Class.Floating a) =>
   (Char,Char) -> Char -> Transposition ->
   Split lower measA vertA horizA heightA widthA a ->
   Full measB vertB horizB heightB widthB a ->
   Full measB vertB horizB heightB widthB a
multiplyTriangular (normalPart,transposedPart) diag transposed
      (Array (Layout.Split _ orderA extentA) a)
      (Array (Layout.Full orderB extentB) b) =
   Array.unsafeCreate (Layout.Full orderB extentB) $ \cPtr -> do
      let (heightA,widthA) = Extent.dimensions extentA
      let (heightB,widthB) = Extent.dimensions extentB
      let transOrderB = transposeOrder transposed orderB
      -- leading dimension of A is its row length when RowMajor,
      -- its column length when ColumnMajor
      let ((uplo, transa), lda) =
            case orderA of
               RowMajor ->
                  ((transposedPart, flipOrder transOrderB), Shape.size widthA)
               ColumnMajor ->
                  ((normalPart, transOrderB), Shape.size heightA)
      -- a RowMajor B is handled as B^T multiplied from the right
      let (side,(m,n)) =
            sideSwapFromOrder orderB (Shape.size heightB, Shape.size widthB)
      evalContT $ do
         sidePtr <- Call.char side
         uploPtr <- Call.char uplo
         transaPtr <- Call.char $ transposeFromOrder transa
         diagPtr <- Call.char diag
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim lda
         bPtr <- ContT $ withForeignPtr b
         ldcPtr <- Call.leadingDim m
         alphaPtr <- Call.number one
         liftIO $ do
            copyBlock (m*n) bPtr cPtr
            BlasGen.trmm sidePtr uploPtr transaPtr diagPtr
               mPtr nPtr alphaPtr aPtr ldaPtr cPtr ldcPtr


-- | Solve a system with the unit-diagonal lower triangle stored in the
-- leading square of the wide split @a@.
wideSolveL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   Split Triangle measA Extent.Small horizA height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz height nrhs a
wideSolveL transposed conjugated
      (Array (Layout.Split _ orderA extentA) a) =
   let (heightA,widthA) = Extent.dimensions extentA in
   solver "Split.wideSolveL" heightA $ \_n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder $ flipOrder orderA
      diagPtr <- Call.char 'U'
      -- Pass the dimensions of the whole stored array, not of the
      -- triangle: a RowMajor wide array has row stride = width, so
      -- using the height there would give LAPACK a wrong leading
      -- dimension. The system order itself comes from nPtr.
      let m = Shape.size heightA
          n = Shape.size widthA
      solveTriangular transposed conjugated orderA m n a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Solve a system with the upper triangle (with its stored diagonal)
-- in the leading square of the tall split @a@.
tallSolveR ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   Split lower measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a ->
   Full meas vert horiz width nrhs a
tallSolveR transposed conjugated
      (Array (Layout.Split _ orderA extentA) a) =
   let (heightA,widthA) = Extent.dimensions extentA in
   solver "Split.tallSolveR" widthA $ \n nPtr nrhsPtr xPtr ldxPtr -> do
      uploPtr <- Call.char $ uploFromOrder orderA
      diagPtr <- Call.char 'N'
      let m = Shape.size heightA
      solveTriangular transposed conjugated orderA m n a
         uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr

-- | Shared LAPACK @trtrs@ driver for the triangle inside a 'Split' array.
--
-- @m@ and @n@ are the row and column counts of the whole stored array:
-- they determine the leading dimension for @orderA@ and the number of
-- elements copied when a conjugated temporary is needed. The order of the
-- triangular system is taken from @nPtr@.
solveTriangular ::
   Class.Floating a =>
   Transposition -> Conjugation ->
   Order -> Int -> Int -> ForeignPtr a ->
   Ptr CChar -> Ptr CChar ->
   Ptr CInt -> Ptr CInt -> Ptr a -> Ptr CInt -> ContT r IO ()
solveTriangular transposed conjugated orderA m n a
      uploPtr diagPtr nPtr nrhsPtr xPtr ldxPtr = do
   -- trtrs has no "conjugate without transposing" flag, so that case
   -- works on a conjugated temporary copy with trans = 'N'.
   let (trans, getA) =
         case (transposeOrder transposed orderA, conjugated) of
            (RowMajor, NonConjugated) -> ('T', ContT $ withForeignPtr a)
            (RowMajor, Conjugated) -> ('C', ContT $ withForeignPtr a)
            (ColumnMajor, NonConjugated) -> ('N', ContT $ withForeignPtr a)
            (ColumnMajor, Conjugated) -> ('N', conjugateToTemp (m*n) a)
   transPtr <- Call.char trans
   aPtr <- getA
   ldaPtr <- Call.leadingDim $ case orderA of ColumnMajor -> m; RowMajor -> n
   liftIO $
      withInfo "trtrs" $
      LapackGen.trtrs uploPtr transPtr diagPtr
         nPtr nrhsPtr aPtr ldaPtr xPtr ldxPtr