{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module Downhill.Linear.Expr
  ( -- * Expression
    Expr (..),
    Term (..),

    -- * Vectors
    BasicVector (..),
    SparseVector (..),
    DenseVector (..),
    DenseBuilder (..),
    toDenseBuilder,

    -- * Generics
    genericSumBuilder,
    genericIdentityBuilder,
    genericSumMaybeBuilder,
    genericIdentityMaybeBuilder,

    -- * Misc
    maybeToMonoid,
  )
where

import Data.Kind (Type)
import Data.Maybe (fromMaybe)
import Data.Semigroup (Sum (Sum, getSum))
import Data.VectorSpace (AdditiveGroup (..), VectorSpace (..), zeroV)
import GHC.Generics (Generic (Rep, from, to), K1 (K1), M1 (M1), U1 (U1), V1, (:*:) ((:*:)))

-- | A single term of a linear expression. @Term f x@ pairs a function
-- @f :: v -> VecBuilder u@ with a sub-expression @x :: Expr a u@. The
-- intermediate type @u@ does not occur in the result type @Term a v@, so it
-- is hidden from users of the term.
--
-- Law: the function @f@ must be /linear/.
data Term a v where
  Term :: (v -> VecBuilder u) -> Expr a u -> Term a v

-- | @Expr a v@ represents a linear expression of type @v@, containing some free variables of type @a@.
data Expr a v where
  -- | The free variable itself; here the expression type coincides with the
  -- variable type.
  ExprVar :: Expr a a
  -- | An expression built from a list of 'Term's, all of type @v@. The
  -- constructor also packs the @BasicVector v@ instance, so it becomes
  -- available when pattern matching on 'ExprSum'.
  ExprSum :: BasicVector v => [Term a v] -> Expr a v

-- | Vectors that can be accumulated through a monoidal builder.
--
-- Every instance picks a builder type 'VecBuilder' and provides conversions
-- in both directions: 'sumBuilder' collapses a builder into a vector and
-- 'identityBuilder' turns a vector into a builder.
--
-- Default implementations are available when the builder has the shape
-- @Maybe b@ and both @b@ and @v@ are 'Generic' with structurally matching
-- representations; they delegate to 'genericSumMaybeBuilder' and
-- 'genericIdentityMaybeBuilder'.
class Monoid (VecBuilder v) => BasicVector v where
  -- | @VecBuilder v@ is a sparse representation of vector @v@. Edges of a computational graph
  -- produce builders, which are then summed into vectors in nodes. Monoid operation '<>'
  -- means addition of vectors, but it doesn't need to compute the sum immediately - it
  -- might defer computation until 'sumBuilder' is evaluated.
  --
  -- @
  -- sumBuilder mempty = zeroV
  -- sumBuilder (x <> y) = sumBuilder x ^+^ sumBuilder y
  -- @
  --
  -- 'mempty' must be cheap. '<>' must be O(1).
  type VecBuilder v :: Type

  -- | Collapse an accumulated builder into the vector it represents.
  sumBuilder :: VecBuilder v -> v

  -- | Convert a vector into a builder that represents it.
  identityBuilder :: v -> VecBuilder v

  default sumBuilder ::
    forall b.
    ( VecBuilder v ~ Maybe b,
      Generic b,
      Generic v,
      GBasicVector (Rep b) (Rep v),
      AdditiveGroup v
    ) =>
    VecBuilder v ->
    v
  sumBuilder = genericSumMaybeBuilder @b @v

  default identityBuilder ::
    forall b.
    ( VecBuilder v ~ Maybe b,
      Generic b,
      Generic v,
      GBasicVector (Rep b) (Rep v),
      AdditiveGroup v
    ) =>
    v ->
    VecBuilder v
  identityBuilder = genericIdentityMaybeBuilder @b @v

-- | Collapse an optional monoidal value: 'Nothing' becomes 'mempty', and
-- @Just x@ becomes @x@.
maybeToMonoid :: Monoid m => Maybe m -> m
maybeToMonoid = fromMaybe mempty

-- | Collapse an optional vector: 'Nothing' becomes 'zeroV', and @Just x@
-- becomes @x@.
_maybeToVector :: AdditiveGroup v => Maybe v -> v
_maybeToVector = fromMaybe zeroV

-- | 'Integer' is accumulated through the additive 'Sum' monoid.
instance BasicVector Integer where
  type VecBuilder Integer = Sum Integer
  sumBuilder = getSum
  identityBuilder = Sum

-- | Pairs are built componentwise. The builder is wrapped in 'Maybe';
-- 'Nothing' is replaced by 'mempty' of the pair of component builders
-- before summing.
instance (BasicVector a, BasicVector b) => BasicVector (a, b) where
  type VecBuilder (a, b) = Maybe (VecBuilder a, VecBuilder b)
  sumBuilder = sumPair . maybeToMonoid
    where
      -- Sum each component builder independently.
      sumPair (a, b) = (sumBuilder a, sumBuilder b)
  identityBuilder (x, y) = Just (identityBuilder x, identityBuilder y)

-- | Triples are built componentwise. The builder is wrapped in 'Maybe';
-- 'Nothing' is replaced by 'mempty' of the triple of component builders
-- before summing.
instance (BasicVector a, BasicVector b, BasicVector c) => BasicVector (a, b, c) where
  type VecBuilder (a, b, c) = Maybe (VecBuilder a, VecBuilder b, VecBuilder c)
  sumBuilder = sumTriple . maybeToMonoid
    where
      -- Sum each component builder independently.
      sumTriple (a, b, c) = (sumBuilder a, sumBuilder b, sumBuilder c)
  identityBuilder (x, y, z) = Just (identityBuilder x, identityBuilder y, identityBuilder z)

-- | 'Float' is accumulated through the additive 'Sum' monoid.
instance BasicVector Float where
  type VecBuilder Float = Sum Float
  sumBuilder = getSum
  identityBuilder = Sum

-- | 'Double' is accumulated through the additive 'Sum' monoid.
instance BasicVector Double where
  type VecBuilder Double = Sum Double
  sumBuilder = getSum
  identityBuilder = Sum

-- | Normally a graph node would compute the sum of gradients and then
-- propagate it to ancestor nodes. That's the best strategy when
-- some computation needs to be performed for backpropagation.
-- Some operations, like constructing/deconstructing tuples or
-- wrapping/unwrapping, don't need to compute the sum. Doing so only
-- destroys sparsity. A node of type @SparseVector v@ won't sum
-- the gradients, it will simply forward builders to its parents.
--
-- Representationally this is just the builder of @v@ under a new name.
newtype SparseVector v = SparseVector
  {unSparseVector :: VecBuilder v}

-- | '<>' on @SparseVector v@ is borrowed unchanged from the 'Semigroup'
-- instance of the wrapped builder @VecBuilder v@.
deriving via (VecBuilder v) instance Semigroup (VecBuilder v) => Semigroup (SparseVector v)

-- | A sparse vector uses the builder of @v@ as its own builder, so
-- \"summing\" only wraps the builder and 'identityBuilder' only unwraps it;
-- no vector of type @v@ is ever computed.
instance Monoid (VecBuilder v) => BasicVector (SparseVector v) where
  type VecBuilder (SparseVector v) = VecBuilder v
  sumBuilder = SparseVector
  identityBuilder = unSparseVector

-- | Internal wrapper around a vector. Its only purpose is to carry a
-- 'Semigroup' instance that 'DenseBuilder' derives its instances through.
newtype DenseSemibuilder v = DenseSemibuilder {_unDenseSemibuilder :: v}

-- | Combining two semibuilders adds the wrapped vectors with '^+^'.
instance AdditiveGroup v => Semigroup (DenseSemibuilder v) where
  DenseSemibuilder x <> DenseSemibuilder y = DenseSemibuilder (x ^+^ y)

-- | Builder for 'DenseVector': either nothing accumulated yet ('Nothing')
-- or an already summed vector ('Just').
--
-- 'Semigroup' and 'Monoid' are derived via @Maybe (DenseSemibuilder v)@, so
-- 'mempty' is @DenseBuilder Nothing@ and two 'Just' values are combined with
-- '^+^'.
newtype DenseBuilder v = DenseBuilder (Maybe v)
  deriving (Semigroup, Monoid) via (Maybe (DenseSemibuilder v))

-- | Wrap a vector into a non-empty 'DenseBuilder'.
toDenseBuilder :: v -> DenseBuilder v
toDenseBuilder = DenseBuilder . Just

-- | When sparsity is not needed, we can use vector @v@ as a builder of itself.
-- @DenseVector@ takes care of that.
--
-- The 'AdditiveGroup' and 'VectorSpace' instances are reused from the
-- wrapped type @v@.
newtype DenseVector v = DenseVector v
  deriving (AdditiveGroup, VectorSpace) via v

-- | Dense vectors accumulate through 'DenseBuilder'. An empty builder sums
-- to 'zeroV'; a non-empty builder already holds the sum.
instance AdditiveGroup v => BasicVector (DenseVector v) where
  type VecBuilder (DenseVector v) = DenseBuilder v
  sumBuilder (DenseBuilder Nothing) = DenseVector zeroV
  sumBuilder (DenseBuilder (Just x)) = DenseVector x
  identityBuilder (DenseVector v) = DenseBuilder (Just v)

-- | Generic counterpart of 'BasicVector'. It relates a generic
-- representation @b@ of a builder type to a generic representation @v@ of a
-- vector type.
class GBasicVector b v where
  -- | Convert the representation of a builder into the representation of a vector.
  gsumBuilder :: b p -> v p
  -- | Convert the representation of a vector into the representation of a builder.
  gidentityBuilder :: v p -> b p

-- | Field case: a field of the builder record must hold exactly the
-- 'VecBuilder' of the corresponding vector field, and is converted with the
-- field's own 'BasicVector' instance.
instance (BasicVector v, b ~ VecBuilder v) => GBasicVector (K1 x b) (K1 x v) where
  gsumBuilder (K1 x) = K1 (sumBuilder x)
  gidentityBuilder (K1 x) = K1 (identityBuilder x)

-- | Metadata case: unwrap, convert the contents and rewrap. The metadata
-- parameters @y@ and @y'@ are independent, so the builder and the vector
-- representations may carry different metadata.
instance (GBasicVector b v) => GBasicVector (M1 x y b) (M1 x y' v) where
  gsumBuilder (M1 x) = M1 (gsumBuilder x)
  gidentityBuilder (M1 x) = M1 (gidentityBuilder x)

-- | Product case: both halves are converted independently.
instance (GBasicVector bu u, GBasicVector bv v) => GBasicVector (bu :*: bv) (u :*: v) where
  gsumBuilder (x1 :*: x2) = gsumBuilder x1 :*: gsumBuilder x2
  gidentityBuilder (x1 :*: x2) = gidentityBuilder x1 :*: gidentityBuilder x2

-- | Empty-type case: 'V1' has no constructors, so both conversions are
-- empty case analyses.
instance GBasicVector V1 V1 where
  gsumBuilder = \case {}
  gidentityBuilder = \case {}

-- | Nullary-constructor case: 'U1' carries no data, so both conversions
-- return 'U1'.
instance GBasicVector U1 U1 where
  gsumBuilder U1 = U1
  gidentityBuilder U1 = U1

-- | Generic implementation of 'sumBuilder' for a builder type @b@ whose
-- generic representation matches that of the vector type @v@. It converts
-- the builder to its representation, maps it with 'gsumBuilder' and
-- rebuilds a @v@.
genericSumBuilder :: forall b v. (Generic b, Generic v, GBasicVector (Rep b) (Rep v)) => b -> v
genericSumBuilder = to . gsumBuilder . from

-- | Generic implementation of 'identityBuilder'. It converts the vector to
-- its generic representation, maps it with 'gidentityBuilder' and rebuilds
-- a builder of type @b@.
genericIdentityBuilder :: forall b v. (Generic b, Generic v, GBasicVector (Rep b) (Rep v)) => v -> b
genericIdentityBuilder = to . gidentityBuilder . from

-- | Like 'genericSumBuilder', but for builders wrapped in 'Maybe':
-- 'Nothing' sums to 'zeroV', and @Just b@ is summed generically.
genericSumMaybeBuilder ::
  forall b v.
  (Generic b, Generic v, AdditiveGroup v, GBasicVector (Rep b) (Rep v)) =>
  Maybe b ->
  v
genericSumMaybeBuilder = maybe zeroV genericSumBuilder

-- | Like 'genericIdentityBuilder', but wraps the resulting builder in
-- 'Just'.
genericIdentityMaybeBuilder :: forall b v. (Generic b, Generic v, GBasicVector (Rep b) (Rep v)) => v -> Maybe b
genericIdentityMaybeBuilder = Just . genericIdentityBuilder