-- |
-- Module:      Data.Geo.Jord.Vector3d
-- Copyright:   (c) 2020 Cedric Liegeois
-- License:     BSD3
-- Maintainer:  Cedric Liegeois <ofmooseandmen@yahoo.fr>
-- Stability:   experimental
-- Portability: portable
--
-- 3-element vector.
--
module Data.Geo.Jord.Vector3d
    ( Vector3d(..)
    , vadd
    , vsub
    , vdot
    , vnorm
    , vcross
    , vmultm
    , vscale
    , vunit
    , vzero
    , transpose
    , mdot
    ) where

-- | A vector made of 3 'Double' components, read with 'vx', 'vy' and 'vz'.
data Vector3d = Vector3d
    { vx :: Double -- ^ first component
    , vy :: Double -- ^ second component
    , vz :: Double -- ^ third component
    }
    deriving (Eq, Show)

-- | Component-wise sum of 2 vectors.
vadd :: Vector3d -> Vector3d -> Vector3d
vadd v1 v2 = Vector3d (plus vx) (plus vy) (plus vz)
  where
    -- sum of the component selected by the given accessor
    plus get = get v1 + get v2

-- | Component-wise difference of 2 vectors: @vsub v1 v2@ is @v1 - v2@.
vsub :: Vector3d -> Vector3d -> Vector3d
vsub v1 v2 = Vector3d (minus vx) (minus vy) (minus vz)
  where
    -- difference of the component selected by the given accessor
    minus get = get v1 - get v2

-- | Cross product @v1 × v2@ of 2 vectors: a vector perpendicular to both
-- inputs.
vcross :: Vector3d -> Vector3d -> Vector3d
vcross v1 v2 = Vector3d (det vy vz) (det vz vx) (det vx vy)
  where
    -- 2x2 determinant built from two components of each vector
    det f g = f v1 * g v2 - g v1 * f v2

-- | Dot (scalar) product of 2 vectors: the sum of the products of their
-- matching components.
vdot :: Vector3d -> Vector3d -> Double
vdot v1 v2 = px + py + pz
  where
    px = vx v1 * vx v2
    py = vy v1 * vy v2
    pz = vz v1 * vz v2

-- | Euclidean norm (length) of a vector: the square root of the sum of its
-- squared components.
vnorm :: Vector3d -> Double
vnorm v = sqrt sumOfSquares
  where
    squared get = get v * get v
    sumOfSquares = squared vx + squared vy + squared vz

-- | @vmultm v rm@ multiplies vector @v@ by the __3x3__ matrix @rm@, given as
-- the list of its 3 rows: each component of the result is the 'vdot' of @v@
-- with the corresponding row of @rm@.
--
-- Calls 'error' if @rm@ does not have exactly 3 rows.
vmultm :: Vector3d -> [Vector3d] -> Vector3d
vmultm v rm =
    -- Match the rows directly rather than checking 'length' and then using an
    -- incomplete irrefutable pattern: this is total from the compiler's point
    -- of view and does not walk an arbitrarily long (or infinite) list.
    case rm of
        [r1, r2, r3] -> Vector3d (vdot v r1) (vdot v r2) (vdot v r3)
        _ -> error ("Invalid matrix: " ++ show rm)

-- | @vscale v s@ scales vector @v@ by the factor @s@: every component of @v@
-- is multiplied by @s@.
vscale :: Vector3d -> Double -> Vector3d
vscale v s = Vector3d (scaled vx) (scaled vy) (scaled vz)
  where
    scaled get = get v * s

-- | Normalises a vector: for any vector with a non-zero 'vnorm', the 'vnorm'
-- of the produced vector is @1@ (up to floating-point rounding).
--
-- A vector whose 'vnorm' is @0@ (e.g. 'vzero') has no direction and is
-- returned unchanged. Without that check, scaling it by @1 / 0@ (Infinity)
-- would produce NaN components.
-- A vector whose inverse norm is exactly @1@ is also returned as is.
vunit :: Vector3d -> Vector3d
vunit v
    | n == 0 = v
    | s == 1.0 = v
    | otherwise = vscale v s
  where
    n = vnorm v
    s = 1.0 / n

-- | The zero vector: all 3 components are @0@, so its 'vnorm' is @0@.
vzero :: Vector3d
vzero = Vector3d 0.0 0.0 0.0

-- | Transposes a __square (3x3)__ matrix given as a list of 'Vector3d' rows:
-- row @i@ of the result holds component @i@ of every input row.
--
-- The input rows are converted to lists of doubles, transposed, and converted
-- back; a result row that does not have exactly 3 elements makes 'ds2v' fail.
transpose :: [Vector3d] -> [Vector3d]
transpose = fmap ds2v . transpose' . fmap v2ds

-- | Transposes a matrix given as a list of rows of doubles.
--
-- Total on every input:
--
-- * the empty matrix transposes to @[]@ (previously this case recursed
--   forever, yielding an infinite list of empty rows);
-- * transposition stops once the first row is exhausted;
-- * heads and tails are taken by pattern matching instead of the partial
--   'head' / 'tail', so a shorter later row is simply skipped rather than
--   crashing.
transpose' :: [[Double]] -> [[Double]]
transpose' [] = []
transpose' ([] : _) = []
transpose' rows = [h | (h : _) <- rows] : transpose' [t | (_ : t) <- rows]

-- | Matrix product @a · b@ of 2 __3x3__ matrices, each given as a list of
-- 'Vector3d' rows. Entry @(i, j)@ of the result is the 'vdot' of row @i@ of
-- @a@ with column @j@ of @b@.
mdot :: [Vector3d] -> [Vector3d] -> [Vector3d]
mdot a b = [ds2v (map (vdot row) columns) | row <- a]
  where
    -- the columns of b are the rows of its transpose
    columns = transpose b

-- | Lists the components of a 'Vector3d' in order: @[x, y, z]@.
v2ds :: Vector3d -> [Double]
v2ds v = case v of
    Vector3d a b c -> [a, b, c]

-- | Builds a 'Vector3d' from a list of exactly 3 doubles @[x, y, z]@.
-- Calls 'error' for a list of any other length.
ds2v :: [Double] -> Vector3d
ds2v ds = case ds of
    [a, b, c] -> Vector3d a b c
    _ -> error ("Invalid list: " ++ show ds)