{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{- |
LU factorisation with row pivoting, computed by LAPACK @getrf@,
together with solving (@getrs@), determinant and multiplication
by the factors.
-}
module Numeric.LAPACK.Linear.Plain (
   LowerUpper,
   Tall, Wide, Square, LiberalSquare,
   Transposition(..),
   Conjugation(..),
   Inversion(..),
   mapExtent,
   fromMatrix,
   toMatrix,
   solve,
   multiplyFull,

   determinant,

   extractP,
   multiplyP,

   extractL,
   wideExtractL,
   wideMultiplyL,
   wideSolveL,

   extractU,
   tallExtractU,
   tallMultiplyU,
   tallSolveU,

   caseTallWide,
   ) where

import qualified Numeric.LAPACK.Matrix.Divide as Divide
import qualified Numeric.LAPACK.Matrix.Multiply as Multiply
import qualified Numeric.LAPACK.Matrix.Type as Matrix
import qualified Numeric.LAPACK.Matrix.Banded.Basic as Banded
import qualified Numeric.LAPACK.Matrix.Basic as Basic
import qualified Numeric.LAPACK.Matrix.Array.Private as ArrMatrix
import qualified Numeric.LAPACK.Matrix.Layout as LayoutPub
import qualified Numeric.LAPACK.Matrix.Layout.Private as Layout
import qualified Numeric.LAPACK.Matrix.Extent.Strict as ExtentStrict
import qualified Numeric.LAPACK.Matrix.Extent.Private as ExtentPriv
import qualified Numeric.LAPACK.Matrix.Extent as Extent
import qualified Numeric.LAPACK.Matrix.Private as MatrixPriv
import qualified Numeric.LAPACK.Permutation.Private as Perm
import qualified Numeric.LAPACK.Split as Split
import Numeric.LAPACK.Output ((/+/))
import Numeric.LAPACK.Matrix.Plain.Format (formatArray)
import Numeric.LAPACK.Matrix.Type (Matrix, FormatMatrix(formatMatrix))
import Numeric.LAPACK.Matrix.Triangular.Basic (Lower, Upper)
import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), Triangle(Triangle))
import Numeric.LAPACK.Matrix.Modifier
         (Transposition(NonTransposed, Transposed),
          Conjugation(NonConjugated, Conjugated),
          Inversion(NonInverted, Inverted))
import Numeric.LAPACK.Matrix.Private (Full)
import Numeric.LAPACK.Linear.Private (solver, withInfo)
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Shape.Private (Unchecked(Unchecked), deconsUnchecked)
import Numeric.LAPACK.Private
         (copyBlock, copyTransposed, copyToColumnMajor, copyToColumnMajorTemp)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), (!))

import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (forM_)

import Data.Monoid ((<>))


-- | Type tag for matrices stored as an LU factorisation.
data LU

-- The factorisation consists of the pivot array written by @getrf@
-- and a single array holding both triangular factors
-- (split-triangle layout).
data instance
   Matrix LU xl xu lower upper meas vert horiz height width a where
      LowerUpper ::
         Banded.RectangularDiagonal meas vert horiz height width
            (Perm.Element height) ->
         SplitArray meas vert horiz height width a ->
         LowerUpperFlex lower upper meas vert horiz height width a

deriving instance
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Storable a,
    Show height, Show width, Show a) =>
   Show (Matrix LU xl xu lower upper meas vert horiz height width a)

-- | Storage of the combined L and U factors.
type SplitArray meas vert horiz height width a =
   Split.Split Layout.Triangle meas vert horiz height width a

-- | Access the combined L/U storage, dropping the pivots.
split_ ::
   Matrix LU xl xu lower upper meas vert horiz height width a ->
   SplitArray meas vert horiz height width a
split_ (LowerUpper _pivot split) = split

type LowerUpperFlex = Matrix LU () ()
type LowerUpper = LowerUpperFlex Layout.Filled Layout.Filled

-- | LU factorisation with vertical extent 'Extent.Big'.
type Tall height width =
   LowerUpper Extent.Size Extent.Big Extent.Small height width
-- | LU factorisation with horizontal extent 'Extent.Big'.
type Wide height width =
   LowerUpper Extent.Size Extent.Small Extent.Big height width
