{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where

import Numeric.LAPACK.Matrix.Shape.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder)

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class

import Foreign.Marshal.Array (advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr, mallocForeignPtrArray)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, poke, peek)

import Text.Printf (printf)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (foldM)
import Control.Applicative ((<$>))
import Data.Functor.Identity (Identity(Identity, runIdentity))

import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Internal (Array(Array))

import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))

-- This module defines its own 'sum' and 'product'.
import Prelude hiding (sum, product)


{- |
The real scalar type underlying an element type:
the type itself for 'Float' and 'Double',
the component type for 'Complex'.
-}
type family RealOf x
type instance RealOf Float = Float
type instance RealOf Double = Double
type instance RealOf (Complex a) = a

-- | The complex type built from the real type underlying @x@.
type ComplexOf x = Complex (RealOf x)


-- | The constants 0, 1 and -1 for every 'Class.Floating' element type.
zero, one, minusOne :: Class.Floating a => a
zero =
   runIdentity $
   Class.switchFloating (Identity 0) (Identity 0) (Identity 0) (Identity 0)
one =
   runIdentity $
   Class.switchFloating (Identity 1) (Identity 1) (Identity 1) (Identity 1)
minusOne =
   runIdentity $
   Class.switchFloating
      (Identity (-1)) (Identity (-1)) (Identity (-1)) (Identity (-1))


{- |
@fill a n dstPtr@ writes @n@ copies of @a@ to consecutive elements at @dstPtr@.
BLAS @copy@ is called with a source increment of 0,
so the single source value is read repeatedly.
-}
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
   evalContT $ do
      nPtr <- Call.cint n
      srcPtr <- Call.number a
      incxPtr <- Call.cint 0
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr

-- | @copyBlock n srcPtr dstPtr@ copies @n@ contiguous elements via BLAS @copy@.
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr =
   evalContT $ do
      nPtr <- Call.cint n
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr srcPtr incxPtr dstPtr incyPtr

