{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Downhill.Linear.Lift
(
lift1,
lift2,
lift3,
lift1_dense,
lift2_dense,
lift3_dense,
lift1_sparse,
lift2_sparse,
lift3_sparse,
)
where
import Downhill.Linear.BackGrad (BackGrad (..), castBackGrad, realNode)
import Downhill.Linear.Expr (BasicVector (..), Expr (ExprSum), FullVector (identityBuilder), SparseVector (unSparseVector))
import Prelude hiding (fst, snd, zip)
-- | Build a new 'BackGrad' node with a single parent.
--
-- The parent is unwrapped to its function of type
-- @forall x. (x -> VecBuilder a) -> Term r x@. That function is applied to
-- the backward function @fa@, giving the node's only term. The term is
-- wrapped in an 'ExprSum' and registered with 'realNode'.
lift1 ::
  forall z r a.
  BasicVector z =>
  (z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
lift1 fa (BackGrad da) =
  let node :: Expr r z
      node = ExprSum [da fa]
   in realNode node
-- | Two-parent version of 'lift1'.
--
-- Each parent contributes one term to the node's 'ExprSum', in argument
-- order:
--
-- * @da fa@ comes from the first parent.
-- * @db fb@ comes from the second parent.
--
-- The sum is then wrapped with 'realNode'.
lift2 ::
  forall z r a b.
  BasicVector z =>
  (z -> VecBuilder a) ->
  (z -> VecBuilder b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r z
lift2 fa fb (BackGrad da) (BackGrad db) =
  let terms = [da fa, db fb]
   in realNode (ExprSum terms :: Expr r z)
-- | Three-parent version of 'lift1'.
--
-- The node's 'ExprSum' holds one term per parent, in argument order:
-- @da fa@, @db fb@, @dc fc@. The sum is then wrapped with 'realNode'.
lift3 ::
  forall z r a b c.
  BasicVector z =>
  (z -> VecBuilder a) ->
  (z -> VecBuilder b) ->
  (z -> VecBuilder c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r z
lift3 fa fb fc (BackGrad da) (BackGrad db) (BackGrad dc) =
  realNode (ExprSum [da fa, db fb, dc fc] :: Expr r z)
-- | Like 'lift1', but the backward function takes a @VecBuilder z@
-- instead of a @z@.
--
-- How it works:
--
-- * The lift is done at type @SparseVector z@.
-- * Each incoming value is unwrapped with 'unSparseVector' before @fa@
--   sees it.
-- * The resulting @BackGrad r (SparseVector z)@ is retyped to
--   @BackGrad r z@ with 'castBackGrad'.
--
-- The cast requires @VecBuilder z ~ VecBuilder (SparseVector z)@, which is
-- castBackGrad's own constraint.
lift1_sparse ::
  forall r a z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  BackGrad r a ->
  BackGrad r z
lift1_sparse fa x =
  let lifted :: BackGrad r (SparseVector z)
      lifted = lift1 @(SparseVector z) (fa . unSparseVector) x
   in castBackGrad lifted
-- | Two-parent version of 'lift1_sparse'.
--
-- Both backward functions are adapted to take a @SparseVector z@ by
-- unwrapping it with 'unSparseVector'. The lift is done with 'lift2' at
-- type @SparseVector z@, and the result is retyped to @BackGrad r z@ with
-- 'castBackGrad'.
lift2_sparse ::
  forall r a b z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  (VecBuilder z -> VecBuilder b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r z
lift2_sparse fa fb xa xb =
  castBackGrad (lift2 @(SparseVector z) (unwrap fa) (unwrap fb) xa xb)
  where
    -- One polymorphic adapter replaces a separate helper per parent.
    unwrap :: (VecBuilder z -> VecBuilder w) -> SparseVector z -> VecBuilder w
    unwrap f = f . unSparseVector
-- | Three-parent version of 'lift1_sparse'.
--
-- Each backward function is adapted to take a @SparseVector z@ by
-- unwrapping it with 'unSparseVector'. The lift is done with 'lift3' at
-- type @SparseVector z@, and the result is retyped to @BackGrad r z@ with
-- 'castBackGrad'.
lift3_sparse ::
  forall r a b c z.
  BasicVector z =>
  (VecBuilder z -> VecBuilder a) ->
  (VecBuilder z -> VecBuilder b) ->
  (VecBuilder z -> VecBuilder c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r z
lift3_sparse fa fb fc xa xb xc =
  castBackGrad $
    lift3 @(SparseVector z) (unwrap fa) (unwrap fb) (unwrap fc) xa xb xc
  where
    -- The same adapter is used at the result types a, b and c.
    unwrap :: (VecBuilder z -> VecBuilder w) -> SparseVector z -> VecBuilder w
    unwrap f = f . unSparseVector
-- | Like 'lift1', but the backward function returns a plain @a@ value.
--
-- The value is turned into a @VecBuilder a@ with 'identityBuilder', which
-- requires @FullVector a@.
lift1_dense ::
  (BasicVector v, FullVector a) =>
  (v -> a) ->
  BackGrad r a ->
  BackGrad r v
lift1_dense fa = lift1 (\x -> identityBuilder (fa x))
-- | Two-parent version of 'lift1_dense'.
--
-- Each backward function's output is passed through 'identityBuilder'
-- before the lift is delegated to 'lift2'.
lift2_dense ::
  (BasicVector v, FullVector a, FullVector b) =>
  (v -> a) ->
  (v -> b) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r v
lift2_dense fa fb = lift2 (toBuilder fa) (toBuilder fb)
  where
    -- Polymorphic so it can be used at both result types a and b.
    toBuilder :: FullVector w => (u -> w) -> u -> VecBuilder w
    toBuilder f = identityBuilder . f
-- | Three-parent version of 'lift1_dense'.
--
-- Each backward function's output is passed through 'identityBuilder'
-- before the lift is delegated to 'lift3'.
lift3_dense ::
  (BasicVector v, FullVector a, FullVector b, FullVector c) =>
  (v -> a) ->
  (v -> b) ->
  (v -> c) ->
  BackGrad r a ->
  BackGrad r b ->
  BackGrad r c ->
  BackGrad r v
lift3_dense fa fb fc = lift3 (toBuilder fa) (toBuilder fb) (toBuilder fc)
  where
    -- Polymorphic so it can be used at the result types a, b and c.
    toBuilder :: FullVector w => (u -> w) -> u -> VecBuilder w
    toBuilder f = identityBuilder . f