{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.BLAS.Vector (
Vector,
RealOf,
ComplexOf,
toList,
fromList,
autoFromList,
CheckedArray.append, (+++),
CheckedArray.take, CheckedArray.drop,
CheckedArray.takeLeft, CheckedArray.takeRight,
swap,
CheckedArray.singleton,
constant,
zero,
one,
unit,
dot, inner, (-*|),
sum,
absSum,
norm2,
norm2Squared,
normInf,
normInf1,
argAbsMaximum,
argAbs1Maximum,
product,
scale, scaleReal, (.*|),
add, sub, (|+|), (|-|),
negate, raise,
mac,
mul, mulConj,
minimum, argMinimum,
maximum, argMaximum,
limits, argLimits,
CheckedArray.foldl,
CheckedArray.foldl1,
CheckedArray.foldMap,
conjugate,
fromReal,
toComplex,
realPart,
imaginaryPart,
zipComplex,
unzipComplex,
) where
import qualified Numeric.BLAS.Matrix.RowMajor as RowMajor
import qualified Numeric.BLAS.Scalar as Scalar
import qualified Numeric.BLAS.Private as Private
import Numeric.BLAS.Matrix.Modifier (Conjugation(NonConjugated, Conjugated))
import Numeric.BLAS.Scalar (ComplexOf, RealOf, minusOne)
import Numeric.BLAS.Private
(ComplexShape, ixReal, ixImaginary, fill, copyConjugate, realPtr)
import qualified Numeric.BLAS.FFI.Generic as Blas
import qualified Numeric.BLAS.FFI.Complex as BlasComplex
import qualified Numeric.BLAS.FFI.Real as BlasReal
import qualified Numeric.Netlib.Utility as Call
import qualified Numeric.Netlib.Class as Class
import qualified Foreign.Marshal.Array.Guarded as ForeignArray
import Foreign.Marshal.Array (advancePtr)
import Foreign.ForeignPtr (withForeignPtr, castForeignPtr)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable, peek, peekElemOff, pokeElemOff)
import Foreign.C.Types (CInt)
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad.Trans.Cont (ContT(ContT), evalContT)
import Control.Monad.IO.Class (liftIO)
import Control.Monad.ST (runST)
import Control.Monad (fmap, return, (=<<))
import Control.Applicative (liftA3, (<$>))
import qualified Data.Array.Comfort.Storable.Mutable.Unchecked as UMutArray
import qualified Data.Array.Comfort.Storable.Mutable as MutArray
import qualified Data.Array.Comfort.Storable.Unchecked as Array
import qualified Data.Array.Comfort.Storable as CheckedArray
import qualified Data.Array.Comfort.Shape as Shape
import Data.Array.Comfort.Storable.Unchecked (Array(Array), append, (!))
import Data.Array.Comfort.Shape ((::+))
import Data.Function (id, flip, ($), (.))
import Data.Complex (Complex)
import Data.Maybe (Maybe(Nothing,Just), maybe)
import Data.Tuple.HT (mapFst, uncurry3)
import Data.Tuple (fst, snd)
import Data.Ord ((>=))
import Data.Eq (Eq, (==))
import Prelude (Int, fromIntegral, (-), Char, error, IO)
-- | A vector is a comfort-array storable 'Array'; the alias only documents
-- intent.
type Vector = Array

-- | All elements, in the order defined by the shape.
toList :: (Shape.C sh, Storable a) => Vector sh a -> [a]
toList vec = Array.toList vec

-- | Build a vector of the given shape from a list, using the checked
-- 'CheckedArray.fromList' (not the unchecked module variant).
fromList :: (Shape.C sh, Storable a) => sh -> [a] -> Vector sh a
fromList shape xs = CheckedArray.fromList shape xs

-- | Build a vector with a zero-based 'Int' shape from a list.
autoFromList :: (Storable a) => [a] -> Vector (Shape.ZeroBased Int) a
autoFromList xs = Array.vectorFromList xs
-- | A vector of the given shape with every element equal to @value@.
constant :: (Shape.C sh, Class.Floating a) => sh -> a -> Vector sh a
constant shape value = Array.unsafeCreateWithSize shape (fill value)

