{-# LANGUAGE TypeFamilies #-}
{- |
Low-level helpers shared by the BLAS and LAPACK bindings:
copying and conjugating raw arrays, simple reductions,
work-arounds for BLAS corner cases
and management of LAPACK workspaces and @info@ codes.
-}
module Numeric.LAPACK.Private where

import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Scalar (RealOf, zero, one, isZero)

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, poke, peek, pokeElemOff, peekElemOff)

import Text.Printf (printf)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (Const(Const,getConst), liftA2, (<$>))

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)

import Prelude hiding (sum)


{- |
Reinterpret a pointer to scalars as a pointer to their real component type.
This is a plain 'castPtr'; no data is touched.
-}
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr = castPtr


{- |
@fill a n dstPtr@ stores @a@ in the @n@ consecutive elements
starting at @dstPtr@.
It is implemented by a BLAS @copy@ with source increment 0,
such that the same source element is read for every target element.
-}
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr = evalContT $ do
   nPtr <- Call.cint n
   srcPtr <- Call.number a
   incxPtr <- Call.cint 0
   incyPtr <- Call.cint 1
   liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr


{- |
@copyBlock n srcPtr dstPtr@ copies @n@ consecutive elements
from @srcPtr@ to @dstPtr@ using BLAS @copy@ with both increments set to 1.
-}
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr = evalContT $ do
   nPtr <- Call.cint n
   incxPtr <- Call.cint 1
   incyPtr <- Call.cint 1
   liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr


{- |
Copy the first @n@ elements behind the 'ForeignPtr'
into an array obtained from @Call.allocaArray@
and pass the pointer to that copy to the continuation.
NOTE(review): the copy presumably lives only as long as the continuation runs;
confirm against @Numeric.Netlib.Utility@.
-}
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
   ptr <- ContT $ withForeignPtr fptr
   tmpPtr <- Call.allocaArray n
   liftIO $ copyArray tmpPtr ptr n
   return tmpPtr


{- |
Make a temporary copy only for complex matrices.

For the two real cases of @Class.switchFloating@
the original array is passed on unmodified and without copying.
For the two complex cases the continuation gets a conjugated copy
(see 'complexConjugateToTemp').
-}
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp $
   Class.switchFloating
      (CopyToTemp $ ContT . withForeignPtr)
      (CopyToTemp $ ContT . withForeignPtr)
      (CopyToTemp $ complexConjugateToTemp n)
      (CopyToTemp $ complexConjugateToTemp n)

{- |
Wrapper that brings the element type into the last type argument position
as needed for dispatching with @Class.switchFloating@
in 'conjugateToTemp'.
-}
newtype CopyToTemp r a =
   CopyToTemp {runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)}

{- |
Copy @n@ complex elements to a temporary array (see 'copyToTemp')
and conjugate the copy in place using LAPACK's @lacgv@.
The original array is not altered.
-}
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   nPtr <- Call.cint n
   xPtr <- copyToTemp n x
   incxPtr <- Call.cint 1
   liftIO $ LapackComplex.lacgv nPtr xPtr incxPtr
   return xPtr


{- |
Conjugate a vector in place via 'lacgv' if and only if
the first argument is 'Conjugated'.
Otherwise do nothing.
-}
condConjugate ::
   (Class.Floating a) =>
   Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
condConjugate conj nPtr yPtr incyPtr =
   when (conj==Conjugated) $ lacgv nPtr yPtr incyPtr

{- |
Copy vector @x@ to vector @y@ with BLAS @copy@
and then conjugate @y@ in place with 'lacgv'.
-}
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr = do
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
   lacgv nPtr yPtr incyPtr

{- |
Copy vector @x@ to vector @y@ with BLAS @copy@
and conjugate @y@ in place if the first argument is 'Conjugated'
(see 'condConjugate').
-}
copyCondConjugate ::
   (Class.Floating a) =>
   Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr = do
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr
   condConjugate conj nPtr yPtr incyPtr

{- |
For 'NonConjugated' the continuation gets the pointer to the original array,
that is, there is no copy at all.
For 'Conjugated' the work is delegated to 'conjugateToTemp'.
Since the result may alias the input,
callers must not write through the returned pointer.
-}
condConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp conj n x =
   case conj of
      NonConjugated -> ContT $ withForeignPtr x
      Conjugated -> conjugateToTemp n x

