{-# language TypeFamilies, FlexibleInstances, MultiParamTypeClasses, CPP #-}
module Data.Sparse.Internal.SVector where
import Control.Arrow
import Control.Monad (unless)
import Control.Monad.IO.Class
import qualified Data.Foldable as F
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as VM
import Data.Complex
import Foreign.C.Types (CSChar, CInt, CShort, CLong, CLLong, CIntMax, CFloat, CDouble)
import Control.Monad.Primitive
import Data.Sparse.Internal.SVector.Mutable hiding (fromList)
import qualified Data.Sparse.Internal.SVector.Mutable as SMV (fromList)
import Numeric.LinearAlgebra.Class
-- | A sparse vector: a logical dimension plus two parallel vectors holding
-- the positions of the stored entries and the values stored there.
--
-- NOTE(review): the merge helpers in this module walk both index vectors in
-- lock step, so they presumably expect strictly ascending indices; nothing
-- here enforces that.
data SVector a = SV
  { svDim :: {-# UNPACK #-} !Int -- ^ logical dimension (length) of the vector
  , svIx  :: V.Vector Int        -- ^ positions of the stored entries
  , svVal :: V.Vector a          -- ^ stored values, aligned with 'svIx'
  } deriving Eq
-- | Debug rendering: the dimension, the number of stored entries, and the
-- stored (index, value) pairs.
instance Show a => Show (SVector a) where
  show (SV dim ixs vals) =
    unwords
      [ "SV ("
      , show dim
      , "),"
      , show (V.length ixs)
      , "NZ:"
      , show (V.zip ixs vals)
      ]
-- | Map over the stored values only; dimension and sparsity pattern are kept.
instance Functor SVector where
  fmap f sv = sv { svVal = V.map f (svVal sv) }
-- | Folds visit the stored values only (structural zeros are not visited).
instance Foldable SVector where
  foldr step acc (SV _ _ vals) = V.foldr step acc vals
-- | Traverse the stored values, rebuilding a vector with the same dimension
-- and index vector.
instance Traversable SVector where
  traverse f (SV dim ixs vals) = fmap (SV dim ixs) (traverse f vals)
-- | The number of stored entries is the length of the index vector.
instance HasData (SVector a) where
  nnz sv = V.length (svIx sv)
-- | Wrap a dense vector as a sparse one in which every position is stored.
-- The index vector is @0 .. len-1@, where @len@ is the input length.
fromDense :: V.Vector a -> SVector a
fromDense xs = SV len (V.enumFromN 0 len) xs
  where
    len = V.length xs
-- | Build a sparse vector of dimension @n@ from (index, value) pairs.
-- The pairs are stored in the order given; no sorting or deduplication
-- is performed.
fromVector :: Int -> V.Vector (Int, a) -> SVector a
fromVector n pairs = SV n (V.map fst pairs) (V.map snd pairs)
-- | List-based counterpart of 'fromVector': dimension @n@ plus
-- (index, value) pairs, stored in the order given.
fromList :: Int -> [(Int, a)] -> SVector a
fromList n entries = fromVector n (V.fromList entries)
-- | Look up the value stored at logical index @i@.
--
-- The position of @i@ inside 'svIx' is located first, and the value is then
-- read from that position in 'svVal'. (The stored index @i@ is not used as a
-- position itself.)
--
-- Returns 'Nothing' when @i@ is not a stored index (a structural zero), or
-- when the value vector is too short to hold the matching entry.
index :: SVector a -> Int -> Maybe a
index cv i = V.elemIndex i (svIx cv) >>= (svVal cv V.!?)
-- | Combine two sparse vectors entrywise with @g@, keeping only the indices
-- stored in both operands (e.g. @intersectWith (*)@ is the pointwise
-- product restricted to the common support).
--
-- The result dimension is the smaller of the two operand dimensions. It used
-- to be the number of surviving entries, which could be smaller than the
-- stored indices themselves.
intersectWith :: (a -> b -> c) -> SVector a -> SVector b -> SVector c
intersectWith g v1 v2 = SV ndim ixf vf
  where
    ndim = min (svDim v1) (svDim v2)
    (ixf, vf) = V.unzip $ V.create (intersectWithM g v1 v2)
-- | Merge step behind 'intersectWith': walk both operands in step and emit
-- @(index, g u v)@ for every index stored in both.
--
-- Returns a freshly allocated mutable vector holding exactly the emitted
-- pairs, in merge order.
--
-- NOTE(review): assumes each operand stores its indices in strictly
-- ascending order; this is not checked.
intersectWithM :: PrimMonad m =>
  (a -> b -> c)
  -> SVector a
  -> SVector b
  -> m (VM.MVector (PrimState m) (Int, c))
intersectWithM g (SV _ ixA valA) (SV _ ixB valB) = do
  buf <- VM.new capacity
  nfin <- go ixA valA ixB valB 0 buf
  return $ VM.take nfin buf
  where
    -- Every emitted pair consumes one value from each operand, so the shorter
    -- value vector bounds the number of results. (Sizing by the dimension
    -- would over-allocate for very sparse inputs.)
    capacity = min (V.length valA) (V.length valB)
    -- Cursor state is carried purely in the vectors passed along. Each branch
    -- must pass the *current* remainders of both operands, never the
    -- originals.
    go ia xa ib xb i buf
      | V.null xa || V.null xb || V.null ia || V.null ib = return i
      | otherwise = do
          let (x, xs) = headTail xa
              (y, ys) = headTail xb
              (ja, jas) = headTail ia
              (jb, jbs) = headTail ib
          case compare ja jb of
            EQ -> do
              VM.write buf i (ja, g x y)
              go jas xs jbs ys (i + 1) buf
            -- index only in the first operand: skip it, keep the second as is
            LT -> go jas xs ib xb i buf
            -- index only in the second operand: skip it, keep the first as is
            GT -> go ia xa jbs ys i buf
-- | Combine two sparse vectors entrywise with @g@ over the union of their
-- stored indices. Where only one operand stores an index, the missing side
-- is replaced by the fill value @z@. The result dimension is the larger of
-- the two operand dimensions.
unionWith :: (t -> t -> a) -> t -> SVector t -> SVector t -> SVector a
unionWith g z v1 v2 = SV (max (svDim v1) (svDim v2)) ixs vals
  where
    (ixs, vals) = V.unzip (V.create (unionWithM g z v1 v2))
-- | Merge step behind 'unionWith': walk both operands in step and emit one
-- pair per index stored in either operand:
--
-- * @(j, g u v)@ when both store index @j@;
-- * @(j, g u z)@ when only the first does;
-- * @(j, g z v)@ when only the second does.
--
-- Returns a freshly allocated mutable vector holding exactly the emitted
-- pairs, in merge order.
--
-- NOTE(review): assumes each operand stores its indices in strictly
-- ascending order; this is not checked.
unionWithM :: PrimMonad m =>
  (a -> a -> b)
  -> a
  -> SVector a
  -> SVector a
  -> m (VM.MVector (PrimState m) (Int, b))
unionWithM g z (SV _ ixA valA) (SV _ ixB valB) = do
  buf <- VM.new capacity
  nfin <- go ixA valA ixB valB 0 buf
  return $ VM.take nfin buf
  where
    -- Every emitted pair consumes at least one stored value, so the union
    -- never has more entries than both operands together. (Bounding it by
    -- the smaller dimension truncated results, e.g. zeroV ^+^ x.)
    capacity = V.length valA + V.length valB
    go ia xa ib xb i buf
      | V.null xa && V.null xb = return i
      | V.null xa = do
          -- first operand exhausted: drain the second
          VM.write buf i (V.head ib, g z (V.head xb))
          go ia xa (V.tail ib) (V.tail xb) (i + 1) buf
      | V.null xb = do
          -- second operand exhausted: drain the first
          VM.write buf i (V.head ia, g (V.head xa) z)
          go (V.tail ia) (V.tail xa) ib xb (i + 1) buf
      | otherwise = do
          let (x, xs) = headTail xa
              (y, ys) = headTail xb
              (ja, jas) = headTail ia
              (jb, jbs) = headTail ib
          case compare ja jb of
            EQ -> do
              VM.write buf i (ja, g x y)
              go jas xs jbs ys (i + 1) buf
            LT -> do
              VM.write buf i (ja, g x z)
              go jas xs ib xb (i + 1) buf
            GT -> do
              VM.write buf i (jb, g z y)
              go ia xa jbs ys (i + 1) buf
-- CPP instance generator for a real numeric element type t. zeroV is the
-- empty sparse vector. Addition is a union merge that fills absent entries
-- with 0. Scaling multiplies every stored value. The inner product sums the
-- pointwise product over indices stored in both operands.
-- Haskell comments cannot live inside the macro body (it expands onto one
-- line), which is why this note sits above the define.
#define SVType(t) \
instance AdditiveGroup (SVector t) where {zeroV = SV 0 V.empty V.empty; x ^+^ y = unionWith (+) 0 x y; negateV x = negate <$> x};\
instance VectorSpace (SVector t) where {type Scalar (SVector t) = (t); s .* x = (* s) <$> x};\
instance InnerSpace (SVector t) where {x <.> y = sum (intersectWith (*) x y)}
-- CPP instance generator for complex element types Complex t. zeroV is the
-- empty sparse vector. Addition is a union merge with (0 :+ 0) as the fill
-- value. The inner product conjugates the first operand before the pointwise
-- product.
-- The last instance line must NOT end in a backslash. Otherwise the line
-- continuation pulls the next source line into the macro body and removes it
-- from the module.
#define SVTypeC(t) \
instance AdditiveGroup (SVector (Complex t)) where {zeroV = SV 0 V.empty V.empty; (^+^) = unionWith (+) (0 :+ 0) ; negateV = fmap negate};\
instance VectorSpace (SVector (Complex t)) where {type Scalar (SVector (Complex t)) = (Complex t); n .* x = fmap (* n) x };\
instance InnerSpace (SVector (Complex t)) where {x <.> y = sum $ intersectWith (*) (conjugate <$> x) y}
-- | Merge an immutable sparse vector with a mutable one into a fresh mutable
-- sparse vector, combining entries with @g@.
--
-- An index stored by both operands yields @g u v@. An index stored only by
-- the first yields @g u z@, and one stored only by the second yields
-- @g z v@. Neither input is modified; the result owns newly allocated
-- buffers trimmed to the number of merged entries.
--
-- Calls 'error' when the two dimensions differ.
--
-- NOTE(review): assumes each operand stores its indices in strictly
-- ascending order; this is not checked.
unionWithSMV :: PrimMonad m =>
  (a -> a -> a)
  -> a -> SVector a -> SMVector m a -> m (SMVector m a)
unionWithSMV g z (SV n ixu uu) (SMV n2 ixm_ vm_) = do
  -- validate before allocating anything
  unless (n == n2) (error "unionWithSMV : operand vectors must have the same length")
  ixOut <- VM.new cap
  valOut <- VM.new cap
  nfin <- merge ixOut valOut 0 0 0
  return $ SMV n (VM.take nfin ixOut) (VM.take nfin valOut)
  where
    lu = V.length ixu
    lv = VM.length ixm_
    -- the union never has more entries than both operands together
    cap = lu + lv
    -- write one merged (index, value) entry at output slot k
    emit ixOut valOut k j x = do
      VM.write ixOut k j
      VM.write valOut k x
    -- Two-cursor merge: iu walks the immutable operand, iv walks the mutable
    -- operand (read only), and k is the next free output slot.
    merge ixOut valOut iu iv k
      | iu == lu && iv == lv = return k
      | iu == lu = do
          jv <- VM.read ixm_ iv
          y <- VM.read vm_ iv
          emit ixOut valOut k jv (g z y)
          merge ixOut valOut iu (iv + 1) (k + 1)
      | iv == lv = do
          emit ixOut valOut k (ixu V.! iu) (g (uu V.! iu) z)
          merge ixOut valOut (iu + 1) iv (k + 1)
      | otherwise = do
          let ju = ixu V.! iu
          jv <- VM.read ixm_ iv
          case compare ju jv of
            EQ -> do
              y <- VM.read vm_ iv
              emit ixOut valOut k ju (g (uu V.! iu) y)
              merge ixOut valOut (iu + 1) (iv + 1) (k + 1)
            LT -> do
              emit ixOut valOut k ju (g (uu V.! iu) z)
              merge ixOut valOut (iu + 1) iv (k + 1)
            GT -> do
              y <- VM.read vm_ iv
              emit ixOut valOut k jv (g z y)
              merge ixOut valOut iu (iv + 1) (k + 1)
-- | Copy an immutable sparse vector into a fresh mutable one. Both the index
-- and value vectors are copied with 'V.thaw', so the input is not shared.
thaw :: PrimMonad m => SVector a -> m (SMVector m a)
thaw (SV n ix v) = SMV n <$> V.thaw ix <*> V.thaw v
-- | Snapshot a mutable sparse vector as an immutable one. Both component
-- vectors are copied with 'V.freeze', so later writes to the mutable vector
-- do not affect the result.
freeze :: PrimMonad m => SMVector m a -> m (SVector a)
freeze (SMV n ixm vm) = SV n <$> V.freeze ixm <*> V.freeze vm
-- | Apply the same arrow to both components of a pair.
both :: Arrow arr => arr b c -> arr (b, b) (c, c)
both f = (***) f f
-- | Split a vector into its first element and the remainder. The components
-- are 'V.head' and 'V.tail', so forcing either one on an empty vector raises
-- an error.
headTail :: V.Vector a -> (a, V.Vector a)
headTail v = (V.head v, V.tail v)