-- | The all-zero vector of the given shape.
zero :: (Shape.C sh, Class.Floating a) => sh -> Vector sh a
zero shape = constant shape Scalar.zero

-- | The all-one vector of the given shape.
one :: (Shape.C sh, Class.Floating a) => sh -> Vector sh a
one shape = constant shape Scalar.one
-- | Unit vector: zero everywhere except for a one at index @ix@.
-- The index is translated to a storage offset with 'Shape.offset'.
unit ::
   (Shape.Indexed sh, Class.Floating a) =>
   sh -> Shape.Index sh -> Vector sh a
unit sh ix =
   Array.unsafeCreateWithSize sh $ \size ptr -> do
      fill Scalar.zero size ptr
      let k = Shape.offset sh ix
      pokeElemOff ptr k Scalar.one
infixr 5 +++

-- | Concatenate two vectors; the result has the sum shape @shx::+shy@.
-- Operator alias for 'append'.
(+++) ::
   (Shape.C shx, Shape.C shy, Storable a) =>
   Vector shx a -> Vector shy a -> Vector (shx::+shy) a
xs +++ ys = append xs ys
-- | Exchange the elements at indices @i@ and @j@.
-- Works on a thawed mutable array inside 'runST' and returns it frozen;
-- both indices are read with the checked 'MutArray.read'.
swap ::
   (Shape.Indexed sh, Storable a) =>
   Shape.Index sh -> Shape.Index sh -> Vector sh a -> Vector sh a
swap i j x = runST $ do
   arr <- MutArray.thaw x
   atI <- MutArray.read arr i
   atJ <- MutArray.read arr j
   -- the two writes are independent; order does not matter
   MutArray.write arr j atI
   MutArray.write arr i atJ
   UMutArray.unsafeFreeze arr
infixl 7 -*|, .*|

-- | Wrapper that lets 'Class.switchFloating' select a binary reduction per
-- element type.
newtype Dot f a = Dot {runDot :: f a -> f a -> a}

-- | Bilinear dot product: for complex vectors neither argument is
-- conjugated (transpose @'T'@ in 'dotComplex').  @(-*|)@ is an operator
-- alias.  Shapes must be equal; this is asserted at run time.
dot, (-*|) ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> a
dot =
   runDot
      (Class.switchFloating
         (Dot dotReal)
         (Dot dotReal)
         (Dot (dotComplex 'T'))
         (Dot (dotComplex 'T')))
(-*|) = dot
-- | Inner product.  For complex vectors the first argument is conjugated
-- (conjugate transpose @'C'@ in 'dotComplex'); for real vectors it equals
-- 'dot'.  Shapes must be equal.
inner ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> a
inner =
   runDot
      (Class.switchFloating
         (Dot dotReal)
         (Dot dotReal)
         (Dot (dotComplex 'C'))
         (Dot (dotComplex 'C')))
-- | Real dot product through BLAS @dot@, after asserting equal shapes.
dotReal ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh a -> Vector sh a -> a
dotReal xs@(Array shX _) (Array shY yFPtr) =
   unsafePerformIO $ do
      Call.assert "dot: shapes mismatch" (shX == shY)
      evalContT $ do
         (nPtr, xPtr, incxPtr) <- vectorArgs xs
         yPtr <- ContT (withForeignPtr yFPtr)
         incyPtr <- Call.cint 1
         liftIO (BlasReal.dot nPtr xPtr incxPtr yPtr incyPtr)
