{-# language TypeFamilies, MultiParamTypeClasses, KindSignatures, FlexibleContexts, FlexibleInstances, ConstraintKinds #-}
{-# language AllowAmbiguousTypes #-}
{-# language CPP #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.LinearAlgebra.Class
-- Copyright   :  (c) Marco Zocca 2017
--
-- Maintainer  :  zocca marco gmail
-- Stability   :  experimental
-- Portability :  portable
--
-- Typeclasses for linear algebra and related concepts
--
-----------------------------------------------------------------------------
module Numeric.LinearAlgebra.Class where

-- import Control.Applicative
import Data.Complex

-- import Control.Exception
-- import Control.Exception.Common

-- import Data.Typeable (Typeable)

import qualified Data.Vector as V (Vector)

import Data.Sparse.Types
import Numeric.Eps

-- * Matrix and vector elements (optionally Complex)
class (Eq e, Fractional e, Floating e, Num (EltMag e), Ord (EltMag e)) => Elt e where
  -- | Type of the magnitude of an element: the element type itself for
  -- real elements, the underlying real type for Complex elements
  type EltMag e :: *
  -- | Magnitude
  mag :: e -> EltMag e
  -- | Complex conjugate; by default the argument is returned unchanged,
  -- which is the correct behaviour for real-valued elements
  conj :: e -> e
  conj x = x

-- | Real elements are their own magnitude type; 'conj' keeps its default.
instance Elt Double where
  type EltMag Double = Double
  mag x = x
instance Elt Float where
  type EltMag Float = Float
  mag x = x
-- | Complex elements: the magnitude is the real modulus and conjugation
-- is the usual complex conjugate.
instance RealFloat e => Elt (Complex e) where
  type EltMag (Complex e) = e
  mag z = magnitude z
  conj z = conjugate z

infixl 6 ^+^, ^-^

-- | Types with a zero element, an addition and an additive inverse.
class AdditiveGroup v where
  -- | The zero element: identity for '(^+^)'
  zeroV :: v
  -- | Addition
  (^+^) :: v -> v -> v
  -- | Additive inverse
  negateV :: v -> v
  -- | Group subtraction; by default the additive inverse of the second
  -- argument is added to the first
  (^-^) :: v -> v -> v
  (^-^) x y = x ^+^ negateV y

infixr 7 .*

-- * Vector space @v@.
class (AdditiveGroup v, Num (Scalar v)) => VectorSpace v where
  -- | Type of the scalars that act on @v@
  type Scalar v :: *
  -- | Scale a vector by a scalar (scalar on the left)
  (.*) :: Scalar v -> v -> v

-- | Vector spaces equipped with an inner (dot) product.
class VectorSpace v => InnerSpace v where
  -- | Inner/dot product, valued in the scalar type of the space
  (<.>) :: v -> v -> Scalar v

-- | Inner product; prefix synonym for '(<.>)'
dot :: InnerSpace v => v -> v -> Scalar v
dot u w = u <.> w

infixr 7 ./
infixl 7 *.

-- | Scale a vector by the reciprocal of a number (e.g. for normalization):
-- @vec ./ s@ is @recip s .* vec@.
(./) :: (VectorSpace v, s ~ Scalar v, Fractional s) => v -> s -> v
(./) vec s = let k = recip s in k .* vec

-- | Vector multiplied by scalar: '(.*)' with its arguments swapped
(*.) :: (VectorSpace v, s ~ Scalar v) => v -> s -> v
vec *. s = s .* vec

-- | Convex combination of two vectors: @a .* u ^+^ (1 - a) .* v@.
-- The weight @a@ is expected to lie in [0, 1]; this is not checked.
cvx :: VectorSpace v => Scalar v -> v -> v -> v
cvx a u v = weightedU ^+^ weightedV
  where
    weightedU = a .* u
    weightedV = (1 - a) .* v

-- ** Hilbert-space distance function

-- | @hilbertDistSq x y@ computes the squared L2 distance between two
-- vectors, i.e. the inner product of @x ^-^ y@ with itself.
hilbertDistSq :: InnerSpace v => v -> v -> Scalar v
hilbertDistSq x y =
  let d = x ^-^ y
  in d <.> d

-- * Normed vector spaces

-- | Inner-product spaces equipped with a family of norms.
--
-- @Magnitude v@ is the type of norm values; @RealScalar v@ is the type of
-- the exponent @p@ taken by the Lp norm functions.
class ( InnerSpace v
      , Num (RealScalar v)
      , Eq (RealScalar v)
      , Epsilon (Magnitude v)
      , Show (Magnitude v)
      , Ord (Magnitude v)
      ) => Normed v where
  type Magnitude v :: *
  type RealScalar v :: *
  -- | L1 norm
  norm1 :: v -> Magnitude v
  -- | Euclidean (L2) norm squared
  norm2Sq :: v -> Magnitude v
  -- | Lp norm (p > 0)
  normP :: RealScalar v -> v -> Magnitude v
  -- | Normalize w.r.t. Lp norm
  normalize :: RealScalar v -> v -> v
  -- | Normalize w.r.t. L2 norm
  normalize2 :: v -> v
  -- | Normalize by the @Scalar v@-valued L2 norm instead of by 'norm2';
  -- by default the vector is divided by its own scalar-valued L2 norm
  normalize2' :: Floating (Scalar v) => v -> v
  normalize2' x = x ./ norm2' x
  -- | Euclidean (L2) norm; by default the square root of 'norm2Sq'
  norm2 :: Floating (Magnitude v) => v -> Magnitude v
  norm2 x = sqrt (norm2Sq x)
  -- | Euclidean (L2) norm valued in the scalar type; returns a Complex
  -- (norm :+ 0) for Complex-valued vectors. By default the square root of
  -- the inner product of the vector with itself.
  norm2' :: Floating (Scalar v) => v -> Scalar v
  norm2' x = sqrt (x <.> x)
  -- | Lp norm (p > 0); dispatches to 'norm1' when p == 1, to 'norm2' when
  -- p == 2, and to 'normP' otherwise
  norm :: Floating (Magnitude v) => RealScalar v -> v -> Magnitude v
  norm p v
    | p == 1 = norm1 v
    | p == 2 = norm2 v
    | otherwise = normP p v

-- | Infinity-norm (Real): the largest absolute value among the elements.
--
-- Every element goes through 'abs' before the maximum is taken, so negative
-- entries count by their magnitude, just as 'normInftyC' takes the largest
-- 'magnitude'. Uses 'maximum', so an empty container raises an error.
normInftyR :: (Foldable t, Num a, Ord a) => t a -> a
normInftyR x = maximum (foldr ((:) . abs) [] x)

-- | Infinity-norm (Complex): the largest 'magnitude' among the elements.
-- Uses 'maximum', so an empty container raises an error.
normInftyC :: (Foldable t, RealFloat a, Functor t) => t (Complex a) -> a
normInftyC x = maximum (magnitude <$> x)

-- | Lp inner product (p > 0): the products of corresponding elements
-- (paired with 'liftI2') are each raised to the power @p@ and summed, and
-- the sum is raised to the power @1/p@.
dotLp :: (Set t, Foldable t, Floating a) => a -> t a -> t a ->  a
dotLp p v1 v2 = total ** (1 / p)
  where
    total = sum (liftI2 (\a b -> (a * b) ** p) v1 v2)

-- | Elementwise reciprocal: 'recip' applied to every element of a functor
reciprocal :: (Functor f, Fractional b) => f b -> f b
reciprocal xs = recip <$> xs

-- | Scale: every element of a functor is multiplied on the right by @n@
scale :: (Num b, Functor f) => b -> f b -> f b
scale n xs = fmap (\x -> x * n) xs

-- * Matrix ring

-- | A matrix ring is any collection of matrices over some ring R that forms
-- a ring under matrix addition and matrix multiplication.
class (AdditiveGroup m, Epsilon (MatrixNorm m)) => MatrixRing m where
  -- | Type of the values returned by 'normFrobenius'
  type MatrixNorm m :: *
  -- | Matrix-matrix product
  (##) :: m -> m -> m
  -- | Matrix times matrix transpose (A B^T)
  (##^) :: m -> m -> m
  -- | Matrix transpose times matrix (A^T B); by default the transpose of
  -- the first argument is multiplied by the second
  (#^#) :: m -> m -> m
  (#^#) a b = (##) (transpose a) b
  -- | Matrix transpose (Hermitian conjugate in the Complex case)
  transpose :: m -> m
  -- | Frobenius norm
  normFrobenius :: m -> MatrixNorm m

-- a "sparse matrix ring" ?

-- class MatrixRing m a => SparseMatrixRing m a where
--   (#~#) :: Epsilon a => Matrix m a -> Matrix m a -> Matrix m a

-- * Linear vector space

-- | Vector spaces acted upon by an associated matrix type.
-- A MatrixRing (MatrixType v) superclass constraint is currently disabled.
class VectorSpace v => LinearVectorSpace v where
  -- | Type of the matrices acting on @v@
  type MatrixType v :: *
  -- | Matrix-vector action
  (#>) :: MatrixType v -> v -> v
  -- | Dual matrix-vector action
  (<#) :: v -> MatrixType v -> v

-- ** LinearVectorSpace + Normed

-- | Constraint synonym: @v@ is both a 'LinearVectorSpace' and 'Normed'
type V v = (Normed v, LinearVectorSpace v)

-- ** Linear systems

-- | Vector types @v@ for which linear systems whose system matrix has type
-- @MatrixType v@ can be solved.
class LinearVectorSpace v => LinearSystem v where
-- NOTE(review): the head of the method signature (method name, the double
-- colon and any constraint on the monad m) is missing from this copy; the
-- three lines below are only its argument and result types, and m is not
-- bound anywhere visible. Restore the head line before building.
-- TODO confirm the intended name and constraints against version control.
-- | Solve a linear system; uses GMRES internally as default method
MatrixType v   -- ^ System matrix
-> v              -- ^ Right-hand side
-> m v            -- ^ Result

-- * FiniteDim : finite-dimensional objects

class FiniteDim f where
  -- | Type of the size descriptor returned by 'dim'
  type FDSize f
  -- | Dimension (i.e. Int for SpVector, (Int, Int) for SpMatrix)
  dim :: f -> FDSize f

-- * HasData : accessing inner data (do not export)

class HasData f where
  -- | Type of the underlying data returned by 'dat'
  type HDData f
  -- | Number of nonzeros
  nnz :: f -> Int
  -- | Underlying data
  dat :: f -> HDData f

-- * Sparse : sparse data structures

class (FiniteDim f, HasData f) => Sparse f where
  -- | Sparsity: the fraction of elements that are nonzero
  spy :: Fractional b => f -> b

-- * Set : types that behave as sets

class Functor f => Set f where
  -- | Union binary lift: apply a function on the /union/ of two "sets"
  liftU2 :: (a -> a -> a) -> f a -> f a -> f a
  -- | Intersection binary lift: apply a function on the /intersection/ of
  -- two "sets"
  liftI2 :: (a -> a -> b) -> f a -> f a -> f b

-- * SpContainer : sparse container data structures (insertion, lookup, conversion to a list, and lookup with a default of 0)
class Sparse c => SpContainer c where
  -- | Index type of the container
  type ScIx c :: *
  -- | Element type of the container
  type ScElem c
  -- | Insertion of an element at an index
  scInsert :: ScIx c -> ScElem c -> c -> c
  -- | Lookup at an index; 'Nothing' when no element is found
  scLookup :: c -> ScIx c -> Maybe (ScElem c)
  -- | Contents as a list of (index, element) pairs
  scToList :: c -> [(ScIx c, ScElem c)]
  -- | Lookup with default, infix form ("safe": should throw an exception if
  -- lookup is outside matrix bounds)
  (@@) :: c -> ScIx c -> ScElem c

-- * SparseVector

class SpContainer v => SparseVector v where
  -- | Index type used by 'svFromList'
  type SpvIx v :: *
  -- | Build from (index, element) pairs; the Int argument is presumably
  -- the vector dimension (TODO confirm against instances)
  svFromList :: Int -> [(SpvIx v, ScElem v)] -> v
  -- | Build from a dense list of elements; the Int argument is presumably
  -- the vector dimension (TODO confirm against instances)
  svFromListDense :: Int -> [ScElem v] -> v
  -- | Concatenation of a collection of vectors
  svConcat :: Foldable t => t v -> v
  -- svZipWith :: (e -> e -> e) -> v e -> v e -> v e

-- * SparseMatrix

class SpContainer m => SparseMatrix m where
  -- | Build from a lexicographic ordering, an (Int, Int) size (presumably
  -- rows and columns) and a vector of (row, column, element) triples
  smFromVector :: LexOrd -> (Int, Int) -> V.Vector (IxRow, IxCol, ScElem m) -> m
  -- smFromFoldableDense :: Foldable t => t e -> m e
  -- | Matrix transpose
  smTranspose :: m -> m
  -- smExtractSubmatrix ::
  --   m e -> (IxRow, IxRow) -> (IxCol, IxCol) -> m e
  -- | Encode a (row, column) pair as a LexIx under the given ordering
  encodeIx :: m -> LexOrd -> (IxRow, IxCol) -> LexIx
  -- | Decode a LexIx into a (row, column) pair under the given ordering
  decodeIx :: m -> LexOrd -> LexIx -> (IxRow, IxCol)

-- data RowsFirst = RowsFirst
-- data ColsFirst = ColsFirst

-- * SparseMatVec

-- Combining functions for relating matrices and vectors structurally, e.g. extracting or inserting rows, columns and submatrices

-- class (SparseMatrix m o e, SparseVector v e) => SparseMatVec m o v e where
--   smvInsertRow :: m e -> v e -> IxRow -> m e
--   smvInsertCol :: m e -> v e -> IxCol -> m e
--   smvExtractRow :: m e -> IxRow -> v e
--   smvExtractCol :: m e -> IxCol -> v e

-- * Utilities

-- | Lift a real number onto the complex plane, with a zero imaginary part
toC :: Num a => a -> Complex a
toC = (:+ 0)

-- | Instances for builtin types
--
-- The macro below makes its argument type an AdditiveGroup (zero, addition
-- and negation taken from Num) and a VectorSpace over itself, with scaling
-- given by multiplication. Its body uses explicit braces and semicolons
-- because the backslash-newline continuations join it into a single line;
-- for the same reason no comments can be placed inside it.
#define ScalarType(t) \
instance AdditiveGroup (t) where {zeroV = 0; (^+^) = (+); negateV = negate};\
instance VectorSpace (t) where {type Scalar (t) = t; (.*) = (*) };

-- ScalarType(Int)
-- ScalarType(Integer)
ScalarType(Float)
ScalarType(Double)
ScalarType(Complex Float)
ScalarType(Complex Double)
-- ScalarType(CSChar)
-- ScalarType(CInt)
-- ScalarType(CShort)
-- ScalarType(CLong)
-- ScalarType(CLLong)
-- ScalarType(CIntMax)
-- ScalarType(CFloat)
-- ScalarType(CDouble)

#undef ScalarType

-- Real scalars: the inner product is ordinary multiplication.
instance InnerSpace Float where
  x <.> y = x * y
instance InnerSpace Double where
  x <.> y = x * y
-- Complex scalars: the second argument is conjugated.
instance InnerSpace (Complex Float) where
  (<.>) x y = x * conjugate y
instance InnerSpace (Complex Double) where
  (<.>) x y = x * conjugate y

-- The macro below makes a real scalar type Normed over itself: every norm
-- of a scalar is its absolute value, every normalization is its sign, and
-- the squared L2 norm is its square. The body is joined into one line by the
-- backslash continuations, hence the braces and semicolons.
#define SimpleNormedInstance(t) \
instance Normed (t) where {\
type Magnitude (t) = t;\
type RealScalar (t) = t;\
norm _ = abs; norm1 = abs; norm2 = abs; norm2' = abs; normP _ = abs;\
norm2Sq = (** 2);\
normalize _ = signum; normalize2 = signum; normalize2' = signum};

SimpleNormedInstance(Float)
SimpleNormedInstance(Double)

#undef SimpleNormedInstance

-- The macro below makes Complex t Normed, with magnitudes and Lp exponents
-- of the real type t. The Lp norm raises the absolute values of the real
-- and imaginary parts to the power p, as norm1 already uses absolute values:
-- raising the signed parts would let negative terms cancel for odd p and
-- would give NaN for non-integer p whenever a part is negative.
-- The body is joined into one line by the backslash continuations, hence
-- the braces and semicolons and the absence of comments inside it.
#define ComplexNormedInstance(t) \
instance Normed (Complex (t)) where {\
type Magnitude  (Complex (t)) = t;\
type RealScalar (Complex (t)) = t;\
norm1   (r :+ i) = abs r + abs i;\
norm2Sq (r :+ i) = r*r + i*i;\
normP p (r :+ i) = (abs r ** p + abs i ** p) ** (1/p);\
normalize p c = toC (1 / normP p c) * c;\
normalize2  c = (1 / norm2' c) * c;\
norm2  = magnitude;\
norm2' = toC . magnitude;};

ComplexNormedInstance(Float)
ComplexNormedInstance(Double)

#undef ComplexNormedInstance