-- | This module provides the core mathematical definitions used by the rest of
-- Goal. In Goal, all mathematical structures are 'Manifold's, even when they are
-- not especially complicated ones; 'Manifold's may indicate highly articulated
-- structures, but may also indicate simpler concepts such as (vector) spaces.
--
-- 'Manifold's are sets of points which can be described locally as 'Euclidean'
-- spaces. In geometry, a point is typically a member of the actual 'Manifold'.
-- However, arbitrary types of points will often be difficult to represent
-- directly, and so points in Goal are always represented by their
-- 'Coordinates' with respect to a given chart.
--
-- Charts are in turn represented by phantom types. Mathematically, charts are
-- maps between the 'Manifold' and the relevant 'Cartesian' coordinate system.
-- However, since we do not represent the points of a 'Manifold' explicitly,
-- we also cannot represent charts explicitly. As such, a chart type merely
-- indexes a point so as to indicate how to interpret its 'Coordinates'.
module Goal.Geometry.Manifold
    ( -- * Manifolds
      Manifold (dimension)
    , Transition (transition)
    -- ** Sets
    , Embedded (Embedded, disembed)
    -- ** Points
    , Coordinates
    , (:#:) (coordinates, manifold)
    , coordinate
    , chart
    , breakChart
    , alterChart
    , listCoordinates
    , alterCoordinates
    , toPair
    -- ** Charts
    , Cartesian (Cartesian)
    , Polar (Polar)
    -- ** Constructors
    , fromList
    , fromCoordinates
    , euclideanPoint
    , realNumber
    -- * Direct Sums
    -- ** Replicated
    , mapReplicated
    , joinReplicated
    , concatReplicated
    -- ** DirectSum
    , joinPair
    , splitPair
    , joinPair'
    , splitPair'
    , joinTriple
    , splitTriple
    , joinTriple'
    , splitTriple'
    ) where


--- Imports ---


-- Goal --

import Goal.Core

import Goal.Geometry.Set

-- Qualified --

import qualified Data.Vector.Storable as C


--- Manifolds ---


-- | A geometric object with a certain 'dimension'. We assume that a 'Manifold'
-- somehow represents all the geometric, coordinate-independent structure under
-- consideration. 'Manifold's should satisfy
--
-- > dimension (manifold p) == C.length (coordinates p)
--
-- for every point @p@; 'fromCoordinates' enforces this when building points.
class Eq m => Manifold m where
    -- | The number of 'Coordinates' used to describe a point on the 'Manifold'.
    dimension :: m -> Int

-- | A point on the 'Manifold' @m@, expressed under the chart @c@. The chart is
-- a phantom type: it is never stored, and only records how the 'coordinates'
-- are to be interpreted.
data c :#: m = Point
    { coordinates :: !Coordinates -- ^ The coordinates of the point under chart @c@.
    , manifold :: m               -- ^ The 'Manifold' on which the point lies.
    } deriving (Eq, Read, Show)

infixr 1 :#:

-- | Looks up a single coordinate of a point by its (zero-based) index. As
-- with 'C.!', an out-of-range index is a runtime error.
coordinate :: Int -> c :#: m -> Double
coordinate i p = coordinates p C.! i

-- | Pairs a 'Manifold' @m@ with a phantom chart @c@. Its 'Set' instance takes
-- the elements to be the points @c :#: m@.
data Embedded m c = Embedded
    { disembed :: m -- ^ The underlying 'Manifold'.
    } deriving (Eq, Read, Show)

-- | Fixes the chart of a point to the type of the first argument, whose value
-- is ignored. This is often necessary when typeclass methods are used to
-- generate points under a variety of coordinate systems.
chart :: Manifold m => c -> c :#: m -> c :#: m
chart _ p = p

-- | Reinterprets a point under an arbitrary chart while keeping its
-- 'coordinates' and 'manifold' exactly as they are. No coordinate
-- transformation is performed; see 'transition' for that.
breakChart :: Manifold m => c :#: m -> d :#: m
breakChart (Point cs m) = Point cs m

-- | Combines 'breakChart' and 'chart': reinterprets a point under the chart
-- given by the type of the first argument (whose value is ignored), leaving
-- the coordinates untouched.
alterChart :: Manifold m => d -> c :#: m -> d :#: m
alterChart _ p = breakChart p

-- | Returns the first two coordinates of a point as a pair. Any further
-- coordinates are ignored; a point with fewer than two coordinates yields a
-- runtime error when the missing component is demanded.
toPair :: c :#: m -> (Double,Double)
toPair p =
    let x = coordinate 0 p
        y = coordinate 1 p
     in (x,y)

-- | Applies a function to every coordinate of a point, keeping both its chart
-- and its 'Manifold' unchanged.
alterCoordinates :: Manifold m => (Double -> Double) -> c :#: m -> c :#: m
alterCoordinates f p = p { coordinates = C.map f (coordinates p) }

-- | The 'Coordinates' of a point, converted to a list.
listCoordinates :: c :#: m -> [Double]
listCoordinates = C.toList . coordinates

-- | A 'transition' takes a point represented under the chart @c@ and
-- re-expresses it under the chart @d@. This will usually require recomputing
-- its 'Coordinates'. Transitions in opposite directions between the same
-- charts should be mutual inverses (up to floating point error), i.e.
--
-- > transition (transition p :: d :#: m) == (p :: c :#: m)
--
class Transition c d m where
    transition :: c :#: m -> d :#: m

-- | Builds a point from a list of coordinates, so that callers need not work
-- with vectors directly. As with 'fromCoordinates', the list length must equal
-- the 'dimension' of the 'Manifold'.
fromList :: Manifold m => m -> [Double] -> c :#: m
fromList m = fromCoordinates m . C.fromList

-- | Builds a point on the given 'Manifold' from its 'Coordinates'. Calls
-- 'error' when the number of coordinates differs from the 'dimension' of the
-- 'Manifold'.
fromCoordinates :: Manifold m => m -> Coordinates -> c :#: m
fromCoordinates m cs
    | nm /= nc = error
        $ "Coordinate dimension (" ++ show nc ++ ") does not match Manifold dimension (" ++ show nm ++ ")."
    | otherwise = Point cs m
  where
    nm = dimension m
    nc = C.length cs

-- | A convenience function for building 'Cartesian' points in a 'Euclidean'
-- space whose dimension is the length of the given list.
euclideanPoint :: [Double] -> Cartesian :#: Euclidean
euclideanPoint xs =
    let n = length xs
     in fromList (Euclidean n) xs

-- | A convenience function for building an element of a 'Continuum' from a
-- single number.
realNumber :: Double -> Cartesian :#: Continuum
realNumber = fromList Continuum . pure

--- Construction ---


-- Euclidean --

-- | The 'Cartesian' coordinate system. This is a phantom chart type: it
-- carries no data and is only used to index points.
data Cartesian = Cartesian

-- | The 'Polar' coordinate system. For 'Euclidean' points, the coordinates
-- under this chart are the radius followed by the angles, as used by the
-- 'Transition' instances between 'Polar' and 'Cartesian'.
data Polar = Polar

-- | Maps a function over each component point of a point on a 'Replicated'
-- 'Manifold', returning the results in order.
mapReplicated :: Manifold m => (c :#: m -> x) -> c :#: Replicated m -> [x]
mapReplicated pf ps = pf . fromCoordinates m <$> blocks
  where
    Replicated m k = manifold ps
    b = dimension m
    cs = coordinates ps
    -- Consecutive length-b slices of the coordinates, one per replicate.
    blocks = [ C.slice (i * b) b cs | i <- [0 .. k - 1] ]

-- | Joins a list of points into a single point on a 'Replicated' 'Manifold'.
-- The 'Manifold' of the first point is used as the replicated base. Be
-- advised that this function assumes the 'Manifold's of all the points are
-- equal, and does not check it. Calls 'error' on an empty list, since there is
-- then no 'Manifold' to replicate.
joinReplicated :: Manifold m => [c :#: m] -> c :#: Replicated m
joinReplicated [] = error "joinReplicated: cannot join an empty list of points"
joinReplicated ps@(p:_) =
    -- 'C.concat' copies each coordinate vector once, whereas folding with
    -- 'C.++' would re-copy the growing accumulator at every step.
    Point (C.concat $ coordinates <$> ps) $ Replicated (manifold p) (length ps)

-- | Concatenates two points on 'Replicated' 'Manifold's into one point whose
-- replication count is the sum of the two. The base 'Manifold' of the first
-- argument is kept; the two bases are assumed, not checked, to be equal.
concatReplicated :: c :#: Replicated m -> c :#: Replicated m -> c :#: Replicated m
concatReplicated p q = case (manifold p, manifold q) of
    (Replicated m x, Replicated _ y) ->
        Point (coordinates p C.++ coordinates q) $ Replicated m (x + y)

-- Direct Sums --

-- | Joins a pair of points into a single point on the direct sum of their
-- charts and 'Manifold's.
joinPair :: (Manifold m, Manifold n) => c :#: m -> d :#: n -> (c,d) :#: (m,n)
joinPair p q = unsafeJoinPair p q

-- | Splits a point on the direct sum of two 'Manifold's into its two
-- components, each under its own chart.
splitPair :: (Manifold m, Manifold n) => (c,d) :#: (m,n) -> (c :#: m, d :#: n)
splitPair pq = unsafeSplitPair pq

-- | Like 'joinPair', but for two points sharing the chart @c@, which is also
-- the chart of the result.
joinPair' :: (Manifold m, Manifold n) => c :#: m -> c :#: n -> c :#: (m,n)
joinPair' p q = unsafeJoinPair p q

-- | Like 'splitPair', but for a point whose chart @c@ is shared by both
-- components.
splitPair' :: (Manifold m, Manifold n) => c :#: (m,n) -> (c :#: m, c :#: n)
splitPair' pq = unsafeSplitPair pq

-- | Concatenates the coordinates of two points into a point on the direct sum
-- of their 'Manifold's. The chart @e@ of the result is unconstrained; the
-- exported wrappers pin it down.
unsafeJoinPair :: (Manifold m, Manifold n) => c :#: m -> d :#: n -> e :#: (m,n)
unsafeJoinPair cm dn =
    let mn = (manifold cm, manifold dn)
        cs = coordinates cm C.++ coordinates dn
     in fromCoordinates mn cs

-- | Splits the coordinates of a point on a direct sum of two 'Manifold's at
-- the 'dimension' of the first, producing a point on each 'Manifold'. The
-- charts of the outputs are unconstrained.
unsafeSplitPair :: (Manifold m, Manifold n) => c :#: (m,n) -> (d :#: m, e :#: n)
unsafeSplitPair cmn = (fromCoordinates m mcs, fromCoordinates n ncs)
  where
    (m,n) = manifold cmn
    (mcs,ncs) = C.splitAt (dimension m) (coordinates cmn)

-- | Joins a triple of points into a single point on the direct sum of their
-- charts and 'Manifold's.
joinTriple :: (Manifold m, Manifold n, Manifold o) => c :#: m -> d :#: n -> e :#: o -> (c,d,e) :#: (m,n,o)
joinTriple p q r = unsafeJoinTriple p q r

-- | Splits a point on the direct sum of three 'Manifold's into its three
-- components, each under its own chart.
splitTriple :: (Manifold m, Manifold n, Manifold o) => (c,d,e) :#: (m,n,o) -> (c :#: m, d :#: n, e :#: o)
splitTriple pqr = unsafeSplitTriple pqr

-- | Like 'joinTriple', but for three points sharing the chart @c@, which is
-- also the chart of the result.
joinTriple' :: (Manifold m, Manifold n, Manifold o) => c :#: m -> c :#: n -> c :#: o -> c :#: (m,n,o)
joinTriple' p q r = unsafeJoinTriple p q r

-- | Like 'splitTriple', but for a point whose chart @c@ is shared by all three
-- components.
splitTriple' :: (Manifold m, Manifold n, Manifold o) => c :#: (m,n,o) -> (c :#: m, c :#: n, c :#: o)
splitTriple' pqr = unsafeSplitTriple pqr

-- | Concatenates the coordinates of three points into a point on the direct
-- sum of their 'Manifold's. The chart @f@ of the result is unconstrained; the
-- exported wrappers pin it down.
unsafeJoinTriple :: (Manifold m, Manifold n, Manifold o) => c :#: m -> d :#: n -> e :#: o -> f :#: (m,n,o)
unsafeJoinTriple cm dn eo =
    let mno = (manifold cm, manifold dn, manifold eo)
        cs = C.concat [coordinates cm, coordinates dn, coordinates eo]
     in fromCoordinates mno cs

-- | Splits the coordinates of a point on a direct sum of three 'Manifold's,
-- first at the 'dimension' of the first 'Manifold' and then at that of the
-- second, producing a point on each. The charts of the outputs are
-- unconstrained.
unsafeSplitTriple :: (Manifold m, Manifold n, Manifold o) => c :#: (m,n,o) -> (d :#: m, e :#: n, f :#: o)
unsafeSplitTriple cmno =
    (fromCoordinates m mcs, fromCoordinates n ncs, fromCoordinates o ocs)
  where
    (m,n,o) = manifold cmno
    (mcs,rest) = C.splitAt (dimension m) (coordinates cmno)
    (ncs,ocs) = C.splitAt (dimension n) rest


--- Instances ---


-- | Transitioning a point to the chart it is already expressed in leaves it
-- unchanged.
instance Transition c c m where
    transition p = p

-- Embedded --

-- | An 'Embedded' 'Manifold' is a 'Set' whose elements are the points on the
-- underlying 'Manifold' under the chart @c@.
instance Manifold m => Set (Embedded m c) where
    type Element (Embedded m c) = c :#: m

-- Euclidean --

-- | The dimension of a 'Euclidean' space is the size stored in its
-- constructor.
instance Manifold Euclidean where
    dimension e = case e of
        Euclidean k -> k

-- | A 'Continuum' is one-dimensional.
instance Manifold Continuum where
    dimension = const 1

-- | Converts hyperspherical coordinates (r, phi_1, ..., phi_{n-1}) into
-- Cartesian coordinates (x_1, ..., x_n), where
--
-- > x_i = r * cos phi_i * sin phi_1 * ... * sin phi_{i-1}    (i < n)
-- > x_n = r * sin phi_1 * ... * sin phi_{n-1}
--
-- The pattern @r:phis@ fails on a point with no coordinates.
instance Transition Polar Cartesian Euclidean where
    transition p =
        let r:phis = listCoordinates p
            -- phiss !! i == reverse (take i phis), i.e. the angles preceding
            -- phis !! i; zip drops the final (complete) entry.
            phiss = reverse . tails $ reverse phis
            m = manifold p
            xs = [ r * cos phi * product (sin <$> phis') | (phi,phis') <- zip phis phiss ]
         in fromList m $ xs ++ [r * product (sin <$> phis)]

-- | Converts Cartesian coordinates (x_1, ..., x_n) into hyperspherical
-- coordinates (r, phi_1, ..., phi_{n-1}), where
--
-- > r       = sqrt (x_1^2 + ... + x_n^2)
-- > phi_i   = acos (x_i / sqrt (x_i^2 + ... + x_n^2))        (i < n - 1)
-- > phi_n-1 = acos (x_{n-1} / sqrt (x_{n-1}^2 + x_n^2)), replaced by
-- >           2*pi - phi_n-1 when x_n < 0, so that it lies in [0, 2*pi).
--
-- When x_i, ..., x_n are all zero the angle phi_i is undetermined; it is set
-- to 0 instead of the NaN that @acos (0/0)@ would produce.
instance Transition Cartesian Polar Euclidean where
    transition p =
        let (Euclidean n) = manifold p
            xs = listCoordinates p
            xs2 = listCoordinates $ alterCoordinates (^2) p
            r = sqrt $ sum xs2
            -- Angle of xi against the norm of itself and all later coordinates.
            angle xi xs2i
                | nrm == 0 = 0
                | otherwise = acos $ xi / nrm
              where
                nrm = sqrt $ sum xs2i
            (phis,phin0:_) = splitAt (n-2) $ zipWith angle xs (tails xs2)
            xn = last xs
            -- '>=' (not '>') so that x_n == 0 gives an angle in [0, pi]
            -- rather than 2*pi, which lies outside [0, 2*pi).
            phin = if xn >= 0 then phin0 else 2*pi - phin0
         in fromList (Euclidean n) $ r : (phis ++ [phin])

-- DirectSum --

-- | The direct sum of two 'Manifold's; its 'dimension' is the total of the
-- components' 'dimension's.
instance (Manifold m, Manifold n) => Manifold (m,n) where
    dimension (m,n) = sum [dimension m, dimension n]

-- | The direct sum of three 'Manifold's; its 'dimension' is the total of the
-- components' 'dimension's.
instance (Manifold m, Manifold n, Manifold o) => Manifold (m,n,o) where
    dimension (m,n,o) = sum [dimension m, dimension n, dimension o]

-- Replicated --

-- | A 'Replicated' 'Manifold' stacks a number of copies of a base 'Manifold',
-- so its 'dimension' is the base 'dimension' times the number of copies.
instance Manifold m => Manifold (Replicated m) where
    dimension (Replicated base copies) = dimension base * copies