-- | Complex dot product computed as a matrix-vector product.
-- @x@ is passed as an m×1 matrix (leading dimension m) and @y@ as the
-- vector.  @trans@ is handed to gemv: @'T'@ uses plain transposition,
-- @'C'@ conjugate transposition of @x@.  The 1-element result is written to
-- a temporary and read back.
--
-- NOTE(review): 'Private.gemv' is used instead of 'Blas.gemv', presumably
-- because reference gemv returns early for m == 0 without writing the
-- output; confirm in "Numeric.BLAS.Private".
dotComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Char -> Vector sh (Complex a) -> Vector sh (Complex a) -> Complex a
dotComplex trans (Array shX xFPtr) (Array shY yFPtr) =
   unsafePerformIO $ do
      Call.assert "dot: shapes mismatch" (shX == shY)
      let rows = Shape.size shX
      evalContT $ do
         transPtr <- Call.char trans
         mPtr <- Call.cint rows
         nPtr <- Call.cint 1
         alphaPtr <- Call.number Scalar.one
         matPtr <- ContT (withForeignPtr xFPtr)
         ldPtr <- Call.leadingDim rows
         vecPtr <- ContT (withForeignPtr yFPtr)
         incVecPtr <- Call.cint 1
         betaPtr <- Call.number Scalar.zero
         resultPtr <- Call.alloca
         incResultPtr <- Call.cint 1
         liftIO $ do
            Private.gemv
               transPtr mPtr nPtr alphaPtr matPtr ldPtr
               vecPtr incVecPtr betaPtr resultPtr incResultPtr
            peek resultPtr
-- | Sum of all elements, delegated to 'Private.sum' over the contiguous
-- storage with stride 1.
sum :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
sum (Array sh fptr) =
   let count = Shape.size sh
   in  unsafePerformIO $
       withForeignPtr fptr $ \ptr -> Private.sum count ptr 1
-- | Sum of absolute values via BLAS @asum@ (the complex BLAS variant for
-- complex vectors, which sums @|re| + |im|@ per element).
absSum :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
absSum arr =
   unsafePerformIO $ evalContT $ do
      (nPtr, xPtr, incPtr) <- vectorArgs arr
      liftIO (asum nPtr xPtr incPtr)

-- | Select the real or complex BLAS @asum@ routine by element type.
asum :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
asum =
   getNrm
      (Class.switchFloating
         (Nrm BlasReal.asum)
         (Nrm BlasReal.asum)
         (Nrm BlasComplex.casum)
         (Nrm BlasComplex.casum))
-- | Euclidean norm via BLAS @nrm2@ (@cnrm2@ for complex vectors).
norm2 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm2 arr =
   unsafePerformIO $ evalContT $ do
      (nPtr, xPtr, incPtr) <- vectorArgs arr
      liftIO (nrm2 nPtr xPtr incPtr)

-- | Select the real or complex BLAS @nrm2@ routine by element type.
nrm2 :: Class.Floating a => Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)
nrm2 =
   getNrm
      (Class.switchFloating
         (Nrm BlasReal.nrm2)
         (Nrm BlasReal.nrm2)
         (Nrm BlasComplex.cnrm2)
         (Nrm BlasComplex.cnrm2))
-- | Dispatch wrapper for BLAS routines that take the @(n, x, incx)@ triple
-- and return a real result; used by 'asum' and 'nrm2' with
-- 'Class.switchFloating'.
newtype Nrm a =
   Nrm {getNrm :: Ptr CInt -> Ptr a -> Ptr CInt -> IO (RealOf a)}
-- | Dispatch wrapper for functions mapping a container to a real value;
-- used by 'norm2Squared'.
newtype Norm f a = Norm {getNorm :: f a -> RealOf a}
-- | Squared Euclidean norm.  A complex vector is first viewed as a real
-- vector over @(sh, ComplexShape)@ via 'decomplex', so the result is the
-- sum of squares of all real and imaginary parts.
norm2Squared :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
norm2Squared =
   getNorm
      (Class.switchFloating
         (Norm norm2SquaredReal)
         (Norm norm2SquaredReal)
         (Norm (\v -> norm2SquaredReal (decomplex v)))
         (Norm (\v -> norm2SquaredReal (decomplex v))))
-- | Sum of squares of a real vector: BLAS @dot@ of the vector with itself.
norm2SquaredReal :: (Shape.C sh, Class.Real a) => Vector sh a -> a
norm2SquaredReal arr =
   unsafePerformIO $ evalContT $ do
      (n, ptr, inc) <- vectorArgs arr
      liftIO (BlasReal.dot n ptr inc ptr inc)