{- |
In contrast to 'condConjugateToTemp'
this function always allocates a temporary array of @n@ elements
and copies the input into it.
The copy is conjugated if the first argument is 'Conjugated'.
-}
copyCondConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
   bPtr <- Call.allocaArray n
   liftIO $ evalContT $ do
      aPtr <- ContT $ withForeignPtr a
      sizePtr <- Call.cint n
      incPtr <- Call.cint 1
      liftIO $ copyCondConjugate conj sizePtr aPtr incPtr bPtr incPtr
   return bPtr


{- |
In ColumnMajor:
Copy an m-by-n-matrix with lda>=m and ldb>=m.

Argument order: @copySubMatrix m n lda aPtr ldb bPtr@.
This is 'copySubTrapezoid' with selector @\'A\'@.
-}
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix = copySubTrapezoid 'A'

{- |
@copySubTrapezoid side m n lda aPtr ldb bPtr@
calls LAPACK's @lacpy@ and passes @side@ as its first (UPLO) argument.
NOTE(review): according to the LAPACK reference,
@\'U\'@ and @\'L\'@ restrict the copy
to the upper and lower trapezoid, respectively,
and any other character copies the full matrix.
-}
copySubTrapezoid ::
   (Class.Floating a) =>
   Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr = evalContT $ do
   uploPtr <- Call.char side
   mPtr <- Call.cint m
   nPtr <- Call.cint n
   ldaPtr <- Call.leadingDim lda
   ldbPtr <- Call.leadingDim ldb
   liftIO $ LapackGen.lacpy uploPtr mPtr nPtr aPtr ldaPtr bPtr ldbPtr

{- |
@copyTransposed n m aPtr ldb bPtr@ performs @m@ strided copies.
For every @k@ below @m@ the @n@ source elements
at the offsets @k@, @k+m@, @k+2*m@ and so on relative to @aPtr@
are stored contiguously starting at offset @k*ldb@ relative to @bPtr@.
In other words: A row-major matrix with @n@ rows and @m@ columns
is converted to a column-major matrix with leading dimension @ldb@.
-}
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr = evalContT $ do
   nPtr <- Call.cint n
   incaPtr <- Call.cint m
   incbPtr <- Call.cint 1
   liftIO $ sequence_ $ take m $
      zipWith
         (\akPtr bkPtr -> BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr)
         (pointerSeq 1 aPtr)
         (pointerSeq ldb bPtr)

{- |
Copy an m-by-n-matrix to ColumnMajor order.

A 'RowMajor' source is transposed using 'copyTransposed',
a 'ColumnMajor' source is copied as one contiguous block of @m*n@ elements.
-}
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor order m n aPtr bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr m bPtr
      ColumnMajor -> copyBlock (m*n) aPtr bPtr

{- |
Like 'copyToColumnMajor' but the target has leading dimension @ldb@.
For a 'ColumnMajor' source with @m==ldb@ source and target are
laid out identically and a single block copy is used,
otherwise the columns are copied with 'copySubMatrix'.
-}
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor order m n aPtr ldb bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr ldb bPtr
      ColumnMajor ->
         if m==ldb
            then copyBlock (m*n) aPtr bPtr
            else copySubMatrix m n m aPtr ldb bPtr

{- |
@pointerSeq k ptr@ is the infinite list of pointers
that starts with @ptr@ and advances by @k@ elements in every step.
-}
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k ptr = iterate (flip advancePtr k) ptr


{- |
@createHigherArray shapeX m n nrhs act@
creates an array of shape @shapeX@ and lets @act@ fill it.
The action gets a buffer together with its leading dimension.

If @m>n@ then @act@ works on a temporary buffer of @m*nrhs@ elements
with leading dimension @m@
and afterwards the upper @n@ rows of its @nrhs@ columns
are copied to the result array with leading dimension @n@.
Otherwise @act@ works directly on the result array
with leading dimension @n@.

The result of @act@ is returned along with the array.
-}
createHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int ->
   ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
   fmap swap $
   ArrayIO.unsafeCreateWithSizeAndResult shapeX $ \ _ xPtr ->
   if m>n
      then
         runContT (Call.allocaArray (m*nrhs)) $ \tmpPtr -> do
            r <- act (tmpPtr,m)
            copySubMatrix n nrhs m tmpPtr n xPtr
            return r
      else act (xPtr,n)


{- |
Wrapper for dispatching 'sum' on the element type
with @Class.switchFloating@.
-}
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

