{-# language FlexibleContexts, TypeFamilies #-}
{-# language DeriveFunctor, DeriveFoldable #-}
module Data.Sparse.SpVector where
import Control.Exception
import Control.Monad.Catch (MonadThrow (..))
import Control.Exception.Common
import GHC.Exts
import Data.Sparse.Utils
import Data.Sparse.Types
import Data.Sparse.Internal.IntM
import Numeric.Eps
import Numeric.LinearAlgebra.Class
import Data.Complex
import Data.Maybe
import Text.Printf
import qualified Data.IntMap.Strict as IM
import qualified Data.Foldable as F
import qualified Data.Vector as V
-- | A sparse vector: its nominal dimension together with the map of the
-- entries that are explicitly stored.
data SpVector a = SV
  { svDim  :: {-# UNPACK #-} !Int
    -- ^ Nominal number of slots (stored or not).
  , svData :: !(IntM a)
    -- ^ The explicitly stored entries, keyed by index.
  } deriving (Eq, Functor, Foldable)
-- | Renders the dimension followed by the list of stored (index, value)
-- pairs, e.g. @SV (3) [(0,1.0),(2,5.0)]@.
instance Show a => Show (SpVector a) where
  show (SV d x) = concat ["SV (", show d, ") ", show (toList x)]
-- | Density of a sparse vector: the number of stored entries divided by the
-- dimension.
--
-- A vector with no slots (dimension 0 or less) is reported as having
-- density 0. Dividing by zero would give NaN for 'Double' and a runtime
-- error for exact types such as 'Rational'.
spySV :: Fractional b => SpVector a -> b
spySV s
  | d <= 0    = 0
  | otherwise = fromIntegral (size (dat s)) / fromIntegral d
  where
    d = dim s
-- | Number of explicitly stored entries in the vector.
nzSV :: SpVector a -> Int
nzSV = size . dat
-- | Human-readable summary of a vector: its dimension, the number of stored
-- entries, and the density as a percentage with three decimals.
sizeStrSV :: SpVector a -> String
sizeStrSV sv = unwords (sizePart ++ nzPart)
  where
    sizePart   = ["(", show (dim sv), "elements ) , "]
    nzPart     = [show (nzSV sv), "NZ ( density", densityStr, ")"]
    density    = spy sv :: Double
    densityStr = printf "%1.3f %%" (density * 100) :: String
-- | Element-wise combination of two vectors.
--
-- The result takes the larger of the two dimensions. The entries are
-- combined by the corresponding 'liftU2' / 'liftI2' on the underlying
-- 'IntM' maps (presumably union- and intersection-style; see
-- "Data.Sparse.Internal.IntM" to confirm).
instance Set SpVector where
  liftU2 f (SV d1 m1) (SV d2 m2) = SV (d1 `max` d2) (liftU2 f m1 m2)
  liftI2 f (SV d1 m1) (SV d2 m2) = SV (d1 `max` d2) (liftI2 f m1 m2)
-- | Left folds over the stored entries, passing each entry's index to the
-- combining function.
--
-- The unprimed version delegates to 'foldlWithKey' and the primed version
-- delegates to 'foldlWithKey''.
foldlWithKeySV, foldlWithKeySV' :: (a -> IM.Key -> b -> a) -> a -> SpVector b -> a
foldlWithKeySV f z (SV _ m) = foldlWithKey f z m
foldlWithKeySV' f z (SV _ m) = foldlWithKey' f z m
-- | The dimension of a sparse vector is its stored 'svDim' field.
instance FiniteDim (SpVector a) where
  type FDSize (SpVector a) = Int
  dim (SV d _) = d
-- | Exposes the entry map of a vector.
--
-- 'nnz' counts the stored entries (the 'Foldable' length of the map).
instance HasData (SpVector a) where
  type HDData (SpVector a) = IntM a
  dat (SV _ m) = m
  nnz = length . svData
-- | Sparsity of a vector, delegated to 'spySV'.
instance Sparse (SpVector a) where
  spy v = spySV v
-- | Container interface for sparse vectors, indexed by 'Int'.
--
-- The methods forward to the module-level functions:
--
-- * insertion goes to 'insertSpVector';
-- * lookup goes to 'lookupSV';
-- * listing goes to 'toListSV';
-- * @(\@\@)@ goes to 'lookupDenseSV', so unstored slots read as 0.
instance Elt a => SpContainer (SpVector a) where
  type ScIx (SpVector a) = Int
  type ScElem (SpVector a) = a
  scInsert i x v = insertSpVector i x v
  scLookup = flip lookupSV
  scToList v = toListSV v
  (@@) = flip lookupDenseSV
-- | Vector addition combines entries with 'liftU2'.
--
-- Negation maps over the stored entries. The zero vector has dimension 0
-- and the underlying 'zeroV' entry map.
instance AdditiveGroup a => AdditiveGroup (SpVector a) where
  zeroV = SV 0 zeroV
  v ^+^ w = liftU2 (^+^) v w
  negateV = fmap negateV
-- | Scalar multiplication scales every stored entry.
instance VectorSpace a => VectorSpace (SpVector a) where
  type Scalar (SpVector a) = Scalar a
  s .* v = (s .*) <$> v
-- | Inner product: the sum of the per-entry inner products of the entries
-- paired up by 'liftI2'.
--
-- Dimensions are not compared here ('dotS' performs that check).
instance InnerSpace a => InnerSpace (SpVector a) where
  (<.>) v w = sum pairwise
    where
      pairwise = liftI2 (<.>) v w
-- | Norms of a sparse vector.
--
-- Each norm is computed by folding over the stored entries only, using
-- the norms of the individual entries.
instance (Normed a, Magnitude a ~ RealScalar a, RealScalar a ~ Scalar a) => Normed (SpVector a) where
  type Magnitude (SpVector a) = Magnitude a
  type RealScalar (SpVector a) = RealScalar a
  -- Sum of the entries' 1-norms.
  norm1 v = sum (norm1 <$> v)
  -- Sum of the entries' squared 2-norms.
  norm2Sq v = sum (norm2Sq <$> v)
  -- (sum over entries of normP p x ** p) ** (1 / p)
  normP p v = powSum ** (1 / p)
    where
      powSum = sum ((\x -> normP p x ** p) <$> v)
  normalize p v = v ./ normP p v
  normalize2 v = v ./ norm2 v
  normalize2' v = v ./ norm2' v
  norm2 v = sqrt (norm2Sq v)
  norm2' v = sqrt (norm2Sq v)
-- | Inner product of two sparse vectors of equal dimension.
--
-- Calls 'error' when the dimensions differ; 'dotSSafe' is the
-- 'MonadThrow' variant.
dotS :: InnerSpace t => SpVector t -> SpVector t -> Scalar (IntM t)
dotS (SV m a) (SV n b)
  | n /= m    = error $ unwords ["<.> : Incompatible dimensions:", show m, show n]
  | otherwise = a <.> b
-- | Inner product of two sparse vectors.
--
-- When the dimensions differ it throws 'DotSizeMismatch' (carrying both
-- dimensions) in the caller's 'MonadThrow' monad.
dotSSafe :: (InnerSpace t, MonadThrow m) =>
            SpVector t -> SpVector t -> m (Scalar (IntM t))
dotSSafe (SV m a) (SV n b)
  | n /= m    = throwM (DotSizeMismatch m n)
  | otherwise = pure (a <.> b)
-- | The zero vector of dimension @n@: no entries are stored.
zeroSV :: Int -> SpVector a
zeroSV = flip SV empty

-- | A one-slot vector whose slot 0 holds the given value.
singletonSV :: a -> SpVector a
singletonSV = SV 1 . singleton 0
-- | Canonical basis vector of dimension @n@, with @i@ counted from 1: a
-- single @1@ is stored at (0-based) index @i - 1@.
--
-- NOTE(review): @i@ is not bounds-checked, so an @i@ outside @[1, n]@
-- produces a vector with an out-of-range entry.
ei :: Num a => Int -> IM.Key -> SpVector a
ei n i = SV n $ insert ix 1 empty
  where
    ix = i - 1
-- | Build a vector of dimension @d@ from an 'IM.IntMap'.
--
-- An entry is kept only when it is nonzero according to 'isNz' and its
-- key is in bounds according to 'inBounds0'.
mkSpVector :: Epsilon a => Int -> IM.IntMap a -> SpVector a
mkSpVector d im = SV d (IntM (IM.filterWithKey keep im))
  where
    keep k v = isNz v && inBounds0 d k

-- | Like 'mkSpVector', but only entries with out-of-bounds keys are
-- dropped; zero values are kept.
mkSpVector1 :: Int -> IM.IntMap a -> SpVector a
mkSpVector1 d = SV d . IntM . IM.filterWithKey (\k _ -> inBounds0 d k)
-- | Real vector of dimension @d@ whose entry map is built from the list by
-- 'mkIm'.
mkSpVR :: Int -> [Double] -> SpVector Double
mkSpVR d = SV d . mkIm

-- | Complex vector of dimension @d@ whose entry map is built from the list
-- by 'mkImC'.
mkSpVC :: Int -> [Complex Double] -> SpVector (Complex Double)
mkSpVC d = SV d . mkImC
-- | Vector of dimension @d@ built from the first @d@ list elements,
-- numbered by 'indexed'.
--
-- Extra list elements are ignored; a shorter list leaves the remaining
-- slots unstored.
fromListDenseSV :: Int -> [a] -> SpVector a
fromListDenseSV d = SV d . fromList . indexed . take d
-- | Sample @f@ at every index in @ix@, keeping only the samples that meet
-- two conditions:
--
-- * the index is in bounds for dimension @n@;
-- * the sample is nonzero according to 'isNz'.
spVectorDenseIx :: Epsilon a => (Int -> a) -> UB -> [Int] -> SpVector a
spVectorDenseIx f n ix =
  fromListSV n [ (i, v) | i <- ix, let v = f i, inBounds0 n i, isNz v ]

-- | 'spVectorDenseIx' over the contiguous index range @[lo .. hi]@.
spVectorDenseLoHi :: Epsilon a => (Int -> a) -> UB -> Int -> Int -> SpVector a
spVectorDenseLoHi f n lo hi = spVectorDenseIx f n (enumFromTo lo hi)
-- | Unchecked one-hot vector: dimension @n@ with a single @1@ stored at
-- index @k@.
--
-- No bounds check is performed; 'oneHotSV' is the checked version.
oneHotSVU :: Num a => Int -> IxRow -> SpVector a
oneHotSVU n = SV n . flip singleton 1
-- | One-hot vector of dimension @n@ with a single @1@ at index @k@.
--
-- Calls 'error' unless @k@ is in bounds. Indices are 0-based, so valid
-- values are @0 <= k < n@.
oneHotSV :: Num a => Int -> IxRow -> SpVector a
oneHotSV n k
  | inBounds0 n k = oneHotSVU n k
  -- Index n does not exist in a dimension-n vector, hence the strict upper
  -- bound in the message.
  | otherwise     = error "`oneHotSV n k` must satisfy 0 <= k < n"
-- | Dense vector of dimension @d@ with every slot storing 1.
onesSV :: Num a => Int -> SpVector a
onesSV = flip constv 1

-- | Dense vector of dimension @d@ with every slot storing an explicit 0.
--
-- This differs from 'zeroSV', which stores nothing.
zerosSV :: Num a => Int -> SpVector a
zerosSV = flip constv 0

-- | Dense vector of dimension @d@ storing @x@ in every slot; the slots are
-- numbered by 'indexed'.
constv :: Int -> a -> SpVector a
constv d x = SV d (fromList (indexed (replicate d x)))
-- | Convert a dense 'V.Vector' into a sparse vector of the same length.
--
-- Every element is stored, zeros included.
fromVector :: V.Vector a -> SpVector a
fromVector qv = V.ifoldl' step (zeroSV (V.length qv)) qv
  where
    step acc i x = insertSpVector i x acc

-- | The stored values, in the order 'toListSV' yields them.
--
-- Unstored slots are skipped; 'toVectorDense' fills them with 0 instead.
toVector :: SpVector a -> V.Vector a
toVector = V.fromList . map snd . toListSV

-- | All slots of the vector, with unstored slots filled with 0.
toVectorDense :: Num a => SpVector a -> V.Vector a
toVectorDense sv = V.fromList (toDenseListSV sv)
-- | Insert (or overwrite) the entry at index @i@.
--
-- When @i@ is out of bounds this throws 'OOBIxError', the same exception
-- 'insertSpVectorSafe' raises. Previously the call failed with an opaque
-- non-exhaustive-guards error.
insertSpVector :: IM.Key -> a -> SpVector a -> SpVector a
insertSpVector i x (SV d xim)
  | inBounds0 d i = SV d (insert i x xim)
  | otherwise     = throw (OOBIxError "insertSpVector" i)
-- | Insert (or overwrite) the entry at index @i@.
--
-- When @i@ is out of bounds it throws 'OOBIxError' in the caller's
-- 'MonadThrow' monad.
insertSpVectorSafe :: MonadThrow m => Int -> a -> SpVector a -> m (SpVector a)
insertSpVectorSafe i x (SV d xim)
  | not (inBounds0 d i) = throwM (OOBIxError "insertSpVector" i)
  | otherwise           = pure (SV d (insert i x xim))
-- | Vector of dimension @d@ built from (index, value) pairs.
--
-- Pairs whose index is out of bounds are silently dropped. The pairs are
-- inserted by a right fold, so the earliest pair is inserted last.
-- NOTE(review): if 'insert' overwrites existing keys, the first occurrence
-- of a repeated index wins — confirm.
fromListSV :: Foldable t => Int -> t (Int, a) -> SpVector a
fromListSV d iix = SV d (foldr step empty iix)
  where
    step (i, x) acc = if inBounds0 d i then insert i x acc else acc
-- | Vector whose dimension is the list length and whose entries are the
-- list elements, numbered by 'indexed'.
createv :: [a] -> SpVector a
createv ll = fromListSV (length ll) (indexed ll)

-- | 'createv' specialised to real entries.
vr :: [Double] -> SpVector Double
vr xs = createv xs

-- | 'createv' specialised to complex entries.
vc :: [Complex Double] -> SpVector (Complex Double)
vc xs = createv xs
-- | The stored (index, value) pairs.
toListSV :: SpVector a -> [(Int, a)]
toListSV = toList . dat

-- | All @d@ slots as a list, indices @0 .. d-1@, with unstored slots
-- reading as 0.
toDenseListSV :: Num b => SpVector b -> [b]
toDenseListSV (SV d (IntM im)) = [ IM.findWithDefault 0 i im | i <- [0 .. d - 1] ]
-- | Right fold over the stored entries, passing each entry's index to the
-- combining function.
ifoldSV :: (IM.Key -> a -> b -> b) -> b -> SpVector a -> b
ifoldSV f e = IM.foldrWithKey f e . unIM . svData

-- | The entry stored at index @i@, if any.
lookupSV :: IM.Key -> SpVector a -> Maybe a
lookupSV i = IM.lookup i . unIM . svData

-- | The entry stored at index @i@, or @def@ when that slot is not stored.
lookupDefaultSV :: a -> IM.Key -> SpVector a -> a
lookupDefaultSV def i = IM.findWithDefault def i . unIM . svData

-- | The entry at index @i@, with unstored slots reading as 0.
lookupDenseSV :: Num a => IM.Key -> SpVector a -> a
lookupDenseSV i = lookupDefaultSV 0 i
-- | Drop slot 0 and shift the remaining entries down by one.
--
-- The dimension shrinks by one but is clamped at 0, so the tail of an
-- empty vector is an empty vector rather than one of dimension -1.
tailSV :: SpVector a -> SpVector a
tailSV (SV n (IntM sv)) = SV (max 0 (n - 1)) $ IntM ta
  where
    ta = IM.mapKeys (\i -> i - 1) $ IM.delete 0 sv

-- | The entry at index 0, or 0 when that slot is not stored.
headSV :: Num a => SpVector a -> a
headSV (SV _ (IntM im)) = fromMaybe 0 (IM.lookup 0 im)

-- | @takeSV n@ keeps the entries with index below @n@ and sets the
-- dimension to @n@.
--
-- The dimension is clamped at 0 for negative @n@. An @n@ larger than the
-- current dimension yields a larger dimension with unstored trailing slots.
--
-- @dropSV n@ removes the entries with index below @n@, shifts the rest
-- down by @n@, and shrinks the dimension by @n@. Its dimension is also
-- clamped at 0, for the case where @n@ exceeds the current dimension.
takeSV, dropSV :: Int -> SpVector a -> SpVector a
takeSV n (SV _ sv) = SV (max 0 n) $ filterWithKey (\i _ -> i < n) sv
dropSV n (SV n0 (IntM sv)) =
  SV (max 0 (n0 - n)) $ IntM $ IM.mapKeys (subtract n) $ IM.filterWithKey (\i _ -> i >= n) sv
-- | Extract the entries whose index lies in the half-open range
-- @[rmin, rmax)@.
--
-- The entries are re-indexed to start at 0, and the result has dimension
-- @rmax - rmin@. Calls 'error' unless @0 < rmax - rmin <= dim v@.
rangeSV :: (IM.Key, IM.Key) -> SpVector a -> SpVector a
rangeSV (rmin, rmax) (SV n (IntM sv))
  | len > 0 && len <= n = SV len $ IntM sv'
  | otherwise = error $ unwords ["rangeSV : invalid bounds", show (rmin, rmax)]
  where
    len = rmax - rmin
    -- The upper bound is exclusive so that every kept entry, once shifted by
    -- rmin, lands in [0, len). An inclusive test would also keep index rmax,
    -- which maps to key len and lies outside the result's dimension.
    sv' = IM.mapKeys (subtract rmin) $ IM.filterWithKey (\i _ -> i >= rmin && i < rmax) sv
-- | Concatenate two vectors.
--
-- The second vector's entries are shifted up by the first vector's
-- dimension, and the dimensions add.
concatSV :: SpVector a -> SpVector a -> SpVector a
concatSV (SV n1 (IntM s1)) (SV n2 (IntM s2)) =
  SV (n1 + n2) (IntM (s1 `IM.union` IM.mapKeys (n1 +) s2))
-- | Keep only the stored entries whose value satisfies @q@; the dimension
-- is unchanged.
filterSV :: (a -> Bool) -> SpVector a -> SpVector a
filterSV q (SV d (IntM im)) = SV d (IntM (IM.filter q im))

-- | Keep only the stored entries whose (index, value) satisfies @q@; the
-- dimension is unchanged.
ifilterSV :: (Int -> a -> Bool) -> SpVector a -> SpVector a
ifilterSV q (SV d m) = SV d (filterWithKey q m)

-- | Drop the stored entries that 'isNz' does not consider nonzero
-- (presumably a tolerance test; see "Numeric.Eps").
sparsifySV :: Epsilon a => SpVector a -> SpVector a
sparsifySV v = filterSV isNz v
-- | A vector orthogonal to @v@ under 'dot'.
--
-- The result has the same dimension as @v@ and is all ones except slot 0.
-- Slot 0 holds @-(ones `dot` tail v) / head v@, which for real @t@ makes
-- the inner product with @v@ vanish.
--
-- NOTE(review): this divides by the entry at index 0 of @v@. When that
-- entry is zero or unstored ('headSV' returns 0 in both cases), the result
-- is a division by zero. It cannot be guarded under the current
-- constraints, because there is no 'Eq' on @t@.
orthogonalSV :: (Scalar (SpVector t) ~ t, InnerSpace (SpVector t), Fractional t) =>
                SpVector t -> SpVector t
orthogonalSV v = concatSV firstSlot rest
  where
    rest      = onesSV (dim v - 1)
    firstSlot = singletonSV (negate (rest `dot` tailSV v / headSV v))