-- |
-- Module      : Data.Manifold.Function.LocalModel
-- Copyright   : (c) Justus Sagemüller 2017
-- License     : GPL v3
-- 
-- Maintainer  : (@) jsagemue $ uni-koeln.de
-- Stability   : experimental
-- Portability : portable
-- 

{-# LANGUAGE ScopedTypeVariables      #-}
{-# LANGUAGE UnicodeSyntax            #-}
{-# LANGUAGE TypeOperators            #-}
{-# LANGUAGE TupleSections            #-}
{-# LANGUAGE TypeFamilies             #-}
{-# LANGUAGE UndecidableInstances     #-}
{-# LANGUAGE FlexibleContexts         #-}
{-# LANGUAGE StandaloneDeriving       #-}
{-# LANGUAGE TemplateHaskell          #-}
{-# LANGUAGE ConstraintKinds          #-}

module Data.Manifold.Function.LocalModel (
    -- * The model class
      LocalModel (..), ModellableRelation
    -- ** Local data fit models
    , AffineModel(..), QuadraticModel(..)
    , estimateLocalJacobian, estimateLocalHessian
    , propagationCenteredModel
    , propagationCenteredQuadraticModel
    , quadraticModel_derivatives
    -- ** Differential equations
    , DifferentialEqn, LocalDifferentialEqn(..)
    , propagateDEqnSolution_loc, LocalDataPropPlan(..)
    -- ** Range interpolation
    , rangeWithinVertices
    ) where


import Data.Manifold.Types
import Data.Manifold.PseudoAffine
import Data.Manifold.Types.Primitive ((^))
import Data.Manifold.Shade
import Data.Manifold.Riemannian

import Data.VectorSpace
import Math.LinearMap.Category

import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE

import qualified Prelude as Hask

import Control.Category.Constrained.Prelude
import Control.Arrow.Constrained

import Control.Lens
import Control.Lens.TH


-- | The local form of a differential equation. Given a local model of the
--   solution (of kind @m x y@), it yields an optional shade for the value
--   and an optional shade for the Jacobian ('LocalLinear' @x y@).
newtype LocalDifferentialEqn m x y = LocalDifferentialEqn {
      _rescanDifferentialEqn :: m x y
                             -> (Maybe (Shade' y), Maybe (Shade' (LocalLinear x y)))
    }
makeLenses ''LocalDifferentialEqn

-- | A differential equation: for every shade in the joint @(x,y)@ space, a
--   'LocalDifferentialEqn' that applies there.
type DifferentialEqn m x y = Shade (x,y) -> LocalDifferentialEqn m x y

-- | Everything needed to propagate data from a source position to a nearby
--   target position.
data LocalDataPropPlan x y = LocalDataPropPlan
       { _sourcePosition    :: !(Interior x)
         -- ^ Position at which '_sourceData' is known.
       , _targetPosOffset   :: !(Needle x)
         -- ^ Offset from the source position to the target position.
       , _sourceData        :: !y
         -- ^ Data at the source position.
       , _targetAPrioriData :: !y
         -- ^ Data assumed at the target position before propagation.
       , _relatedData       :: [(Needle x, y)]
         -- ^ Further data points, each paired with its offset from the
         --   source position.
       }
deriving instance (Show (Interior x), Show y, Show (Needle x))
             => Show (LocalDataPropPlan x y)
makeLenses ''LocalDataPropPlan


{-# DEPRECATED estimateLocalJacobian "Use `fitLocally`" #-}
-- | Estimate the Jacobian of a relation from (local input offset, output shade)
--   pairs; @mex@ is the metric used on input offsets.
--
--   * With exactly two points, the estimate is the rank-one map sending the
--     input difference @δx@ to the output difference @δy@, and annihilating
--     everything @mex@-orthogonal to @δx@. Its uncertainty comes from the two
--     output metrics convolved, evaluated on @δj $ δx@.
--   * With three or more points, the estimate from all but the first point is
--     mixed with the pairwise estimates between the first point and each other.
--   * Otherwise (in particular with fewer than two points), a zero Jacobian
--     with an empty norm is returned.
--
--   Yields 'Nothing' if a difference between points cannot be computed or if
--   mixing the shades fails.
estimateLocalJacobian :: forall x y . ( WithField Double Manifold x, Refinable y
                                      , SimpleSpace (Needle x), SimpleSpace (Needle y) )
            => Metric x -> [(Local x, Shade' y)]
                             -> Maybe (Shade' (LocalLinear x y))
estimateLocalJacobian = elj ( pseudoAffineWitness :: PseudoAffineWitness x
                            , pseudoAffineWitness :: PseudoAffineWitness y )
 where elj ( PseudoAffineWitness (SemimanifoldWitness BoundarylessWitness)
           , PseudoAffineWitness (SemimanifoldWitness BoundarylessWitness) )
        mex [(Local x₁, Shade' y₁ ey₁),(Local x₀, Shade' y₀ ey₀)]
         = do -- Propagate a failing difference as Nothing instead of crashing
              -- on an irrefutable pattern.
              δx <- x₁.-~.x₀
              δy <- y₁.-~.y₀
              let δx' = mex<$|δx
                  -- Dual vector with  dx <.>^ δx == 1.
                  dx = δx'^/(δx'<.>^δx)
                  σey = convolveMetric ([]::[y]) ey₀ ey₁
              return $ Shade' (dx-+|>δy)
                              (Norm . LinearFunction $ \δj -> δx ⊗ (σey<$|δj $ δx))
       elj _ mex (po:ps)
           | DualSpaceWitness <- dualSpaceWitness :: DualNeedleWitness y
           , _:_:_ <- ps   -- at least two further points, without an O(n) length
               = mixShade's =<< (:|) <$> estimateLocalJacobian mex ps
                             <*> sequenceA [estimateLocalJacobian mex [po,p] | p<-ps]
       elj _ _ _ = return $ Shade' zeroV mempty


-- | A local affine model of a relation between @x@ and @y@.
--
--   'evalLocalModel' combines the offset shade with the linear coefficient
--   applied to the input offset.
data AffineModel x y = AffineModel
       { _affineModelOffset :: Shade y
         -- ^ Model value at zero input offset.
       , _affineModelLCoeff :: Shade (Needle x +> Needle y)
         -- ^ Linear coefficient (Jacobian).
       }
deriving instance (Show (Shade y), Show (Shade (Needle x+>Needle y)))
              => Show (AffineModel x y)
makeLenses ''AffineModel


-- | A local quadratic model of a relation between @x@ and @y@.
--
--   'evalLocalModel' combines the offset shade with the linear coefficient
--   applied to @δx@ and the quadratic coefficient applied to @squareV δx@.
data QuadraticModel x y = QuadraticModel
       { _quadraticModelOffset :: Shade y
         -- ^ Model value at zero input offset.
       , _quadraticModelLCoeff :: Shade (Needle x +> Needle y)
         -- ^ Linear coefficient.
       , _quadraticModelQCoeff :: Shade (Needle x ⊗〃+> Needle y)
         -- ^ Quadratic coefficient, acting on the symmetric square of the
         --   input offset (half the Hessian).
       }
deriving instance ( Show (Shade y)
                  , Show (Shade (Needle x+>Needle y))
                  , Show (Shade (Needle x⊗〃+>Needle y)) )
              => Show (QuadraticModel x y)
makeLenses ''QuadraticModel

-- | Coefficient tuple of a quadratic model, nested as
--   @(constant, (linear, quadratic))@.
type QModelTup s x y
    = ( Needle y
      , ( Needle x +> Needle y
        , SymmetricTensor s (Needle x) +> Needle y ) )



-- | Fit a 'QuadraticModel' to (input offset, output shade) pairs.
--
--   This is a linear regression on the coefficients @(c,(b,a))@ of
--   @c + b·δx + a·squareV δx@. Output offsets are taken relative to the
--   barycentre of the data; the model's offset centre is that barycentre
--   shifted by the fitted constant @c@.
quadratic_linearRegression :: forall s x y .
                      ( WithField s PseudoAffine x
                      , WithField s PseudoAffine y, Geodesic y
                      , SimpleSpace (Needle x), SimpleSpace (Needle y) )
            => NE.NonEmpty (Needle x, Shade' y) -> QuadraticModel x y
quadratic_linearRegression = case ( dualSpaceWitness :: DualSpaceWitness (Needle x)
                                  , dualSpaceWitness :: DualSpaceWitness (Needle y) ) of
    (DualSpaceWitness, DualSpaceWitness) -> gLinearRegression
         (\δx -> lfun $ \(c,(b,a)) -> (a $ squareV δx) ^+^ (b $ δx) ^+^ c )
         (\cmy (cBest, (bBest, aBest)) σ
            -> let (σc, (σb, σa)) = second summandSpaceNorms $ summandSpaceNorms σ
               in QuadraticModel (Shade (cmy⊙+^cBest $ ([]::[y])) σc)
                              (Shade bBest σb)
                              (Shade aBest σa) )

-- | Generic linear regression for a model that is linear in its parameters @ψ@.
--
--   * Output centres are expressed as offsets from their barycentre @cmy@.
--   * @fwdCalc δx@ maps parameters to the predicted output offset at @δx@.
--   * @analyse@ receives @cmy@, the best-fit parameters and their variance, and
--     builds the final model from them.
--
--   Fails with an error if the data have no barycentre, or if a data point
--   cannot be expressed as an offset from it.
gLinearRegression :: forall s x y m ψ .
                      ( WithField s PseudoAffine x
                      , WithField s PseudoAffine y, Geodesic y
                      , SimpleSpace (Needle x), SimpleSpace (Needle y)
                      , SimpleSpace ψ, Scalar ψ ~ s )
            => (Needle x -> ψ -+> Needle y)
               -> (Interior y -> ψ -> Variance ψ -> m x y)
               -> NE.NonEmpty (Needle x, Shade' y) -> m x y
gLinearRegression fwdCalc analyse = qlr (pseudoAffineWitness, geodesicWitness)
 where qlr :: (PseudoAffineWitness y, GeodesicWitness y)
                   -> NE.NonEmpty (Needle x, Shade' y) -> m x y
       qlr (PseudoAffineWitness (SemimanifoldWitness _), GeodesicWitness _) ps
                 = analyse cmy ψ σψ
        where cmy = case pointsBarycenter $ _shade'Ctr.snd<$>ps of
                 Just c -> c
                 Nothing -> Hask.error
                   "gLinearRegression: no barycentre of the data points."
              vsxy = case Hask.mapM (\(x, Shade' y ey) -> (x,).(,ey)<$>y.-~.cmy) ps of
                 Just v -> v
                 Nothing -> Hask.error
                   "gLinearRegression: data point not expressible relative to barycentre."
              ψ = linearFit_bestModel regResult
              σψ = dualNorm . rescaleUncertainty $ linearFit_modelUncertainty regResult
              -- If the reduced chi-squared is positive and finite, scale the
              -- uncertainty metric by 1/(1+sqrt χν²). Otherwise (e.g. no positive
              -- number of degrees of freedom), keep the uncertainty unchanged.
              rescaleUncertainty = case linearFit_χν² regResult of
                 χν² | χν² > 0, recip χν² > 0 -> scaleNorm (recip $ 1 + sqrt χν²)
                 _ -> id
              regResult = linearRegression (arr . fwdCalc) (NE.toList vsxy)

-- | Extract the value, first-derivative and second-derivative shades of a
--   'QuadraticModel'.
--
--   The model evaluates its quadratic coefficient on @squareV δx@, without a
--   factor 1/2, so the Hessian is twice that coefficient.
quadraticModel_derivatives :: forall x y .
          ( PseudoAffine x, PseudoAffine y
          , SimpleSpace (Needle x), SimpleSpace (Needle y)
          , Scalar (Needle y) ~ Scalar (Needle x) ) =>
     QuadraticModel x y -> (Shade' y, (Shade' (LocalLinear x y), Shade' (LocalBilinear x y)))
quadraticModel_derivatives (QuadraticModel sh shð shð²)
    | (PseudoAffineWitness (SemimanifoldWitness BoundarylessWitness))
                                     :: PseudoAffineWitness y <- pseudoAffineWitness
    , DualSpaceWitness :: DualSpaceWitness (Needle x) <- dualSpaceWitness
    , DualSpaceWitness :: DualSpaceWitness (Needle y) <- dualSpaceWitness
             = (dualShade sh, ( dualShade shð
                              , linIsoTransformShade (2*^id) $ dualShade shð² ))

{-# DEPRECATED estimateLocalHessian "Use `fitLocally`" #-}
-- | Fit a 'QuadraticModel' to a non-empty list of (local input offset, output
--   shade) pairs. Every point given is used in the regression.
estimateLocalHessian :: forall x y . ( WithField Double Manifold x, Refinable y, Geodesic y
                                     , FlatSpace (Needle x), FlatSpace (Needle y) )
            => NonEmpty (Local x, Shade' y) -> QuadraticModel x y
estimateLocalHessian pts = quadratic_linearRegression $ first getLocalOffset <$> pts


-- | Fit a local model centred halfway between the source and target positions
--   of a propagation plan. The fit uses the source data, the target's a-priori
--   data and all related data, each re-expressed relative to that midpoint.
--
--   Fails with an error if 'fitLocally' yields no model (e.g. too few points).
propagationCenteredModel :: forall x y m .
                         ( ModellableRelation x y, LocalModel m )
         => LocalDataPropPlan x (Shade' y) -> m x y
propagationCenteredModel propPlan = case fitLocally (NE.toList ptsFromCenter) of
    Just model -> model
    Nothing -> Hask.error
       "propagationCenteredModel: fitLocally failed (too few data points?)."
 where ctrOffset = propPlan^.targetPosOffset^/2
       ptsFromCenter = (negateV ctrOffset, propPlan^.sourceData)
                     :| [(δx^-^ctrOffset, shy)
                        | (δx, shy)
                            <- (propPlan^.targetPosOffset, propPlan^.targetAPrioriData)
                               : propPlan^.relatedData
                        ]


-- | 'propagationCenteredModel', specialised to 'QuadraticModel'.
propagationCenteredQuadraticModel :: forall x y .
                         ( ModellableRelation x y )
         => LocalDataPropPlan x (Shade' y) -> QuadraticModel x y
propagationCenteredQuadraticModel = propagationCenteredModel


-- | Propagate the data of a 'LocalDataPropPlan' from its source to its target
--   position using a differential equation.
--
--   1. Fit a local model centred between source and target.
--   2. Feed it to the differential equation, evaluated on a shade covering the
--      source and related data (positions plus output uncertainties).
--   3. Apply the resulting Jacobian shade to the target offset and convolve
--      the result with the source data.
--
--   Yields 'Nothing' if the equation provides no Jacobian.
propagateDEqnSolution_loc :: forall x y m . (ModellableRelation x y, LocalModel m)
           => DifferentialEqn m x y
               -> LocalDataPropPlan x (Shade' y)
               -> Maybe (Shade' y)
propagateDEqnSolution_loc f propPlan
                  = pdesl (dualSpaceWitness :: DualNeedleWitness x)
                          (dualSpaceWitness :: DualNeedleWitness y)
                          (boundarylessWitness :: BoundarylessWitness x)
                          (pseudoAffineWitness :: PseudoAffineWitness y)
                          (geodesicWitness :: GeodesicWitness y)
 where pdesl DualSpaceWitness DualSpaceWitness BoundarylessWitness
             (PseudoAffineWitness (SemimanifoldWitness BoundarylessWitness))
             (GeodesicWitness _)
          = propagate <$> jacobian
         where (_,jacobian) = f shxy ^. rescanDifferentialEqn
                               $ propagationCenteredModel propPlan
               -- Offset from the source to the target position.
               δx = propPlan^.targetPosOffset
               propagate :: Shade' (LocalLinear x y) -> Shade' y
               propagate jacobianSh'
                   = convolveShade' (propPlan^.sourceData)
                        (dualShade . linearProjectShade (lfun ($ δx))
                                   $ dualShade' jacobianSh')
               -- Midpoint between the source and target positions.
               mx = propPlan^.sourcePosition .+~^ propPlan^.targetPosOffset ^/ 2 :: x
               -- Cover, for the source and every related data point, its position
               -- relative to the midpoint. Pair it with its output centre (relative
               -- to ym), spread along the spanning system of its output norm.
               shxy = coverAllAround (mx, ym)
                                     [ (δxi ^-^ propPlan^.targetPosOffset ^/ 2, δyi ^+^ v)
                                     | (δxi,neυ) <- (zeroV, propPlan^.sourceData)
                                                    : propPlan^.relatedData
                                     , let Just δyi = neυ^.shadeCtr .-~. ym
                                     , v <- normSpanningSystem' (neυ^.shadeNarrowness)
                                     ]
                where Just ym = middleBetween (propPlan^.sourceData.shadeCtr)
                                              (propPlan^.targetAPrioriData.shadeCtr)


-- | Constraints on input space @x@ and output space @y@ needed for local
--   modelling. The scalar field is 'Double', i.e. ℝ.
type ModellableRelation x y = ( WithField Double Manifold x
                              , Refinable y, Geodesic y
                              , FlatSpace (Needle x), FlatSpace (Needle y) )

-- | Models that can be fitted to local data and evaluated at input offsets.
class LocalModel m where
  -- | Fit a model to (input offset, output shade) pairs. In the instances in
  --   this module, this yields 'Nothing' when there are too few data points.
  fitLocally :: ModellableRelation x y
                  => [(Needle x, Shade' y)] -> Maybe (m x y)
  -- | Lens onto the model's output offset (its value at zero input offset).
  tweakLocalOffset :: ModellableRelation x y
                  => Lens' (m x y) (Shade y)
  -- | Evaluate the model at an input offset.
  evalLocalModel :: ModellableRelation x y => m x y -> Needle x -> Shade' y

-- | For a model with @n@ parameters, return @n - 1 + round (sqrt n)@.
--
--   'fitLocally' splits this many leading data points off and requires one
--   more point beyond them. It therefore uses @n + round (sqrt n)@ points,
--   which overdetermines the fit by roughly @sqrt n@ points.
modelParametersOverdetMargin :: Int -> Int
modelParametersOverdetMargin nParams = nParams - 1 + margin
 where margin = round (sqrt (fromIntegral nParams :: Double))


-- | Dimension of the space of affine functions on @v@: one constant term plus
--   one coefficient per basis direction. The argument is only a type proxy.
p¹Dimension :: forall v p . FiniteDimensional v => p v -> Int
p¹Dimension _ = 1 + d
 where d = subbasisDimension (entireBasis :: SubBasis v)

-- | Dimension of the space of quadratic functions on @v@. This is one constant
--   term, @d@ linear coefficients and @d(d+1)/2@ coefficients of the symmetric
--   quadratic form. The argument is only a type proxy.
p²Dimension :: forall v p . FiniteDimensional v => p v -> Int
p²Dimension _ = 1 + d + (d*(d+1))`div`2
 where d = subbasisDimension (entireBasis :: SubBasis v)

instance LocalModel AffineModel where
  fitLocally = aFitL dualSpaceWitness
   where aFitL :: forall x y . ModellableRelation x y
                    => DualSpaceWitness (Needle y)
                      -> [(Needle x, Shade' y)] -> Maybe (AffineModel x y)
         aFitL DualSpaceWitness dataPts
          -- Use only the first  modelParametersOverdetMargin p + 1  points, where
          -- p is the dimension of affine functions on the input. Fewer points
          -- yield Nothing.
          | (p₀:ps, pLast:_) <- splitAt (modelParametersOverdetMargin
                                        $ p¹Dimension ([]::[Needle x])) dataPts
                 = Just . gLinearRegression
                            (\δx -> lfun $ \(b,a) -> (a $ δx) ^+^ b )
                            (\cmy (bBest, aBest) σ
                               -> let (σb, σa) = summandSpaceNorms σ
                                  in AffineModel (Shade (cmy⊙+^bBest $ ([]::[y]))
                                                        $ scaleNorm 2 σb)
                               -- The magic factor 2 seems dubious ↗, but testing indicates
                               -- that this is necessary to not overrate the accuracy.
                               --   TODO:  check the algorithms in linearmap-category.
                                                 (Shade aBest σa) )
                     $ (p₀:|ps++[pLast])
          | otherwise  = Nothing
  tweakLocalOffset = affineModelOffset
  evalLocalModel = aEvL pseudoAffineWitness
   where -- Offset shade convolved with the linear coefficient applied to δx.
         aEvL :: forall x y . ModellableRelation x y
                => PseudoAffineWitness y -> AffineModel x y -> Needle x -> Shade' y
         aEvL (PseudoAffineWitness (SemimanifoldWitness _)) (AffineModel shy₀ shj) δx
          = convolveShade' (dualShade shy₀)
                           (dualShade . linearProjectShade (lfun ($ δx)) $ shj)

instance LocalModel QuadraticModel where
  fitLocally = qFitL
   where qFitL :: forall x y . ModellableRelation x y
                    => [(Needle x, Shade' y)] -> Maybe (QuadraticModel x y)
         qFitL dataPts
          -- Use only the first  modelParametersOverdetMargin p + 1  points, where
          -- p is the dimension of quadratic functions on the input. Fewer
          -- points yield Nothing.
          | (p₀:ps, pLast:_) <- splitAt (modelParametersOverdetMargin
                                        $ p²Dimension ([]::[Needle x])) dataPts
                 = Just . quadratic_linearRegression
                     $ (p₀:|ps++[pLast])
          | otherwise  = Nothing
  tweakLocalOffset = quadraticModelOffset
  evalLocalModel = qEvL pseudoAffineWitness
   where -- Offset shade, convolved with the linear coefficient applied to δx
         -- and the quadratic coefficient applied to the symmetric square of δx.
         qEvL :: forall x y . ModellableRelation x y
                => PseudoAffineWitness y -> QuadraticModel x y -> Needle x -> Shade' y
         qEvL (PseudoAffineWitness (SemimanifoldWitness _))
              (QuadraticModel shy₀ shj shjj) δx
          = (dualShade shy₀)
           `convolveShade'`
            (dualShade . linearProjectShade (lfun ($ δx)) $ shj)
           `convolveShade'`
            (dualShade . linearProjectShade (lfun ($ squareV δx)) $ shjj)