module Math.LinearAlgebra.Sparse.Vector
-- TODO: explicit export list
where

import Data.Functor
import Data.Foldable as F
import Data.List     as L
import Data.IntMap   as M hiding ((!))
import Data.Monoid

--------------------------------------------------------------------------------
-- IntMap Utilities (for internal use) --
-----------------------------------------

-- | Dot product of two `IntMap`s (for internal use).
--
--   Folds strictly over the entries of the first map and multiplies each
--   value by the value stored under the same key in the second map (a key
--   missing from the second map counts as zero). An entry of the first map
--   stored under key @0@ is skipped and contributes nothing.
(··) :: (Num α) => IntMap α -> IntMap α -> α
v ·· w = M.foldlWithKey' step 0 v
    where step total idx val
              | idx == 0  = total
              | otherwise = total + (M.findWithDefault 0 idx w * val)

-- | Shifts (re-enumerates) keys of IntMap by given number: every key @i@
--   becomes @i + k@, values are left untouched.
shiftKeys :: Int -> IntMap α -> IntMap α
shiftKeys k = M.mapKeysMonotonic (+ k)

-- | Adds element to the map at given index, shifting all keys after it.
--
--   The entries with keys below @i@ stay where they are; an entry already
--   stored at @i@ moves to @i + 1@, and every key above @i@ is increased by
--   one. Passing `Nothing` only makes room: index @i@ is left without an entry.
--
--   `M.singleton` is qualified on purpose: this module imports both
--   "Data.List" and "Data.IntMap" unqualified, and since base-4.15 both export
--   a @singleton@, so the bare name is ambiguous.
addElem :: Maybe α -> Int -> IntMap α -> IntMap α
addElem v i m = M.unions [before, inserted, displaced, shifted]
    where (before, old, after) = M.splitLookup i m
          -- the new element, if there is one
          inserted  = maybe M.empty (M.singleton i) v
          -- the element that used to live at index i, if any
          displaced = maybe M.empty (M.singleton (i + 1)) old
          -- same re-enumeration as @shiftKeys 1@
          shifted   = M.mapKeysMonotonic (+ 1) after

-- | Deletes element of the map at given index, shifting all keys after it:
--   entries below @i@ are kept as they are, an entry at @i@ (if any) is
--   dropped, and every key above @i@ is decreased by one.
delElem :: Int -> IntMap α -> IntMap α
delElem i m = M.union below (shiftKeys (-1) above)
    where (below, _, above) = M.splitLookup i m

-- | Splits map using predicate and returns a pair of re-enumerated maps.
--
--   The first map is the input with every entry that fails the predicate
--   deleted as by `delElem`: each surviving key is decreased by the number of
--   failing entries that had a smaller key. The second map holds the failing
--   values, in ascending order of their original keys, renumbered from 1.
--   For example:
--
-- >>> partitionMap (>0) (fromList [(1,1),(2,-1),(3,-2),(4,3),(5,-4)])
-- ( fromList [(1,1),(2,3)], fromList [(1,-1),(2,-2),(3,-4)] )
--
partitionMap :: (α -> Bool) -> IntMap α -> (IntMap α, IntMap α)
partitionMap p m = (M.fromDistinctAscList kept, M.fromDistinctAscList (zip [1..] rejected))
    where (kept, rejected) = go 0 (M.toAscList m)
          -- Single ascending pass; @gap@ counts the failing entries seen so
          -- far, i.e. how far the next surviving key has to move down.
          go _ [] = ([], [])
          go gap ((i, x) : rest)
              | p x       = let (ks, rs) = go gap rest       in ((i - gap, x) : ks, rs)
              | otherwise = let (ks, rs) = go (gap + 1) rest in (ks, x : rs)

--------------------------------------------------------------------------------
-- SPARSE VECTOR DATATYPE --
----------------------------

-- | Index of a vector element (`sparseList` and `fillVec` count from 1)
type Index = Int