-- | Maximum norm: 'Scalar.absolute' of the element found by 'absMax'
-- (largest modulus), or of zero for an empty vector.
normInf :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
normInf arr =
   unsafePerformIO $ do
      found <- absMax arr
      return (Scalar.absolute (maybe Scalar.zero snd found))
-- | Like 'normInf', but the element is chosen directly by BLAS @iamax@ and
-- measured with 'Scalar.norm1'.  For complex data BLAS @iamax@ ranks by
-- @|re| + |im|@; presumably 'Scalar.norm1' is the same measure.  An empty
-- vector yields 'Scalar.norm1' of zero.
normInf1 :: (Shape.C sh, Class.Floating a) => Vector sh a -> RealOf a
normInf1 arr =
   unsafePerformIO $ evalContT $ do
      (nPtr, xPtr, incPtr) <- vectorArgs arr
      k <- liftIO (Blas.iamax nPtr xPtr incPtr)
      found <- liftIO (peekElemOff1 xPtr k)
      return (Scalar.norm1 (maybe Scalar.zero snd found))
-- | Index and value of the element with the largest modulus (see 'absMax').
-- Calls 'error' when the vector is empty.
argAbsMaximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbsMaximum arr =
   maybe (error "Vector.argAbsMaximum: empty vector") toIndex $
      unsafePerformIO (absMax arr)
   where
      -- turn the storage offset back into a shape index
      toIndex = mapFst (Shape.uncheckedIndexFromOffset (Array.shape arr))
-- | Storage offset and value of the element with the largest modulus, or
-- 'Nothing' for an empty vector.  Real vectors use BLAS @iamax@ directly.
-- Complex vectors go through 'absMaxComplex', because BLAS @iamax@ ranks
-- complex numbers by @|re| + |im|@ rather than by modulus.
absMax ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> IO (Maybe (Int, a))
absMax arr@(Array sh x) =
   case Scalar.complexSingletonOfFunctor arr of
      Scalar.Real ->
         evalContT $ do
            (nPtr, xPtr, incPtr) <- vectorArgs arr
            liftIO $ do
               k <- Blas.iamax nPtr xPtr incPtr
               peekElemOff1 xPtr k
      Scalar.Complex ->
         evalContT $ do
            xPtr <- ContT (withForeignPtr x)
            liftIO $ do
               k <- absMaxComplex (Shape.size sh) xPtr
               peekElemOff1 xPtr k
-- | 1-based BLAS index of the complex element with the largest modulus, or
-- 0 for an empty vector (the @iamax@ convention that 'peekElemOff1' maps
-- to 'Nothing').
-- A temporary real buffer of length @n@ is filled from the real and
-- imaginary parts, and @iamax@ is applied to it.
-- NOTE(review): presumably the buffer receives re*re + im*im (squared
-- modulus); confirm against the semantics of Private.mul/mulAdd.
absMaxComplex :: (Class.Real a) => Int -> Ptr (Complex a) -> IO CInt
absMaxComplex n sxPtr =
   ForeignArray.alloca n $ \syPtr -> do
      -- View the complex array as interleaved reals; the real parts are
      -- every second real, hence stride 2.
      let xrPtr = realPtr sxPtr
      Private.mul NonConjugated n xrPtr 2 xrPtr 2 syPtr 1
      -- Imaginary parts start one real further on, also with stride 2.
      let xiPtr = advancePtr xrPtr 1
      Private.mulAdd NonConjugated n xiPtr 2 xiPtr 2 Scalar.one syPtr 1
      evalContT $ do
         nPtr <- Call.cint n
         incyPtr <- Call.cint 1
         liftIO $ Blas.iamax nPtr syPtr incyPtr
-- | Index and value of the element selected by BLAS @iamax@ (largest
-- absolute value for reals, largest @|re| + |im|@ for complex numbers).
-- Calls 'error' when the vector is empty.
argAbs1Maximum ::
   (Shape.InvIndexed sh, Class.Floating a) =>
   Vector sh a -> (Shape.Index sh, a)
