{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- | While 'BackGrad' is intended to be simple to construct manually, this module provides helpers
--   that do it with less boilerplate.
module Downhill.Linear.Lift
  ( -- * Lifts
    lift1,
    lift2,
    lift3,

    -- * Dense lifts
    lift1_dense,
    lift2_dense,
    lift3_dense,

    -- * Lifts for 'SparseVector'
    lift1_sparse,
    lift2_sparse,
    lift3_sparse,
  )
where

import Downhill.Linear.BackGrad (BackGrad (..), castBackGrad, realNode)
import Downhill.Linear.Expr (BasicVector (..), Expr (ExprSum), SparseVector (unSparseVector))
import Prelude hiding (fst, snd, zip)

-- | Lift a linear function of one argument into 'BackGrad'.
--
-- The function argument goes in the backward direction: it maps a value of
-- the result type @z@ to a 'VecBuilder' for the argument type @a@. The
-- resulting node is a 'realNode' over an 'ExprSum' with exactly one term,
-- obtained by applying the argument's term function to @fa@.
lift1 ::
  forall z r a.
  BasicVector z =>
  (z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
lift1 fa (BackGrad da) = realNode node
  where
    node :: Expr r z
    node = ExprSum [da fa]

-- | Lift a linear function of two arguments into 'BackGrad'.
--
-- Each function argument goes in the backward direction: it maps a value of
-- the result type @z@ to a 'VecBuilder' for the corresponding argument. The
-- resulting node is a 'realNode' over an 'ExprSum' with one term per
-- argument.
lift2 ::
  forall z r a b.
  BasicVector z =>
  (z -> VecBuilder a) ->
  (z -> VecBuilder b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r z
lift2 fa fb (BackGrad da) (BackGrad db) = realNode node
  where
    node :: Expr r z
    node = ExprSum [da fa, db fb]

-- | Lift a linear function of three arguments into 'BackGrad'.
--
-- Each function argument goes in the backward direction: it maps a value of
-- the result type @z@ to a 'VecBuilder' for the corresponding argument. The
-- resulting node is a 'realNode' over an 'ExprSum' with one term per
-- argument.
lift3 ::
  forall z r a b c.
  BasicVector z =>
  (z -> VecBuilder a) ->
  (z -> VecBuilder b) ->
  (z -> VecBuilder c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r z
lift3 fa fb fc (BackGrad da) (BackGrad db) (BackGrad dc) = realNode node
  where
    node :: Expr r z
    node = ExprSum [da fa, db fb, dc fc]

-- | Same as 'sparseNode', included here for completeness.
--
-- Like 'lift1', but the backward function consumes a 'VecBuilder' of the
-- result instead of a dense @z@. Internally the node is built with
-- @'SparseVector' z@ as its value type, and its payload is unwrapped with
-- 'unSparseVector' before @fa@ is applied. 'castBackGrad' then retypes the
-- node as @'BackGrad' r z@. That cast relies on @VecBuilder (SparseVector z)@
-- being equal to @VecBuilder z@.
lift1_sparse ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
lift1_sparse fa = castBackGrad . lift1 @(SparseVector z) fa'
  where
    -- Adapt the builder-consuming function to the 'SparseVector' node type.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector
-- | Two-argument counterpart of 'lift1_sparse'.
--
-- Like 'lift2', but each backward function consumes a 'VecBuilder' of the
-- result instead of a dense @z@. The node is built with @'SparseVector' z@ as
-- its value type. 'castBackGrad' then retypes it as @'BackGrad' r z@, which
-- relies on @VecBuilder (SparseVector z)@ being equal to @VecBuilder z@.
lift2_sparse ::
  forall r a b z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  (VecBuilder z -> VecBuilder b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r z
lift2_sparse fa fb a b = castBackGrad $ lift2 @(SparseVector z) fa' fb' a b
  where
    -- Adapt the builder-consuming functions to the 'SparseVector' node type.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector
    fb' :: SparseVector z -> VecBuilder b
    fb' = fb . unSparseVector

-- | Three-argument counterpart of 'lift1_sparse'.
--
-- Like 'lift3', but each backward function consumes a 'VecBuilder' of the
-- result instead of a dense @z@. The node is built with @'SparseVector' z@ as
-- its value type. 'castBackGrad' then retypes it as @'BackGrad' r z@, which
-- relies on @VecBuilder (SparseVector z)@ being equal to @VecBuilder z@.
lift3_sparse ::
  forall r a b c z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  (VecBuilder z -> VecBuilder b) ->
  (VecBuilder z -> VecBuilder c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r z
lift3_sparse fa fb fc a b c =
  castBackGrad $
    lift3 @(SparseVector z) fa' fb' fc' a b c
  where
    -- Adapt the builder-consuming functions to the 'SparseVector' node type.
    fa' :: SparseVector z -> VecBuilder a
    fa' = fa . unSparseVector
    fb' :: SparseVector z -> VecBuilder b
    fb' = fb . unSparseVector
    fc' :: SparseVector z -> VecBuilder c
    fc' = fc . unSparseVector

-- | Like 'lift1', but the backward function returns a plain value of the
-- argument type rather than a 'VecBuilder'.
--
-- The value is turned into a builder with 'identityBuilder'.
lift1_dense ::
  (BasicVector v, BasicVector a) =>
  (v -> a) ->
  BackGrad r a ->
  BackGrad r v
lift1_dense fa = lift1 (identityBuilder . fa)

-- | Like 'lift2', but each backward function returns a plain value of the
-- corresponding argument type rather than a 'VecBuilder'.
--
-- Each value is turned into a builder with 'identityBuilder'.
lift2_dense ::
  (BasicVector v, BasicVector a, BasicVector b) =>
  (v -> a) ->
  (v -> b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r v
lift2_dense fa fb = lift2 (identityBuilder . fa) (identityBuilder . fb)

-- | Like 'lift3', but each backward function returns a plain value of the
-- corresponding argument type rather than a 'VecBuilder'.
--
-- Each value is turned into a builder with 'identityBuilder'.
lift3_dense ::
  (BasicVector v, BasicVector a, BasicVector b, BasicVector c) =>
  (v -> a) ->
  (v -> b) ->
  (v -> c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r v
lift3_dense fa fb fc =
  lift3
    (identityBuilder . fa)
    (identityBuilder . fb)
    (identityBuilder . fc)