-- | Operations on vector-like functors: the 'Additive' class with its
-- pointwise addition, subtraction and interpolation, scalar multiplication
-- and division, and construction of basis vectors.
module Linear.Vector
( Additive(..)
, negated
, (^*)
, (*^)
, (^/)
, basis
, basisFor
) where
import Control.Applicative
import Data.Complex
import Data.Foldable (foldMap)
import Data.Functor.Bind
import Data.Functor.Identity
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
import Data.Monoid (Sum(..))
import Data.Traversable (Traversable, mapAccumL)
import Linear.Instances ()
-- Vector addition and subtraction use the same fixity as the Prelude's
-- (+) and (-).
infixl 6 ^+^, ^-^
-- Scaling operators use the same fixity as the Prelude's (*) and (/), so
-- they bind tighter than (^+^) and (^-^).
infixl 7 ^*, *^, ^/
-- | A vector-like functor whose elements can be added, subtracted and
-- interpolated pointwise.
--
-- Unless HLINT is defined, 'zero' and 'liftU2' have default
-- implementations for any 'Applicative' functor, and the remaining methods
-- are derived from them.
class Bind f => Additive f where
  -- | The zero vector. Defaults to @'pure' 0@ for 'Applicative' functors.
  zero :: Num a => f a
#ifndef HLINT
  default zero :: (Applicative f, Num a) => f a
  zero = pure 0
#endif

  -- | Sum of two vectors. Defaults to @'liftU2' (+)@.
  (^+^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^+^) :: (Num a) => f a -> f a -> f a
  (^+^) = liftU2 (+)
#endif

  -- | Difference of two vectors. Defaults to adding the negation of the
  -- second argument, so it stays consistent with '^+^'.
  (^-^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^-^) :: (Num a) => f a -> f a -> f a
  x ^-^ y = x ^+^ negated y
#endif

  -- | Linear interpolation: the first vector is weighted by @alpha@ and the
  -- second by @1 - alpha@.
  --
  -- > lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v

  -- | Combine two vectors with a binary function. Defaults to 'liftA2' for
  -- 'Applicative' functors.
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
#ifndef HLINT
  default liftU2 :: (Applicative f) => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
#endif
-- | 'IntMap' as a sparse vector: the zero vector is the empty map, and
-- 'liftU2' merges with 'IntMap.unionWith', so a key present in only one
-- operand keeps its value unchanged.
instance Additive IntMap where
  liftU2 = IntMap.unionWith
  zero = IntMap.empty
-- | 'Map' as a sparse vector: the zero vector is the empty map, and
-- 'liftU2' merges with 'Map.unionWith', so a key present in only one
-- operand keeps its value unchanged.
instance Ord k => Additive (Map k) where
  liftU2 = Map.unionWith
  zero = Map.empty
-- | 'HashMap' as a sparse vector: the zero vector is the empty map, and
-- 'liftU2' merges with 'HashMap.unionWith', so a key present in only one
-- operand keeps its value unchanged.
instance (Eq k, Hashable k) => Additive (HashMap k) where
  liftU2 = HashMap.unionWith
  zero = HashMap.empty
-- The instances below define no methods of their own; they rely entirely on
-- the class's default implementations (@pure 0@ and @liftA2@), which require
-- an 'Applicative' instance for the functor.

-- | Functions into a numeric type, using the class defaults.
instance Additive ((->) b)
-- | Complex numbers, using the class defaults.
instance Additive Complex
-- | The identity functor, using the class defaults.
instance Additive Identity
-- | Negate every element of a vector.
--
-- > negated [1, -2, 0] == [-1, 2, 0]
negated :: (Functor f, Num a) => f a -> f a
negated v = negate <$> v
-- | Multiply every element of a vector by a scalar, with the scalar as the
-- left operand of each product.
--
-- > 3 *^ [1, 2, 3] == [3, 6, 9]
(*^) :: (Functor f, Num a) => a -> f a -> f a
k *^ v = (k *) <$> v
-- | Multiply every element of a vector by a scalar, with the scalar as the
-- right operand of each product.
--
-- > [1, 2, 3] ^* 3 == [3, 6, 9]
(^*) :: (Functor f, Num a) => f a -> a -> f a
(^*) v k = (* k) <$> v
-- | Divide every element of a vector by a scalar.
--
-- > [2, 4] ^/ 2 == [1, 2]
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
(^/) v k = (/ k) <$> v
-- | Replace the element at a zero-based traversal position with a new
-- value. Every other element is kept, and a position that does not occur in
-- the structure leaves it unchanged.
setElement :: Traversable t => Int -> a -> t a -> t a
setElement i x = snd . mapAccumL step 0
  where
    -- The counter is forced at every step so it does not build a chain of
    -- thunks; the chosen element itself stays lazy.
    step pos old = next `seq` (next, chosen)
      where
        next = pos + 1
        chosen
          | i == pos  = x
          | otherwise = old
-- | The basis vectors of an 'Applicative' vector type: one vector per
-- element position, holding 1 at that position and 0 everywhere else.
--
-- A type with no element positions yields the empty list, because the range
-- @[0 .. n - 1]@ is empty when @n@ is 0.
basis :: (Applicative t, Traversable t, Num a) => [t a]
basis = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    -- The all-zero vector that every basis vector is built from.
    z = pure 0
    -- Number of element positions, counted by folding over the zero vector.
    n = getSum $ foldMap (const (Sum 1)) z
-- | The basis vectors for the shape of a given vector: one vector per
-- element of the argument, holding 1 at that position and 0 everywhere
-- else. Only the shape of the argument is used; its element values are
-- discarded.
--
-- An argument with no elements yields the empty list, because the range
-- @[0 .. n - 1]@ is empty when @n@ is 0.
basisFor :: (Traversable t, Enum a, Num a) => t a -> [t a]
basisFor v = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    -- Same shape as the argument, with every element replaced by 0.
    z = 0 <$ v
    -- Number of elements in the argument.
    n = getSum $ foldMap (const (Sum 1)) v