argAbs1Maximum arr =
   unsafePerformIO $ evalContT $ do
      (nPtr, xPtr, incPtr) <- vectorArgs arr
      found <- liftIO (peekElemOff1 xPtr =<< Blas.iamax nPtr xPtr incPtr)
      return $
         maybe
            (error "Vector.argAbs1Maximum: empty vector")
            (mapFst (Shape.uncheckedIndexFromOffset (Array.shape arr)))
            found
-- | Marshal the BLAS vector argument triple for contiguous storage: the
-- element count, the data pointer (valid for the rest of the 'ContT'
-- computation), and increment 1.
vectorArgs ::
   (Shape.C sh) => Array sh a -> ContT r IO (Ptr CInt, Ptr a, Ptr CInt)
vectorArgs (Array sh x) = do
   nPtr <- Call.cint (Shape.size sh)
   xPtr <- ContT (withForeignPtr x)
   incPtr <- Call.cint 1
   return (nPtr, xPtr, incPtr)
-- | Turn a 1-based BLAS index into a 0-based offset and read the element
-- there.  Index 0 (what @iamax@ returns for an empty vector) gives
-- 'Nothing'.
peekElemOff1 :: (Storable a) => Ptr a -> CInt -> IO (Maybe (Int, a))
peekElemOff1 ptr k1 =
   case fromIntegral k1 of
      0 -> return Nothing
      k -> do
         let off = k - 1
         x <- peekElemOff ptr off
         return (Just (off, x))
-- | Product of all elements, delegated to 'Private.product' over the
-- contiguous storage with stride 1.
product :: (Shape.C sh, Class.Floating a) => Vector sh a -> a
product (Array sh fptr) =
   let count = Shape.size sh
   in  unsafePerformIO $
       withForeignPtr fptr $ \ptr -> Private.product count ptr 1
-- | Smallest and largest element, taken from 'limits'.
-- On an empty vector they fail via 'argAbs1Maximum'.
minimum, maximum :: (Shape.C sh, Class.Real a) => Vector sh a -> a
minimum xs = fst (limits xs)
maximum xs = snd (limits xs)

