{-# LANGUAGE TypeFamilies #-}
module Numeric.LAPACK.Private where

import Numeric.LAPACK.Matrix.Layout.Private
         (Order(RowMajor, ColumnMajor), transposeFromOrder)
import Numeric.LAPACK.Wrapper (Flip(Flip, getFlip))

import qualified Numeric.LAPACK.FFI.Generic as LapackGen
import qualified Numeric.LAPACK.FFI.Complex as LapackComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as BlasGen
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.LAPACK.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.LAPACK.Scalar (RealOf, zero, one, isZero)

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (copyArray, advancePtr)
import Foreign.Marshal.Alloc (alloca)
import Foreign.C.Types (CChar, CInt)
import Foreign.ForeignPtr (ForeignPtr, withForeignPtr)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, poke, peek, pokeElemOff, peekElemOff)

import Text.Printf (printf)

import Control.Monad.Trans.Cont (ContT(ContT), evalContT, runContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (Const(Const,getConst), liftA2, (<$>))

import qualified Data.Array.Comfort.Storable.Unchecked.Monadic as ArrayIO
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable (Array)

import qualified Data.Complex as Complex
import Data.Complex (Complex)
import Data.Tuple.HT (swap)

import Prelude hiding (sum)


{- |
Reinterpret a pointer to scalars as a pointer to their real components.
This is a plain pointer cast; no memory is read or written.
-}
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr ptr = castPtr ptr


{- |
Set @n@ consecutive elements starting at @dstPtr@ to the value @a@.

Implemented as BLAS @copy@ whose source increment is 0, so the single value
stored behind the source pointer is replicated into every destination slot.
-}
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      valuePtr <- Call.number a
      srcIncPtr <- Call.cint 0
      dstIncPtr <- Call.cint 1
      liftIO (BlasGen.copy countPtr valuePtr srcIncPtr dstPtr dstIncPtr)


{- |
Copy @n@ contiguous elements from @srcPtr@ to @dstPtr@
using BLAS @copy@ with unit stride on both sides.
-}
copyBlock :: (Class.Floating a) => Int -> Ptr a -> Ptr a -> IO ()
copyBlock n srcPtr dstPtr =
   evalContT $ do
      countPtr <- Call.cint n
      srcStridePtr <- Call.cint 1
      dstStridePtr <- Call.cint 1
      liftIO $
         BlasGen.copy countPtr srcPtr srcStridePtr dstPtr dstStridePtr

{- |
Copy @n@ elements out of a 'ForeignPtr' into scratch memory allocated with
'Call.allocaArray' in the surrounding 'ContT' computation,
and return the pointer to that copy.
The 'ForeignPtr' is held via 'withForeignPtr' in the same continuation.
-}
copyToTemp :: (Storable a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToTemp n fptr = do
   srcPtr <- ContT (withForeignPtr fptr)
   tmpPtr <- Call.allocaArray n
   liftIO (copyArray tmpPtr srcPtr n)
   return tmpPtr


{- |
Make a temporary copy only for complex matrices.

For real element types nothing is copied: the pointer inside the
'ForeignPtr' is handed out directly.
For complex element types the @n@ elements are copied to scratch memory and
conjugated there by 'complexConjugateToTemp'.
-}
conjugateToTemp ::
   (Class.Floating a) => Int -> ForeignPtr a -> ContT r IO (Ptr a)
conjugateToTemp n =
   runCopyToTemp $
   Class.switchFloating
      passThrough
      passThrough
      (CopyToTemp (complexConjugateToTemp n))
      (CopyToTemp (complexConjugateToTemp n))
   where
      -- expose the original buffer unchanged
      passThrough :: CopyToTemp s b
      passThrough = CopyToTemp (\fptr -> ContT (withForeignPtr fptr))

{- |
Strategy for obtaining a pointer to the data of a 'ForeignPtr'
inside a 'ContT' computation.
The element type is the last parameter, so @CopyToTemp r@ can serve as the
functor argument of 'Class.switchFloating' (see 'conjugateToTemp').
-}
newtype CopyToTemp r a =
   CopyToTemp {
      runCopyToTemp :: ForeignPtr a -> ContT r IO (Ptr a)
   }

{- |
Copy @n@ complex elements into scratch memory (via 'copyToTemp')
and conjugate that copy in place with LAPACK @lacgv@ at unit stride.
Only the copy is conjugated; the 'ForeignPtr' data is left untouched.
-}
complexConjugateToTemp ::
   Class.Real a =>
   Int -> ForeignPtr (Complex a) -> ContT r IO (Ptr (Complex a))
complexConjugateToTemp n x = do
   countPtr <- Call.cint n
   copyPtr <- copyToTemp n x
   stridePtr <- Call.cint 1
   liftIO (LapackComplex.lacgv countPtr copyPtr stridePtr)
   return copyPtr


{- |
Conjugate the vector at @yPtr@ in place with 'lacgv',
but only for 'Conjugated'; otherwise do nothing.
-}
condConjugate ::
   (Class.Floating a) => Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
condConjugate conj nPtr yPtr incyPtr =
   case conj of
      Conjugated -> lacgv nPtr yPtr incyPtr
      _ -> return ()

{- |
Copy vector @x@ to @y@ with BLAS @copy@,
then conjugate @y@ in place with 'lacgv'.
-}
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr >>
   lacgv nPtr yPtr incyPtr

{- |
Copy vector @x@ to @y@ with BLAS @copy@,
then conjugate @y@ in place if @conj@ is 'Conjugated'.
-}
copyCondConjugate ::
   (Class.Floating a) =>
   Conjugation -> Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyCondConjugate conj nPtr xPtr incxPtr yPtr incyPtr =
   BlasGen.copy nPtr xPtr incxPtr yPtr incyPtr >>
   condConjugate conj nPtr yPtr incyPtr

{- |
Obtain a pointer to the @n@ elements of @x@, conjugated if requested.

'NonConjugated' exposes the original buffer directly, without copying.
'Conjugated' delegates to 'conjugateToTemp'.
-}
condConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
condConjugateToTemp NonConjugated _ x = ContT (withForeignPtr x)
condConjugateToTemp Conjugated n x = conjugateToTemp n x

{- |
Allocate scratch space for @n@ elements in the surrounding 'ContT'
computation and fill it with a copy of @a@, conjugated if requested.
The 'ForeignPtr' is only held while the copy is being made.
-}
copyCondConjugateToTemp ::
   (Class.Floating a) =>
   Conjugation -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyCondConjugateToTemp conj n a = do
   bPtr <- Call.allocaArray n
   liftIO $ withForeignPtr a $ \aPtr ->
      evalContT $ do
         sizePtr <- Call.cint n
         incPtr <- Call.cint 1
         liftIO $ copyCondConjugate conj sizePtr aPtr incPtr bPtr incPtr
   return bPtr



{- |
In ColumnMajor:
copy an m-by-n matrix with lda>=m and ldb>=m.

This passes 'A' as the UPLO argument of LAPACK @lacpy@ (via
'copySubTrapezoid'); LAPACK treats any UPLO other than 'U'/'L' as
"copy the full matrix".
-}
copySubMatrix ::
   (Class.Floating a) =>
   Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubMatrix m n lda aPtr ldb bPtr =
   copySubTrapezoid 'A' m n lda aPtr ldb bPtr

{- |
Thin wrapper around LAPACK @lacpy@ (column-major).

Copies the part of the @m@-by-@n@ matrix at @aPtr@ (leading dimension
@lda@) selected by the UPLO character @side@ into @bPtr@ (leading
dimension @ldb@).
-}
copySubTrapezoid ::
   (Class.Floating a) =>
   Char -> Int -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copySubTrapezoid side m n lda aPtr ldb bPtr =
   evalContT $ do
      uploPtr <- Call.char side
      rowsPtr <- Call.cint m
      colsPtr <- Call.cint n
      ldaPtr <- Call.leadingDim lda
      ldbPtr <- Call.leadingDim ldb
      liftIO $
         LapackGen.lacpy uploPtr rowsPtr colsPtr aPtr ldaPtr bPtr ldbPtr

{- |
Convert an @n@-by-@m@ matrix stored row by row (row length @m@) at @aPtr@
into column-major storage at @bPtr@ with leading dimension @ldb@.

Column @k@ of the result (for @k@ in @[0 .. m-1]@) is gathered from @aPtr@
starting at element offset @k@ with stride @m@.
It is written contiguously at element offset @k*ldb@ of @bPtr@.
Each column takes one BLAS @copy@ call.
-}
copyTransposed ::
   (Class.Floating a) =>
   Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyTransposed n m aPtr ldb bPtr =
   evalContT $ do
      nPtr <- Call.cint n
      incaPtr <- Call.cint m
      incbPtr <- Call.cint 1
      let copyColumn k =
            BlasGen.copy nPtr
               (advancePtr aPtr k) incaPtr
               (advancePtr bPtr (k*ldb)) incbPtr
      liftIO $ mapM_ copyColumn [0 .. m-1]


{- |
Copy an m-by-n matrix to ColumnMajor order.
The destination is packed, i.e. its leading dimension is @m@.
-}
copyToColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Ptr a -> IO ()
copyToColumnMajor order m n aPtr bPtr =
   case order of
      ColumnMajor -> copyBlock (m*n) aPtr bPtr
      RowMajor -> copyTransposed m n aPtr m bPtr

{- |
Copy an @m@-by-@n@ matrix stored in the given 'Order' into column-major
storage with leading dimension @ldb@.

A column-major source whose leading dimension already equals @ldb@ is
copied as a single contiguous block.
-}
copyToSubColumnMajor ::
   (Class.Floating a) =>
   Order -> Int -> Int -> Ptr a -> Int -> Ptr a -> IO ()
copyToSubColumnMajor order m n aPtr ldb bPtr =
   case order of
      RowMajor -> copyTransposed m n aPtr ldb bPtr
      ColumnMajor
         | m == ldb -> copyBlock (m*n) aPtr bPtr
         | otherwise -> copySubMatrix m n m aPtr ldb bPtr


{- |
Copy an @m@-by-@n@ matrix given in @order@ into scratch memory for
@m*n@ elements allocated in the surrounding 'ContT' computation.
The copy is in packed column-major layout.
Returns the pointer to that scratch copy.
-}
copyToColumnMajorTemp ::
   (Class.Floating a) =>
   Order -> Int -> Int -> ForeignPtr a -> ContT r IO (Ptr a)
copyToColumnMajorTemp order m n fptr = do
   srcPtr <- ContT (withForeignPtr fptr)
   dstPtr <- Call.allocaArray (m*n)
   liftIO (copyToColumnMajor order m n srcPtr dstPtr)
   return dstPtr


{- |
Infinite list of pointers that starts at @ptr@.
Each pointer is @k@ elements (not bytes) further than its predecessor.
-}
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k = iterate step
   where step p = advancePtr p k


{- |
Create an array of shape @shapeX@ whose content is produced by @act@.
Returns @act@'s result together with the array.

@act@ receives a buffer and the leading dimension to use for it:

* If @m > n@, the buffer is scratch memory for an @m@-by-@nrhs@
  column-major matrix with leading dimension @m@.
  After @act@ finishes, the top @n@ rows of that matrix are copied into the
  result array with leading dimension @n@.

* Otherwise @act@ writes straight into the result array,
  with leading dimension @n@.
-}
createHigherArray ::
   (Shape.C sh, Class.Floating a) =>
   sh -> Int -> Int -> Int ->
   ((Ptr a, Int) -> IO rank) -> IO (rank, Array sh a)
createHigherArray shapeX m n nrhs act =
   swap <$>
   ArrayIO.unsafeCreateWithSizeAndResult shapeX (\ _ xPtr -> fillResult xPtr)
   where
      fillResult xPtr
         | m > n =
            runContT (Call.allocaArray (m*nrhs)) $ \tmpPtr -> do
               r <- act (tmpPtr, m)
               copySubMatrix n nrhs m tmpPtr n xPtr
               return r
         | otherwise = act (xPtr, n)



{- |
Summation routine indexed by element type.
It is used as the functor argument of 'Class.switchFloating' in 'sum'.
-}
newtype Sum a =
   Sum {
      runSum :: Int -> Ptr a -> Int -> IO a
   }

{- |
Sum of @n@ elements read from @xPtr@ with stride @incx@ (in elements).
Real element types are handled by 'sumReal',
complex ones by 'sumComplex'.
-}
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum n xPtr incx =
   runSum
      (Class.switchFloating
         (Sum sumReal) (Sum sumReal)
         (Sum sumComplex) (Sum sumComplex))
      n xPtr incx

{- |
Real sum computed as a BLAS dot product.
The second operand is a single stored 1 read with stride 0,
i.e. a vector of ones.
-}
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   evalContT $ do
      countPtr <- Call.cint n
      xStridePtr <- Call.cint incx
      onesPtr <- Call.real one
      onesStridePtr <- Call.cint 0
      liftIO $ BlasReal.dot countPtr xPtr xStridePtr onesPtr onesStridePtr

{- |
Sum of @n@ complex elements read with the given stride
(counted in complex elements).

'sumComplex' views the data as interleaved real and imaginary parts.
It computes each part as a real BLAS dot product against a single stored 1
read with stride 0.

'sumComplexAlt' views the data as a 2-by-@n@ real column-major matrix with
leading dimension @2*inca@ and multiplies it by a vector of @n@ ones using
@gemv@.
-}
sumComplex, sumComplexAlt ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      nPtr <- Call.cint n
      let sxPtr = realPtr xPtr
      -- skip over the imaginary parts when walking the real parts and vice versa
      incxPtr <- Call.cint (2*incx)
      yPtr <- Call.real one
      incyPtr <- Call.cint 0
      liftIO $
         liftA2 (Complex.:+)
            (BlasReal.dot nPtr sxPtr incxPtr yPtr incyPtr)
            (BlasReal.dot nPtr (advancePtr sxPtr 1) incxPtr yPtr incyPtr)

-- gemv requires a leading dimension of at least M = 2 here.
-- @2*inca@ only satisfies that for @inca >= 1@,
-- so other strides fall back to 'sumComplex', which accepts any stride.
sumComplexAlt n aPtr inca
   | inca < 1 = sumComplex n aPtr inca
   | otherwise =
      evalContT $ do
         transPtr <- Call.char 'N'
         mPtr <- Call.cint 2
         nPtr <- Call.cint n
         onePtr <- Call.number one
         inc0Ptr <- Call.cint 0
         let saPtr = realPtr aPtr
         ldaPtr <- Call.leadingDim (2*inca)
         sxPtr <- Call.allocaArray n
         incxPtr <- Call.cint 1
         betaPtr <- Call.number zero
         yPtr <- Call.alloca
         let syPtr = realPtr yPtr
         incyPtr <- Call.cint 1
         liftIO $ do
            -- gemv quick-returns without writing y when n == 0,
            -- so y must not be left as uninitialized memory.
            poke yPtr (zero Complex.:+ zero)
            -- vector of n ones
            BlasGen.copy nPtr onePtr inc0Ptr sxPtr incxPtr
            gemv
               transPtr mPtr nPtr onePtr saPtr ldaPtr
               sxPtr incxPtr betaPtr syPtr incyPtr
            peek yPtr


-- | Multiply two strided vectors element by element, writing into @y@.
-- This is expressed as a Hermitian band matrix-vector product
-- ('BlasGen.hbmv') with band width @k = 0@, @alpha = 1@ and @beta = 0@.
-- The band matrix therefore consists of its diagonal only,
-- read from @aPtr@ with stride @inca@ (passed as the leading dimension).
-- NOTE(review): BLAS treats the diagonal of a Hermitian matrix as real,
-- so for complex @a@ presumably only the real parts enter - hence the name.
mulReal ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulReal n aPtr inca xPtr incx yPtr incy =
   evalContT $ do
      upperPtr <- Call.char 'U'
      sizePtr <- Call.cint n
      bandPtr <- Call.cint 0
      alphaPtr <- Call.number one
      ldaPtr <- Call.leadingDim inca
      incxPtr <- Call.cint incx
      betaPtr <- Call.number zero
      incyPtr <- Call.cint incy
      liftIO $
         BlasGen.hbmv upperPtr sizePtr bandPtr
            alphaPtr aPtr ldaPtr
            xPtr incxPtr
            betaPtr yPtr incyPtr

-- | Multiply two strided vectors element by element, writing into @y@.
-- This uses a general band matrix-vector product ('BlasGen.gbmv')
-- with no sub- and no super-diagonals, @alpha = 1@ and @beta = 0@.
-- The square @n@-by-@n@ band matrix is just the diagonal taken from @aPtr@
-- with stride @inca@ (passed as the leading dimension).
-- With 'Conjugated' the matrix is applied as conjugate transpose (@'C'@),
-- which for a diagonal matrix presumably amounts to conjugating @a@.
mul ::
   (Class.Floating a) =>
   Conjugation -> Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul conj n aPtr inca xPtr incx yPtr incy =
   evalContT $ do
      transPtr <- Call.char transChar
      dimPtr <- Call.cint n
      klPtr <- Call.cint 0
      kuPtr <- Call.cint 0
      alphaPtr <- Call.number one
      ldaPtr <- Call.leadingDim inca
      incxPtr <- Call.cint incx
      betaPtr <- Call.number zero
      incyPtr <- Call.cint incy
      liftIO $
         BlasGen.gbmv transPtr dimPtr dimPtr klPtr kuPtr
            alphaPtr aPtr ldaPtr
            xPtr incxPtr
            betaPtr yPtr incyPtr
   where
      -- BLAS transposition flag for the diagonal matrix
      transChar =
         case conj of
            NonConjugated -> 'N'
            Conjugated -> 'C'

{- |
Product of the @n@ elements of a strided vector
that starts at @aPtr@ and has stride @inca@.
The empty product is 'one'.

Use the foldBalanced trick:
neighbouring elements are first multiplied pairwise (via 'mulPairs')
into a scratch buffer, and 'productLoop' then keeps halving that buffer.
The multiplications thus form a balanced tree instead of a linear chain.
-}
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca
   | n < 1 = return one
   | n == 1 = peek aPtr
   | otherwise =
      -- productLoop requires room for 2*len-1 elements
      ForeignArray.alloca (2*len-1) $ \bufPtr -> do
         mulPairs pairs aPtr inca bufPtr 1
         -- an odd element out has no partner; carry it over unchanged
         when (odd n) $
            peekElemOff aPtr ((n-1)*inca) >>= pokeElemOff bufPtr pairs
         productLoop len bufPtr
   where
      pairs = div n 2
      len = n - pairs

{- |
If 'mul' were based on a scalar loop,
we would not need to cut the vector into chunks.

Multiplies the @n@ contiguous elements starting at @xPtr@ by repeated halving.
Adjacent pairs are multiplied and the products are stored right behind
the current elements.
The loop then continues on the shorter sequence.
If @n@ is odd, the unpaired last element stays directly in front of the products.

The invariant is:
When calling @productLoop n xPtr@,
starting from xPtr there is storage allocated for 2*n-1 elements.

@n < 1@ yields the empty product 'one'.
Without this guard, @n = 0@ would recurse forever, because halving 0 leaves 0.
-}
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr =
   case compare n 1 of
      LT -> return one
      EQ -> peek xPtr
      GT -> do
         let n2 = div n 2
         -- products x0*x1, x2*x3, ... go right behind the n current elements
         mulPairs n2 xPtr 1 (advancePtr xPtr n) 1
         -- if n is odd, the unpaired element sits at offset 2*n2 = n-1,
         -- directly in front of the products, so continue from there
         productLoop (n-n2) (advancePtr xPtr (2*n2))

-- | @mulPairs n aPtr inca xPtr incx@ multiplies neighbouring elements
-- of the strided vector at @aPtr@ (stride @inca@), using 'mul'.
-- The @i@-th result, for @i < n@, combines the elements at offsets
-- @2*i*inca@ and @(2*i+1)*inca@.
-- It is stored at @xPtr@ with stride @incx@.
mulPairs ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca xPtr incx =
   mul NonConjugated n evenPtr stride oddPtr stride xPtr incx
   where
      stride = 2*inca
      evenPtr = aPtr
      oddPtr = advancePtr aPtr inca


-- | Wrapper that gives the @lacgv@ signature a single type parameter,
-- so that an implementation can be chosen per element type
-- with 'Class.switchFloating'.
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

-- | Element-type generic @lacgv@.
-- For complex element types this dispatches to LAPACK's @lacgv@
-- (complex conjugation of a vector).
-- For real element types there is nothing to conjugate,
-- so all arguments are ignored and nothing happens.
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV
      (Class.switchFloating
         (LACGV (\_ _ _ -> return ()))
         (LACGV (\_ _ _ -> return ()))
         (LACGV LapackComplex.lacgv)
         (LACGV LapackComplex.lacgv))


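{-
Work around an inconsistency of BLAS.
For a matrix with zero columns
(or zero rows, if the matrix is applied transposed),
BLAS's gemv and gbmv do not initialize the target vector.
These work-arounds fill the target vector with zeros beforehand
if beta is zero, so the result is defined even for an uninitialized target.
Other values of beta are left to BLAS.
-}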
{-# INLINE gemv #-}
-- | Drop-in replacement for 'BlasGen.gemv' with identical arguments.
-- 'initializeMV' runs first, so that the target vector is zero-filled
-- in the degenerate case that BLAS leaves untouched.
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr >>
   BlasGen.gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{-# INLINE gbmv #-}
-- | Drop-in replacement for 'BlasGen.gbmv' with identical arguments.
-- The band widths @kl@ and @ku@ are passed through unchanged.
-- 'initializeMV' runs first, so that the target vector is zero-filled
-- in the degenerate case that BLAS leaves untouched.
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr =
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr >>
   BlasGen.gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
Initialize the target vector of a matrix-vector product
for the case that BLAS's gemv/gbmv would not touch it.

The arguments are the corresponding ones of gemv/gbmv:
transposition flag, row count, column count, beta, y and the increment of y.

The matrix may contribute nothing: it has no columns,
or no rows when it is applied transposed.
If so, and @beta@ is zero, @y@ is filled with zeros.
In all other cases @y@ is left alone.

The transposition flag is compared case-insensitively, as BLAS does.
Both @'N'@ and @'n'@ mean "not transposed".
Previously only @'N'@ was recognized, and @'n'@ swapped the dimensions.
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- CStr.castCCharToChar <$> peek transPtr
   let notTransposed = trans == 'N' || trans == 'n'
       -- (length of y, length of x); swapped if the matrix is transposed
       (mtPtr,ntPtr) = if notTransposed then (mPtr,nPtr) else (nPtr,mPtr)
   n <- peek ntPtr
   beta <- peek betaPtr
   when (n == 0 && isZero beta) $
      -- source increment 0 broadcasts the single (zero) beta
      -- into every element of y
      Marshal.with 0 $ \incbPtr ->
         BlasGen.copy mtPtr betaPtr incbPtr yPtr incyPtr


-- | @multiplyMatrix orderA orderB m k n a b cPtr@ multiplies
-- the @m@-by-@k@ matrix @a@ by the @k@-by-@n@ matrix @b@ with 'BlasGen.gemm'.
-- The @m@-by-@n@ result is written to @cPtr@ with leading dimension @m@.
-- Each operand is stored in the given 'Order'.
-- Its transposition flag comes from 'transposeFromOrder', and its
-- leading dimension is the row length (row-major) or column length
-- (column-major).
-- For @k == 0@ the result is explicitly filled with zeros.
multiplyMatrix ::
   (Class.Floating a) =>
   Order -> Order -> Int -> Int -> Int ->
   ForeignPtr a -> ForeignPtr a -> Ptr a -> IO ()
multiplyMatrix orderA orderB m k n a b cPtr
   | k == 0 = fill zero (m*n) cPtr
   | otherwise = evalContT $ do
      transaPtr <- Call.char $ transposeFromOrder orderA
      transbPtr <- Call.char $ transposeFromOrder orderB
      mPtr <- Call.cint m
      nPtr <- Call.cint n
      kPtr <- Call.cint k
      alphaPtr <- Call.number one
      aPtr <- ContT $ withForeignPtr a
      ldaPtr <- Call.leadingDim $ leadingDimOf orderA k m
      bPtr <- ContT $ withForeignPtr b
      ldbPtr <- Call.leadingDim $ leadingDimOf orderB n k
      betaPtr <- Call.number zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         BlasGen.gemm transaPtr transbPtr mPtr nPtr kPtr
            alphaPtr aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      -- pick the leading dimension according to the storage order
      leadingDimOf order rowMajorLd columnMajorLd =
         case order of
            RowMajor -> rowMajorLd
            ColumnMajor -> columnMajorLd



-- | Combination of 'withInfo' and 'withAutoWorkspace'.
-- The computation receives a work array, the work size and the info pointer,
-- in this order.
-- The workspace is sized automatically.
-- A nonzero info code is reported as an error built from @msg@ and @name@.
withAutoWorkspaceInfo ::
   (Class.Floating a) =>
   String -> String -> (Ptr a -> Ptr CInt -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspaceInfo msg name computation =
   withInfo msg name runWith
   where
      runWith infoPtr =
         withAutoWorkspace $ \workPtr lworkPtr ->
            computation workPtr lworkPtr infoPtr

-- | Run a computation that supports a workspace size query.
--
-- The computation is first called with a work size of @-1@ and a
-- one-element work array.
-- The requested size is read back from that element,
-- rounded up with 'ceilingSize' and raised to at least 1.
-- A work array of that size is then allocated,
-- and the computation is run again with the actual size.
withAutoWorkspace ::
   (Class.Floating a) =>
   (Ptr a -> Ptr CInt -> IO ()) -> IO ()
withAutoWorkspace computation = evalContT $ do
   lworkPtr <- Call.cint (-1)
   lwork <- liftIO $ alloca $ \queryPtr -> do
      computation queryPtr lworkPtr
      optimal <- peek queryPtr
      return $ max 1 $ ceilingSize optimal
   workPtr <- Call.allocaArray lwork
   liftIO $ do
      pokeCInt lworkPtr lwork
      computation workPtr lworkPtr

-- | Allocate an info cell, run the computation with it,
-- and then check the info code:
--
-- * zero means success;
--
-- * a negative value @-i@ raises an error formatted with 'argMsg',
--   naming routine @name@ and argument position @i@;
--
-- * a positive value raises an error @name ++ ": "@ followed by @msg@
--   formatted with that value (so @msg@ needs one @%d@ placeholder).
withInfo :: String -> String -> (Ptr CInt -> IO ()) -> IO ()
withInfo msg name computation =
   alloca $ \infoPtr -> do
      computation infoPtr
      info <- peekCInt infoPtr
      if info == 0
         then return ()
         else if info < 0
            then error $ printf argMsg name (negate info)
            else error $ name ++ ": " ++ printf msg info

-- | 'printf' formats for error messages.
-- 'argMsg' takes a routine name (@%s@) and an argument position (@%d@).
-- The others contain a single @%d@ placeholder,
-- so they fit the @msg@ argument of 'withInfo'.
argMsg, errorCodeMsg, rankMsg, definiteMsg, eigenMsg :: String
argMsg = "%s: illegal value in %d-th argument"
errorCodeMsg = "unknown error code %d"
rankMsg = "deficient rank %d"
definiteMsg = "minor of order %d not positive definite"
eigenMsg = "%d off-diagonal elements not converging"


-- | Store an 'Int' into a C integer cell.
-- The value is converted with 'fromIntegral',
-- so values outside the 'CInt' range wrap around.
pokeCInt :: Ptr CInt -> Int -> IO ()
pokeCInt ptr x = poke ptr (fromIntegral x)

-- | Read a C integer cell as an 'Int'.
peekCInt :: Ptr CInt -> IO Int
peekCInt ptr = do
   x <- peek ptr
   return (fromIntegral x)


-- | Round a size stored as a floating-point element up to an 'Int'.
-- 'withAutoWorkspace' uses it to convert a queried workspace size.
-- For complex element types only the real part is considered.
ceilingSize :: (Class.Floating a) => a -> Int
ceilingSize =
   getFlip
      (Class.switchFloating
         (Flip ceiling)
         (Flip ceiling)
         (Flip (\z -> ceiling (Complex.realPart z)))
         (Flip (\z -> ceiling (Complex.realPart z))))


-- | Select between two values by the element type of the first argument.
-- Only that argument's type matters; its value is ignored.
-- The first value is returned for 'Float' and 'Double',
-- the second for 'Complex' 'Float' and 'Complex' 'Double'.
caseRealComplexFunc :: (Class.Floating a) => f a -> b -> b -> b
caseRealComplexFunc f realValue complexValue =
   let choice =
          Class.switchFloating
             (Const realValue) (Const realValue)
             (Const complexValue) (Const complexValue)
   in getConstFunc f choice

-- | 'getConst' whose phantom type is fixed by the first argument.
-- That argument's value is ignored.
getConstFunc :: f c -> Const a c -> a
getConstFunc _proxy (Const x) = x


-- | Two-valued tag whose constructors name the components of a complex number.
-- The constructor order fixes the derived 'Ord', 'Enum' and 'Bounded'
-- instances: 'RealPart' comes first.
data ComplexPart =
     RealPart
   | ImaginaryPart
   deriving (Eq, Ord, Show, Enum, Bounded)