{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver -fplugin=GHC.TypeLits.Normalise -fconstraint-solver-iterations=10 #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Multilayer perceptrons which instantiate backpropagation through laziness.
-- Right now the structure is simpler than it could be, but it leads to nice
-- types. If anyone ever wants to use a DNN with super-Affine biases, the code
-- is willing.
module Goal.Geometry.Map.NeuralNetwork
    ( -- * Neural Networks
      NeuralNetwork
    ) where


--- Imports ---


-- Goal --

import Goal.Core

import Goal.Geometry.Manifold
import Goal.Geometry.Map
import Goal.Geometry.Vector
import Goal.Geometry.Map.Linear
import Goal.Geometry.Differential

import qualified Goal.Core.Vector.Storable as S

--- Multilayer ---


-- | A multilayer, artificial neural network, described purely at the type
-- level (the type has no value-level constructors).
--
-- * @x@ is the input manifold and @z@ the output manifold.
-- * @f@ is the map kind of the final layer, an @Affine f z z y@.
-- * @gys@ lists the hidden layers, starting from the output side: each pair
--   @(g, y)@ gives a hidden manifold @y@ (the input of the layer after it)
--   and the map kind @g@ of the layer that produces it.
data NeuralNetwork
    (gys :: [(Type -> Type -> Type, Type)])
    (f :: Type -> Type -> Type)
    z
    x


--- Instances ---


-- | A network with no hidden layers consists of a single affine layer, so its
-- dimension is exactly that layer's dimension.
instance Manifold (Affine f z z x)
    => Manifold (NeuralNetwork '[] f z x) where
    type Dimension (NeuralNetwork '[] f z x) =
        Dimension (Affine f z z x)

-- | A network with at least one hidden layer has the dimension of its final
-- affine layer plus the dimension of the sub-network feeding into it.
instance
    ( Manifold (Affine f z z y)
    , Manifold (NeuralNetwork gys g y x)
    ) => Manifold (NeuralNetwork ('(g,y) : gys) f z x) where
    type Dimension (NeuralNetwork ('(g,y) : gys) f z x) =
        Dimension (Affine f z z y)
            + Dimension (NeuralNetwork gys g y x)


-- | View a network with no hidden layers as its single affine layer.
--
-- The conversion uses 'breakPoint', which only requires the two manifolds to
-- have equal 'Dimension'. The 'Manifold' instance for @NeuralNetwork '[]@
-- guarantees that equality.
fromSingleLayerNetwork :: c # NeuralNetwork '[] f z x -> c # Affine f z z x
{-# INLINE fromSingleLayerNetwork #-}
fromSingleLayerNetwork = breakPoint

-- | View a single affine layer as a network with no hidden layers.
--
-- This is the counterpart of 'fromSingleLayerNetwork'. It also uses
-- 'breakPoint', relying on the equal 'Dimension' given by the 'Manifold'
-- instance for @NeuralNetwork '[]@.
toSingleLayerNetwork :: c # Affine f z z x -> c # NeuralNetwork '[] f z x
{-# INLINE toSingleLayerNetwork #-}
toSingleLayerNetwork = breakPoint

-- | Separates a 'NeuralNetwork' into the final layer and the rest of the network.
--
-- The leading coordinates parameterize the final affine layer
-- (@Affine f z z y@). The remaining coordinates parameterize the sub-network
-- that produces @y@ from @x@. This is the inverse of 'joinNeuralNetwork'.
splitNeuralNetwork
    :: (Manifold (Affine f z z y), Manifold (NeuralNetwork gys g y x))
    => c # NeuralNetwork ('(g,y):gys) f z x
    -> (c # Affine f z z y, c # NeuralNetwork gys g y x)
{-# INLINE splitNeuralNetwork #-}
splitNeuralNetwork (Point xs) =
    -- The split index is inferred from the type of the first component.
    let (xys,xns) = S.splitAt xs
     in (Point xys, Point xns)

-- | Joins a layer onto the end of a 'NeuralNetwork'.
--
-- The affine layer's coordinates are placed first, followed by the
-- sub-network's coordinates. This matches the layout 'splitNeuralNetwork'
-- expects.
joinNeuralNetwork
    :: (Manifold (Affine f z z y), Manifold (NeuralNetwork gys g y x))
    => c # Affine f z z y
    -> c # NeuralNetwork gys g y x
    -> c # NeuralNetwork ('(g,y):gys) f z x
{-# INLINE joinNeuralNetwork #-}
joinNeuralNetwork (Point xys) (Point xns) =
    Point $ xys S.++ xns

-- | A network with at least one hidden layer is the product of its final
-- affine layer ('First') and the sub-network feeding into it ('Second').
instance (Manifold (Affine f z z y), Manifold (NeuralNetwork gys g y x))
  => Product (NeuralNetwork ('(g,y) : gys) f z x) where
      type First (NeuralNetwork ('(g,y) : gys) f z x)
        = Affine f z z y
      type Second (NeuralNetwork ('(g,y) : gys) f z x)
        = NeuralNetwork gys g y x
      join = joinNeuralNetwork
      split = splitNeuralNetwork

-- | Evaluates a multilayer network in three steps:
--
-- 1. Run the sub-network @g@ on the input.
-- 2. Move its @c@-coordinate outputs into @Dual c@ coordinates with
--    'transition'.
-- 3. Feed the result to the final affine layer @f@.
instance (Map c f z y, Map c (NeuralNetwork gys g) y x, Transition c (Dual c) y)
  => Map c (NeuralNetwork ('(g,y) : gys) f) z x where
    {-# INLINE (>.>) #-}
    (>.>) fg x =
        let (f,g) = split fg
         in f >.> transition (g >.> x)
    {-# INLINE (>$>) #-}
    (>$>) fg xs =
        let (f,g) = split fg
         in f >$> map transition (g >$> xs)

-- | A network with no hidden layers is evaluated as its single affine layer.
instance Map c f z x => Map c (NeuralNetwork '[] f) z x where
    {-# INLINE (>.>) #-}
    (>.>) f x = fromSingleLayerNetwork f >.> x
    {-# INLINE (>$>) #-}
    (>$>) f xs = fromSingleLayerNetwork f >$> xs

-- | Propagation through a network with no hidden layers delegates to its
-- single affine layer. The dual-coordinate point that layer returns is
-- converted back to network coordinates with 'toSingleLayerNetwork'.
instance (Propagate c f z x) => Propagate c (NeuralNetwork '[] f) z x where
    {-# INLINE propagate #-}
    propagate dps qs f =
        let (df,ps) = propagate dps qs $ fromSingleLayerNetwork f
         in (toSingleLayerNetwork df,ps)

-- | Backpropagation through a multilayer network, written as a single lazy
-- @let@ that ties a knot.
--
-- The hidden outputs @ys@ are needed to build the errors @dys@. Those errors
-- are passed to the very 'propagate' call that produces @ys@. This relies on
-- @ys@ not depending on the error argument of the sub-network's 'propagate'
-- (presumably true of every 'Propagate' instance; TODO confirm).
instance
    ( Propagate c f z y, Propagate c (NeuralNetwork gys g) y x, Map c f y z
    , Transition c (Dual c) y, Legendre y, Riemannian c y, Bilinear f z y)
  => Propagate c (NeuralNetwork ('(g,y) : gys) f) z x where
      {-# INLINE propagate #-}
      propagate dzs xs fg =
          let (f,g) = split fg
              -- The @f z y@ component (second half of the split) of the final
              -- affine layer. It is used with '<$<' to map the output errors
              -- @dzs@ back onto @y@.
              fmtx = snd $ split f
              -- Hidden outputs in dual coordinates, the inputs of layer @f@.
              mys = transition <$> ys
              (df,zhts) = propagate dzs mys f
              -- Knot: this call yields @ys@ while consuming @dys@, and @dys@
              -- is itself built from @ys@ below.
              (dg,ys) = propagate dys xs g
              dys0 = dzs <$< fmtx
              -- Apply 'flat' at each hidden output to the pulled-back error.
              dys = zipWith flat ys dys0
           in (join df dg, zhts)