-- | Index and value of the smallest and largest element, taken from
-- 'argLimits'.
argMinimum, argMaximum ::
   (Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   Vector sh a -> (ix,a)
argMinimum xs = fst (argLimits xs)
argMaximum xs = snd (argLimits xs)
-- | @(minimum, maximum)@ computed with two BLAS @iamax@ passes.
-- The element of largest magnitude is one extreme: the maximum if it is
-- non-negative, otherwise the minimum.  The element farthest from it
-- (largest magnitude after subtracting it) is the other extreme.
limits :: (Shape.C sh, Class.Real a) => Vector sh a -> (a,a)
limits xs0 =
   if extreme >= 0
      then (opposite, extreme)
      else (extreme, opposite)
   where
      -- 'Shape.Deferred' provides an invertible index for any shape
      xs = Array.mapShape Shape.Deferred xs0
      extreme = snd (argAbs1Maximum xs)
      opposite = xs ! fst (argAbs1Maximum (raise (-extreme) xs))
-- | Like 'limits', but also returns the index of each extreme:
-- @((argMin, min), (argMax, max))@.
argLimits ::
   (Shape.InvIndexed sh, Shape.Index sh ~ ix, Class.Real a) =>
   Vector sh a -> ((ix,a),(ix,a))
argLimits xs =
   if snd absExtreme >= 0
      then (farthest, absExtreme)
      else (absExtreme, farthest)
   where
      absExtreme = argAbs1Maximum xs
      farIx = fst (argAbs1Maximum (raise (-(snd absExtreme)) xs))
      farthest = (farIx, xs ! farIx)
-- | Multiply every element by a scalar; @(.*|)@ is an operator alias.
-- 'scale' copies the input with BLAS @copy@ and scales the copy in place
-- with @scal@.  '_scale' is an unused alternative that expresses the same
-- product as @gemm@ of a 1×1 matrix (the scalar) with a 1×n matrix
-- (the vector).
scale, _scale, (.*|) ::
   (Shape.C sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a
(.*|) = scale
scale alpha (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n yPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      alphaPtr <- Call.number alpha
      xPtr <- ContT (withForeignPtr x)
      incxPtr <- Call.cint 1
      incyPtr <- Call.cint 1
      liftIO $ do
         Blas.copy nPtr xPtr incxPtr yPtr incyPtr
         Blas.scal nPtr alphaPtr yPtr incyPtr
_scale a (Array sh b) =
   Array.unsafeCreateWithSize sh $ \n cPtr ->
   evalContT $ do
      transaPtr <- Call.char 'N'
      transbPtr <- Call.char 'N'
      mPtr <- Call.cint m
      kPtr <- Call.cint k
      nPtr <- Call.cint n
      alphaPtr <- Call.number Scalar.one
      aPtr <- Call.number a
      ldaPtr <- Call.leadingDim m
      bPtr <- ContT (withForeignPtr b)
      ldbPtr <- Call.leadingDim k
      betaPtr <- Call.number Scalar.zero
      ldcPtr <- Call.leadingDim m
      liftIO $
         Blas.gemm
            transaPtr transbPtr mPtr nPtr kPtr alphaPtr
            aPtr ldaPtr bPtr ldbPtr betaPtr cPtr ldcPtr
   where
      -- one output row, inner dimension one
      m = 1
      k = 1
-- | Multiply by a real scalar.  A complex vector is temporarily viewed as
-- a real vector over @(sh, ComplexShape)@, so real and imaginary parts are
-- scaled independently instead of using complex multiplication.
scaleReal ::
   (Shape.C sh, Class.Floating a) =>
   RealOf a -> Vector sh a -> Vector sh a
scaleReal =
   getScaleReal
      (Class.switchFloating
         (ScaleReal scale)
         (ScaleReal scale)
         (ScaleReal (\alpha v -> recomplex (scale alpha (decomplex v))))
         (ScaleReal (\alpha v -> recomplex (scale alpha (decomplex v)))))

-- | Dispatch wrapper for 'scaleReal'.
newtype ScaleReal f a = ScaleReal {getScaleReal :: RealOf a -> f a -> f a}
-- | Reinterpret a complex vector as a real one with an extra
-- 'ComplexShape' dimension (real and imaginary part).  The storage is
-- shared through 'castForeignPtr'; nothing is copied.
decomplex ::
   (Class.Real a) =>
   Vector sh (Complex a) -> Vector (sh, ComplexShape) a
decomplex (Array sh fptr) =
   Array (sh, Shape.static) (castForeignPtr fptr)

-- | Inverse of 'decomplex': reinterpret the shared real storage as complex.
recomplex ::
   (Class.Real a) =>
   Vector (sh, ComplexShape) a -> Vector sh (Complex a)
recomplex (Array (sh, Shape.NestedTuple _) fptr) =
   Array sh (castForeignPtr fptr)
infixl 6 |+|, |-|, `add`, `sub`

-- | Element-wise sum and difference, with @(|+|)@ and @(|-|)@ as operator
-- aliases.  Both go through 'mac', which asserts equal shapes.
add, sub, (|+|), (|-|) ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
add x y = mac Scalar.one x y
-- x - y is evaluated as (-1)*y + x
sub x y = mac minusOne y x
(|+|) = add
(|-|) = sub
-- | @mac alpha x y@ is @alpha*x + y@ in a fresh vector: BLAS @copy@ of @y@
-- followed by @axpy@.  Shapes must be equal (asserted); the result takes
-- the shape of @x@.
mac ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   a -> Vector sh a -> Vector sh a -> Vector sh a
mac alpha (Array shX x) (Array shY y) =
   Array.unsafeCreateWithSize shX $ \n zPtr -> do
      Call.assert "mac: shapes mismatch" (shX == shY)
      evalContT $ do
         nPtr <- Call.cint n
         alphaPtr <- Call.number alpha
         xPtr <- ContT (withForeignPtr x)
         incxPtr <- Call.cint 1
         yPtr <- ContT (withForeignPtr y)
         incyPtr <- Call.cint 1
         inczPtr <- Call.cint 1
         liftIO $ do
            Blas.copy nPtr yPtr incyPtr zPtr inczPtr
            Blas.axpy nPtr alphaPtr xPtr incxPtr zPtr inczPtr
-- | Negate every element by 'scaleReal' with -1, so complex vectors are
-- negated on their real representation rather than by complex
-- multiplication.  The element-type switch is needed because the real
-- scalar @-1@ must be typed at @RealOf a@.
negate :: (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh a
negate =
   getNegate
      (Class.switchFloating
         (Negate (scaleReal Scalar.minusOne))
         (Negate (scaleReal Scalar.minusOne))
         (Negate (scaleReal Scalar.minusOne))
         (Negate (scaleReal Scalar.minusOne)))

-- | Dispatch wrapper for 'negate'.
newtype Negate f a = Negate {getNegate :: f a -> f a}
-- | Add the scalar @c@ to every element.  After copying the input, @c@ is
-- broadcast by passing it to BLAS @axpy@ with increment 0.
raise :: (Shape.C sh, Class.Floating a) => a -> Vector sh a -> Vector sh a
raise c (Array shX x) =
   Array.unsafeCreateWithSize shX $ \n yPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      xPtr <- ContT (withForeignPtr x)
      unitStridePtr <- Call.cint 1
      cPtr <- Call.number c
      onePtr <- Call.number Scalar.one
      zeroStridePtr <- Call.cint 0
      liftIO $ do
         Blas.copy nPtr xPtr unitStridePtr yPtr unitStridePtr
         Blas.axpy nPtr onePtr cPtr zeroStridePtr yPtr unitStridePtr
-- | Element-wise product of two vectors of equal shape.
mul ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
mul x y = mulConjugation NonConjugated x y

-- | Element-wise product with one factor conjugated.
-- NOTE(review): which factor is conjugated is decided inside 'Private.mul'
-- (presumably the first); confirm there.
mulConj ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Vector sh a -> Vector sh a -> Vector sh a
mulConj x y = mulConjugation Conjugated x y
-- | Shared implementation of 'mul' and 'mulConj': asserts equal shapes,
-- then lets 'Private.mul' fill a fresh vector (shaped like the second
-- argument) with unit strides.
mulConjugation ::
   (Shape.C sh, Eq sh, Class.Floating a) =>
   Conjugation -> Vector sh a -> Vector sh a -> Vector sh a
mulConjugation conj (Array shL l) (Array shR r) =
   Array.unsafeCreateWithSize shR $ \n outPtr -> do
      Call.assert "mul: shapes mismatch" (shL == shR)
      evalContT $ do
         lPtr <- ContT (withForeignPtr l)
         rPtr <- ContT (withForeignPtr r)
         liftIO (Private.mul conj n lPtr 1 rPtr 1 outPtr 1)
-- | Dispatch wrapper for shape-preserving vector transformations selected
-- by element type.
newtype Conjugate f a = Conjugate {getConjugate :: f a -> f a}

-- | Complex conjugate of every element; the identity for real vectors.
conjugate ::
   (Shape.C sh, Class.Floating a) =>
   Vector sh a -> Vector sh a
conjugate =
   getConjugate
      (Class.switchFloating
         (Conjugate id)
         (Conjugate id)
         (Conjugate complexConjugate)
         (Conjugate complexConjugate))
-- | Copy a complex vector while conjugating it, via 'copyConjugate'.
complexConjugate ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> Vector sh (Complex a)
complexConjugate (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n dstPtr ->
   evalContT $ do
      nPtr <- Call.cint n
      srcPtr <- ContT (withForeignPtr x)
      incSrcPtr <- Call.cint 1
      incDstPtr <- Call.cint 1
      liftIO (copyConjugate nPtr srcPtr incSrcPtr dstPtr incDstPtr)
-- | Embed a vector of real parts into the element type: the identity for
-- real types, 'complexFromReal' (zero imaginary parts) for complex types.
fromReal ::
   (Shape.C sh, Class.Floating a) => Vector sh (RealOf a) -> Vector sh a
fromReal =
   getFromReal
      (Class.switchFloating
         (FromReal id)
         (FromReal id)
         (FromReal complexFromReal)
         (FromReal complexFromReal))

-- | Dispatch wrapper for 'fromReal'.
newtype FromReal f a = FromReal {getFromReal :: f (RealOf a) -> f a}
-- | Widen to the complex counterpart of the element type: real vectors
-- get zero imaginary parts, complex vectors are returned unchanged.
toComplex ::
   (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh (ComplexOf a)
toComplex =
   getToComplex
      (Class.switchFloating
         (ToComplex complexFromReal)
         (ToComplex complexFromReal)
         (ToComplex id)
         (ToComplex id))

-- | Dispatch wrapper for 'toComplex'.
newtype ToComplex f a = ToComplex {getToComplex :: f a -> f (ComplexOf a)}
-- | Build a complex vector with the given real parts and zero imaginary
-- parts.  The output is written as interleaved reals (stride 2): real
-- parts are copied in, and a single zero is broadcast (source increment
-- 0) into the imaginary slots.
complexFromReal ::
   (Shape.C sh, Class.Real a) => Vector sh a -> Vector sh (Complex a)
complexFromReal (Array sh x) =
   Array.unsafeCreateWithSize sh $ \n yPtr ->
      let rePtr = realPtr yPtr
          imPtr = advancePtr rePtr 1
      in  evalContT $ do
             nPtr <- Call.cint n
             xPtr <- ContT (withForeignPtr x)
             incxPtr <- Call.cint 1
             incyPtr <- Call.cint 2
             incZeroPtr <- Call.cint 0
             zeroPtr <- Call.number Scalar.zero
             liftIO $ do
                Blas.copy nPtr xPtr incxPtr rePtr incyPtr
                Blas.copy nPtr zeroPtr incZeroPtr imPtr incyPtr
-- | Real parts of the elements: the identity for real vectors; for complex
-- vectors the 'ixReal' column of the 'decomplex' view.
realPart ::
   (Shape.C sh, Class.Floating a) => Vector sh a -> Vector sh (RealOf a)
realPart =
   getToReal
      (Class.switchFloating
         (ToReal id)
         (ToReal id)
         (ToReal (\v -> RowMajor.takeColumn ixReal (decomplex v)))
         (ToReal (\v -> RowMajor.takeColumn ixReal (decomplex v))))

-- | Dispatch wrapper for 'realPart'.
newtype ToReal f a = ToReal {getToReal :: f a -> f (RealOf a)}
-- | Imaginary parts of a complex vector: the 'ixImaginary' column of the
-- 'decomplex' view.
imaginaryPart ::
   (Shape.C sh, Class.Real a) => Vector sh (Complex a) -> Vector sh a
imaginaryPart v = RowMajor.takeColumn ixImaginary (decomplex v)
-- | Combine a vector of real parts and a vector of imaginary parts
-- (equal shapes, asserted) into a complex vector.  Each input is copied
-- into every second real slot of the interleaved output (stride 2).
zipComplex ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Vector sh a -> Vector sh a -> Vector sh (Complex a)
zipComplex (Array shr xr) (Array shi xi) =
   Array.unsafeCreateWithSize shr $ \n yPtr -> do
      Call.assert "zipComplex: shapes mismatch" (shr==shi)
      let yrPtr = realPtr yPtr
          yiPtr = advancePtr yrPtr 1
      evalContT $ do
         nPtr <- Call.cint n
         xrPtr <- ContT (withForeignPtr xr)
         xiPtr <- ContT (withForeignPtr xi)
         incxPtr <- Call.cint 1
         incyPtr <- Call.cint 2
         liftIO $ do
            Blas.copy nPtr xrPtr incxPtr yrPtr incyPtr
            Blas.copy nPtr xiPtr incxPtr yiPtr incyPtr
-- | Split a complex vector into its real parts and imaginary parts.
unzipComplex ::
   (Shape.C sh, Class.Real a) =>
   Vector sh (Complex a) -> (Vector sh a, Vector sh a)
unzipComplex x = (re, im)
   where
      re = realPart x
      im = imaginaryPart x