-- | Type of internal vector storage: values keyed by their index
type SVec α = IntMap α

-- | Sparse vector: the declared size together with an indexed map that is
--   meant to hold only the non-zero values
data SparseVector α = SV
     { dim :: Int
       -- ^ real size of vector (with zeroes)
     , vec :: SVec α
       -- ^ IntMap storing non-zero values
     }
     deriving (Eq)

-- | Sets vector's size. Only the `dim` field is replaced; the stored entries
--   are passed through unchanged (nothing is dropped when the size shrinks).
setLength ::  Int -> SparseVector α -> SparseVector α
setLength n (SV _ entries) = SV n entries

-- | Vector of zero size with no values
emptyVec :: SparseVector α
emptyVec = SV { dim = 0, vec = M.empty }

-- | Vector of given size with no non-zero values (its storage is empty)
zeroVec ::  Int -> SparseVector α
zeroVec n = SV n M.empty

-- | Vector of length 1 with given value. A zero value is not stored: the
--   result is then @zeroVec 1@.
--
--   `M.singleton` is qualified on purpose: this module imports both
--   "Data.List" and "Data.IntMap" unqualified, and since base-4.15 both export
--   a @singleton@, so the bare name is ambiguous.
singVec :: (Eq α, Num α) => α -> SparseVector α
singVec x
    | x == 0    = zeroVec 1
    | otherwise = SV 1 (M.singleton 1 x)

-- | `fmap` applies given function to the stored (non-zero) values only;
--   the size of the vector is kept as it is
instance Functor SparseVector where
    fmap f (SV n entries) = SV n (M.map f entries)

-- | fold functions are applied to the stored (non-zero) values; the implicit
--   zeroes are not visited
instance Foldable SparseVector where
    foldr f z = F.foldr f z . vec

-- | `Num` operators like @(*)@, @(+)@ and @(-)@ work on sparse vectors
--   like @`zipWith` (…)@ works on lists, except size of result is maximum
--   of arguments sizes.
--
--   Results of @(+)@ and @(*)@ that turn out to be zero are removed from the
--   storage, so the map keeps holding non-zero values only and the derived
--   `Eq` treats e.g. @sparseList [1] + sparseList [-1]@ and @sparseList [0]@
--   as equal.
--
--   `signum`, `abs` and `negate` work through `fmap`, so usual `Num` laws
--   are satisfied (such as @(signum v)*(abs v) == v@) for vectors that store
--   no zeroes.
--
--   `fromInteger` constructs a vector with single element. So,
--
-- >>> 3 + (sparseList [0,2,1])
-- sparseList [3,2,1]
--
instance (Eq α, Num α) => Num (SparseVector α) where
    (SV n v) + (SV m w) = SV (max n m) (M.filter (/= 0) (M.unionWith (+) v w))
    (SV n v) * (SV m w) = SV (max n m) (M.filter (/= 0) (M.intersectionWith (*) v w))
    negate              = fmap negate
    fromInteger x       = singVec (fromInteger x)
    abs                 = fmap abs
    signum              = fmap signum

-- | @(\<\>)@ works like concatenation of two vectors: sizes are added and
--   indexes of second are shifted by length of first.
--
--   The operation lives in `Semigroup` because GHC 8.4 and later (base-4.11)
--   reject a `Monoid` instance that has no `Semigroup` instance behind it.
instance Semigroup (SparseVector α) where
    (SV n v) <> (SV m w) = SV (n + m) (v `M.union` shiftKeys n w)

-- | `mempty` is the vector of zero size; `mappend` is the class default,
--   i.e. the concatenation @(\<\>)@ of the `Semigroup` instance
instance Monoid (SparseVector α) where
    mempty = emptyVec

-- | This is like cons (`:`) operator for lists: the value becomes a
--   one-element vector that is put in front of the given one.
--
--   @x .> v = singVec x \<\> v@
--
(.>) :: (Eq α, Num α) => α -> SparseVector α -> SparseVector α
x .> v = mappend (singVec x) v

