{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Adapted from [Linear.Vector](https://hackage.haskell.org/package/linear-1.21.8/docs/Linear-Vector.html) and [Linear.Metric](https://hackage.haskell.org/package/linear-1.21.8/docs/Linear-Metric.html)
module Nonlinear.Vector
  ( Vec (..),
    negated,
    (^*),
    (*^),
    (^/),
    basis,
    basisFor,
    scaled,
    outer,
    unit,
    dot,
    quadrance,
    qd,
    distance,
    norm,
    signorm,
    normalize,
    project,
  )
where

import Control.Applicative (liftA2)
import Data.Foldable (Foldable (foldl'), toList)
import Nonlinear.Internal (ASetter', Lens', imap, set)

-- |
-- Vectors whose size is known statically.
--
-- Conceptually this plays the role of 'Data.Functor.Rep.Representable', except
-- that 'Traversable' and 'Monad' are required as superclasses. That makes it a
-- catch-all class for the things we would normally think of as fixed-size
-- vectors.
--
-- Requiring 'Monad' may look odd, but the usual (diagonal) 'Monad' instance can
-- itself be written in terms of 'construct', so the requirement rules nothing out.
class (Traversable f, Monad f) => Vec f where
  -- | Build a vector from a function that receives a lens (polymorphic in the
  -- element type) and returns an element. Presumably the function is invoked
  -- once per position with a lens focusing on that position; confirm against
  -- the instances.
  construct :: ((forall x. Lens' (f x) x) -> a) -> f a

-- Same precedence and associativity as Prelude's '*' and '/' (infixl 7),
-- so e.g. @2 *^ v ^/ 4@ parses as @(2 *^ v) ^/ 4@.
infixl 7 ^*, *^, ^/

-- | Negate every component of a vector.
--
-- >>> negated (V2 2 4)
-- V2 (-2) (-4)
negated :: (Vec f, Num a) => f a -> f a
negated v = negate <$> v
{-# INLINE negated #-}

-- | Multiply a vector by a scalar on the left: each component @x@
-- becomes @s * x@.
--
-- >>> 2 *^ V2 3 4
-- V2 6 8
(*^) :: (Vec f, Num a) => a -> f a -> f a
s *^ v = (s *) <$> v
{-# INLINE (*^) #-}

-- | Multiply a vector by a scalar on the right: each component @x@
-- becomes @x * s@.
--
-- >>> V2 3 4 ^* 2
-- V2 6 8
(^*) :: (Vec f, Num a) => f a -> a -> f a
v ^* s = (* s) <$> v
{-# INLINE (^*) #-}

-- | Divide a vector by a scalar on the right: each component @x@
-- becomes @x / s@.
(^/) :: (Vec f, Fractional a) => f a -> a -> f a
v ^/ s = (/ s) <$> v
{-# INLINE (^/) #-}

-- | The standard basis of a vector space whose dimension is fixed by its
-- type. To derive the basis from the shape of an existing vector instead,
-- use 'basisFor'.
basis :: (Vec t, Num a) => [t a]
basis = basisFor shape
  where
    -- Only the shape matters to 'basisFor', so any filler value will do.
    shape = pure ()

-- | The standard basis of the vector space the argument belongs to.
--
-- Only the argument's shape is used; its elements are ignored. The @i@-th
-- vector in the result has @1@ at index @i@ (as numbered by 'imap') and @0@
-- at every other index.
basisFor :: (Vec t, Num a) => t b -> [t a]
basisFor shape = toList (imap row shape)
  where
    -- Basis vector for index i.
    row i _ = imap (\j _ -> kronecker i j) shape
    -- Kronecker delta: 1 on the diagonal, 0 elsewhere.
    kronecker i j
      | i == j = 1
      | otherwise = 0
{-# INLINEABLE basisFor #-}

-- | Build the diagonal (scaling) matrix of a vector: row @i@ holds the
-- @i@-th component of the input at index @i@ and @0@ everywhere else.
--
-- >>> scaled (V2 2 3)
-- V2 (V2 2 0) (V2 0 3)
scaled :: (Vec t, Num a) => t a -> t (t a)
scaled v = imap diagonalRow v
  where
    diagonalRow i _ = imap (keepAt i) v
    -- Keep the component only on the diagonal.
    keepAt i j x
      | i == j = x
      | otherwise = 0
{-# INLINE scaled #-}

-- | Create a unit vector: start from the all-zero vector and set the
-- position(s) focused by the given setter to @1@.
--
-- >>> unit _x :: V2 Int
-- V2 1 0
unit :: (Vec t, Num a) => ASetter' (t a) a -> t a
unit setter = set setter 1 zeros
  where
    zeros = pure 0
{-# INLINE unit #-}

-- | Outer (tensor) product of two vectors. The row for component @x@ of
-- the first vector is the second vector with each component @y@ replaced
-- by @y * x@.
outer :: (Vec f, Vec g, Num a) => f a -> g a -> f (g a)
outer xs ys = fmap row xs
  where
    row x = (* x) <$> ys
{-# INLINE outer #-}

-- | Compute the inner product of two vectors or (equivalently)
-- convert a vector @f a@ into a covector @f a -> a@.
--
-- Corresponding components are multiplied and the products are summed.
--
-- >>> V2 1 2 `dot` V2 3 4
-- 11
dot :: (Vec f, Num a) => f a -> f a -> a
-- Component-wise product (not sum), then a strict left-fold sum.
dot a b = foldl' (+) 0 (liftA2 (*) a b)
{-# INLINE dot #-}

-- | Compute the squared norm, i.e. the sum of the squares of the
-- components. The name quadrance comes from Norman J. Wildberger's
-- rational trigonometry.
quadrance :: (Vec f, Num a) => f a -> a
quadrance v = foldl' addSquare 0 v
  where
    addSquare acc x = acc + x * x
{-# INLINE quadrance #-}

-- | Compute the quadrance of the difference of two vectors, i.e. the
-- squared Euclidean distance between them (the sum of the squared
-- component-wise differences).
--
-- >>> qd (V2 1 2) (V2 4 6)
-- 25
qd :: (Vec f, Num a) => f a -> f a -> a
-- Square the differences via 'quadrance'; summing the raw differences
-- would not be a quadrance at all.
qd a b = quadrance (liftA2 (-) a b)
{-# INLINE qd #-}

-- | Euclidean distance between two vectors: the 'norm' of their
-- component-wise difference.
distance :: (Vec f, Floating a) => f a -> f a -> a
distance u v = norm difference
  where
    difference = liftA2 (-) u v
{-# INLINE distance #-}

-- | Euclidean length of a vector: the square root of its 'quadrance'.
norm :: (Vec f, Floating a) => f a -> a
norm = sqrt . quadrance
{-# INLINE norm #-}

-- | Scale a non-zero vector to unit length by dividing every component by
-- the vector's 'norm'. A zero vector has norm 0, so each of its components
-- would be divided by zero.
signorm :: (Vec f, Floating a) => f a -> f a
signorm v = (/ len) <$> v
  where
    len = norm v
{-# INLINE signorm #-}

-- | Normalize a vector to have unit 'norm'.
--
-- This is currently an alias for 'signorm'. Unlike @normalize@ in the
-- @linear@ package, it does /not/ special-case vectors whose norm is 0
-- or 1: every vector is divided by its norm. In particular a zero vector
-- is divided by zero, so callers must not pass one.
--
-- NOTE(review): adding linear's zero/unit check would need an extra @Eq@
-- (or epsilon) constraint, which would change this function's type.
normalize :: (Vec f, Floating a) => f a -> f a
normalize = signorm
{-# INLINE normalize #-}

-- | @project u v@ computes the projection of @v@ onto @u@, i.e.
-- @(dot v u / quadrance u) *^ u@.
--
-- @u@ should be non-zero, because the scale factor divides by
-- @quadrance u@.
project :: (Vec v, Fractional a) => v a -> v a -> v a
project u v = factor *^ u
  where
    factor = dot v u / quadrance u
{-# INLINE project #-}