{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverlappingInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      : SparseVector
-- Copyright   : (c) Grzegorz Chrupala 2008
-- Maintainer  : gchrupala@lsv.uni-saarland.de
--
-- A typeclass for sparse arrays.
--
-- NOTE(review): none of the three instances below overlap (their heads are
-- 'Double', 'Map.Map' and 'IntMap.IntMap'); the deprecated
-- OverlappingInstances pragma is kept only so that downstream modules that
-- might rely on these instances being overlappable keep compiling.
--
module SparseVector
  ( Vector(..)
  , dot
  , minus
  ) where

import qualified Data.Map as Map
import qualified Data.IntMap as IntMap
import Data.List (foldl')
import qualified Prelude
import Prelude hiding (abs, sum)

-- | Vector-like values built from 'Double's, possibly nested inside maps
-- (the container instances require their element type to be a 'Vector').
class Vector v where
  -- | Elementwise addition
  plus :: v -> v -> v
  -- | Multiplication by scalar
  scale :: v -> Double -> v
  -- | Elementwise multiplication
  mult :: v -> v -> v
  -- | Summation of all elements; by default a strict left fold over 'elems'.
  sum :: v -> Double
  sum = foldl' (+) 0 . elems
  -- | Elementwise absolute value
  abs :: v -> v
  -- | Conversion to a flat list of elements
  elems :: v -> [Double]
  -- | Zero vector
  zero :: v

-- | Dot product: the 'sum' of the elementwise product.
dot :: (Vector v) => v -> v -> Double
dot v w = sum (mult v w)

-- | Elementwise subtraction, defined as adding @w@ scaled by @-1@.
minus :: (Vector v) => v -> v -> v
v `minus` w = v `plus` (w `scale` (-1))

-- | A single 'Double' behaves as a one-element vector.
instance Vector Double where
  v `plus` w = v + w
  v `scale` n = v * n
  sum = id
  v `mult` w = v * w
  -- Inside this module the unqualified name 'abs' is the class method
  -- (Prelude's is hidden), so it must be qualified; @abs = abs@ would loop.
  abs = Prelude.abs
  elems = return
  zero = 0

-- | Maps from keys to sub-vectors. In 'mult', a key missing from one operand
-- is looked up as 'zero', and products equal to 'zero' are not stored.
instance (Ord k, Eq v, Vector v) => Vector (Map.Map k v) where
  v `plus` w = Map.unionWith plus v w
  v `scale` n = Map.map (`scale` n) v
  v `mult` w =
    intersectWith Map.size Map.insert Map.foldrWithKey Map.findWithDefault
      mult v w
  abs = Map.map abs
  elems = concatMap elems . Map.elems
  zero = Map.empty

-- | Same as the 'Map.Map' instance, specialised to 'Int' keys.
instance (Eq v, Vector v) => Vector (IntMap.IntMap v) where
  v `plus` w = IntMap.unionWith plus v w
  v `scale` n = IntMap.map (`scale` n) v
  -- 'IntMap.foldWithKey' has been removed from containers; 'foldrWithKey'
  -- is its replacement and matches the 'Map.Map' instance above.
  v `mult` w =
    intersectWith IntMap.size IntMap.insert IntMap.foldrWithKey
      IntMap.findWithDefault mult v w
  abs = IntMap.map abs
  elems = concatMap elems . IntMap.elems
  zero = IntMap.empty

-- | @flipIf b f@ is @flip f@ when @b@ is 'True', and @f@ otherwise.
flipIf :: Bool -> (a -> a -> b) -> a -> a -> b
flipIf b = if b then flip else id

-- | Container-agnostic sparse intersection shared by the map instances.
--
-- Folds over whichever operand is not larger (by @size@), combining each of
-- its entries with the other operand's entry for the same key via @f@ (a
-- missing key is looked up as 'zero'). Results equal to 'zero' are dropped,
-- and the survivors are inserted into an initially-'zero' container.
--
-- When the operands are swapped, @f@ receives its arguments swapped too, so
-- the result is only order-independent if @f@ is commutative.
intersectWith
  :: (Ord s, Eq e, Vector e, Vector m)
  => (m -> s)                             -- ^ container size
  -> (k -> e -> m -> m)                   -- ^ insert an entry
  -> ((k -> e -> m -> m) -> m -> m -> m)  -- ^ right fold with key
  -> (e -> k -> m -> e)                   -- ^ lookup with default
  -> (e -> e -> e)                        -- ^ elementwise combination
  -> m
  -> m
  -> m
intersectWith size insert foldWithKey findWithDefault f x y =
  flipIf (size x > size y) step x y
  where
    -- Walk @v@, pairing each entry with the matching entry of @w@.
    step v w = foldWithKey (combine w) zero v
    combine w k a acc =
      let a' = f a (findWithDefault zero k w)
      in if a' == zero then acc else insert k a' acc