{-# LANGUAGE FlexibleInstances , OverlappingInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  SparseVector
-- Copyright   :  (c) Grzegorz Chrupala 2008
-- Maintainer  :  gchrupala@lsv.uni-saarland.de
--
-- A type class for sparse vectors, with instances for 'Double',
-- 'Data.Map.Map' and 'Data.IntMap.IntMap' (which may be nested).
-- 

module SparseVector 
    ( Vector(..)
    , dot
    , minus
    )
where
import qualified Data.Map as Map 
import qualified Data.IntMap as IntMap
import Data.List (foldl')
import qualified Prelude 
import Prelude hiding (abs,sum)


class Vector v where
    -- | Add two vectors element by element.
    plus :: v -> v -> v
    -- | Multiply every element by the given scalar.
    scale :: v -> Double -> v
    -- | Multiply two vectors element by element.
    mult :: v -> v -> v
    -- | Total of all elements.  By default this is a strict left fold
    -- over 'elems'.
    sum :: v -> Double
    sum vec = foldl' (+) 0 (elems vec)
    -- | Take the absolute value of every element.
    abs :: v -> v
    -- | All elements of the vector, as a flat list.
    elems :: v -> [Double]
    -- | The zero vector.
    zero :: v

-- | Inner product of two vectors: the 'sum' of their elementwise product
-- ('mult').
dot :: (Vector v) => v -> v -> Double
dot v = sum . mult v

-- | Elementwise subtraction: @v@ plus @w@ scaled by @-1@.
minus :: (Vector v) => v -> v -> v
minus v w = plus v negated
  where
    negated = scale w (-1)



-- A single 'Double' is a one-element vector; every operation is the
-- corresponding scalar operation.
instance Vector Double where
    v `plus` w = v + w
    v `scale` n = v * n
    sum = id
    v `mult` w = v * w
    -- Prelude's 'abs' is hidden in this module, so an unqualified @abs@ on
    -- the right-hand side would name this very method and loop forever.
    -- Use the qualified Prelude function instead.
    abs = Prelude.abs
    elems x = [x]
    zero = 0

-- A 'Map.Map' is a sparse vector whose keys index (possibly nested)
-- sub-vectors.  An absent key stands for 'zero'.
--
-- NOTE(review): 'mult' drops entries whose result equals 'zero', but 'plus'
-- and 'scale' do not.  For example, @v `minus` v@ keeps every key with a
-- zero value instead of becoming 'Map.empty'.
instance (Ord k, Eq v, Vector v) => Vector (Map.Map k v) where
    -- Union of both key sets; values under shared keys are added.
    v `plus` w = Map.unionWith plus v w
    v `scale` n = Map.map (`scale` n) v
    -- Product over the smaller map's keys (see 'intersectWith').
    v `mult` w =
        intersectWith Map.size Map.insert Map.foldrWithKey Map.findWithDefault
                      mult v w
    abs = Map.map abs
    -- Every leaf element of every stored value, in ascending key order.
    elems = concatMap elems . Map.elems
    zero = Map.empty

-- An 'IntMap.IntMap' is a sparse vector indexed by 'Int' keys.  It mirrors
-- the 'Map.Map' instance.
instance (Eq v, Vector v) => Vector (IntMap.IntMap v) where
    -- Union of both key sets; values under shared keys are added.
    v `plus` w = IntMap.unionWith plus v w
    v `scale` n = IntMap.map (`scale` n) v
    -- 'IntMap.foldrWithKey' replaces 'foldWithKey', which was deprecated and
    -- then removed from containers.  It folds in the same order, and it
    -- matches the Map instance.
    v `mult` w =
        intersectWith IntMap.size IntMap.insert IntMap.foldrWithKey
                      IntMap.findWithDefault mult v w
    abs = IntMap.map abs
    -- Every leaf element of every stored value, in ascending key order.
    elems = concatMap elems . IntMap.elems
    zero = IntMap.empty


-- | @flipIf b f@ swaps the two arguments of @f@ when @b@ is 'True' and
-- leaves @f@ unchanged otherwise.  Both arguments must share one type so
-- that the two cases have the same type.
flipIf :: Bool -> (a -> a -> c) -> a -> a -> c
flipIf True  = flip
flipIf False = id

-- | Combine two sparse maps elementwise, traversing only the smaller one.
--
-- The container operations are passed in explicitly, so the same code
-- serves both "Data.Map" and "Data.IntMap".
--
-- * For each key of the smaller map (by @size@), its value is combined
--   with the value under the same key in the other map.
-- * If the key is absent there, 'zero' is used in its place.
-- * Results equal to 'zero' are not stored.
--
-- @f@ always receives the element from @x@ first and the element from @y@
-- second, whichever map is traversed.
intersectWith
  :: (Eq v, Vector v, Vector m)
  => (m -> Int)                          -- ^ number of entries in a map
  -> (k -> v -> m -> m)                  -- ^ insert a key/value pair
  -> ((k -> v -> m -> m) -> m -> m -> m) -- ^ right fold with keys
  -> (v -> k -> m -> v)                  -- ^ lookup with a default
  -> (v -> v -> v)                       -- ^ elementwise combining function
  -> m
  -> m
  -> m
intersectWith size insert foldWithKey findWithDefault f x y =
    -- Swap the maps so the smaller one is traversed, and flip @f@ with them
    -- so its argument order stays (x-element, y-element).
    flipIf swapped (combineInto (flipIf swapped f)) x y
  where
    swapped = size x > size y
    -- Fold over @v@, pairing each entry with @w@'s value for the same key.
    combineInto g v w = foldWithKey step zero v
      where
        step k a acc
          | a' == zero = acc
          | otherwise  = insert k a' acc
          where
            a' = g a (findWithDefault zero k w)