{- |
@sum n xPtr incx@ adds up @n@ elements
starting at @xPtr@ with stride @incx@.
Real element types are handled by 'sumReal',
complex ones by 'sumComplex'.
-}
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum $
   Class.switchFloating
      (Sum sumReal) (Sum sumReal) (Sum sumComplex) (Sum sumComplex)

{- |
Sum of real numbers,
computed as BLAS @dot@ product of the vector
with a single 'one' that is read with increment 0.
-}
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx = evalContT $ do
   nPtr <- Call.cint n
   incxPtr <- Call.cint incx
   yPtr <- Call.real one
   incyPtr <- Call.cint 0
   liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr

{- |
Two implementations of the sum of complex numbers.

'sumComplex' views the complex array as an array of real numbers
and runs two real @dot@ products with stride @2*incx@:
one that starts at the first real number (real parts)
and one that starts at the second real number (imaginary parts).

'sumComplexAlt' views the complex array as a real matrix
with two rows and leading dimension @2*inca@
and multiplies it with a vector of @n@ ones using 'gemv'.
It is not used within this module.
-}
sumComplex, sumComplexAlt ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx = evalContT $ do
   nPtr <- Call.cint n
   let sxPtr = realPtr xPtr
   incxPtr <- Call.cint (2*incx)
   yPtr <- Call.real one
   incyPtr <- Call.cint 0
   liftIO $
      liftA2 (Complex.:+)
         (BlasReal.dot nPtr sxPtr incxPtr yPtr incyPtr)
         (BlasReal.dot nPtr (advancePtr sxPtr 1) incxPtr yPtr incyPtr)

sumComplexAlt n aPtr inca = evalContT $ do
   transPtr <- Call.char 'N'
   mPtr <- Call.cint 2
   nPtr <- Call.cint n
   onePtr <- Call.number one
   inc0Ptr <- Call.cint 0
   let saPtr = realPtr aPtr
   ldaPtr <- Call.leadingDim (2*inca)
   sxPtr <- Call.allocaArray n
   incxPtr <- Call.cint 1
   betaPtr <- Call.number zero
   yPtr <- Call.alloca
   let syPtr = realPtr yPtr
   incyPtr <- Call.cint 1
   liftIO $ do
      -- fill the temporary vector with ones
      BlasGen.copy nPtr onePtr inc0Ptr sxPtr incxPtr
      gemv
         transPtr mPtr nPtr onePtr saPtr ldaPtr
         sxPtr incxPtr betaPtr syPtr incyPtr
      peek yPtr


{- |
@mulReal n aPtr inca xPtr incx yPtr incy@
computes an element-wise product of the vectors @a@ and @x@
and writes it to @y@.
It calls BLAS @hbmv@ with zero off-diagonals,
that is, @a@ is passed as the diagonal of a Hermitian band matrix
with leading dimension @inca@.
NOTE(review): according to the BLAS reference
the imaginary parts of the diagonal of a Hermitian matrix
are assumed to be zero,
thus for complex element types
presumably only the real parts of @a@ contribute.
-}
mulReal ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulReal n aPtr inca xPtr incx yPtr incy = evalContT $ do
   uploPtr <- Call.char 'U'
   nPtr <- Call.cint n
   kPtr <- Call.cint 0
   alphaPtr <- Call.number one
   ldaPtr <- Call.leadingDim inca
   incxPtr <- Call.cint incx
   betaPtr <- Call.number zero
   incyPtr <- Call.cint incy
   liftIO $
      BlasGen.hbmv uploPtr
         nPtr kPtr alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
@mul n aPtr inca xPtr incx yPtr incy@
computes the element-wise product of the vectors @a@ and @x@
and writes it to @y@.
It calls BLAS @gbmv@ with zero sub- and super-diagonals,
that is, @a@ is passed as the diagonal of an n-by-n band matrix
with leading dimension @inca@.
-}
mul ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul n aPtr inca xPtr incx yPtr incy = evalContT $ do
   transPtr <- Call.char 'N'
   nPtr <- Call.cint n
   klPtr <- Call.cint 0
   kuPtr <- Call.cint 0
   alphaPtr <- Call.number one
   ldaPtr <- Call.leadingDim inca
   incxPtr <- Call.cint incx
   betaPtr <- Call.number zero
   incyPtr <- Call.cint incy
   liftIO $
      BlasGen.gbmv transPtr
         nPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
         xPtr incxPtr betaPtr yPtr incyPtr