-- | Square LU factorisation whose height and width types may differ.
type LiberalSquare height width = SquareMeas Extent.Size height width
-- | Square LU factorisation with identical height and width shape.
type Square sh = SquareMeas Extent.Shape sh sh

type SquareMeas meas height width =
   LowerUpper meas Extent.Small Extent.Small height width


-- Shows the permutation P above the raw L/U storage.
instance FormatMatrix LU where
   formatMatrix fmt lu@(LowerUpper _ipiv m) =
      Perm.format (extractP NonInverted lu) /+/ formatArray fmt m


-- | Change the extent type of both the pivots and the L/U storage.
mapExtent ::
   (Extent.C vertA, Extent.C horizA) =>
   (Extent.C vertB, Extent.C horizB) =>
   ExtentStrict.Map measA vertA horizA measB vertB horizB height width ->
   LowerUpperFlex lower upper measA vertA horizA height width a ->
   LowerUpperFlex lower upper measB vertB horizB height width a
mapExtent f (LowerUpper pivot split) =
   let g = ExtentStrict.apply f
   in LowerUpper (Banded.mapExtent g pivot) $
      Array.mapShape (Layout.splitMapExtent g) split

-- | Retag the pivot elements with a new height type.
--   Only the phantom index type changes; the pointer is cast as-is.
mapPivotHeight ::
   (sh0 -> sh1) ->
   Vector shape (Perm.Element sh0) -> Vector shape (Perm.Element sh1)
mapPivotHeight _f (Array shape xs) = Array shape (castForeignPtr xs)

-- | Replace the height shape of a factorisation with measure 'Extent.Size'.
mapHeight ::
   (Extent.C vert, Extent.C horiz) =>
   (heightA -> heightB) ->
   LowerUpperFlex lower upper Extent.Size vert horiz heightA width a ->
   LowerUpperFlex lower upper Extent.Size vert horiz heightB width a
mapHeight f (LowerUpper pivot split) =
   LowerUpper
      (Banded.mapHeight f $ mapPivotHeight f pivot)
      (Split.mapHeight f split)

-- | Replace the width shape of a factorisation with measure 'Extent.Size'.
--   Pivot elements refer to the height, so only their array shape changes.
mapWidth ::
   (Extent.C vert, Extent.C horiz) =>
   (widthA -> widthB) ->
   LowerUpperFlex lower upper Extent.Size vert horiz height widthA a ->
   LowerUpperFlex lower upper Extent.Size vert horiz height widthB a
mapWidth f (LowerUpper pivot split) =
   LowerUpper (Banded.mapWidth f pivot) (Split.mapWidth f split)


{- |
Compute the LU factorisation of a full matrix with LAPACK @getrf@.
The factors are stored column-major in a split-triangle layout;
the pivot indices are stored alongside.
-}
fromMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   Full meas vert horiz height width a ->
   LowerUpper meas vert horiz height width a
fromMatrix (Array (Layout.Full order extent) a) =
   let (height,width) = Extent.dimensions extent
   in uncurry LowerUpper $
      Array.unsafeCreateWithSizeAndResult
         (snd $ Layout.rectangularDiagonal extent) $ \_ ipivPtr ->
      ArrayIO.unsafeCreate
         (Layout.Split Layout.Triangle ColumnMajor extent) $ \luPtr ->
      evalContT $ do
         let m = Shape.size height
         let n = Shape.size width
         mPtr <- Call.cint m
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         ldaPtr <- Call.leadingDim m
         liftIO $ do
            -- getrf overwrites its matrix argument with the factors,
            -- so A is first copied (as column-major) into the result.
            copyToColumnMajor order m n aPtr luPtr
            withInfo "getrf" $
               LapackGen.getrf mPtr nPtr luPtr ldaPtr
                  (Perm.deconsElementPtr ipivPtr)

-- | Solve @A·X = B@ for @X@, given the LU factorisation of square @A@.
solve ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Square height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
solve = solveTrans NonTransposed

{- |
Solve @A·X = B@ ('NonTransposed') or @A^T·X = B@ ('Transposed')
via LAPACK @getrs@.
Row-major factor storage is copied to a temporary column-major buffer first.
-}
solveTrans ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Eq height, Shape.C height, Shape.C width, Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper
      Extent.Shape Extent.Small Extent.Small height height a ->
   Full meas vert horiz height width a ->
   Full meas vert horiz height width a
