{-# LANGUAGE TypeFamilies #-}
-- | Solving linear systems and inverting general square matrices
-- via LAPACK's LU-based drivers (@gesv@, @getrf@, @getri@).
module Numeric.LAPACK.Linear.General (
   solve,
   inverse,
   ) where

import Numeric.LAPACK.Matrix.Square (Square)
import Numeric.LAPACK.Matrix (General)
import qualified Numeric.LAPACK.Matrix.Shape.Private as MatrixShape
import Numeric.LAPACK.Matrix.Shape.Private (Order(ColumnMajor))
import Numeric.LAPACK.Private
         (withAutoWorkspace, copyBlock, copyToColumnMajor)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import qualified Data.Array.Comfort.Storable.Internal as Array
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Applicative ((<$>))

import Text.Printf (printf)


-- | Solve @A X = B@ for @X@ using LAPACK @gesv@.
--
-- Neither input array is modified. @A@ is copied into a temporary
-- column-major buffer, and @B@ is copied into the result, which @gesv@
-- then overwrites with the solution. The result is always
-- 'ColumnMajor', whatever the storage order of the inputs.
--
-- Fails (via 'Call.assert') if the size shape of @A@ differs from the
-- height of @B@. Also fails (via 'error') if @gesv@ reports a nonzero
-- @info@, e.g. when @A@ is singular.
solve ::
   (Shape.C sh, Eq sh, Shape.C nrhs, Class.Floating a) =>
   Square sh a -> General sh nrhs a -> General sh nrhs a
solve
   (Array (MatrixShape.Square orderA shA) a)
   (Array (MatrixShape.General orderB heightB widthB) b) =
   Array.unsafeCreate (MatrixShape.General ColumnMajor heightB widthB) $
   \xPtr -> do
      Call.assert "Square.solve: height shapes mismatch" (shA == heightB)
      let n = Shape.size heightB
          nrhs = Shape.size widthB
          -- Both buffers are dense column-major with n rows. LAPACK
          -- still demands LDA, LDB >= max(1,N) when n == 0.
          lda = leadingDim n
          ldb = leadingDim n
      evalContT $ do
         nPtr <- Call.cint n
         nrhsPtr <- Call.cint nrhs
         aPtr <- ContT $ withForeignPtr a
         -- gesv overwrites A with its LU factors, so work on a copy.
         atmpPtr <- Call.allocaArray (n*n)
         ldaPtr <- Call.cint lda
         ipivPtr <- Call.allocaArray n
         bPtr <- ContT $ withForeignPtr b
         ldbPtr <- Call.cint ldb
         liftIO $ do
            copyToColumnMajor orderA n n aPtr atmpPtr
            -- The right-hand side is copied into the result buffer,
            -- which gesv overwrites in place with the solution X.
            copyToColumnMajor orderB n nrhs bPtr xPtr
            withInfo "gesv" $
               LapackGen.gesv nPtr nrhsPtr atmpPtr ldaPtr ipivPtr xPtr ldbPtr

-- | Invert a square matrix using LAPACK @getrf@ (LU factorization)
-- followed by @getri@ (inversion from the LU factors).
--
-- The result has exactly the input's shape, including its storage order.
-- Fails (via 'error') if either LAPACK routine reports a nonzero @info@,
-- e.g. when the matrix is singular.
inverse :: (Shape.C sh, Class.Floating a) => Square sh a -> Square sh a
inverse (Array shape@(MatrixShape.Square _order sh) a) =
   Array.unsafeCreateWithSize shape $ \blockSize bPtr -> do
      let n = Shape.size sh
      evalContT $ do
         nPtr <- Call.cint n
         aPtr <- ContT $ withForeignPtr a
         -- LAPACK demands LDA >= max(1,N), even for an empty matrix.
         ldbPtr <- Call.cint (leadingDim n)
         ipivPtr <- Call.allocaArray n
         liftIO $ do
            -- The storage order can be ignored. A row-major buffer is the
            -- column-major buffer of the transpose, and
            -- inv(A^T) = inv(A)^T. So inverting the raw buffer in place
            -- gives the correct result in the same order.
            copyBlock blockSize aPtr bPtr
            withInfo "getrf" $ LapackGen.getrf nPtr nPtr bPtr ldbPtr ipivPtr
            withInfo "getri" $ \infoPtr ->
               withAutoWorkspace $ \workPtr lworkPtr ->
                  LapackGen.getri
                     nPtr bPtr ldbPtr ipivPtr workPtr lworkPtr infoPtr

-- | Leading dimension for a dense column-major buffer with @n@ rows.
-- LAPACK rejects leading dimensions below @max(1,N)@, even when @N = 0@.
leadingDim :: Int -> Int
leadingDim = max 1

-- | Run a LAPACK call that reports its status through an @INFO@
-- out-parameter, and turn a nonzero status into an 'error'.
--
-- A negative @info@ is reported as an illegal value in argument @-info@.
-- A positive @info@ is reported as a zero diagonal entry at position
-- @info@. This matches the meaning of a positive @info@ for the LU-based
-- routines used in this module.
withInfo :: String -> (Ptr CInt -> IO ()) -> IO ()
withInfo name computation =
   alloca $ \infoPtr -> do
      computation infoPtr
      info <- fromIntegral <$> peek infoPtr
      case compare info (0::Int) of
         EQ -> return ()
         LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
         GT -> error $ printf "%s: %d-th diagonal value is zero" name info