-- | This module provides the multilinear 'Map's: 'Tensor' product 'Manifold's,
-- whose 'Point's can be treated as matrices, and 'Manifold's of 'Affine'
-- functions. A map is a 'Manifold' where the 'Point's of the 'Manifold'
-- represent parametric functions between 'Manifold's. The defining feature of
-- 'Map's is that they have a particular 'Domain' and 'Codomain', which
-- themselves are 'Manifold's.

module Goal.Geometry.Map.Multilinear (
    -- * Tensors
      Tensor (Tensor)
    -- ** Construction
    , (>.<)
    -- ** Matrix Operations
    , (<#>)
    , matrixRank
    , matrixInverse
    , matrixTranspose
    , matrixSquareRoot
    , matrixApply
    , matrixMap
    , matrixDiagonalConcatenate
    -- ** Cartesian
    , coordinateTransform
    , linearProjection
    -- ** HMatrix Conversion
    , toHMatrix
    , fromHMatrix
    -- * Affine Functions
    , Affine (Affine)
    , splitAffine
    , joinAffine
    ) where

--- Imports ---

import Prelude hiding (map,minimum,maximum)

-- Package --

import Goal.Core

import Goal.Geometry.Set
import Goal.Geometry.Manifold
import Goal.Geometry.Linear
import Goal.Geometry.Map

-- Qualified --

import qualified Data.Vector.Storable as C
import qualified Numeric.LinearAlgebra.HMatrix as H

--import Data.Vector.Storable.UnsafeSerialize



--- Affine Functions ---


-- | 'Manifold's of 'Affine' functions. The first field is the 'Codomain' and
-- the second field is the 'Domain' of the functions.
data Affine m n
    = Affine m n
    deriving (Eq, Read, Show)

splitAffine :: (Manifold m, Manifold n) => Function c d :#: Affine m n -> (d :#: m, Function c d :#: Tensor m n)
-- | Splits a 'Point' on an 'Affine' space into a constant and a matrix. The
-- leading @'dimension' m@ coordinates form the constant, and the remaining
-- coordinates form the matrix on the corresponding 'Tensor' space.
splitAffine aff = (fromCoordinates cdmn cnstcs, fromCoordinates (Tensor cdmn dmn) tnscs)
  where
    Affine cdmn dmn = manifold aff
    -- The constant is stored ahead of the matrix in the coordinate vector.
    (cnstcs,tnscs) = C.splitAt (dimension cdmn) $ coordinates aff

joinAffine :: (Manifold m, Manifold n) => d :#: m -> Function c d :#: Tensor m n -> Function c d :#: Affine m n
-- | Combines a constant and a matrix into a 'Point' on an 'Affine' space. The
-- coordinates of the constant are stored ahead of the coordinates of the
-- matrix.
joinAffine dm mtx = fromCoordinates (Affine cdmn dmn) cs
  where
    Tensor cdmn dmn = manifold mtx
    cs = coordinates dm C.++ coordinates mtx

-- Tensor Products --

-- | 'Manifold' of 'Tensor's given by the tensor product of the underlying pair
-- of 'Manifold's. The first field is the 'Codomain' and the second field is the
-- 'Domain' of the corresponding linear functions.
data Tensor m n
    = Tensor m n
    deriving (Eq, Read, Show)

toHMatrix :: Manifold n => c :#: Tensor m n -> H.Matrix Double
-- | Converts a 'Point' on a 'Tensor' product 'Manifold' to a matrix for snappy
-- calculation. The coordinates are reshaped so that the number of columns is
-- the 'dimension' of the second 'Manifold' of the 'Tensor'.
toHMatrix pq = H.reshape (dimension dmn) (coordinates pq)
  where
    Tensor _ dmn = manifold pq

fromHMatrix :: (Manifold m, Manifold n) => Tensor m n -> H.Matrix Double -> c :#: Tensor m n
-- | Converts a matrix into a 'Point' on the given 'Tensor' product 'Manifold'
-- by flattening the matrix into a coordinate vector.
fromHMatrix tns mtx = fromCoordinates tns $ H.flatten mtx

matrixRank :: (Manifold m, Manifold n) => c :#: Tensor m n -> Int
-- | The rank of the matrix represented by the given 'Tensor' point.
matrixRank pq = H.rank $ toHMatrix pq

(>.<) :: (Manifold m, Manifold n) => d :#: m -> c :#: n -> Function (Dual c) d :#: Tensor m n
-- | '>.<' denotes the outer product between two points. It provides a way of
-- constructing matrices of the 'Tensor' product space built from the
-- 'Manifold's of the two points.
(>.<) p q = fromHMatrix tns $ H.outer (coordinates p) (coordinates q)
  where
    tns = Tensor (manifold p) (manifold q)

(<#>) :: (Manifold m, Manifold n, Manifold o)
      => Function d e :#: Tensor m n -> Function c d :#: Tensor n o -> Function c e :#: Tensor m o
-- | Tensor product composition, computed as the product of the two underlying
-- matrices. The result keeps the first 'Manifold' of the left argument and the
-- second 'Manifold' of the right argument.
(<#>) p q = fromHMatrix (Tensor cdmn dmn) (toHMatrix p <> toHMatrix q)
  where
    Tensor cdmn _ = manifold p
    Tensor _ dmn = manifold q

matrixSquareRoot :: Manifold m => c :#: Tensor m m -> c :#: Tensor m m
-- | The square root of a matrix, returned as a 'Point' on the same 'Tensor'
-- 'Manifold' as the input.
matrixSquareRoot pq = fromHMatrix (manifold pq) (H.sqrtm (toHMatrix pq))

matrixInverse :: (Manifold n, Manifold m) => Function c d :#: Tensor m n -> Function d c :#: Tensor n m
-- | The inverse of a given 'Tensor' point. The two 'Manifold's of the 'Tensor'
-- swap places in the result.
matrixInverse pq = fromHMatrix (Tensor dmn cdmn) $ H.inv (toHMatrix pq)
  where
    Tensor cdmn dmn = manifold pq

matrixTranspose :: (Manifold m, Manifold n) => Function c d :#: Tensor m n -> Function (Dual d) (Dual c) :#: Tensor n m
-- | The transpose of a given 'Tensor' point. The two 'Manifold's of the
-- 'Tensor' swap places in the result.
matrixTranspose pq = fromHMatrix (Tensor dmn cdmn) $ H.tr (toHMatrix pq)
  where
    Tensor cdmn dmn = manifold pq

matrixDiagonalConcatenate :: (Manifold m, Manifold n, Manifold o, Manifold p)
    => Function c d :#: Tensor m n
    -> Function e f :#: Tensor o p
    -> Function (c,e) (d,f) :#: Tensor (m,o) (n,p)
-- | Creates a block diagonal matrix from the two given matrices. The
-- 'Manifold's of the result are the pairs of the corresponding 'Manifold's of
-- the arguments.
matrixDiagonalConcatenate cdmn efop = fromHMatrix tns blks
  where
    Tensor m n = manifold cdmn
    Tensor o p = manifold efop
    tns = Tensor (m,o) (n,p)
    blks = H.diagBlock [toHMatrix cdmn, toHMatrix efop]


coordinateTransform :: Manifold m => [c :#: m] -> Function Cartesian c :#: Tensor m Euclidean
-- | Returns the coordinate transformation from 'Euclidean' space into the space
-- defined by the given basis vectors. This is a glorified fromColumns function:
-- the coordinates of each basis vector become one column of the matrix, and the
-- 'Euclidean' domain has one dimension per basis vector.
--
-- The target 'Manifold' is read off the first basis vector, so at least one
-- basis vector is required; an empty list results in an 'error'.
coordinateTransform [] = error "coordinateTransform: empty list of basis vectors"
coordinateTransform bss@(bs:_) =
    let tns = Tensor (manifold bs) . Euclidean $ length bss
     in fromHMatrix tns . H.fromColumns $ coordinates <$> bss

linearProjection :: Manifold m => [Cartesian :#: m] -> Function Cartesian Cartesian :#: Tensor m m
-- | Returns the linear projection operator for the given subset of basis
-- vectors. With @A@ the 'coordinateTransform' of the basis vectors, the
-- operator is the product of @A@, the inverse of @A^T A@, and @A^T@.
linearProjection bss = mtx <#> matrixInverse (mtxt <#> mtx) <#> mtxt
  where
    mtx = coordinateTransform bss
    mtxt = matrixTranspose mtx

matrixApply :: (Manifold m, Manifold n) => (Function c d :#: Tensor n m) -> (c :#: m) -> d :#: n
-- | Matrix vector multiplication. The matrix is applied to the coordinates of
-- the input, and the result is returned as a point on the first 'Manifold' of
-- the 'Tensor'.
matrixApply pq p = fromCoordinates cdmn $ toHMatrix pq H.#> coordinates p
  where
    Tensor cdmn _ = manifold pq
    -- Disabled variant which checks the 'Manifold' of the input before applying
    -- the matrix:
    {-
    let (Tensor n m) = manifold pq
     in if m == manifold p
          then fromCoordinates n $ toHMatrix pq H.#> coordinates p
          else error "matrix applied to wrong Manifold"
          -}

matrixMap :: (Manifold m, Manifold n) => (Function c d :#: Tensor m n) -> [c :#: n] -> [d :#: m]
-- | Mapped matrix vector multiplication. The input vectors are first collected
-- as the columns of a single matrix, so that all outputs are computed with one
-- matrix-matrix product (this can greatly improve computation time). The
-- columns of the product are returned as points on the first 'Manifold' of the
-- 'Tensor'.
matrixMap pq ps = map (fromCoordinates cdmn) $ H.toColumns (toHMatrix pq <> xs)
  where
    Tensor cdmn _ = manifold pq
    -- One column per input point.
    xs = H.fromColumns (coordinates <$> ps)
    -- Disabled variant which checks the 'Manifold' of every input before
    -- applying the matrix:
    {-
    let (Tensor n m) = manifold pq
        mtx = toHMatrix pq
        xs = H.fromColumns $ coordinates <$> ps
     in if all (== m) $ manifold <$> ps
           then map (fromCoordinates n) . H.toColumns $ mtx <> xs
           else error "matrix applied to wrong Manifold"
           -}


--- Instances ---


-- Tensor Products --

instance (Manifold m, Manifold n) => Manifold (Tensor m n) where
    -- The dimension of the tensor product is the product of the dimensions of
    -- the two underlying manifolds.
    dimension (Tensor m n) = dimension m * dimension n

instance (Manifold m, Manifold n) => Map (Tensor m n) where
    -- The second manifold is the domain, and the first is the codomain.
    type Domain (Tensor m n) = n
    type Codomain (Tensor m n) = m
    domain (Tensor _ dmn) = dmn
    codomain (Tensor cdmn _) = cdmn

instance (Manifold m, Manifold n) => Apply c d (Tensor m n) where
    -- Application of a tensor is matrix vector multiplication.
    pq >.> p = matrixApply pq p
    -- Mapped application goes through a single matrix-matrix product.
    pq >$> ps = matrixMap pq ps

-- Affine Map --

instance (Manifold m, Manifold n) => Manifold (Affine m n) where
    -- One coordinate per matrix entry, plus one coordinate per entry of the
    -- constant on the first manifold.
    dimension (Affine m n) = cdmnDim * dmnDim + cdmnDim
      where
        cdmnDim = dimension m
        dmnDim = dimension n

instance (Manifold m, Manifold n) => Map (Affine m n) where
    -- The second manifold is the domain, and the first is the codomain.
    type Domain (Affine m n) = n
    type Codomain (Affine m n) = m
    domain (Affine _ dmn) = dmn
    codomain (Affine cdmn _) = cdmn

instance (Manifold m, Manifold n) => Apply c d (Affine m n) where
    -- Apply the matrix part to the input, and combine the result with the
    -- constant part using '<+>'.
    p >.> x = mtx >.> x <+> b
      where
        (b,mtx) = splitAffine p
    -- Mapped version: the matrix part is applied to all the inputs at once,
    -- and each output is then combined with the constant part.
    p >$> xs = map (<+> b) $ mtx >$> xs
      where
        (b,mtx) = splitAffine p