solveTrans trans
      (LowerUpper
         (Array _ ipiv)
         (Array (Layout.Split Layout.Triangle orderLU extentLU) lu)) =
   solver "LowerUpper.solve" (Extent.squareSize extentLU) $
      \n nPtr nrhsPtr xPtr ldxPtr -> do
         let lda = n
         transPtr <- Call.char $
            case trans of
               NonTransposed -> 'N'
               Transposed -> 'T'
         aPtr <-
            case orderLU of
               RowMajor -> copyToColumnMajorTemp orderLU n n lu
               ColumnMajor -> ContT $ withForeignPtr lu
         ldaPtr <- Call.leadingDim lda
         ipivPtr <- fmap Perm.deconsElementPtr $ ContT $ withForeignPtr ipiv
         liftIO $ withInfo "getrs" $
            LapackGen.getrs transPtr nPtr nrhsPtr
               aPtr ldaPtr ipivPtr xPtr ldxPtr

{- |
Determinant of the factorised square matrix.
It is computed from the split storage ('Split.determinantR').
The sign is adjusted according to the pivot list ('Perm.condNegate').

Caution:
@LU.determinant . LU.fromMatrix@ will fail for singular matrices.
-}
determinant ::
   (Extent.Measure meas, Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small Extent.Small height width a ->
   a
determinant (LowerUpper ipiv split) =
   Perm.condNegate (map Perm.deconsElement $ Array.toList ipiv) $
   Split.determinantR split


{- |
Build the row permutation recorded in the pivot array.
The given 'Inversion' is combined with 'Inverted' before it is passed
to 'Perm.fromTruncatedPivots'.
-}
extractP ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   Inversion ->
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Perm.Permutation height
extractP inverted (LowerUpper ipiv _) =
   Perm.fromTruncatedPivots (Inverted <> inverted)
      (Array.mapShape Perm.Shape ipiv)

{- |
Permute the rows of @B@ using the pivots of the factorisation.

* 'NonInverted' applies the pivot interchanges in reverse order.
  For column-major @B@ this is @laswp@ with @incx = -1@.
  It is the variant 'toMatrix' uses to form @P·L·U@.

* 'Inverted' applies them in forward order (@incx = 1@).

The heights of the factorisation and of @B@ must match.
This is checked by 'Call.assert'.
-}
multiplyP ::
   (Extent.Measure measA, Extent.C vertA, Extent.C horizA,
    Extent.Measure measB, Extent.C vertB, Extent.C horizB,
    Eq height, Shape.C height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Inversion ->
   LowerUpperFlex lower upper measA vertA horizA height widthA a ->
   Full measB vertB horizB height widthB a ->
   Full measB vertB horizB height widthB a
multiplyP inverted
      (LowerUpper ipiv@(Array shapeIPiv ipivFPtr)
         (Array (Layout.Split _ _ extentLU) _lu))
      (Array shape@(Layout.Full order extent) a) =
   Array.unsafeCreate shape $ \bPtr -> do
      Call.assert "multiplyP: heights mismatch"
         (Extent.height extentLU == Extent.height extent)
      let (height,width) = Extent.dimensions extent
      let m = Shape.size height
      let n = Shape.size width
      let k = Shape.size shapeIPiv
      evalContT $ do
         aPtr <- ContT $ withForeignPtr a
         ipivPtr <- ContT $ withForeignPtr ipivFPtr
         -- The interchanges work in place on a copy of B.
         liftIO $ copyBlock (n*m) aPtr bPtr
         case order of
            ColumnMajor -> do
               nPtr <- Call.cint n
               ldaPtr <- Call.leadingDim m
               k1Ptr <- Call.cint 1
               k2Ptr <- Call.cint k
               incxPtr <- Call.cint $
                  case inverted of
                     Inverted -> 1
                     NonInverted -> -1
               liftIO $
                  LapackGen.laswp nPtr bPtr ldaPtr k1Ptr k2Ptr
                     (Perm.deconsElementPtr ipivPtr) incxPtr
            RowMajor ->
               liftIO $
               swapColumns (Perm.Shape height) n bPtr
                  (Array.mapShape Perm.Shape ipiv) $
               Perm.indices (Inverted <> inverted) (Perm.Shape shapeIPiv)

{- |
For every index @i@ in the given order, swap the run of @n@ contiguous
elements at @i@ with the run at @ipiv!i@.
Each run starts at offset @n * Shape.uncheckedOffset sh ix@.
In a row-major matrix with @n@ columns, a run is one matrix row.
-}
{-# INLINE swapColumns #-}
swapColumns ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz) =>
   (Shape.C height, Shape.C width, Class.Floating a) =>
   (diagShape ~ Layout.RectangularDiagonal meas vert horiz height width) =>
   Perm.Shape height -> Int -> Ptr a ->
   Array (Perm.Shape diagShape) (Perm.Element height) ->
   [Perm.Element diagShape] -> IO ()
swapColumns sh n xPtr ipiv is = evalContT $ do
   nPtr <- Call.cint n
   incPtr <- Call.cint 1
   -- reinterpret a pivot-position index as a row index
   let mapIx (Perm.Element i) = Perm.Element i
   let columnPtr ix = advancePtr xPtr (n * Shape.uncheckedOffset sh ix)
   liftIO $ forM_ is $ \i ->
      BlasGen.swap nPtr (columnPtr $ mapIx i) incPtr (columnPtr (ipiv!i)) incPtr


-- | The lower triangular factor @L@ as a full matrix of the same extent.
extractL ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractL = Split.extractTriangle (Left Triangle) . split_

-- | The lower triangular factor of a non-tall factorisation as a square
--   'Lower' matrix.
wideExtractL ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height width a ->
   Lower height a
wideExtractL = Split.wideExtractL . split_

-- | Multiply by @L@ (or its transpose) of a non-tall factorisation.
wideMultiplyL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C widthA, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper measA Extent.Small horizA height widthA a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
wideMultiplyL transposed = Split.wideMultiplyL transposed . split_

-- | Solve a system with @L@ of a non-tall factorisation.
wideSolveL ::
   (Extent.Measure measA, Extent.C horizA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpperFlex lower upper measA Extent.Small horizA height width a ->
   Full meas vert horiz height nrhs a ->
   Full meas vert horiz height nrhs a
wideSolveL transposed conjugated =
   Split.wideSolveL transposed conjugated . split_


-- | The upper triangular factor @U@ as a full matrix of the same extent.
extractU ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
extractU = Split.extractTriangle (Right Triangle) . split_

-- | The upper triangular factor of a non-wide factorisation as a square
--   'Upper' matrix.
tallExtractU ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height width a ->
   Upper width a
tallExtractU = Split.tallExtractR . split_

-- | Multiply by @U@ (or its transpose) of a non-wide factorisation.
tallMultiplyU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C heightA, Shape.C widthB,
    Class.Floating a) =>
   Transposition ->
   LowerUpperFlex lower upper measA vertA Extent.Small heightA height a ->
   Full meas vert horiz height widthB a ->
   Full meas vert horiz height widthB a
tallMultiplyU transposed = Split.tallMultiplyR transposed . split_

-- | Solve a system with @U@ of a non-wide factorisation.
tallSolveU ::
   (Extent.Measure measA, Extent.C vertA,
    Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq width, Shape.C nrhs,
    Class.Floating a) =>
   Transposition -> Conjugation ->
   LowerUpperFlex lower upper measA vertA Extent.Small height width a ->
   Full meas vert horiz width nrhs a ->
   Full meas vert horiz width nrhs a
tallSolveU transposed conjugated =
   Split.tallSolveR transposed conjugated . split_


-- | Reconstruct the original matrix @P·L·U@ from its factorisation.
toMatrix ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Eq width,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Full meas vert horiz height width a
toMatrix =
   getToMatrix $
   ExtentPriv.switchTagTriple
      (ToMatrix wideToMatrix)
      (ToMatrix wideToMatrix)
      (ToMatrix wideToMatrix)
      (ToMatrix tallToMatrix)
      (ToMatrix $
         either
            (MatrixPriv.fromFull . tallToMatrix)
            (MatrixPriv.fromFull . wideToMatrix)
         . caseTallWide)

-- Wrapper that puts the extent tags last, as 'ExtentPriv.switchTagTriple'
-- dispatches on them.
newtype ToMatrix lower upper height width a meas vert horiz =
   ToMatrix {
      getToMatrix ::
         LowerUpperFlex lower upper meas vert horiz height width a ->
         Full meas vert horiz height width a
   }

-- | @P·L·U@ for a non-wide factorisation, computed as @P·(U^T·L^T)^T@.
tallToMatrix ::
   (Extent.Measure meas, Extent.C vert,
    Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height width a ->
   Full meas vert Extent.Small height width a
tallToMatrix a =
   multiplyP NonInverted a $
   Basic.transpose $ tallMultiplyU Transposed a $
   Basic.transpose $ extractL a

-- | @P·L·U@ for a non-tall factorisation.
wideToMatrix ::
   (Extent.Measure meas, Extent.C horiz,
    Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height width a ->
   Full meas Extent.Small horiz height width a
wideToMatrix a =
   multiplyP NonInverted a $ wideMultiplyL NonTransposed a $ extractU a


{- |
Multiply the factorised matrix @A = P·L·U@ by a full matrix @B@
without reconstructing @A@.
For separate height and width shapes, the height is wrapped in
'Unchecked' so that no @Eq height@ constraint is needed.
-}
multiplyFull ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
multiplyFull a =
   case Matrix.extent a of
      ExtentPriv.Square _ -> multiplyFullAux a
      ExtentPriv.Separate _ _ ->
         Basic.mapHeight deconsUnchecked .
         multiplyFullAux (mapHeight Unchecked a)

multiplyFullAux ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Eq height, Shape.C width, Shape.C fuse, Eq fuse,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
multiplyFullAux =
   getMultiplyFullRight $
   ExtentPriv.switchTagTriple
      {-
      We cannot simply use squareFull here,
      because this requires height~width.
      -}
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight wideMultiplyFullRight)
      (MultiplyFullRight tallMultiplyFullRight)
      (MultiplyFullRight $
         either tallMultiplyFullRight wideMultiplyFullRight . caseTallWide)

-- Wrapper that puts the extent tags last, as 'ExtentPriv.switchTagTriple'
-- dispatches on them.
newtype MultiplyFullRight lower upper height fuse width a meas vert horiz =
   MultiplyFullRight {
      getMultiplyFullRight ::
         LowerUpperFlex lower upper meas vert horiz height fuse a ->
         Full meas vert horiz fuse width a ->
         Full meas vert horiz height width a
   }

-- | @P·L·U·B@ for a non-wide factorisation.
tallMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse,
    Eq height, Eq fuse, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert Extent.Small height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallMultiplyFullRight a =
   multiplyP NonInverted a .
   Basic.multiply (MatrixPriv.weakenTall (extractL a)) .
   tallMultiplyU NonTransposed a

-- | @P·L·U·B@ for a non-tall factorisation.
wideMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse,
    Eq height, Eq fuse, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small horiz height fuse a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
wideMultiplyFullRight a =
   multiplyP NonInverted a .
   wideMultiplyL NonTransposed a .
   Basic.multiply (MatrixPriv.weakenWide (extractU a))


-- | @A^T·x@ for @A = P·L·U@, computed as @U^T·L^T·P^(-1)·x@.
transMultiplyVector ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Eq height, Eq width,
    Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Vector height a -> Vector width a
transMultiplyVector =
   Basic.unliftColumn Layout.ColumnMajor .
   either tallTransMultiplyFullRight wideTransMultiplyFullRight .
   caseTallWide

-- | @U^T·L^T·P^(-1)·B@ for a non-wide factorisation.
tallTransMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse,
    Eq height, Eq fuse, Class.Floating a) =>
   LowerUpperFlex lower upper meas horiz Extent.Small fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
tallTransMultiplyFullRight a =
   tallMultiplyU Transposed a .
   Basic.multiply (Basic.transpose $ MatrixPriv.weakenTall $ extractL a) .
   multiplyP Inverted a

-- | @U^T·L^T·P^(-1)·B@ for a non-tall factorisation.
wideTransMultiplyFullRight ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Shape.C fuse,
    Eq height, Eq fuse, Class.Floating a) =>
   LowerUpperFlex lower upper meas Extent.Small vert fuse height a ->
   Full meas vert horiz fuse width a ->
   Full meas vert horiz height width a
wideTransMultiplyFullRight a =
   Basic.multiply (Basic.transpose $ MatrixPriv.weakenWide $ extractU a) .
   wideMultiplyL Transposed a .
   multiplyP Inverted a


{- |
Reinterpret a factorisation as either 'Tall' or 'Wide',
as decided by 'Layout.caseTallWideSplit' on its storage shape.
-}
caseTallWide ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   Either (Tall height width a) (Wide height width a)
caseTallWide (LowerUpper ipiv (Array shape a)) =
   let -- ipiv and a are passed in explicitly. This keeps consLU free of
       -- local variables, so under MonoLocalBinds (implied by GADTs)
       -- it is still generalised and serves both the Left and Right case.
       consLU ipivb b newShape =
          LowerUpper
             (Array.mapShape
                (\bandShape ->
                   bandShape{Layout.bandedExtent = Layout.splitExtent newShape})
                ipivb)
             (Array newShape b)
   in either (Left . consLU ipiv a) (Right . consLU ipiv a) $
      Layout.caseTallWideSplit shape


-- | Convert the L/U storage to row-major order (currently unused).
_toRowMajor ::
   (Extent.Measure meas, Extent.C vert, Extent.C horiz,
    Shape.C height, Shape.C width, Class.Floating a) =>
   LowerUpperFlex lower upper meas vert horiz height width a ->
   LowerUpperFlex lower upper meas vert horiz height width a
_toRowMajor
      (LowerUpper ipiv
         arr@(Array (Layout.Split Layout.Triangle order extent) a)) =
   LowerUpper ipiv $
   case order of
      RowMajor -> arr
      ColumnMajor ->
         Array.unsafeCreate (Layout.Split Layout.Triangle RowMajor extent) $
         \bPtr -> withForeignPtr a $ \aPtr -> do
            let (height, width) = Extent.dimensions extent
            let n = Shape.size width
            let m = Shape.size height
            copyTransposed n m aPtr n bPtr


instance Matrix.Box LU where
   extent = Layout.splitExtent . Array.shape . split_

instance Matrix.ToQuadratic LU where
   heightToQuadratic (LowerUpper pivot split) =
      LowerUpper
         (Array.mapShape (layoutPivotSquare . Layout.bandedHeight) pivot)
         (Split.heightToQuadratic split)
   widthToQuadratic (LowerUpper pivot split) =
      LowerUpper
         (mapPivotHeight (const $ Layout.bandedWidth $ Array.shape pivot) $
          Array.mapShape (layoutPivotSquare . Layout.bandedWidth) pivot)
         (Split.widthToQuadratic split)

-- | Column-major diagonal layout used for the pivots of a square result.
layoutPivotSquare :: sh -> Layout.Diagonal sh
layoutPivotSquare = LayoutPub.diagonal Layout.ColumnMajor

instance (xl ~ (), xu ~ ()) => Matrix.MapExtent LU xl xu lower upper where
   mapExtent = mapExtent

instance (xl ~ (), xu ~ ()) => Multiply.MultiplyVector LU xl xu where
   matrixVector lu =
      Basic.unliftColumn Layout.ColumnMajor
         (multiplyFull (mapExtent Extent.toGeneral lu))
   vectorMatrix = flip $ \lu ->
      case Matrix.extent lu of
         ExtentPriv.Square _ -> transMultiplyVector lu
         ExtentPriv.Separate _ _ ->
            Array.mapShape deconsUnchecked .
            transMultiplyVector (mapWidth Unchecked lu)

instance (xl ~ (), xu ~ ()) => Multiply.MultiplySquare LU xl xu where
   -- A·B = P·L·U·B
   squareFull lu =
      ArrMatrix.lift1 $
         multiplyP NonInverted lu .
         wideMultiplyL NonTransposed lu .
         tallMultiplyU NonTransposed lu
   -- B·A = (U^T·L^T·P^(-1)·B^T)^T
   fullSquare = flip $ \lu ->
      ArrMatrix.lift1 $
         Basic.transpose .
         tallMultiplyU Transposed lu .
         wideMultiplyL Transposed lu .
         multiplyP Inverted lu .
         Basic.transpose

instance (xl ~ (), xu ~ ()) => Divide.Determinant LU xl xu where
   determinant = determinant

instance (xl ~ (), xu ~ ()) => Divide.Solve LU xl xu where
   solve trans = ArrMatrix.lift1 . solveTrans trans