-- |
-- Module      : Data.Function.Affine
-- Copyright   : (c) Justus Sagemüller 2015
-- License     : GPL v3
-- 
-- Maintainer  : (@) sagemueller $ geo.uni-koeln.de
-- Stability   : experimental
-- Portability : portable
-- 

{-# LANGUAGE FlexibleInstances        #-}
{-# LANGUAGE UndecidableInstances     #-}
{-# LANGUAGE TypeFamilies             #-}
{-# LANGUAGE FunctionalDependencies   #-}
{-# LANGUAGE FlexibleContexts         #-}
{-# LANGUAGE LiberalTypeSynonyms      #-}
{-# LANGUAGE GADTs                    #-}
{-# LANGUAGE RankNTypes               #-}
{-# LANGUAGE TupleSections            #-}
{-# LANGUAGE ConstraintKinds          #-}
{-# LANGUAGE PatternGuards            #-}
{-# LANGUAGE PatternSynonyms          #-}
{-# LANGUAGE ViewPatterns             #-}
{-# LANGUAGE TypeOperators            #-}
{-# LANGUAGE UnicodeSyntax            #-}
{-# LANGUAGE MultiWayIf               #-}
{-# LANGUAGE ScopedTypeVariables      #-}
{-# LANGUAGE RecordWildCards          #-}
{-# LANGUAGE CPP                      #-}


module Data.Function.Affine (
              Affine
            , linearAffine
            , toOffsetSlope, toOffset'Slope 
            ) where
    


import Data.Semigroup

import Data.VectorSpace
import Data.LinearMap
import Data.LinearMap.HerMetric
import Data.AffineSpace
import Data.Tagged
import Data.Manifold.Types.Primitive
import Data.Manifold.PseudoAffine

import qualified Prelude
import qualified Control.Applicative as Hask

import Control.Category.Constrained.Prelude hiding ((^))
import Control.Category.Constrained.Reified
import Control.Arrow.Constrained
import Control.Monad.Constrained
import Data.Foldable.Constrained




-- | Reified affine functions from @d@ to @c@. The parameter @s@ does not occur
--   in any constructor field; the instances in this module constrain it with
--   'MetricScalar' and pair it with domain and codomain through @WithField s@.
data Affine s d c where
   -- | Difference of the two components of a pair of points
   --   (evaluated as @a.-.b@ by 'toOffset'Slope').
   Subtract :: AffineManifold α => Affine s (α,α) (Needle α)
   -- | Translation of a point by a needle (evaluated as @p.+^v@).
   AddTo :: Affine s (α, Needle α) α
   -- | Application of a linear mapping between two linear manifolds.
   ScaleWith :: (LinearManifold α, LinearManifold β) => (α:-*β) -> Affine s α β
   -- | Embedding of a reified well-pointed expression (identity, composition,
   --   tuple shuffling, constants, ...) whose leaves are again 'Affine' arrows.
   ReAffine :: ReWellPointed (Affine s) α β -> Affine s α β

-- | Turn a reified well-pointed expression into an 'Affine' arrow. An
--   expression that merely wraps an 'Affine' arrow is unwrapped again;
--   anything else is embedded with 'ReAffine'.
reAffine :: ReWellPointed (Affine s) α β -> Affine s α β
reAffine expr = case expr of
   ReWellPointed inner -> inner
   _                   -> ReAffine expr

-- Pattern synonyms that let the reified skeleton inside 'ReAffine' be matched
-- (and, for the bidirectional ones, built) as if it consisted of plain
-- 'Affine' constructors.

-- An 'Affine' arrow wrapped directly in the reified skeleton. Note that this
-- is a pattern on 'ReWellPointed' values, not on 'Affine' ones.
pattern Specific f = ReWellPointed f
-- Identity arrow.
pattern Id = ReAffine WellPointedId
infixr 1 :>>>, :<<<
-- Reified composition, written left-to-right (@f@ is applied first).
-- Match-only: both components are passed through 'reAffine' by view patterns.
pattern f :>>> g <- ReAffine (WellPointedCompo (reAffine -> f) (reAffine -> g))
-- The same composition node, written right-to-left (@f@ is still applied first).
pattern g :<<< f <- ReAffine (WellPointedCompo (reAffine -> f) (reAffine -> g))
-- Tuple-shuffling primitives of the skeleton.
pattern Swap = ReAffine WellPointedSwap
pattern AttachUnit = ReAffine WellPointedAttachUnit
pattern DetachUnit = ReAffine WellPointedDetachUnit
pattern Regroup = ReAffine WellPointedRegroup
pattern Regroup' = ReAffine WellPointedRegroup_
-- Arrow into the unit object, and the two projections.
pattern Terminal = ReAffine WellPointedTerminal
pattern Fst = ReAffine WellPointedFst
pattern Snd = ReAffine WellPointedSnd
infixr 3 :***, :&&&
-- Reified parallel composition and fan-out. Match-only, components are
-- unwrapped with 'reAffine' like in ':>>>'.
pattern f :*** g <- ReAffine (WellPointedPar (reAffine -> f) (reAffine -> g))
pattern f :&&& g <- ReAffine (WellPointedFanout (reAffine -> f) (reAffine -> g))
-- Constant arrow with the given value.
pattern Const c = ReAffine (WellPointedConst c)


-- | Offset and slope of an affine function on a linear manifold, i.e. the
--   result of 'toOffset'Slope' with the zero vector as reference point.
toOffsetSlope :: (MetricScalar s, WithField s LinearManifold d
                                 , WithField s AffineManifold c )
                      => Affine s d c -> (c, Needle d :-* Needle c)
toOffsetSlope affFn = affFn `toOffset'Slope` zeroV

-- | Basically evaluates an affine function as a generic differentiable one,
--   yielding at a given reference point the result and Jacobian. Unlike with
--   'Data.Function.Differentiable.Differentiable', the induced 1st-order Taylor
--   series is equal to the function!
--
--   The first component of the result is the value of the function at the
--   reference point, the second one its slope as a linear map between the
--   needle spaces of domain and codomain.
toOffset'Slope :: ( MetricScalar s, WithField s AffineManifold d
                                   , WithField s AffineManifold c )
                      => Affine s d c -> d -> (c, Needle d :-* Needle c)
-- Primitive affine operations.
toOffset'Slope Subtract (a,b) = (a.-.b, linear $ uncurry(^-^))
toOffset'Slope AddTo (p,v) = (p.+^v, linear $ uncurry(^+^))
toOffset'Slope (ScaleWith q) ref = (lapply q ref, q)
-- Reified category skeleton.
toOffset'Slope Id ref = (ref, linear id)
-- Chain rule: evaluate g at the value of f, compose the slopes.
toOffset'Slope (f :>>> g) ref = case toOffset'Slope f ref of
                  (cf,sf) -> case toOffset'Slope g cf of
                     (cg,sg)     -> (cg, sg*.*sf)
toOffset'Slope Swap ref = (swap ref, linear swap)
toOffset'Slope AttachUnit ref = ((ref,Origin), linear (,Origin))
toOffset'Slope DetachUnit ref = (fst ref, linear fst)
toOffset'Slope Regroup ref = (regroup ref, linear regroup)
toOffset'Slope Regroup' ref = (regroup' ref, linear regroup')
toOffset'Slope (f:***g) ref = case ( toOffset'Slope f (fst ref)
                                 , toOffset'Slope g (snd ref) ) of
                  ((cf, sf), (cg, sg)) -> ((cf,cg), linear $ lapply sf *** lapply sg)
toOffset'Slope Terminal _ = (Origin, zeroV)
toOffset'Slope Fst ref = (fst ref, linear fst)
toOffset'Slope Snd ref = (snd ref, linear snd)
-- NOTE(review): unlike the ':***' clause, the components are passed through
-- 'arr' here; presumably that is the identity on 'Affine' arrows -- confirm.
toOffset'Slope (f:&&&g) ref = case ( toOffset'Slope (arr f) ref
                                  , toOffset'Slope (arr g) ref ) of
                  ((cf, sf), (cg, sg)) -> ((cf,cg), linear $ lapply sf &&& lapply sg)
toOffset'Slope (Const c) _ = (c, zeroV)
-- A skeleton node that merely wraps an 'Affine' arrow is not covered by any of
-- the pattern synonyms above; without this clause such a value would cause a
-- pattern-match failure. Evaluate the wrapped arrow instead.
toOffset'Slope (ReAffine (ReWellPointed f)) ref = toOffset'Slope f ref
            
-- | Try to rewrite an affine function into the shape
--   @id &&& const c >>> Subtract >>> f@, i.e. subtraction of a fixed point @c@
--   from the argument followed by some further arrow @f@. Arrows for which none
--   of the specialised clauses applies are returned unchanged.
coOffsetForm :: ( MetricScalar s, WithField s AffineManifold d
                                , WithField s AffineManifold c )
                      => Affine s d c -> Affine s d c
-- A linear map is the trivial case: subtract the zero vector first.
coOffsetForm (ScaleWith q) = id&&&const zeroV >>> Subtract >>> ScaleWith q
-- Composition whose first stage can be brought to the shape: keep its
-- subtraction in front and compose the remainders.
coOffsetForm ((coOffsetForm -> Id:&&&Const cof :>>> Subtract :>>> f) :>>> g)
                    = id&&&const cof >>> Subtract >>> (f >>> g)
-- Parallel composition of two arrows that both have the shape: subtract the
-- pair of offsets, then run the remainders in parallel.
coOffsetForm ( (coOffsetForm -> Id:&&&Const cof :>>> Subtract :>>> f)
          :*** (coOffsetForm -> Id:&&&Const cog :>>> Subtract :>>> g) )
     = id&&&const(cof,cog) >>> Subtract >>> (f***g)
-- A bare subtraction of a constant: an explicit identity stage is appended
-- after 'Subtract' -- presumably so that the three-stage shape matched by the
-- clauses above applies to the result; confirm.
coOffsetForm (Id:&&&Const cof :>>> Subtract)
           = (Id&&&Const cof >>> ReAffine (ReWellPointed Subtract`WellPointedCompo`WellPointedId))
coOffsetForm f = f

-- Match-only view of an affine function as "subtract the fixed point @c@ from
-- the argument, then apply @f@". The scrutinee is first passed through
-- 'coOffsetForm'.
pattern PreSubtract c f <- (coOffsetForm -> Id:&&&Const c :>>> Subtract :>>> f)

-- | @preSubtract o f@ is the affine function that first subtracts the fixed
--   point @o@ from its argument and then applies @f@ to the difference.
preSubtract :: ( MetricScalar s, WithField s AffineManifold d
                               , WithField s AffineManifold c )
               => c -> Affine s (Diff c) d -> Affine s c d
-- The specialised clauses may not actually be useful here.
-- Arrows that ignore their argument do not need the subtraction at all.
preSubtract _ (Const value) = const value
preSubtract _ Terminal = Terminal
-- Only the first stage of a composition sees the original argument.
preSubtract origin (firstStage:>>>rest) = (preSubtract origin firstStage) >>>! rest
-- preSubtract t (f:***g) | (c,d)<-t = preSubtract c f *** preSubtract d g
-- Both branches of a fan-out see the same argument.
preSubtract origin (left:&&&right)
      = (preSubtract origin left) &&& (preSubtract origin right)
-- General case: build the subtraction explicitly in the reified skeleton.
preSubtract origin fn = (id &&& const origin) >>>! (Subtract >>>! fn)
   
-- Match-only views of "apply @f@, then combine with the constant @c@ via
-- 'AddTo'". In 'PostAdd' the constant is the second (needle) component of the
-- pair fed to 'AddTo', in 'PostAdd'' it is the first (point) component.
pattern PostAdd c f <- f:&&&Const c :>>> AddTo
pattern PostAdd' c f <- Const c:&&&f :>>> AddTo

-- | @postAdd v f@ applies @f@ and then translates the result by the constant
--   needle @v@ (the constant ends up as second component of the 'AddTo' input).
postAdd :: (MetricScalar s, WithField s AffineManifold d, WithField s AffineManifold c)
               => Diff d -> Affine s c d -> Affine s c d
postAdd shift fn = (fn &&& const shift) >>>! AddTo
-- | @postAdd' p f@ applies the needle-valued @f@ and uses the result to
--   translate the constant point @p@ (the constant ends up as first component
--   of the 'AddTo' input).
postAdd' :: (MetricScalar s, WithField s AffineManifold d, WithField s AffineManifold c)
               => d -> Affine s c (Diff d) -> Affine s c d
postAdd' base fn = (const base &&& fn) >>>! AddTo

-- | An affine arrow can be used as an ordinary function: evaluate it at the
--   argument and discard the slope.
instance (MetricScalar s) => EnhancedCat (->) (Affine s) where
  arr affFn = \ref -> fst (toOffset'Slope affFn ref)

-- | Reified well-pointed expressions over 'Affine' embed into 'Affine' itself;
--   an expression that only wraps an 'Affine' arrow is unwrapped.
instance (MetricScalar s) => EnhancedCat (Affine s) (ReWellPointed (Affine s)) where
  arr expr = case expr of
     Specific inner -> inner
     _              -> ReAffine expr

-- | Pointwise affine-space structure on affine functions: the general
--   fallbacks define @f .-. g@ as @f&&&g >>> Subtract@ and @f .+^ g@ as
--   @f&&&g >>> AddTo@. All clauses are tried from top to bottom; the
--   specialised ones before the fallbacks build a more compact reified
--   expression for particular shapes of the operands.
instance (MetricScalar s, WithField s AffineManifold d, WithField s AffineManifold c)
                  => AffineSpace (Affine s d c) where
  type Diff (Affine s d c) = Affine s d (Diff c)
  
  -- At least one operand is a linear map, possibly with a constant added:
  -- combine offsets and slopes directly.
  ScaleWith q .-. ScaleWith r = ScaleWith $ q^-^r
  (PostAdd c (ScaleWith q)) .-. g = let (d, r) = toOffsetSlope g
                                    in postAdd (c.-.d) $ ScaleWith (q^-^r)
  f .-. (PostAdd d (ScaleWith r)) = let (c, q) = toOffsetSlope f
                                    in postAdd (c.-.d) $ ScaleWith (q^-^r)
  (PostAdd' c (ScaleWith q)) .-. g = let (d, r) = toOffsetSlope g
                                     in postAdd (c.-.d) $ ScaleWith (q^-^r)
  f .-. (PostAdd' d (ScaleWith r)) = let (c, q) = toOffsetSlope f
                                     in postAdd (c.-.d) $ ScaleWith (q^-^r)
  
  -- Identical primitive arrows cancel to the constant zero function.
  Id .-. Id = const zeroV
  Fst .-. Fst = const zeroV
  Snd .-. Snd = const zeroV
  Swap .-. Swap = const zeroV
  AttachUnit .-. AttachUnit = const zeroV
  DetachUnit .-. DetachUnit = const zeroV
  Terminal .-. _ = Terminal
  _ .-. Terminal = Terminal
  Subtract .-. Subtract = const zeroV
  AddTo .-. AddTo = const zeroV
  
  Const c .-. Const d = Const $ c.-.d
  
  -- Difference of the two projections is exactly the 'Subtract' primitive.
  Fst .-. Snd = Subtract

  -- Tuple-valued functions are subtracted component-wise.
  (f:***g) .-. (h:***i) = f.-.h *** g.-.i
  (f:***g) .-. Const (c,d) = f.-.const c *** g.-.const d
  ζ .-. (f:***g) | Const (c,d) <- ζ = const c.-.f *** const d.-.g
  (f:&&&g) .-. (h:&&&i) = f.-.h &&& g.-.i
  (f:&&&_) .-. AttachUnit = f.-.id >>>! AttachUnit
  (f:&&&g) .-. Const (c,d) = f.-.const c &&& g.-.const d
  ζ .-. (f:&&&g) | Const (c,d) <- ζ = const c.-.f &&& const d.-.g

  -- One operand is a plain linear map: evaluate the other one at zero to get
  -- its offset and slope.
  ScaleWith q .-. f = let (c, r) = toOffset'Slope f zeroV
                      in postAdd (negateV c) $ ScaleWith (q^-^r)
  f .-. ScaleWith q = let (c, r) = toOffset'Slope f zeroV
                      in postAdd c $ ScaleWith (r^-^q)
  
  PreSubtract b f .-. g = let (c, q) = toOffsetSlope f
                              (d, r) = toOffset'Slope g b
                          in preSubtract b . postAdd (c.-.d) $ ScaleWith (q^-^r)
      -- f x = q·x + c
      -- g x = r·x + w
      -- d = r·b + w
      -- (q−r)·(x−b) = q·x − q⋅b − r⋅x + r⋅b
      -- s x = f (x−b) − g x
      --     = q⋅(x−b) + c − r⋅x − w
      --     = q⋅x − q⋅b + c − r⋅x − w
      --     = (q−r)·(x−b) + c − r⋅b − w
      --     = (q−r)·(x−b) + c − d
  
  -- According to GHC, this clause overlaps with the above. Hm...
  f .-. PreSubtract b g = let (c, q) = toOffset'Slope f b
                              (d, r) = toOffsetSlope g
                          in preSubtract b $ postAdd (c.-.d) $ ScaleWith (q^-^r)
      -- f x = q·x + v
      -- g x = r·x + d
      -- c = q·b + v
      -- (q−r)·(x−b) = q·x − q⋅b − r⋅x + r⋅b
      -- s x = f x − g (x−b)
      --     = q⋅x + v − r⋅(x−b) − d
      --     = q⋅x + v − r⋅x + r⋅b − d
      --     = (q−r)·(x−b) + q⋅b + v − d
      --     = (q−r)·(x−b) + c − d
  
  -- General fallback: pair up both results and subtract them.
  f .-. g = f&&&g >>> Subtract
  
  
  -- Addition mirrors the structure of the subtraction clauses above.
  ScaleWith q .+^ ScaleWith r = ScaleWith $ q^+^r
  (PostAdd c (ScaleWith q)) .+^ g = let (d, r) = toOffsetSlope g
                                    in postAdd (c.+^d) $ ScaleWith (q^+^r)
  f .+^ (PostAdd d (ScaleWith r)) = let (c, q) = toOffsetSlope f
                                    in postAdd' (c.+^d) $ ScaleWith (q^+^r)
  (PostAdd' c (ScaleWith q)) .+^ g = let (d, r) = toOffsetSlope g
                                     in postAdd' (c.+^d) $ ScaleWith (q^+^r)
  f .+^ (PostAdd' d (ScaleWith r)) = let (c, q) = toOffsetSlope f
                                     in postAdd' (c.+^d) $ ScaleWith (q^+^r)
  (f:***g) .+^ (h:***i) = f.+^h *** g.+^i
  (f:&&&g) .+^ (h:&&&i) = f.+^h &&& g.+^i
  
  Const c .+^ Const c' = const (c.+^c')

  Terminal .+^ _ = Terminal
  Const c .+^ Terminal = Const c
  Const c .+^ f = const c&&&f >>> AddTo
  
  -- Adding a primitive arrow to itself is that arrow followed by doubling.
  Id .+^ Id = Id >>> ScaleWith (linear (^*2))
  Fst .+^ Fst = Fst >>> ScaleWith (linear (^*2))
  Snd .+^ Snd = Snd >>> ScaleWith (linear (^*2))
  -- Sum of the two projections is exactly the 'AddTo' primitive.
  Fst .+^ Snd = AddTo
  Swap .+^ Swap = Swap >>> ScaleWith (linear (^*2))
  
  -- The right operand is a linear primitive: add its linear map to the slope
  -- of the left operand, keep the left operand's offset at zero.
  f .+^ Id = let (c,q) = toOffset'Slope f zeroV
             in const c&&&ScaleWith (q^+^idL) >>>! AddTo
  f .+^ AttachUnit = let (c,q) = toOffset'Slope f zeroV
                     in postAdd' c $ ScaleWith (q^+^linear(,Origin))
  f .+^ DetachUnit = let (c,q) = toOffset'Slope f zeroV
                     in postAdd' c $ ScaleWith (q^+^linear fst)
  f .+^ Swap = let (c,q) = toOffset'Slope f zeroV
               in postAdd' c $ ScaleWith (q^+^linear swap)
  
  PreSubtract b f .+^ g = let (c, q) = toOffsetSlope f
                              (d, r) = toOffset'Slope g b
                          in preSubtract b . postAdd' (c.+^d) $ ScaleWith (q^+^r)
      -- f x = q·x + c
      -- g x = r·x + w
      -- d = r·b + w
      -- (q+r)·(x−b) = q·x − q⋅b + r⋅x − r⋅b
      -- s x = f (x−b) + g x
      --     = q⋅(x−b) + c + r⋅x + w
      --     = q⋅x − q⋅b + c + r⋅x + w
      --     = (q+r)·(x−b) + c + r⋅b + w
      --     = (q+r)·(x−b) + c + d
  
  f .+^ PreSubtract b g = let (c, q) = toOffset'Slope f b
                              (d, r) = toOffsetSlope g
                          in preSubtract b . postAdd' (c.+^d) $ ScaleWith (q^+^r)
      -- f x = q·x + v
      -- g x = r·x + d
      -- c = q·b + v
      -- (q+r)·(x−b) = q·x − q⋅b + r⋅x − r⋅b
      -- s x = f x + g (x−b)
      --     = q⋅x + v + r⋅(x−b) + d
      --     = q⋅x + v + r⋅x − r⋅b + d
      --     = (q+r)·(x−b) + q⋅b + v + d
      --     = (q+r)·(x−b) + c + d
  
  -- General fallback: pair up both results and add them.
  f .+^ g = f&&&g >>> AddTo



-- | Affine functions into a linear manifold form an additive group. Sum and
--   difference are taken from the 'AffineSpace' instance; negation tries to
--   push the sign into the reified expression before falling back to
--   post-composition with the negating linear map.
instance (MetricScalar s, WithField s AffineManifold d, WithField s LinearManifold c)
                  => AdditiveGroup (Affine s d c) where
  zeroV = const zeroV
  
  -- Constants, the terminal arrow and linear maps are negated directly.
  negateV (Const value) = const (negateV value)
  negateV Terminal = Terminal
  negateV (ScaleWith lin) = ScaleWith (negateV lin)
  -- Tuple-valued functions are negated component-wise.
  negateV (left:***right) = (negateV left) *** (negateV right)
  negateV (left:&&&right) = (negateV left) &&& (negateV right)
  -- Compositions: move the sign to whichever stage absorbs it most cheaply.
  negateV (inner:>>>AddTo) = (negateV inner) >>> AddTo
  negateV (inner:>>>Subtract) = (inner>>>swap) >>>! Subtract
  negateV (inner:>>>ScaleWith lin) = (negateV inner) >>>! (ScaleWith lin)
  negateV (inner:>>>outer) = inner >>>! (negateV outer)
  negateV AttachUnit = ScaleWith (linear (negateV >>> (,Origin)))
  negateV Subtract = Swap >>>! Subtract
  -- Fallback: negate the result with a linear map.
  negateV other = other >>>! (ScaleWith (linear negateV))
  
  lhs ^+^ rhs = lhs .+^ rhs
  lhs ^-^ rhs = lhs .-. rhs
  

infixr 1 >>>!, <<<!
-- | Plain left-to-right composition of affine arrows: both operands are
--   brought into the reified skeleton (unwrapping 'ReAffine', wrapping
--   everything else) and composed there, with no attempt at simplification
--   on the 'Affine' level.
(>>>!) :: ( MetricScalar s, WithField s AffineManifold α
          , WithField s AffineManifold β, WithField s AffineManifold γ )
      => Affine s α β -> Affine s β γ -> Affine s α γ
f >>>! g = ReAffine (skeleton f >>> skeleton g)
 where -- View an arrow as an expression of the reified skeleton.
       skeleton :: Affine t x y -> ReWellPointed (Affine t) x y
       skeleton (ReAffine expr) = expr
       skeleton other           = ReWellPointed other

-- | Right-to-left variant of '>>>!': @g <<<! f@ applies @f@ first.
(<<<!) :: ( MetricScalar s, WithField s AffineManifold α
          , WithField s AffineManifold β, WithField s AffineManifold γ )
      => Affine s β γ -> Affine s α β -> Affine s α γ
after <<<! before = before >>>! after

-- | Affine functions form a category whose objects are the affine manifolds
--   over the scalar @s@. Composition recognises a few special shapes and
--   otherwise falls back to plain skeleton composition ('<<<!'). The clauses
--   are tried from top to bottom.
instance (MetricScalar s) => Category (Affine s) where
  type Object (Affine s) o = WithField s AffineManifold o
  
  id = ReAffine id
  
  -- Two linear maps compose to a single linear map.
  ScaleWith ϕ . ScaleWith ψ = ScaleWith $ ϕ*.*ψ
  -- Anything after a linear map: offset of g at zero plus the composed slopes.
  g . ScaleWith ψ = let (d, ϕ) = toOffsetSlope g
                    in postAdd' d $ ScaleWith (ϕ*.*ψ)
  -- Parallel arrows compose component-wise.
  (f:***g) . (h:***i) = f.h *** g.i
  (f:***g) . (h:&&&i) = f.h &&& g.i
  -- g after (constant point c translated by f): evaluate g at c and push the
  -- slope of g onto f.
  g . (PostAdd' c f) = let (d, ϕ) = toOffset'Slope g c
                      in postAdd' d $ ScaleWith ϕ . f
  
  f . g = f <<<! g

-- | Cartesian structure: the unit object is 'ZeroDim', and every
--   tuple-shuffling primitive is the corresponding primitive of the reified
--   skeleton, embedded with 'ReAffine'.
instance (MetricScalar s) => Cartesian (Affine s) where
  type UnitObject (Affine s) = ZeroDim s
  swap = ReAffine swap
  attachUnit = ReAffine attachUnit
  detachUnit = ReAffine detachUnit
  regroup = ReAffine regroup
  regroup' = ReAffine regroup'

-- | Parallel composition. Pairs of constants (or of terminal arrows) collapse
--   to a single constant; otherwise both operands are brought into the reified
--   skeleton (unwrapping 'ReAffine', wrapping everything else) and combined
--   there. The clauses are tried from top to bottom.
instance (MetricScalar s) => Morphism (Affine s) where
  Const c *** Const c' = const (c,c')
  Terminal *** Terminal = const (mempty, mempty)
  ReAffine f *** ReAffine g = ReAffine $ f *** g
  f *** ReAffine g = ReAffine $ ReWellPointed f *** g
  ReAffine f *** g = ReAffine $ f *** ReWellPointed g
  f *** g = ReAffine $ ReWellPointed f *** ReWellPointed g

-- | Terminal arrow, projections and fan-out. The primitives are those of the
--   reified skeleton, embedded with 'ReAffine'. Fan-out collapses pairs of
--   constants (or of terminal arrows) to a single constant and otherwise
--   combines both operands inside the reified skeleton. The clauses are tried
--   from top to bottom.
instance (MetricScalar s) => PreArrow (Affine s) where
  terminal = ReAffine terminal
  fst = ReAffine fst
  snd = ReAffine snd
  Const c &&& Const c' = const (c,c')
  Terminal &&& Terminal = const (mempty, mempty)
  ReAffine f &&& ReAffine g = ReAffine $ f &&& g
  f &&& ReAffine g = ReAffine $ ReWellPointed f &&& g
  ReAffine f &&& g = ReAffine $ f &&& ReWellPointed g
  f &&& g = ReAffine $ ReWellPointed f &&& ReWellPointed g
        
--   Affine cof aof slf &&& Affine cog aog slg
--       = Affine coh (aof.-^lapply slf rco, aog.+^lapply slg rco)
--                  (linear $ lapply slf &&& lapply slg)
--    where rco = (cog.-.cof)^/2
--          coh = cof .+^ rco

-- | Constant arrows are the constants of the reified skeleton; the canonical
--   point of the unit object is 'Origin'.
instance (MetricScalar s) => WellPointed (Affine s) where
  unit = Tagged Origin
  const value = ReAffine (const value)


-- | Regard a linear mapping between two linear manifolds as an affine
--   function (the exported way to build a 'ScaleWith' arrow).
linearAffine :: (MetricScalar s, WithField s LinearManifold α, WithField s LinearManifold β)
            => (α:-*β) -> Affine s α β
linearAffine linMap = ScaleWith linMap


-- | The generic agent type specialised to the 'Affine' category over the
--   scalar @s@; the remaining type arguments are those of 'GenericAgent'.
type AffinFuncValue s = GenericAgent (Affine s)

-- Agent interface for 'Affine': every method is the generic implementation
-- for 'GenericAgent'-based agents.
instance (MetricScalar s) => HasAgent (Affine s) where
  alg = genericAlg
  ($~) = genericAgentMap
-- Agent versions of arrows between tuples, again all generic.
instance (MetricScalar s) => CartesianAgent (Affine s) where
  alg1to2 = genericAlg1to2
  alg2to1 = genericAlg2to1
  alg2to2 = genericAlg2to2
-- Agents can be created from plain values via the generic 'point'.
instance (MetricScalar s)
      => PointAgent (AffinFuncValue s) (Affine s) a x where
  point = genericPoint



-- | Agents of affine functions inherit the additive group structure of the
--   arrows they wrap: every operation unwraps, delegates and wraps again.
instance (WithField s LinearManifold v, WithField s LinearManifold a)
    => AdditiveGroup (AffinFuncValue s a v) where
  zeroV = GenericAgent zeroV
  negateV (GenericAgent fn) = GenericAgent (negateV fn)
  GenericAgent lhs ^+^ GenericAgent rhs = GenericAgent (lhs ^+^ rhs)