#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
#endif
module Linear.Vector
( Additive(..)
, negated
, (^*)
, (*^)
, (^/)
, sumV
, basis
, basisFor
, kronecker
, outer
) where
import Control.Applicative
import Data.Complex
import Data.Foldable as Foldable (Foldable, foldMap, forM_, foldl')
import Data.Functor.Identity
import Data.HashMap.Lazy as HashMap
import Data.Hashable
import Data.IntMap as IntMap
import Data.Map as Map
import Data.Monoid (Sum(..), mempty)
import Data.Vector as Vector
import Data.Vector.Mutable as Mutable
import Data.Traversable (Traversable, mapAccumL)
import GHC.Generics
import Linear.Instances ()
-- Vector addition and subtraction bind like Prelude's '+' and '-'.
infixl 6 ^+^
infixl 6 ^-^
-- Scalar scaling binds like Prelude's '*' and '/', so that
-- @a *^ u ^+^ b *^ v@ parses as @(a *^ u) ^+^ (b *^ v)@.
infixl 7 *^
infixl 7 ^*
infixl 7 ^/
-- | Counterpart of 'Additive' over 'GHC.Generics' representation types.
-- Its 'gzero' backs the default definition of 'zero'.
class GAdditive f where
  -- | The zero value of the representation.
  gzero :: Num x => f x
  -- | Representation-level analogue of 'liftU2'.
  gliftU2 :: (x -> x -> x) -> f x -> f x -> f x
  -- | Representation-level analogue of 'liftI2'.
  gliftI2 :: (x -> y -> z) -> f x -> f y -> f z
-- | A nullary constructor has no fields, so every operation yields 'U1'
-- (after matching both arguments).
instance GAdditive U1 where
  gzero = U1
  gliftU2 = \_ U1 U1 -> U1
  gliftI2 = \_ U1 U1 -> U1
-- | Products are combined factor by factor: each side of ':*:' recurses
-- independently.
instance (GAdditive f, GAdditive g) => GAdditive (f :*: g) where
  gzero = gzero :*: gzero
  gliftU2 op (xl :*: xr) (yl :*: yr) = left :*: right
    where
      left  = gliftU2 op xl yl
      right = gliftU2 op xr yr
  gliftI2 op (xl :*: xr) (yl :*: yr) = left :*: right
    where
      left  = gliftI2 op xl yl
      right = gliftI2 op xr yr
-- | A nested 'Additive' functor defers to that functor's own 'zero',
-- 'liftU2' and 'liftI2'.
instance Additive f => GAdditive (Rec1 f) where
  gzero = Rec1 zero
  gliftU2 op x y = Rec1 (liftU2 op (unRec1 x) (unRec1 y))
  gliftI2 op x y = Rec1 (liftI2 op (unRec1 x) (unRec1 y))
-- | Metadata wrappers are transparent: unwrap, recurse, and rewrap.
instance GAdditive f => GAdditive (M1 i c f) where
  gzero = M1 gzero
  gliftU2 op x y = M1 (gliftU2 op (unM1 x) (unM1 y))
  gliftI2 op x y = M1 (gliftI2 op (unM1 x) (unM1 y))
-- | A single field of the parameter type; this is where the scalar function
-- is actually applied.
instance GAdditive Par1 where
  gzero = Par1 0
  gliftU2 op x y = Par1 (op (unPar1 x) (unPar1 y))
  gliftI2 op x y = Par1 (op (unPar1 x) (unPar1 y))
-- | Vectors that support componentwise addition and scaling.
--
-- Instances supply 'zero' and two lifting combinators:
--
-- * 'liftU2' (union) combines components present in both arguments; the
--   instances in this module keep components present in only one argument
--   unchanged.
--
-- * 'liftI2' (intersection) only produces components present in both.
class Functor f => Additive f where
  -- | The zero vector. By default it is derived via 'Generic1'.
  zero :: Num a => f a
#ifndef HLINT
  default zero :: (GAdditive (Rep1 f), Generic1 f, Num a) => f a
  zero = to1 gzero
#endif

  -- | Sum of two vectors. Defaults to @'liftU2' (+)@.
  (^+^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^+^) :: Num a => f a -> f a -> f a
  (^+^) = liftU2 (+)
#endif

  -- | Difference of two vectors. Defaults to adding the 'negated' second
  -- argument.
  (^-^) :: Num a => f a -> f a -> f a
#ifndef HLINT
  default (^-^) :: Num a => f a -> f a -> f a
  x ^-^ y = x ^+^ negated y
#endif

  -- | Linear interpolation: @alpha@ weights @u@ and @1 - alpha@ weights @v@,
  -- i.e. @lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v@.
  lerp :: Num a => a -> f a -> f a -> f a
  lerp alpha u v = alpha *^ u ^+^ (1 - alpha) *^ v

  -- | Combine two vectors componentwise with union semantics. Defaults to
  -- 'liftA2' for 'Applicative' (dense) vectors.
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
#ifndef HLINT
  default liftU2 :: Applicative f => (a -> a -> a) -> f a -> f a -> f a
  liftU2 = liftA2
#endif

  -- | Combine two vectors componentwise with intersection semantics.
  -- Defaults to 'liftA2' for 'Applicative' (dense) vectors.
  liftI2 :: (a -> b -> c) -> f a -> f b -> f c
#ifndef HLINT
  default liftI2 :: Applicative f => (a -> b -> c) -> f a -> f b -> f c
  liftI2 = liftA2
#endif
-- | 'ZipList' reuses the list instance for unions and its zipping
-- 'Applicative' for intersections. The zero vector is empty.
instance Additive ZipList where
  zero = ZipList []
  liftU2 f xs ys = ZipList (liftU2 f (getZipList xs) (getZipList ys))
  liftI2 = liftA2
-- | Boxed 'Vector's as dense vectors of arbitrary length.
--
-- 'liftU2' applies the function across the overlapping prefix and keeps the
-- tail of the longer vector unchanged. 'liftI2' truncates to the shorter
-- length.
instance Additive Vector where
  zero = mempty
  liftU2 f u v = case compare lu lv of
      LT | lu == 0   -> v
         | otherwise -> overwritePrefix lu v
      EQ -> Vector.zipWith f u v
      GT | lv == 0   -> u
         | otherwise -> overwritePrefix lv u
    where
      lu = Vector.length u
      lv = Vector.length v
      -- Take the longer vector and overwrite its first @n@ slots with
      -- @f (u ! i) (v ! i)@. The unchecked accessors are in bounds because
      -- @n@ is the length of the shorter vector. Names are qualified because
      -- both Data.Vector and Data.Vector.Mutable are imported unqualified and
      -- newer versions of both export 'modify'.
      overwritePrefix n longer =
        Vector.modify
          (\w -> Foldable.forM_ [0 .. n - 1] $ \i ->
             Mutable.unsafeWrite w i (f (Vector.unsafeIndex u i) (Vector.unsafeIndex v i)))
          longer
  liftI2 = Vector.zipWith
-- | 'Maybe' as a sparse vector with at most one component. 'Nothing' is the
-- zero vector, unions keep whichever side is present, and intersections
-- require both.
instance Additive Maybe where
  zero = Nothing
  liftU2 f mx my = case mx of
    Just x -> case my of
      Just y  -> Just (f x y)
      Nothing -> mx
    Nothing -> my
  liftI2 = liftA2
-- | Lists as dense vectors of arbitrary length. Unions combine the common
-- prefix and keep the longer list's tail; intersections truncate.
instance Additive [] where
  zero = []
  liftU2 f = zipLonger
    where
      zipLonger ps qs = case ps of
        [] -> qs
        (p : ps') -> case qs of
          [] -> ps
          (q : qs') -> f p q : zipLonger ps' qs'
  liftI2 = Prelude.zipWith
-- | 'IntMap' as a sparse vector keyed by 'Int'. Unions keep keys present on
-- either side; intersections keep only shared keys.
instance Additive IntMap where
  zero = IntMap.empty
  liftU2 f xs ys = IntMap.unionWith f xs ys
  liftI2 f xs ys = IntMap.intersectionWith f xs ys
-- | 'Map' as a sparse vector with ordered keys. Unions keep keys present on
-- either side; intersections keep only shared keys.
instance Ord k => Additive (Map k) where
  zero = Map.empty
  liftU2 f xs ys = Map.unionWith f xs ys
  liftI2 f xs ys = Map.intersectionWith f xs ys
-- | 'HashMap' as a sparse vector with hashable keys. Unions keep keys present
-- on either side; intersections keep only shared keys.
instance (Eq k, Hashable k) => Additive (HashMap k) where
  zero = HashMap.empty
  liftU2 f xs ys = HashMap.unionWith f xs ys
  liftI2 f xs ys = HashMap.intersectionWith f xs ys
-- | Functions as vectors indexed by their argument. Everything is pointwise,
-- and the zero vector is the constant-zero function.
instance Additive ((->) b) where
  zero = \_ -> 0
  liftU2 f g h = \x -> f (g x) (h x)
  liftI2 f g h = \x -> f (g x) (h x)
-- | A complex number as a two-component vector (real part, imaginary part).
instance Additive Complex where
  zero = 0 :+ 0
  liftU2 op (re1 :+ im1) (re2 :+ im2) = op re1 re2 :+ op im1 im2
  liftI2 op (re1 :+ im1) (re2 :+ im2) = op re1 re2 :+ op im1 im2
-- | 'Identity' as a one-component vector.
instance Additive Identity where
  zero = Identity 0
  liftU2 f (Identity x) (Identity y) = Identity (f x y)
  liftI2 f (Identity x) (Identity y) = Identity (f x y)
-- | Negate every component of a vector.
negated :: (Functor f, Num a) => f a -> f a
negated v = negate <$> v
-- | Sum a container of vectors. Uses a strict left fold that starts at
-- 'zero' and accumulates with '^+^'.
sumV :: (Foldable f, Additive v, Num a) => f (v a) -> v a
sumV vs = Foldable.foldl' (\acc x -> acc ^+^ x) zero vs
-- | Scale a vector by a scalar on the left: each component @x@ becomes
-- @a * x@.
(*^) :: (Functor f, Num a) => a -> f a -> f a
a *^ v = (a *) <$> v
-- | Scale a vector by a scalar on the right: each component @x@ becomes
-- @x * a@.
(^*) :: (Functor f, Num a) => f a -> a -> f a
v ^* a = (* a) <$> v
-- | Divide every component of a vector by a scalar.
(^/) :: (Functor f, Fractional a) => f a -> a -> f a
v ^/ a = (/ a) <$> v
-- | Replace the element at 0-based traversal position @i@ with @x@, leaving
-- all other elements untouched. If @i@ matches no position, the container
-- is returned unchanged.
setElement :: Traversable t => Int -> a -> t a -> t a
setElement i x t = snd (mapAccumL step 0 t)
  where
    -- The running position is forced at each step so the counter does not
    -- build up a chain of thunks.
    step j y =
      let next = j + 1
          out | j == i    = x
              | otherwise = y
      in next `seq` (next, out)
-- | The standard basis: one vector per position, with @1@ at that position
-- and @0@ everywhere else.
--
-- The shape is taken from @'pure' 0@, so this only terminates for containers
-- whose 'pure' has a finite number of positions.
basis :: (Applicative t, Traversable t, Num a) => [t a]
basis = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    z = pure 0
    -- Number of positions in the shape.
    n = getSum $ foldMap (const (Sum 1)) z
-- | The standard basis for vectors shaped like the argument: one vector per
-- position of @v@, with @1@ at that position and @0@ elsewhere. Only the
-- shape of @v@ is used, not its values.
basisFor :: (Traversable t, Num a) => t a -> [t a]
basisFor v = [ setElement k 1 z | k <- [0 .. n - 1] ]
  where
    -- Same shape as v, all components zero.
    z = 0 <$ v
    -- Number of positions in v.
    n = getSum $ foldMap (const (Sum 1)) v
-- | Build a diagonal matrix from a vector. Row @i@ is @'pure' 0@ with its
-- @i@-th position set to the @i@-th element of @v@.
kronecker :: (Applicative t, Num a, Traversable t) => t a -> t (t a)
kronecker v = rows
  where
    zeros = pure 0
    (_, rows) = mapAccumL placeOnDiagonal 0 v
    -- Thread the row index strictly and place each element on the diagonal.
    placeOnDiagonal i e =
      let i' = i + 1
      in i' `seq` (i', setElement i e zeros)
-- | Outer product: the row for element @x@ of @a@ is @b@ with each element
-- @y@ replaced by @y * x@.
outer :: (Functor f, Functor g, Num a) => f a -> g a -> f (g a)
outer a b = fmap scaleRow a
  where
    scaleRow x = fmap (\y -> y * x) b