{-# LANGUAGE TypeFamilies #-}
module Numeric.BLAS.Private where

import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.BLAS.Scalar (RealOf, zero, one, minusOne, isZero)

import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import qualified Foreign.Marshal.Utils as Marshal
import qualified Foreign.C.String as CStr
import Foreign.Marshal.Array (advancePtr)
import Foreign.C.Types (CChar, CInt)
import Foreign.Ptr (Ptr, castPtr)
import Foreign.Storable (Storable, peek, pokeElemOff, peekElemOff)

import Control.Monad.Trans.Cont (evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (when)
import Control.Applicative (liftA2)

import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.Complex as Complex
import Data.Complex (Complex((:+)))

import Prelude hiding (sum)


-- | Alias for a 'Shape.ZeroBased' shape whose size is an 'Int'.
type ShapeInt = Shape.ZeroBased Int

-- | Wrap an 'Int' size into a 'ShapeInt' using the 'Shape.ZeroBased' constructor.
shapeInt :: Int -> ShapeInt
shapeInt = Shape.ZeroBased


{- |
Reinterpret a pointer to @a@ as a pointer to its real component type.
This is a plain 'castPtr'; no data is read or converted.
-}
realPtr :: Ptr a -> Ptr (RealOf a)
realPtr = castPtr


{- |
Infinite list of pointers starting at @ptr@,
where each successor is advanced by @k@ elements (not bytes)
with respect to its predecessor.
-}
pointerSeq :: (Storable a) => Int -> Ptr a -> [Ptr a]
pointerSeq k ptr = iterate (flip advancePtr k) ptr


{- |
@fill a n dstPtr@ writes the scalar @a@ to @n@ consecutive elements
starting at @dstPtr@.

It is implemented as a BLAS @copy@ whose source increment is 0,
such that the single source element is read again for every target element.
-}
fill :: (Class.Floating a) => a -> Int -> Ptr a -> IO ()
fill a n dstPtr = evalContT $ do
   nPtr <- Call.cint n
   srcPtr <- Call.number a
   -- source stride 0: always read the same scalar
   incxPtr <- Call.cint 0
   -- target stride 1: write contiguously
   incyPtr <- Call.cint 1
   liftIO $ Blas.copy nPtr srcPtr incxPtr dstPtr incyPtr


{- |
Copy the vector at @xPtr@ to @yPtr@ using BLAS @copy@
and then apply 'lacgv' to the target vector.
The arguments have the same meaning and order as for @Blas.copy@.
-}
copyConjugate ::
   (Class.Floating a) =>
   Ptr CInt -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
copyConjugate nPtr xPtr incxPtr yPtr incyPtr = do
   Blas.copy nPtr xPtr incxPtr yPtr incyPtr
   -- conjugation happens in-place on the copy, the source stays untouched
   lacgv nPtr yPtr incyPtr



-- | Wrapper that lets 'Class.switchFloating' select a summation routine per element type.
newtype Sum a = Sum {runSum :: Int -> Ptr a -> Int -> IO a}

{- |
@sum n xPtr incx@ adds up @n@ elements starting at @xPtr@
with element stride @incx@.

Dispatches on the element type:
'sumReal' for 'Float' and 'Double',
'sumComplex' for the two complex types.
-}
sum :: Class.Floating a => Int -> Ptr a -> Int -> IO a
sum =
   runSum $
   Class.switchFloating
      (Sum sumReal)
      (Sum sumReal)
      (Sum sumComplex)
      (Sum sumComplex)

{- |
Sum of @n@ real elements starting at @xPtr@ with stride @incx@.

Computed as a BLAS dot product with a second operand
that consists of the single value 'one' read with increment 0.
-}
sumReal :: Class.Real a => Int -> Ptr a -> Int -> IO a
sumReal n xPtr incx =
   evalContT $ do
      nPtr <- Call.cint n
      incxPtr <- Call.cint incx
      yPtr <- Call.real one
      -- stride 0: the constant 'one' serves as every element of the second vector
      incyPtr <- Call.cint 0
      liftIO $ BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr

{- |
Sum of @n@ complex elements starting at the given pointer
with the given element stride.

'sumComplex' views the complex vector as a vector of real numbers
and runs two real dot products against a constant 'one' (increment 0):
the first one starts at the real part of the first element,
the second one starts one real number later, i.e. at the imaginary part.
Both use the real stride @2*incx@.

'sumComplexAlt' computes the same quantity by a matrix-vector product.
-}
sumComplex, sumComplexAlt ::
   Class.Real a => Int -> Ptr (Complex a) -> Int -> IO (Complex a)
sumComplex n xPtr incx =
   evalContT $ do
      nPtr <- Call.cint n
      let sxPtr = realPtr xPtr
      -- one complex step equals two real steps
      incxPtr <- Call.cint (2*incx)
      yPtr <- Call.real one
      incyPtr <- Call.cint 0
      liftIO $
         liftA2 (Complex.:+)
            (BlasReal.dot nPtr sxPtr incxPtr yPtr incyPtr)
            (BlasReal.dot nPtr (advancePtr sxPtr 1) incxPtr yPtr incyPtr)

{-
Alternative to sumComplex based on a single matrix-vector product.

The complex input is viewed as a real matrix with 2 rows and n columns
and leading dimension 2*inca.
It is multiplied by a temporary vector of n ones,
which is created by a BLAS copy with source increment 0.
The two real results are written to the storage of one complex number
that is finally read back.
-}
sumComplexAlt n aPtr inca =
   evalContT $ do
      transPtr <- Call.char 'N'
      mPtr <- Call.cint 2
      nPtr <- Call.cint n
      onePtr <- Call.number one
      inc0Ptr <- Call.cint 0
      let saPtr = realPtr aPtr
      ldaPtr <- Call.leadingDim (2*inca)
      sxPtr <- Call.allocaArray n
      incxPtr <- Call.cint 1
      betaPtr <- Call.number zero
      yPtr <- Call.alloca
      let syPtr = realPtr yPtr
      incyPtr <- Call.cint 1
      liftIO $ do
         -- fill the temporary vector with ones
         Blas.copy nPtr onePtr inc0Ptr sxPtr incxPtr
         -- local gemv wrapper, not Blas.gemv: it also initializes y for n==0
         gemv
            transPtr mPtr nPtr onePtr saPtr ldaPtr
            sxPtr incxPtr betaPtr syPtr incyPtr
         peek yPtr


{- |
@mul conj n aPtr inca xPtr incx yPtr incy@
is 'mulAdd' with the scaling factor for the target vector set to 'zero'.
That is, the previous content of the vector at @yPtr@
does not contribute to the result.
-}
mul ::
   (Class.Floating a) =>
   Conjugation -> Int -> Ptr a -> Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mul conj n aPtr inca xPtr incx yPtr incy =
   mulAdd conj n aPtr inca xPtr incx zero yPtr incy

{- |
@mulAdd conj n aPtr inca xPtr incx beta yPtr incy@

Calls BLAS @gbmv@ for an @n@-by-@n@ band matrix
with zero sub-diagonals and zero super-diagonals,
scaling factor @alpha = 'one'@ and the given @beta@.
The vector at @aPtr@ with stride @inca@ is passed as the band storage
(leading dimension @inca@), i.e. it plays the role of the diagonal.

'NonConjugated' selects the BLAS transposition flag @\'N\'@,
'Conjugated' selects @\'C\'@.
-}
mulAdd ::
   (Class.Floating a) =>
   Conjugation ->
   Int -> Ptr a -> Int -> Ptr a -> Int -> a -> Ptr a -> Int -> IO ()
mulAdd conj n aPtr inca xPtr incx beta yPtr incy = evalContT $ do
   transPtr <- Call.char $ case conj of NonConjugated -> 'N'; Conjugated -> 'C'
   nPtr <- Call.cint n
   -- no off-diagonals: the band matrix degenerates to a diagonal matrix
   klPtr <- Call.cint 0
   kuPtr <- Call.cint 0
   alphaPtr <- Call.number one
   ldaPtr <- Call.leadingDim inca
   incxPtr <- Call.cint incx
   betaPtr <- Call.number beta
   incyPtr <- Call.cint incy
   liftIO $
      Blas.gbmv transPtr
         nPtr nPtr klPtr kuPtr alphaPtr aPtr ldaPtr
         xPtr incxPtr betaPtr yPtr incyPtr

{- |
@product n aPtr inca@ multiplies @n@ elements starting at @aPtr@
with element stride @inca@.

Use the foldBalanced trick:
Neighbouring elements are multiplied pairwise by 'mulPairs'
into a temporary buffer, which is then reduced further by 'productLoop'.

* @n < 1@: the result is 'one'.
* @n == 1@: the single element is returned.
* @n > 1@: a buffer of @2*new-1@ elements with @new = n - div n 2@
  is allocated, as required by 'productLoop'.
-}
product :: (Class.Floating a) => Int -> Ptr a -> Int -> IO a
product n aPtr inca =
   case compare n 1 of
      LT -> return one
      EQ -> peek aPtr
      GT ->
         let n2 = div n 2
             new = n-n2
         in ForeignArray.alloca (2*new-1) $ \xPtr -> do
               mulPairs n2 aPtr inca xPtr 1
               -- an odd-length input leaves its last element unpaired;
               -- append it unchanged behind the n2 pair products
               when (odd n) $
                  pokeElemOff xPtr n2 =<< peekElemOff aPtr ((n-1)*inca)
               productLoop new xPtr

{- |
If 'mul' were based on a scalar loop
we would not need to cut the vector into chunks.

The invariant is:
When calling @productLoop n xPtr@,
starting from xPtr there is storage allocated for 2*n-1 elements.

In each step the first @2 * div n 2@ elements are multiplied pairwise
and the products are stored starting at offset @n@.
The recursion continues at offset @2 * div n 2@ with @n - div n 2@ elements,
which covers a possibly unpaired element followed by the fresh products.
The function expects @n >= 1@; the recursion ends at @n == 1@.
-}
productLoop :: (Class.Floating a) => Int -> Ptr a -> IO a
productLoop n xPtr =
   if n==1
      then peek xPtr
      else do
         let n2 = div n 2
         mulPairs n2 xPtr 1 (advancePtr xPtr n) 1
         productLoop (n-n2) (advancePtr xPtr (2*n2))

{- |
@mulPairs n aPtr inca xPtr incx@
multiplies @n@ pairs of neighbouring elements of the vector at @aPtr@
and stores the products at @xPtr@ with stride @incx@.

Both operands of 'mul' are taken from the same vector with stride @2*inca@:
the first one starts at @aPtr@,
the second one starts @inca@ elements later.
-}
mulPairs ::
   (Class.Floating a) =>
   Int -> Ptr a -> Int -> Ptr a -> Int -> IO ()
mulPairs n aPtr inca xPtr incx =
   let inca2 = 2*inca
   in mul NonConjugated n aPtr inca2 (advancePtr aPtr inca) inca2 xPtr incx


-- | Wrapper that lets 'Class.switchFloating' select a conjugation routine per element type.
newtype LACGV a = LACGV {getLACGV :: Ptr CInt -> Ptr a -> Ptr CInt -> IO ()}

{- |
@lacgv nPtr xPtr incxPtr@ conjugates a vector in-place.

For 'Float' and 'Double' this does nothing,
for the complex types it calls 'clacgv'.
-}
lacgv :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO ()
lacgv =
   getLACGV $
   Class.switchFloating
      (LACGV $ \_ _ _ -> return ())
      (LACGV $ \_ _ _ -> return ())
      (LACGV clacgv)
      (LACGV clacgv)

{- |
Conjugate a complex vector in-place.

The vector is viewed as a vector of real numbers.
Starting at the first imaginary part (real offset 1)
every element at real stride @2*incx@ is scaled by 'minusOne'
using the real BLAS @scal@.
-}
clacgv :: Class.Real a => Ptr CInt -> Ptr (Complex a) -> Ptr CInt -> IO ()
clacgv nPtr xPtr incxPtr =
   Marshal.with minusOne $ \saPtr -> do
      incx <- peek incxPtr
      -- one complex step equals two real steps
      Marshal.with (2*incx) $ \incyPtr ->
         BlasReal.scal nPtr saPtr (advancePtr (realPtr xPtr) 1) incyPtr


{-
Work around an inconsistency of BLAS.
In case of a zero-column matrix
BLAS's gemv and gbmv do not initialize the target vector.
In contrast, these work-arounds do.
-}
{- |
Drop-in replacement for @Blas.gemv@ with the same argument order.
It runs 'initializeMV' on the target vector
before it forwards all arguments unchanged to @Blas.gemv@.
-}
{-# INLINE gemv #-}
gemv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr = do
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   Blas.gemv transPtr mPtr nPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
Drop-in replacement for @Blas.gbmv@ with the same argument order.
It runs 'initializeMV' on the target vector
before it forwards all arguments unchanged to @Blas.gbmv@.
-}
{-# INLINE gbmv #-}
gbmv ::
   (Class.Floating a) =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr CInt -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> Ptr a -> Ptr CInt ->
   Ptr a -> Ptr a -> Ptr CInt -> IO ()
gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr = do
   initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr
   Blas.gbmv transPtr mPtr nPtr klPtr kuPtr
      alphaPtr aPtr ldaPtr xPtr incxPtr betaPtr yPtr incyPtr

{- |
Prepare the target vector of a matrix-vector product.

The transposition flag decides which of the two dimensions
is the length of the target vector and which one is the inner dimension:
For @\'N\'@ the target has @m@ elements and the inner dimension is @n@,
for any other flag it is the other way round.

If the inner dimension is zero and @beta@ is zero,
then the target vector is filled with @beta@
by a BLAS @copy@ with source increment 0.
In all other cases the target vector is left untouched.
In particular, a non-zero @beta@ together with an inner dimension of zero
is not handled here.
-}
initializeMV ::
   Class.Floating a =>
   Ptr CChar -> Ptr CInt -> Ptr CInt -> Ptr a -> Ptr a -> Ptr CInt -> IO ()
initializeMV transPtr mPtr nPtr betaPtr yPtr incyPtr = do
   trans <- peek transPtr
   -- mtPtr: length of the target vector, ntPtr: inner dimension
   let (mtPtr,ntPtr) =
         if trans == CStr.castCharToCChar 'N'
            then (mPtr,nPtr) else (nPtr,mPtr)
   n <- peek ntPtr
   beta <- peek betaPtr
   when (n == 0 && isZero beta) $
      Marshal.with 0 $ \incbPtr ->
      Blas.copy mtPtr betaPtr incbPtr yPtr incyPtr


{-
ToDo:

type ComplexShape =
         Shape.NestedTuple Shape.TupleAccessor (Complex Shape.Element)

This would allow the use of Complex.realPart as accessor,
but it requires GHC>7.6.3 or so, where realPart has no RealFloat constraint.
-}
-- | Alias for a 'Shape.NestedTuple' indexed by 'Shape.TupleIndex' whose tuple type is a 'Complex' of 'Shape.Element'.
type ComplexShape = Shape.NestedTuple Shape.TupleIndex (Complex Shape.Element)

{- |
Indices of the two components of 'ComplexShape'.
They are obtained by pattern matching on the 'Complex' value
that 'Shape.indexTupleFromShape' returns for the static 'ComplexShape'.
-}
ixReal, ixImaginary :: Shape.ElementIndex (Complex Shape.Element)
ixReal :+ ixImaginary =
   Shape.indexTupleFromShape (Shape.static :: ComplexShape)