module Linear.Vector
( Additive(..)
, negated
, (^*)
, (*^)
, (^/)
, basis
, basisFor
) where
import Control.Applicative
import Data.Complex
import Data.Foldable (foldMap)
import Data.Functor.Bind
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
import Data.Monoid (Sum(..))
import Data.Traversable (Traversable, mapAccumL)
import Linear.Instances ()
-- Vector addition and subtraction use the same precedence and left
-- associativity as the Prelude's (+) and (-).
infixl 6 ^+^, ^-^
-- The scaling operators use the same precedence and left associativity as
-- the Prelude's (*) and (/), so they bind tighter than ^+^ and ^-^.
infixl 7 ^*, *^, ^/
-- | Functors whose values can be added and subtracted as vectors.
--
-- When @HLINT@ is not defined, 'zero', '^+^' and '^-^' have default
-- definitions for any 'Applicative' functor; those defaults act
-- component-wise, so such instances can be declared with an empty body.
class Bind f => Additive f where
  -- | The zero vector.
  zero :: Num a => f a
#ifndef HLINT
  default zero :: (Applicative f, Num a) => f a
  zero = pure 0
#endif

  -- | Compute the sum of two vectors.
  (^+^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^+^) :: (Applicative f, Num a) => f a -> f a -> f a
  (^+^) = liftA2 (+)
#endif

  -- | Compute the difference between two vectors.
  (^-^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^-^) :: (Applicative f, Num a) => f a -> f a -> f a
  (^-^) = liftA2 (-)
#endif

  -- | Linearly interpolate between two vectors: @lerp alpha u v@ is
  -- @alpha *^ u ^+^ (1 - alpha) *^ v@, i.e. @u@ is weighted by @alpha@ and
  -- @v@ by the complementary weight @1 - alpha@.
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v
-- | Sparse vectors keyed by 'Int'. Entries are combined key by key; a key
-- present on only one side is kept (negated when it comes from the
-- subtrahend).
instance Additive IntMap where
  zero = IntMap.empty
  u ^+^ v = IntMap.unionWith (+) u v
  u ^-^ v = IntMap.unionWith (+) u (negate <$> v)
-- | Sparse vectors keyed by any ordered type. Entries are combined key by
-- key; a key present on only one side is kept (negated when it comes from
-- the subtrahend).
instance Ord k => Additive (Map k) where
  zero = Map.empty
  u ^+^ v = Map.unionWith (+) u v
  u ^-^ v = Map.unionWith (+) u (negate <$> v)
-- | Sparse vectors keyed by any hashable type. Entries are combined key by
-- key; a key present on only one side is kept (negated when it comes from
-- the subtrahend).
instance (Eq k, Hashable k) => Additive (HashMap k) where
  zero = HashMap.empty
  u ^+^ v = HashMap.unionWith (+) u v
  u ^-^ v = HashMap.unionWith (+) u (negate <$> v)
-- | Functions from a fixed argument type. No methods are overridden, so
-- every operation comes from the class's default definitions.
-- NOTE(review): the required 'Bind' instance is not defined in this file;
-- presumably it comes from "Data.Functor.Bind" -- confirm.
instance Additive ((->) b)
-- | Complex numbers. No methods are overridden, so every operation comes
-- from the class's default definitions.
-- NOTE(review): the 'Applicative' / 'Bind' instances for 'Complex' are not
-- defined in this file; presumably "Linear.Instances" supplies them -- confirm.
instance Additive Complex
-- | Negate every component of a vector.
negated :: (Functor f, Num a) => f a -> f a
negated v = negate <$> v
-- | Scalar multiplication with the scalar on the left: every component is
-- multiplied by the scalar, which stays the left operand of '*'.
(*^) :: (Functor f, Num a) => a -> f a -> f a
s *^ v = (s *) <$> v
-- | Scalar multiplication with the scalar on the right: every component is
-- multiplied by the scalar, which stays the right operand of '*'.
(^*) :: (Functor f, Num a) => f a -> a -> f a
v ^* s = (* s) <$> v
-- | Scalar division: every component is divided by the scalar on the
-- right.
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
v ^/ s = (/ s) <$> v
-- | Replace the element at a zero-based position (in traversal order) with
-- the given value. Positions that do not occur in the structure leave it
-- unchanged.
setElement :: Traversable t => Int -> a -> t a -> t a
setElement i x = snd . mapAccumL step 0
  where
    -- Thread the running position through the traversal, forcing the
    -- incremented counter at each step so it does not pile up as thunks.
    step pos old =
      let next = pos + 1
      in next `seq` (next, if pos == i then x else old)
-- | Produce a default basis for a vector space: one vector per component
-- position, holding 1 at that position and 0 everywhere else. The number
-- of components is taken from @pure 0@, so the shape must be fixed by the
-- type; when the shape is only known from a value, see 'basisFor'.
--
-- A shape with no components yields an empty list.
basis :: (Applicative t, Traversable t, Num a) => [t a]
basis = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    -- The all-zero vector: the template each basis vector is built from.
    z = pure 0
    -- Number of components, counted by folding over the zero vector.
    n = getSum $ foldMap (const (Sum 1)) z
-- | Produce a default basis for the vector space the argument is drawn
-- from: one vector per component of the argument, holding 1 at that
-- position and 0 everywhere else. Only the shape of the argument is used;
-- its component values are ignored.
--
-- An argument with no components yields an empty list. The 'Enum'
-- constraint is not needed by this implementation and is kept only so the
-- signature stays unchanged.
basisFor :: (Traversable t, Enum a, Num a) => t a -> [t a]
basisFor v = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    -- A zero vector with the same shape as the argument.
    z = 0 <$ v
    -- Number of components in the argument.
    n = getSum $ foldMap (const (Sum 1)) v