--------------------------------------------------------------------------------
-- FILTER --
------------

-- | Splits vector using predicate, which is applied to the stored (non-zero)
--   values only. Values that fail the predicate are taken out of the vector
--   (later elements move down to close the gaps) and are returned as the
--   re-enumerated second part. Implicit zeroes are never tested and stay in
--   the first part. For example:
--
-- >>> partitionVec (>0) (sparseList [0,1,-1,2,3,0,-4,5,-6,0,7])
-- ( sparseList [0,1,2,3,0,5,0,7], sparseList [-1,-4,-6] )
--
--   The sizes are derived from the number of removed values: the second part
--   has exactly that many elements and the first part has the rest. Counting
--   the stored entries of the first part instead would ignore its implicit
--   zeroes and leave entries beyond the declared size.
partitionVec :: (Num α) => (α -> Bool) -> SparseVector α -> (SparseVector α, SparseVector α)
partitionVec p (SV d v) = (SV (d - moved) kept, SV moved rejected)
    where (kept, rejected) = partitionMap p v
          moved = M.size rejected

--------------------------------------------------------------------------------
-- LOOKUP/UPDATE --
-------------------

-- | Looks up an element in the vector by its index; an index with no stored
--   value yields zero
(!) :: (Num α) => SparseVector α -> Index -> α
(SV _ entries) ! i = M.findWithDefault 0 i entries

-- | Deletes the stored value at given index, which makes that element zero
--   (size of vector doesn't change, no other index is affected)
eraseInVec :: (Num α) => SparseVector α -> Index -> SparseVector α
eraseInVec (SV n entries) j = SV n (M.delete j entries)

--------------------------------------------------------------------------------
-- TO/FROM LIST --
------------------

-----------
-- Vectors:

-- | Returns plain list with all zeroes restored: the elements at indexes
--   @1 .. dim v@, in order
fillVec :: (Num α) => SparseVector α -> [α]
fillVec v = L.map (v !) [1 .. dim v]

-- | Converts plain list to sparse vector, throwing out all zeroes. Elements
--   are numbered from 1 and the size of the vector is the length of the list.
sparseList :: (Num α, Eq α) => [α] -> SparseVector α
sparseList l = SV (length l) (M.fromList nonZeros)
    where nonZeros = L.filter ((/= 0) . snd) (zip [1..] l)

-- | Shows size and filled vector (zero elements are printed as blanks):
--   the vector is expanded with `fillVec` and rendered by `showSparseList`
instance (Show α, Eq α, Num α) => Show (SparseVector α) where
    show v = showSparseList (fillVec v)

-- | Renders a plain list as its length, a colon and the elements separated
--   by @|@ inside square brackets; each element is rendered by `showNonZero`
showSparseList :: (Show α, Eq α, Num α) => [α] -> String
showSparseList l = L.concat [show (length l), ":  [", cells, "]"]
    where cells = intercalate "|" (L.map showNonZero l)

-- | Shows a single element: zero is rendered as one blank, any other value
--   by its `show`
showNonZero :: (Show α, Eq α, Num α) => α -> String
showNonZero x
    | x == 0    = " "
    | otherwise = show x
               
--------------------------------------------------------------------------------
-- MULTIPLICATIONS --
---------------------

-- | Dot product of two sparse vectors (ASCII name for the @(·)@ operator)
dot :: (Num α) => SparseVector α -> SparseVector α -> α
dot v w = v · w

-- | Unicode alias for `dot`.
--
--   The stored entries of the vector with the smaller declared size are the
--   ones that get traversed; the other vector is only used for lookups.
(·) :: (Num α) => SparseVector α -> SparseVector α -> α
(SV n v) · (SV m w) =
    if n < m
        then v ·· w
        else w ·· v