{- |
Copy the first @n@ elements of a 'ForeignPtr' array
into a temporary array allocated in the 'ContT' scope.
The returned pointer is only valid inside that scope.
-}
copyToTemp :: (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
   ptr <- ContT $ withForeignPtr fptr
   tmpPtr <- Call.allocaArray n
   liftIO $ copyBlock n ptr tmpPtr
   return tmpPtr


{- |
In ColumnMajor: copy an m-by-n matrix with leading dimension @lda@ (>= m)
into an m-by-n matrix with leading dimension @ldb@ (>= m),
using LAPACK @lacpy@ on the whole matrix (@uplo = 'A'@).
-}
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr =
   evalContT $ do
      uploPtr <- Call.char 'A'
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      ldaPtr <- Call.cint lda
      ldbPtr <- Call.cint ldb
      liftIO $ LapackGen.lacpy uploPtr mPtr nPtr aPtr ldaPtr bPtr ldbPtr

{- |
@copyTransposed n m aPtr ldb bPtr@:
for each @k@ in @[0 .. m-1]@, copy the @n@ elements starting at @aPtr+k@
with stride @m@ to the @n@ contiguous elements starting at @bPtr+k*ldb@.
That is, a RowMajor n-by-m matrix (rows of length @m@)
becomes a ColumnMajor n-by-m matrix with leading dimension @ldb@.
-}
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr =
   evalContT $ do
      nPtr <- Call.cint n
      incaPtr <- Call.cint m
      incbPtr <- Call.cint 1
      liftIO $ sequence_ $ take m $
         zipWith
            (\akPtr bkPtr -> BlasGen.copy nPtr akPtr incaPtr bkPtr incbPtr)
            (pointerSeq 1 aPtr)
            (pointerSeq ldb bPtr)


{- |
Copy an m-by-n matrix, stored in the given order,
to a dense ColumnMajor buffer (leading dimension @m@).
-}
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor order m n aPtr bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr m bPtr
      ColumnMajor -> copyBlock (m*n) aPtr bPtr

{- |
Copy a dense m-by-n matrix, stored in the given order,
into a ColumnMajor buffer with leading dimension @ldb@ (>= m).
A dense ColumnMajor source is copied as one block when @m == ldb@.
-}
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor order m n aPtr ldb bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr ldb bPtr
      ColumnMajor ->
         if m==ldb
            then copyBlock (m*n) aPtr bPtr
            else copySubMatrix m n m aPtr ldb bPtr

-- | Infinite list of pointers @ptr, ptr+k, ptr+2k, ...@ (in elements).
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k ptr = iterate (flip advancePtr k) ptr

{- |
Allocate a 'ForeignPtr'-backed array of the given shape.
Return it together with a raw pointer that stays valid
for the rest of the 'ContT' computation.
-}
allocArray :: (Shape.C sh, Storable a) => sh -> ContT r IO (Array sh a, Ptr a)
allocArray sh = do
   fptr <- liftIO $ mallocForeignPtrArray $ Shape.size sh
   ptr <- ContT $ withForeignPtr fptr
   return (Array sh fptr, ptr)

{- |
Allocate the result array for shape @shapeX@.
Provide a buffer of @nrhs@ columns together with its leading dimension.

* If @m > n@, the buffer is a temporary m-by-nrhs array (leading dimension @m@).
  After the continuation finishes, its leading n-by-nrhs part is copied
  into the result array.

* Otherwise, the buffer is the result array itself (leading dimension @n@).

Presumably intended for LAPACK drivers that overwrite an m-by-nrhs
right-hand side with an n-by-nrhs solution; verify against callers.
-}
allocHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int -> ContT r IO (Array sh a, (Ptr a, Int))
allocHigherArray shapeX m n nrhs = do
   (x,xPtr) <- allocArray shapeX
   if m>n
      then do
         tmpPtr <- Call.allocaArray (m*nrhs)
         ContT $ \act -> do
            r <- act (x,(tmpPtr,m))
            copySubMatrix n nrhs m tmpPtr n xPtr
            return r
      else return (x,(xPtr,n))


newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

-- | @sum n xPtr incx@ adds @n@ elements read from @xPtr@ with stride @incx@.
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum $
   Class.switchFloating
      (Sum sumReal) (Sum sumReal) (Sum sumComplex) (Sum sumComplex)

-- | Real sum, computed as the BLAS @dot@ product with a stride-0 vector of ones.
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   evalContT $ do
      nPtr <- Call.cint n
      incxPtr <- Call.cint incx
      yPtr <- Call.real one
      incyPtr <- Call.cint 0
      liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr

{- |
Complex sum, computed as a 1-by-n matrix of ones times @x@ via BLAS @gemv@.
The matrix of ones is a temporary array of @n@ elements.
-}
sumComplex :: Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      transPtr <- Call.char 'N'
      mPtr <- Call.cint 1
      nPtr <- Call.cint n
      alphaPtr <- Call.number one
      onePtr <- Call.number one
      zeroincPtr <- Call.cint 0
      aPtr <- Call.allocaArray n
      ldaPtr <- Call.cint 1
      incxPtr <- Call.cint incx
      betaPtr <- Call.number zero
      -- Pre-initialize y with zero: reference gemv returns immediately
      -- for n == 0 without writing y, which would otherwise be left
      -- uninitialized and make the empty sum return garbage.
      yPtr <- Call.number zero
      incyPtr <- Call.cint 1
      liftIO $ BlasGen.copy nPtr onePtr zeroincPtr aPtr incyPtr
      liftIO $
         BlasGen.gemv transPtr mPtr nPtr
            alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr
      liftIO $ peek yPtr

-- | @product n xPtr incx@ multiplies @n@ elements read with stride @incx@.
product :: Class.Floating a => Int -> Ptr a -> Int -> IO a
product n xPtr incx =
   foldM (\x ptr -> do y <- peek ptr; return $! x*y) one $
   take n $ pointerSeq incx xPtr


newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

{- |
Complex conjugation of a vector in place via LAPACK @lacgv@
for complex element types; a no-op for real element types.
-}
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV $
   Class.switchFloating
      (LACGV $ const $ const $ const $ return ())
      (LACGV $ const $ const $ const $ return ())
      (LACGV LapackComplex.lacgv)
      (LACGV LapackComplex.lacgv)


{- |
@multiplyMatrix orderA orderB m k n a b cPtr@ writes the product
of the m-by-k matrix @a@ and the k-by-n matrix @b@
into @cPtr@ as a ColumnMajor m-by-n matrix, using BLAS @gemm@.
A RowMajor operand is passed as a transposed ColumnMajor matrix.
-}
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr = do
   -- BLAS requires every leading dimension to be at least 1,
   -- even for empty matrices; otherwise gemm reports a parameter error.
   let lda = max 1 $ case orderA of RowMajor -> k; ColumnMajor -> m
   let ldb = max 1 $ case orderB of RowMajor -> n; ColumnMajor -> k
   let ldc = max 1 m
   evalContT $ do
      transaPtr <- Call.char $ transposeFromOrder orderA
      transbPtr <- Call.char $ transposeFromOrder orderB
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.cint lda
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.cint ldb
      betaPtr <- Call.number zero
      ldcPtr <- Call.cint ldc
      liftIO $
         BlasGen.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr


{- |
Run a LAPACK-style computation with an automatically sized workspace
(see 'withAutoWorkspace'), then inspect the @info@ output of the final call.
A negative @info@ raises an "illegal value" error naming the argument.
A positive @info@ raises a "deficient rank" error.
The routine is named by @name@ in both messages.
-}
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo name computation =
   evalContT $ do
      infoPtr <- Call.alloca
      liftIO $ withAutoWorkspace $ \workPtr lworkPtr ->
         computation workPtr lworkPtr infoPtr
      info <- liftIO $ fromIntegral <$> peek infoPtr
      case compare info (0::Int) of
         EQ -> return ()
         LT -> error $ printf "%s: illegal value in %d-th argument" name (-info)
         GT -> error $ printf "%s: deficient rank %d" name info

{- |
Call a LAPACK-style computation twice.

* The first call is a workspace query (@lwork = -1@) with a one-element
  work array. Its first element is read back as the required workspace size.

* The second call gets a temporary work array of that size.

The size is clamped to at least 1, because some LAPACK versions report 0
for empty problems but reject @lwork < 1@.
-}
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation =
   evalContT $ do
      lworkPtr <- Call.cint (-1)
      lwork <- liftIO $ alloca $ \workPtr -> do
         computation workPtr lworkPtr
         max 1 . ceilingSize <$> peek workPtr
      workPtr <- Call.allocaArray lwork
      liftIO $ poke lworkPtr $ fromIntegral lwork
      liftIO $ computation workPtr lworkPtr


newtype FromReal a = FromReal {getFromReal :: RealOf a -> a}

-- | Embed a real value: identity for real types, zero imaginary part for complex.
fromReal :: (Class.Floating a) => RealOf a -> a
fromReal =
   getFromReal $
   Class.switchFloating
      (FromReal id) (FromReal id) (FromReal (:+0)) (FromReal (:+0))

newtype RealPart a = RealPart {getRealPart :: a -> RealOf a}

-- | Real part: identity for real types, 'Complex.realPart' for complex.
realPart :: (Class.Floating a) => a -> RealOf a
realPart =
   getRealPart $
   Class.switchFloating
      (RealPart id) (RealPart id)
      (RealPart Complex.realPart) (RealPart Complex.realPart)


newtype FuncArg b a = FuncArg {runFuncArg :: a -> b}

{- |
Round a workspace size, as reported by a LAPACK workspace query, up to an 'Int'.
For complex element types, only the real part is used.
-}
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   runFuncArg $
   Class.switchFloating
      (FuncArg ceiling) (FuncArg ceiling)
      (FuncArg $ ceiling . Complex.realPart)
      (FuncArg $ ceiling . Complex.realPart)