{- |
Use the foldBalanced trick.

@product n aPtr inca@ multiplies @n@ elements
starting at @aPtr@ with stride @inca@.
For @n@ less than 1 the result is 'one'.
Otherwise the products of adjacent pairs are written
to a temporary array of @2*new-1@ elements,
a leftover element of an odd-sized input is appended
and the remaining work is done by 'productLoop'.
-}
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca =
   case compare n 1 of
      LT -> return one
      EQ -> peek aPtr
      GT ->
         let n2 = div n 2; new = n-n2
         in ForeignArray.alloca (2*new-1) $ \xPtr -> do
               mulPairs n2 aPtr inca xPtr 1
               when (odd n) $
                  pokeElemOff xPtr n2 =<< peekElemOff aPtr ((n-1)*inca)
               productLoop new xPtr

{- |
If 'mul' were based on a scalar loop
we would not need to cut the vector into chunks.

The invariant is:
When calling @productLoop n xPtr@,
starting from xPtr there is storage allocated for 2*n-1 elements.

Every round writes the products of adjacent pairs
right behind the current @n@ elements
and continues with these products,
preceded by the unpaired last element if @n@ is odd.
-}
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr =
   if n==1
      then peek xPtr
      else do
         let n2 = div n 2
         mulPairs n2 xPtr 1 (advancePtr xPtr n) 1
         productLoop (n-n2) (advancePtr xPtr (2*n2))

{- |
@mulPairs n aPtr inca xPtr incx@
multiplies @n@ pairs of adjacent elements of @a@ using 'mul':
The elements at even positions (stride @2*inca@ starting at @aPtr@)
are multiplied with the elements at odd positions
(same stride, starting @inca@ elements later)
and the products are written to @x@ with stride @incx@.
-}
mulPairs ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca xPtr incx =
   let inca2 = 2*inca
   in mul n aPtr inca2 (advancePtr aPtr inca) inca2 xPtr incx


{- |
Wrapper for dispatching 'lacgv' on the element type
with @Class.switchFloating@.
-}
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

{- |
Conjugate a vector in place.
For the two real cases of @Class.switchFloating@ this does nothing,
for the two complex cases it calls LAPACK's @lacgv@.
-}
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV $
   Class.switchFloating
      (LACGV $ const $ const $ const $ return ())
      (LACGV $ const $ const $ const $ return ())
      (LACGV LapackComplex.lacgv)
      (LACGV LapackComplex.lacgv)


