{-# LANGUAGE TypeFamilies, FlexibleContexts #-}
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable, DeriveDataTypeable #-}
module Data.Map.Vector (MapVector(..)) where

import Prelude hiding (foldr)
import Data.Foldable
import Data.Traversable
import Data.Data
import Control.Applicative
import Data.AdditiveGroup
import Data.VectorSpace
import Data.Map (Map)
import qualified Data.Map as Map

-- TODO: Clean up overlong >>> examples.

-- | Note: '<*>' in the 'Applicative' instance operates under /intersection/, i.e.:
--
-- >>> (MapVector $ Map.fromList [("x", id)]) <*> (MapVector $ Map.fromList [("y", 3)])
-- MapVector (fromList [])
--
-- '*' in the 'Num' instance performs elementwise multiplication.  It is defined in terms of
-- '<*>' and therefore also operates under intersection:
--
-- >>> (MapVector $ Map.fromList [("x", 2), ("y", 3)]) * (MapVector $ Map.fromList [("x", 5),("y", 7)])
-- MapVector (fromList [("x",10),("y",21)])
--
-- >>> (MapVector $ Map.fromList [("x", 2), ("y", 3)]) * (MapVector $ Map.fromList [("y", 7)])
-- MapVector (fromList [("y",21)])
--
-- '*^' in the 'VectorSpace' instance multiplies every component by a scalar of type @Scalar v@.
-- Nesting MapVectors preserves the scalar type, e.g. @Scalar (MapVector k (MapVector k' v))@ = @Scalar v@.
--
-- >>> 2 *^ (ConstantMap $ MapVector $ Map.fromList [("x", 3 :: Int), ("y", 5)])
-- ConstantMap (MapVector (fromList [("x",6),("y",10)]))
--
-- Finally, '<.>' in 'InnerSpace' is the dot-product operator.  Again, it operates under intersection.
--
-- >>> (MapVector $ Map.fromList [("x", 2 :: Int), ("y", 3)]) <.> (MapVector $ Map.fromList [("x", 5),("y", 7)])
-- 31
--
-- >>> (pure . MapVector $ Map.fromList [("x", 2 :: Int), ("y", 3)]) <.> (MapVector $ Map.fromList [("a", pure (5::Int))])
-- 25
--
-- Addition, using either '+' or '^+^', operates under union.  Adding a 'ConstantMap' to a
-- 'MapVector' adds the constant to each value stored in the map.

data MapVector k v
    = MapVector (Map k v)
      -- ^ A vector with one component per key of the wrapped 'Map'
    | ConstantMap v
      -- ^ An infinite-dimensional vector with the same value on all dimensions
    deriving ( Eq, Show, Read
             , Functor, Foldable, Traversable
             , Typeable, Data
             )

instance (Ord k) => Applicative (MapVector k) where
    -- A pure value becomes a constant vector: the same value on every key.
    pure = ConstantMap

    -- Component-wise application.  A constant on either side is broadcast
    -- over the other side's keys; two maps only pair up the keys they share
    -- (intersection), so keys present on one side only are dropped.
    mf <*> mx = case mf of
        ConstantMap f -> case mx of
            ConstantMap x -> ConstantMap (f x)
            MapVector xs  -> MapVector (fmap f xs)
        MapVector fs -> case mx of
            ConstantMap x -> MapVector (fmap (\g -> g x) fs)
            MapVector xs  -> MapVector (Map.intersectionWith ($) fs xs)
    {-# INLINABLE (<*>) #-}

instance (AdditiveGroup v, Ord k) => AdditiveGroup (MapVector k v) where
    -- The additive identity is the constant zero vector.
    --
    -- @MapVector Map.empty@ would NOT work here.  Adding a 'ConstantMap' to a
    -- 'MapVector' only keeps the map's keys (see the mixed cases of '^+^'
    -- below), so @pure c ^+^ MapVector Map.empty@ is @MapVector (fromList [])@
    -- and every component of @pure c@ would be lost.  @ConstantMap zeroV@ is a
    -- two-sided identity for both constructors, given that 'zeroV' is one for @v@:
    --   ConstantMap 0 ^+^ ConstantMap c = ConstantMap c
    --   ConstantMap 0 ^+^ MapVector vs  = MapVector vs
    -- It also has the same shape as @0 = pure 0@ in the 'Num' instance.
    zeroV = ConstantMap zeroV

    -- Negation acts on every component.
    negateV = fmap negateV

    -- Addition: constants add directly; a constant is added to each value of a
    -- map; two maps are combined under union (keys missing on one side keep
    -- the other side's value).
    (ConstantMap v) ^+^ (ConstantMap v') = ConstantMap $ v ^+^ v'
    (ConstantMap v) ^+^ (MapVector vs)   = MapVector   $ (v ^+^) <$> vs
    (MapVector vs)  ^+^ (ConstantMap v)  = MapVector   $ (^+^ v) <$> vs
    (MapVector vs)  ^+^ (MapVector vs')  = MapVector   $ Map.unionWith (^+^) vs vs'
    {-# INLINABLE (^+^) #-}
    
instance (Ord k, VectorSpace v) => VectorSpace (MapVector k v) where
    -- The scalar type is inherited from the element type, so nested vectors
    -- share a single scalar: Scalar (MapVector k (MapVector k' v)) = Scalar v.
    type Scalar (MapVector k v) = Scalar v

    -- Scaling multiplies every component (each stored value, or the constant).
    factor *^ vec = fmap (factor *^) vec
    {-# INLINABLE (*^) #-}

instance (Ord k, VectorSpace v, InnerSpace v, AdditiveGroup (Scalar v)) => InnerSpace (MapVector k v) where
    -- Dot product.
    --   * Two constants reduce to the inner product of their values.
    --   * A constant against a map is paired with every value stored in the map.
    --   * Two maps are paired only on the keys they share (intersection).
    -- The per-key products are summed with a strict left fold starting at 'zeroV'.
    lhs <.> rhs = case lhs of
        ConstantMap c -> case rhs of
            ConstantMap c' -> c <.> c'
            MapVector ys   -> total ((c <.>) <$> ys)
        MapVector xs -> case rhs of
            ConstantMap c' -> total ((<.> c') <$> xs)
            MapVector ys   -> total (Map.intersectionWith (<.>) xs ys)
      where
        total = foldl' (^+^) zeroV
    {-# INLINABLE (<.>) #-}

instance (Ord k, AdditiveGroup v, Num v) => Num (MapVector k v) where
    -- Addition is the vector-space addition, i.e. it operates under union.
    (+) = (^+^)
    {-# INLINE (+) #-}

    -- 'negate' must be given explicitly.  'Num' defines @negate x = 0 - x@ and
    -- @x - y = x + negate y@ by default, so leaving both out made 'negate' and
    -- '(-)' loop forever.  Using 'negateV' keeps negation consistent with '(+)'.
    -- The default '(-)' then becomes @x ^+^ negateV y@.
    negate = negateV
    {-# INLINE negate #-}

    -- Elementwise multiplication, defined via '<*>', so it operates under
    -- intersection.
    x * y = (*) <$> x <*> y
    {-# INLINE (*) #-}

    -- 'abs' and 'signum' act elementwise.  Since '*' is elementwise on
    -- matching keys, @abs x * signum x == x@ holds component by component.
    abs = fmap abs
    {-# INLINE abs #-}
    signum = fmap signum
    {-# INLINE signum #-}

    -- Integer literals become constant vectors (the same value on every key).
    fromInteger = pure . fromInteger

-- It looks like a HasBasis instance should be possible; 
-- I just haven't spent the time to figure it out.

-- (Remember to tighten version bounds on vector-space if this is implemented.)

--instance (Ord k, HasBasis v) => HasBasis (MapVector k v) where
--    type Basis (MapVector k v) = (k, Basis v)
--    basisValue (k, v) = MapVector $ Map.fromList $ (k, basisValue v):[]
--    
--    decompose (ConstantMap _) = error "decompose: not defined for ConstantMap"
--    decompose (MapVector vs) =