module Data.LinFunc (LinFunc, Module(..), var, varSum, (*&), vsum, combination, linCombination) where
import Control.Monad
import qualified Data.Map as M
import qualified Data.IntMap as IM
import Data.Ratio
import Data.Array.Base
import Data.Array.IArray
import Data.LinFunc.Class
-- | @LinFunc v c@ is a plain 'M.Map' from keys of type @v@ to values of type @c@.
-- The constructors in this module ('var', '*&', 'linCombination', ...) use the
-- keys as variables and the values as their coefficients.
type LinFunc = M.Map
-- | 'Int' as a module over itself: every operation is the corresponding
-- 'Num' operation.
instance Module Int Int where
  (*^) = (*)
  zero = 0
  (^+^) = (+)
  -- Subtraction must be the binary '(-)' operator; the unit value @()@ does
  -- not have type @Int -> Int -> Int@.
  (^-^) = (-)
  neg = negate
-- | 'Double' as a module over itself: every operation is the corresponding
-- 'Num' operation.
instance Module Double Double where
  (*^) = (*)
  zero = 0
  (^+^) = (+)
  -- Subtraction must be the binary '(-)' operator; the unit value @()@ does
  -- not have type @Double -> Double -> Double@.
  (^-^) = (-)
  neg = negate
-- | 'Integer' as a module over itself: every operation is the corresponding
-- 'Num' operation.
instance Module Integer Integer where
  (*^) = (*)
  zero = 0
  (^+^) = (+)
  -- Subtraction must be the binary '(-)' operator; the unit value @()@ does
  -- not have type @Integer -> Integer -> Integer@.
  (^-^) = (-)
  neg = negate
-- | @'Ratio' a@ as a module over itself: every operation is the corresponding
-- 'Num' operation on ratios.
instance Integral a => Module (Ratio a) (Ratio a) where
  (*^) = (*)
  zero = 0
  (^+^) = (+)
  -- Subtraction must be the binary '(-)' operator; the unit value @()@ does
  -- not have type @Ratio a -> Ratio a -> Ratio a@.
  (^-^) = (-)
  neg = negate
-- | Functions into a module form a module themselves; every operation is
-- applied pointwise to the function results.
instance Module r m => Module r (a -> m) where
  -- Scale the result of the function at every argument.
  r *^ f = \x -> r *^ f x
  -- The function that is 'zero' everywhere.
  zero = \_ -> zero
  f ^+^ g = \x -> f x ^+^ g x
  f ^-^ g = \x -> f x ^-^ g x
  neg f = neg . f
-- | Maps whose values form a module: scaling and negation act on every
-- stored value, and addition merges the two maps, combining the values of
-- keys present in both with the element-level '^+^'.
--
-- '^-^' is not defined here, so whatever default "Data.LinFunc.Class"
-- provides (if any) applies.
instance (Ord k, Module r m) => Module r (M.Map k m) where
  r *^ table = M.map (r *^) table
  zero = M.empty
  lhs ^+^ rhs = M.unionWith (^+^) lhs rhs
  neg table = M.map neg table
-- | 'IM.IntMap's whose values form a module: scaling and negation act on
-- every stored value, and addition merges the two maps, combining the values
-- of keys present in both with the element-level '^+^'.
--
-- '^-^' is not defined here, so whatever default "Data.LinFunc.Class"
-- provides (if any) applies.
instance Module r m => Module r (IM.IntMap m) where
  r *^ table = IM.map (r *^) table
  zero = IM.empty
  lhs ^+^ rhs = IM.unionWith (^+^) lhs rhs
  neg table = IM.map neg table
-- | Boxed 'Int'-indexed arrays whose elements form a module.
--
-- Binary operations use the array with more elements as the base and
-- accumulate the entries of the other array into it.
-- NOTE(review): this assumes every index of the smaller array lies within
-- the bounds of the larger one; 'accum' is not given any bounds adjustment
-- here -- confirm against callers.
instance (Module r m) => Module r (Array Int m) where
  (*^) = amap . (*^)
  -- A single-element array holding the element-level 'zero' at index 0.
  zero = listArray (0, 0) [zero]
  a ^+^ b
    | numElements a >= numElements b = accum (^+^) a (assocs b)
    | otherwise = accum (^+^) b (assocs a)
  a ^-^ b
    | numElements a >= numElements b = accum (^-^) a (assocs b)
    -- When @b@ is the larger array the base must be @-b@, with the entries of
    -- @a@ added on top; accumulating with '^-^' into @b@ itself would compute
    -- @b - a@ and leave the entries found only in @b@ un-negated.
    | otherwise = accum (flip (^+^)) (amap neg b) (assocs a)
  neg = amap neg
-- | Unboxed 'Int'-indexed arrays whose elements form a module.
--
-- Binary operations use the array with more elements as the base and
-- accumulate the entries of the other array into it.
-- NOTE(review): this assumes every index of the smaller array lies within
-- the bounds of the larger one; 'accum' is not given any bounds adjustment
-- here -- confirm against callers.
instance (IArray UArray m, Module r m) => Module r (UArray Int m) where
  (*^) = amap . (*^)
  -- A single-element array holding the element-level 'zero' at index 0.
  zero = listArray (0, 0) [zero]
  a ^+^ b
    | numElements a >= numElements b = accum (^+^) a (assocs b)
    | otherwise = accum (^+^) b (assocs a)
  a ^-^ b
    | numElements a >= numElements b = accum (^-^) a (assocs b)
    -- When @b@ is the larger array the base must be @-b@, with the entries of
    -- @a@ added on top; accumulating with '^-^' into @b@ itself would compute
    -- @b - a@ and leave the entries found only in @b@ un-negated.
    | otherwise = accum (flip (^+^)) (amap neg b) (assocs a)
  neg = amap neg
-- | The linear function consisting of the single variable @v@ with
-- coefficient 1.
var :: (Ord v, Num c) => v -> LinFunc v c
var = (`M.singleton` 1)
-- | @c *& v@ is the linear function consisting of the single variable @v@
-- with coefficient @c@.
(*&) :: (Ord v, Num c) => c -> v -> LinFunc v c
(*&) = flip M.singleton
-- | The sum of the given variables, each contributing a coefficient of 1.
--
-- A variable listed more than once accumulates its coefficients (e.g. a
-- variable appearing twice ends up with coefficient 2), in line with how
-- 'linCombination' merges repeated variables. A plain 'M.fromList' would
-- keep only one entry per key and silently drop the repeats.
varSum :: (Ord v, Num c) => [v] -> LinFunc v c
varSum vs = M.fromListWith (+) [(v, 1) | v <- vs]
-- | Add up a list of module elements with '^+^', associating to the right
-- and ending in 'zero'; the empty list yields 'zero'.
vsum :: Module r v => [v] -> v
vsum [] = zero
vsum (x : rest) = x ^+^ vsum rest
-- | Scale each element by its paired scalar using '*^' and add the results
-- together with 'vsum'.
combination :: Module r m => [(r, m)] -> m
combination xs = vsum $ do
  (r, m) <- xs
  return (r *^ m)
-- | Build a linear function from @(coefficient, variable)@ pairs.
-- Coefficients of a variable that occurs more than once are added together.
linCombination :: (Ord v, Num r) => [(r, v)] -> LinFunc v r
linCombination = M.fromListWith (+) . map flipPair
  where
    -- The map is keyed by variable, so the pair has to be turned around.
    flipPair (coeff, v) = (v, coeff)