{-
Work around an inconsistency of BLAS.
In case of a zero-column matrix
BLAS's gemv and gbmv do not initialize the target vector.
In contrast, these work-arounds do.
-}
{- |
Like @BlasGen.gemv@ with the same argument order
but runs 'initializeMV' on the target vector first.
-}
{-# INLINE gemv #-}
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr = do
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   BlasGen.gemv transPtr mPtr nPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr

{- |
Like @BlasGen.gbmv@ with the same argument order
but runs 'initializeMV' on the target vector first.
-}
{-# INLINE gbmv #-}
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr = do
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   BlasGen.gbmv transPtr mPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
      xPtr incxPtr betaPtr yPtr incyPtr

{- |
Helper of the 'gemv' and 'gbmv' work-arounds.
It determines the length of the target vector
and the size of the summed dimension from @trans@, @m@ and @n@:
For @\'N\'@ the target has @m@ elements and @n@ is summed,
for every other character it is the other way round.
If the summed dimension is zero and @beta@ is zero
then the target vector is overwritten with @beta@
by a BLAS @copy@ with source increment 0.
In all other cases nothing is done.
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- peek transPtr
   let (mtPtr,ntPtr) =
         if trans == CStr.castCharToCChar 'N'
            then (mPtr,nPtr)
            else (nPtr,mPtr)
   n <- peek ntPtr
   beta <- peek betaPtr
   when (n == 0 && isZero beta) $
      Marshal.with 0 $ \incbPtr ->
      BlasGen.copy mtPtr betaPtr incbPtr yPtr incyPtr


{- |
@multiplyMatrix orderA orderB m k n a b cPtr@
multiplies the m-by-k-matrix @a@ with the k-by-n-matrix @b@
using BLAS @gemm@ with @alpha = one@ and @beta = zero@.
The product is written to @cPtr@ with leading dimension @m@.
The leading dimensions of the operands
and the transposition flags (via 'transposeFromOrder')
are derived from the respective storage orders.
-}
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr = do
   let lda = case orderA of RowMajor -> k; ColumnMajor -> m
   let ldb = case orderB of RowMajor -> n; ColumnMajor -> k
   let ldc = m
   evalContT $ do
      transaPtr <- Call.char $ transposeFromOrder orderA
      transbPtr <- Call.char $ transposeFromOrder orderB
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim lda
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim ldc
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
Combination of 'withInfo' and 'withAutoWorkspace':
The computation gets the workspace, its size and the @info@ pointer.
Both the size query and the actual run happen inside one 'withInfo',
thus @info@ is inspected once, after the actual run.
-}
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
   withInfo msg name $ \infoPtr ->
   withAutoWorkspace $ \workPtr lworkPtr ->
      computation workPtr lworkPtr infoPtr

{- |
Run a LAPACK routine twice:
The first run gets @lwork = -1@ and a single workspace element.
After that run the workspace element is read,
rounded up with 'ceilingSize' and clipped to at least 1.
The second run gets a workspace of that size.
NOTE(review): this relies on the LAPACK convention
that @lwork = -1@ requests a workspace size query
which stores the optimal size in the first workspace element.
-}
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation = evalContT $ do
   lworkPtr <- Call.cint (-1)
   lwork <- liftIO $ alloca $ \workPtr -> do
      computation workPtr lworkPtr
      max 1 . ceilingSize <$> peek workPtr
   workPtr <- Call.allocaArray lwork
   liftIO $ pokeCInt lworkPtr lwork
   liftIO $ computation workPtr lworkPtr

{- |
@withInfo msg name computation@ provides storage for a LAPACK @info@ value,
runs the computation and inspects @info@ afterwards.
Zero means success.
A negative value calls 'error' with 'argMsg',
filled with @name@ and the negated value.
A positive value calls 'error' with @name@
followed by @msg@ filled with the value,
thus @msg@ must be a 'printf' format string
with one placeholder for an integer.
-}
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation = alloca $ \infoPtr -> do
   computation infoPtr
   info <- peekCInt infoPtr
   case compare info (0::Int) of
      EQ -> return ()
      LT -> error $ printf argMsg name (-info)
      GT -> error $ name ++ ": " ++ printf msg info

{- |
Format string for a negative @info@:
expects the routine name and the position of the offending argument.
-}
argMsg :: String
argMsg = "%s: illegal value in %d-th argument"

{- | Format string for 'withInfo': generic text for a positive @info@. -}
errorCodeMsg :: String
errorCodeMsg = "unknown error code %d"

{- | Format string for 'withInfo': rank deficiency. -}
rankMsg :: String
rankMsg = "deficient rank %d"

{- | Format string for 'withInfo': matrix is not positive definite. -}
definiteMsg :: String
definiteMsg = "minor of order %d not positive definite"

{- | Format string for 'withInfo': eigenvalue iteration did not converge. -}
eigenMsg :: String
eigenMsg = "%d off-diagonal elements not converging"


{- | Write an 'Int' to a 'CInt' location, converting with 'fromIntegral'. -}
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr = poke ptr . fromIntegral

{- | Read a 'CInt' location as 'Int', converting with 'fromIntegral'. -}
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = fromIntegral <$> peek ptr

{- |
Round a number up to the next integer.
For complex numbers only the real part is considered.
-}
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   getFlip $
   Class.switchFloating
      (Flip ceiling)
      (Flip ceiling)
      (Flip $ ceiling . Complex.realPart)
      (Flip $ ceiling . Complex.realPart)


{- |
@caseRealComplexFunc f r c@ yields @r@
for the two real cases of @Class.switchFloating@
and @c@ for the two complex cases.
The argument @f@ is never inspected,
it only determines the element type @a@.
-}
caseRealComplexFunc :: (Class.Floating a) => f a -> b -> b -> b
caseRealComplexFunc f r c =
   getConstFunc f $
   Class.switchFloating (Const r) (Const r) (Const c) (Const c)

{- |
'getConst' with an additional proxy argument
that unifies its type parameter with the one of the 'Const' value.
The proxy is ignored at runtime.
-}
getConstFunc :: f c -> Const a c -> a
getConstFunc _ = getConst


{- |
Tag for one of the two components of a complex number.
It is not used within this module.
-}
data ComplexPart = RealPart | ImaginaryPart
   deriving (Eq, Ord, Show, Enum, Bounded)