module Numeric.LinearAlgebra.Static.Backprop (
H.R
, H.โ
, vec2
, vec3
, vec4
, (&)
, (#)
, split
, headTail
, vector
, linspace
, H.range
, H.dim
, H.L
, H.Sq
, row
, col
, (|||)
, (===)
, splitRows
, splitCols
, unrow
, uncol
, tr
, H.eye
, diag
, matrix
, H.โ
, H.C
, H.M
, H.๐
, (<>)
, (#>)
, (<.>)
, svd
, svd_
, H.Eigen
, eigensystem
, eigenvalues
, chol
, H.Normed
, norm_0
, norm_1V
, norm_1M
, norm_2V
, norm_2M
, norm_InfV
, norm_InfM
, mean
, meanCov
, meanL
, cov
, H.Disp(..)
, H.Domain
, mul
, app
, dot
, cross
, diagR
, dvmap
, dvmap'
, dmmap
, dmmap'
, outer
, zipWithVector
, zipWithVector'
, det
, invlndet
, lndet
, inv
, toRows
, toColumns
, fromRows
, fromColumns
, konst
, sumElements
, extractV
, extractM
, create
, H.Diag
, takeDiag
, H.Sym
, sym
, mTm
, unSym
, (<ยท>)
) where
import Data.ANum
import Data.Maybe
import Data.Proxy
import Foreign.Storable
import GHC.TypeLits
import Lens.Micro hiding ((&))
import Numeric.Backprop
import Numeric.Backprop.Op
import Numeric.Backprop.Tuple
import Unsafe.Coerce
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
import qualified Data.Vector.Sized as SV
import qualified Data.Vector.Storable.Sized as SVS
import qualified Numeric.LinearAlgebra as HU
import qualified Numeric.LinearAlgebra.Devel as HU
import qualified Numeric.LinearAlgebra.Static as H
import qualified Numeric.LinearAlgebra.Static.Vector as H
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif
-- | Assemble a 2-vector from two scalar 'BVar's using 'H.vec2'.
--
-- On the backward pass the gradient vector is unpacked into its two
-- entries, one per input.
vec2
    :: Reifies s W
    => BVar s H.โ
    -> BVar s H.โ
    -> BVar s (H.R 2)
vec2 = liftOp2
    ( opIsoN
        (\(a ::< b ::< ร) -> H.vec2 a b)
        (\(HU.toList . H.extract -> [ga, gb]) -> ga ::< gb ::< ร)
    )
-- | Assemble a 3-vector from three scalar 'BVar's using 'H.vec3'.
--
-- On the backward pass the gradient vector is unpacked into its three
-- entries, one per input.
vec3
    :: Reifies s W
    => BVar s H.โ
    -> BVar s H.โ
    -> BVar s H.โ
    -> BVar s (H.R 3)
vec3 = liftOp3
    ( opIsoN
        (\(a ::< b ::< c ::< ร) -> H.vec3 a b c)
        (\(HU.toList . H.extract -> [ga, gb, gc]) -> ga ::< gb ::< gc ::< ร)
    )
-- | Assemble a 4-vector from four scalar 'BVar's using 'H.vec4'.
--
-- The four inputs are gathered into a product and passed through a single
-- 'Op'; the backward pass unpacks the gradient vector entry by entry.
vec4
    :: Reifies s W
    => BVar s H.โ
    -> BVar s H.โ
    -> BVar s H.โ
    -> BVar s H.โ
    -> BVar s (H.R 4)
vec4 vX vY vZ vW = liftOp pack (vX :< vY :< vZ :< vW :< ร)
  where
    -- The explicit signature pins the input list for the pattern matches.
    pack :: Op '[H.โ, H.โ, H.โ, H.โ] (H.R 4)
    pack = opIsoN
        (\(a ::< b ::< c ::< d ::< ร) -> H.vec4 a b c d)
        (\(HU.toList . H.extract -> [ga, gb, gc, gd]) ->
            ga ::< gb ::< gc ::< gd ::< ร
        )
-- | Append a scalar to the end of a vector ('H.&').
--
-- The backward pass splits the gradient with 'H.split': the leading part
-- goes to the vector, and the head of the remainder goes to the scalar.
(&) :: (Reifies s W, KnownNat n, 1 <= n, KnownNat (n + 1))
    => BVar s (H.R n)
    -> BVar s H.โ
    -> BVar s (H.R (n + 1))
(&) = liftOp2
    ( opIsoN
        (\(v ::< e ::< ร) -> v H.& e)
        (\(H.split -> (gv, ge)) -> gv ::< fst (H.headTail ge) ::< ร)
    )
infixl 4 &
-- | Concatenate two vectors ('H.#').
--
-- The backward pass splits the gradient with 'H.split' and hands each half
-- to the corresponding input.
(#) :: (Reifies s W, KnownNat n, KnownNat m)
    => BVar s (H.R n)
    -> BVar s (H.R m)
    -> BVar s (H.R (n + m))
(#) = liftOp2
    ( opIsoN
        (\(a ::< b ::< ร) -> a H.# b)
        (\(H.split -> (ga, gb)) -> ga ::< gb ::< ร)
    )
infixl 4 #
-- | Split a vector into its first @p@ entries and the remaining @n - p@
-- entries ('H.split').
--
-- The backward pass re-joins the two gradient pieces with 'H.#'.
split
    :: forall p n s. (Reifies s W, KnownNat p, KnownNat n, p <= n)
    => BVar s (H.R n)
    -> (BVar s (H.R p), BVar s (H.R (n - p)))
split v = (parts ^^. _1, parts ^^. _2)
  where
    -- One lifted op producing a pair; each component is projected by lens.
    parts = liftOp1
        (opIso (tupT2 . H.split) (uncurryT2 (H.#)))
        v
-- | Separate a non-empty vector into its first entry and the remaining
-- @n - 1@ entries ('H.headTail').
--
-- The backward pass wraps the scalar gradient in a length-1 vector and
-- prepends it to the tail gradient.
headTail
    :: (Reifies s W, KnownNat n, 1 <= n)
    => BVar s (H.R n)
    -> (BVar s H.โ, BVar s (H.R (n - 1)))
headTail v = (parts ^^. _1, parts ^^. _2)
  where
    parts = liftOp1
        ( opIso
            (tupT2 . H.headTail)
            (\(T2 gHead gTail) -> (H.konst gHead :: H.R 1) H.# gTail)
        )
        v
-- | Build a vector 'BVar' from a sized vector of scalar 'BVar's.
--
-- The scalars are first collected into a single 'BVar' with 'collectVar',
-- then converted to and from 'H.R' with 'H.vecR' and 'H.rVec'.
vector
    :: forall n s. (Reifies s W, KnownNat n)
    => SV.Vector n (BVar s H.โ)
    -> BVar s (H.R n)
vector vs = liftOp1 toR (collectVar vs)
  where
    toR = opIso (H.vecR . SVG.convert) (SVG.convert . H.rVec)
-- | Lifted 'H.linspace': @n@ evenly spaced points between a lower and an
-- upper bound.
--
-- With @w_i = (i - 1) / (n - 1)@ the i-th output is
-- @l * (1 - w_i) + u * w_i@, so the upper bound receives @sum (w_i * d_i)@
-- and the lower bound receives @sum d_i@ minus that amount.
--
-- NOTE(review): @n1@ is zero when @n = 1@, so the gradient divides by zero
-- in that case.
linspace
    :: forall n s. (Reifies s W, KnownNat n)
    => BVar s H.โ
    -> BVar s H.โ
    -> BVar s (H.R n)
linspace = liftOp2 . op2 $ \l u ->
    ( H.linspace (l, u)
    , \d ->
        let n1   = fromInteger $ natVal (Proxy @n) - 1
            -- Weighted sum of the gradient with weights (i - 1) / (n - 1).
            dDot = ((H.range - 1) H.<.> d) / n1
            dSum = HU.sumElements . H.extract $ d
        in  (dSum - dDot, dDot)
    )
-- | View a vector as a single-row matrix ('H.row'); the gradient is mapped
-- back with 'H.unrow'.
row :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s (H.L 1 n)
row = liftOp1 (opIso H.row H.unrow)
-- | View a vector as a single-column matrix ('H.col'); the gradient is
-- mapped back with 'H.uncol'.
col :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s (H.L n 1)
col = liftOp1 (opIso H.col H.uncol)
-- | Horizontal concatenation of two matrices ('H.|||').
--
-- The backward pass cuts the gradient apart again with 'H.splitCols'.
(|||) :: (Reifies s W, KnownNat c, KnownNat r1, KnownNat (r1 + r2))
      => BVar s (H.L c r1)
      -> BVar s (H.L c r2)
      -> BVar s (H.L c (r1 + r2))
(|||) = liftOp2
    ( opIsoN
        (\(a ::< b ::< ร) -> a H.||| b)
        (\(H.splitCols -> (ga, gb)) -> ga ::< gb ::< ร)
    )
infixl 3 |||
-- | Vertical concatenation of two matrices ('H.===').
--
-- The backward pass cuts the gradient apart again with 'H.splitRows'.
(===) :: (Reifies s W, KnownNat c, KnownNat r1, KnownNat (r1 + r2))
      => BVar s (H.L r1 c)
      -> BVar s (H.L r2 c)
      -> BVar s (H.L (r1 + r2) c)
(===) = liftOp2
    ( opIsoN
        (\(a ::< b ::< ร) -> a H.=== b)
        (\(H.splitRows -> (ga, gb)) -> ga ::< gb ::< ร)
    )
infixl 2 ===
-- | Split a matrix into its first @p@ rows and the remaining @m - p@ rows
-- ('H.splitRows').
--
-- The backward pass stacks the two gradient blocks with 'H.==='.
splitRows
    :: forall p m n s. (Reifies s W, KnownNat p, KnownNat m, KnownNat n, p <= m)
    => BVar s (H.L m n)
    -> (BVar s (H.L p n), BVar s (H.L (m - p) n))
splitRows v = (parts ^^. _1, parts ^^. _2)
  where
    parts = liftOp1
        ( opIso
            (tupT2 . H.splitRows)
            (\(T2 gTop gBottom) -> gTop H.=== gBottom)
        )
        v
-- | Split a matrix into its first @p@ columns and the remaining @n - p@
-- columns ('H.splitCols').
--
-- The backward pass joins the two gradient blocks with 'H.|||'.
splitCols
    :: forall p m n s. (Reifies s W, KnownNat p, KnownNat m, KnownNat n, KnownNat (n - p), p <= n)
    => BVar s (H.L m n)
    -> (BVar s (H.L m p), BVar s (H.L m (n - p)))
splitCols v = (parts ^^. _1, parts ^^. _2)
  where
    parts = liftOp1
        (opIso (tupT2 . H.splitCols) (uncurryT2 (H.|||)))
        v
-- | Extract the only row of a single-row matrix ('H.unrow'); the gradient
-- is mapped back with 'H.row'.
unrow
    :: (Reifies s W, KnownNat n)
    => BVar s (H.L 1 n)
    -> BVar s (H.R n)
unrow = liftOp1 (opIso H.unrow H.row)
-- | Extract the only column of a single-column matrix ('H.uncol'); the
-- gradient is mapped back with 'H.col'.
uncol
    :: (Reifies s W, KnownNat n)
    => BVar s (H.L n 1)
    -> BVar s (H.R n)
uncol = liftOp1 (opIso H.uncol H.col)
-- | Lifted transpose ('H.tr'); the gradient is transposed back with the
-- same function.
tr  :: (Reifies s W, HU.Transposable m mt, HU.Transposable mt m, Num m, Num mt)
    => BVar s m
    -> BVar s mt
tr = liftOp1 (opIso H.tr H.tr)
-- | Square matrix with the given vector on its diagonal ('H.diag').
--
-- The gradient of the input is the diagonal of the incoming gradient
-- ('H.takeDiag').
diag
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s (H.Sq n)
diag = liftOp1 $ op1 $ \v -> (H.diag v, H.takeDiag)
-- | Build an @m@ by @n@ matrix 'BVar' from a flat list of scalar 'BVar's.
--
-- The list must contain exactly @m * n@ elements; otherwise this calls
-- 'error'. The elements are reshaped into rows of @n@ columns with
-- 'HU.reshape'.
matrix
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => [BVar s H.โ]
    -> BVar s (H.L m n)
matrix vs = case SV.fromList @(m * n) vs of
    Just sized ->
        liftOp1
            ( opIso
                (fromJust . H.create . HU.reshape cols . VG.convert . SV.fromSized)
                (SV.concatMap (SVG.convert . H.rVec) . H.lRows)
            )
            (collectVar sized)
    Nothing -> error "matrix: invalid number of elements"
  where
    -- Number of columns, taken from the type-level @n@.
    cols = fromInteger (natVal (Proxy @n))
-- | Matrix-matrix product; an operator alias for 'mul' specialised to
-- 'H.L'.
(<>)
    :: (Reifies s W, KnownNat m, KnownNat k, KnownNat n)
    => BVar s (H.L m k)
    -> BVar s (H.L k n)
    -> BVar s (H.L m n)
a <> b = mul a b
infixr 8 <>
-- | Matrix-vector product; an operator alias for 'app' specialised to
-- 'H.L' and 'H.R'.
(#>)
    :: (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> BVar s (H.R n)
    -> BVar s (H.R m)
a #> v = app a v
infixr 8 #>
-- | Dot product of two vectors; an operator alias for 'dot' specialised to
-- 'H.R'.
(<.>)
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s (H.R n)
    -> BVar s H.โ
u <.> v = dot u v
infixr 8 <.>
-- | Singular values of a matrix: the middle component of 'H.svd'.
--
-- The gradient is rebuilt as @u * diagR 0 dS * tr v@, with @u@ and @v@
-- taken from the forward decomposition.
svd :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> BVar s (H.R n)
svd = liftOp1 $ op1 $ \x ->
    let (u, sigma, v) = H.svd x
    in  ( sigma
        , \(dSigma :: H.R n) -> (u H.<> H.diagR 0 dSigma) H.<> H.tr v
        )
-- | Full singular value decomposition ('H.svd'), returning all three
-- factors.
--
-- Only the singular values can be differentiated through. If the gradients
-- arriving at the first or third factor are not entirely zero (checked with
-- 'H.norm_0'), the backward pass calls 'error'.
svd_
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> (BVar s (H.L m m), BVar s (H.R n), BVar s (H.L n n))
svd_ r = (parts ^^. _1, parts ^^. _2, parts ^^. _3)
  where
    parts = liftOp1 decompose r
    decompose :: Op '[H.L m n] (T3 (H.L m m) (H.R n) (H.L n n))
    decompose = op1 $ \x ->
        let (u, sigma, v) = H.svd x
        in  ( T3 u sigma v
            , \(T3 dU dSigma dV) ->
                if H.norm_0 dU == 0 && H.norm_0 dV == 0
                    then (u H.<> H.diagR 0 dSigma) H.<> H.tr v
                    else error "svd_: Cannot backprop if U and V are used."
            )
-- | Shared helper for the eigen-decomposition ops: returns the eigenvalues,
-- the eigenvector matrix, its inverse and its transpose.
helpEigen :: KnownNat n => H.Sym n -> (H.R n, H.L n n, H.L n n, H.L n n)
helpEigen x =
    let (vals, vecs) = H.eigensystem x
    in  (vals, vecs, H.inv vecs, H.tr vecs)
-- | Eigenvalues and eigenvectors of a symmetric matrix ('H.eigensystem').
--
-- The backward pass combines the eigenvalue gradient on the diagonal with
-- an off-diagonal term built from the eigenvector gradient. The result is
-- reinterpreted as a 'H.Sym' with 'unsafeCoerce'.
--
-- NOTE(review): @fMat@ holds the eigenvalue differences @l_j - l_i@ off the
-- diagonal; the textbook gradient uses their reciprocals. Confirm against a
-- numerical gradient.
-- NOTE(review): the 'unsafeCoerce' presumably relies on 'H.Sym' being a
-- newtype over a square matrix -- confirm.
eigensystem
    :: forall n s. (Reifies s W, KnownNat n)
    => BVar s (H.Sym n)
    -> (BVar s (H.R n), BVar s (H.L n n))
eigensystem u = (t ^^. _1, t ^^. _2)
  where
    o :: Op '[H.Sym n] (T2 (H.R n) (H.L n n))
    o = op1 $ \x ->
        let (l, v, vInv, vTr) = helpEigen x
            -- Every row of lRep is the eigenvalue vector.
            lRep = H.rowsL . SV.replicate $ l
            -- Zero on the diagonal, eigenvalue differences elsewhere.
            fMat = (1 - H.eye) * (lRep - H.tr lRep)
        in  ( T2 l v
            , \(T2 dL dV) -> unsafeCoerce $
                       H.tr vInv
                  H.<> (H.diag dL + fMat * (vTr H.<> dV))
                  H.<> vTr
            )
    t = liftOp1 o u
-- | Eigenvalues of a symmetric matrix (first component of 'helpEigen').
--
-- The gradient is @tr vInv * diag dL * vTr@, reinterpreted as a 'H.Sym'
-- with 'unsafeCoerce'.
--
-- NOTE(review): the 'unsafeCoerce' presumably relies on 'H.Sym' being a
-- newtype over a square matrix -- confirm.
eigenvalues
    :: forall n s. (Reifies s W, KnownNat n)
    => BVar s (H.Sym n)
    -> BVar s (H.R n)
eigenvalues = liftOp1 $ op1 $ \x ->
    let (vals, _, vecsInv, vecsTr) = helpEigen x
    in  ( vals
        , \dVals -> unsafeCoerce (H.tr vecsInv H.<> H.diag dVals H.<> vecsTr)
        )
-- | Cholesky factor of a symmetric matrix ('H.chol').
--
-- The backward pass forms an intermediate matrix @s@ from the factor, its
-- inverse and a triangular mask @phi@, then returns
-- @s + tr s - eye * s@. The result is reinterpreted as a 'H.Sym' with
-- 'unsafeCoerce'.
--
-- NOTE(review): the 'unsafeCoerce' presumably relies on 'H.Sym' being a
-- newtype over a square matrix -- confirm.
chol
    :: forall n s. (Reifies s W, KnownNat n)
    => BVar s (H.Sym n)
    -> BVar s (H.Sq n)
chol = liftOp1 . op1 $ \x ->
    let l    = H.chol x
        lInv = H.inv l
        -- Mask: 1 above the diagonal, 0.5 on it, 0 below.
        phi :: H.Sq n
        phi = H.build $ \i j -> case compare i j of
            LT -> 1
            EQ -> 0.5
            GT -> 0
    in  ( l
        , \dL ->
            let s = H.tr lInv H.<> (phi * (H.tr l H.<> dL)) H.<> lInv
            in  unsafeCoerce $ s + H.tr s - H.eye * s
        )
-- | Lifted 'H.norm_0'. The gradient is always zero.
norm_0
    :: (Reifies s W, H.Normed a, Num a)
    => BVar s a
    -> BVar s H.โ
norm_0 = liftOp1 $ op1 $ \x -> (H.norm_0 x, \_ -> 0)
-- | Lifted 'H.norm_1' for vectors. The gradient is the incoming scalar
-- times the element-wise sign of the input.
norm_1V
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s H.โ
norm_1V = liftOp1 $ op1 $ \x ->
    ( H.norm_1 x
    , \d -> H.konst d * signum x
    )
-- | Lifted 'H.norm_1' for matrices.
--
-- Only columns whose own 1-norm equals the matrix norm receive a gradient
-- (the incoming scalar times the column's element-wise sign); every other
-- column gets zero.
norm_1M
    :: (Reifies s W, KnownNat n, KnownNat m)
    => BVar s (H.L n m)
    -> BVar s H.โ
norm_1M = liftOp1 $ op1 $ \x ->
    let total = H.norm_1 x
        backward d = H.colsL (SV.map perCol (H.lCols x))
          where
            dVec = H.konst d
            perCol c
                | H.norm_1 c == total = dVec * signum c
                | otherwise           = 0
    in  (total, backward)
-- | Lifted 'H.norm_2' for vectors. The gradient is the input scaled by the
-- incoming scalar divided by the norm.
--
-- The backward pass divides by the norm, so it is undefined when the norm
-- is zero.
norm_2V
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s H.โ
norm_2V = liftOp1 $ op1 $ \x ->
    let len = H.norm_2 x
    in  (len, \d -> x * H.konst (d / len))
-- | Lifted 'H.norm_2' for matrices.
--
-- The gradient is the incoming scalar times the outer product of the first
-- columns of the two outer factors returned by 'H.svd'.
norm_2M
    :: (Reifies s W, KnownNat n, KnownNat m)
    => BVar s (H.L n m)
    -> BVar s H.โ
norm_2M = liftOp1 $ op1 $ \x ->
    let (head . H.toColumns -> uFirst, _, head . H.toColumns -> vFirst) = H.svd x
        size = H.norm_2 x
    in  (size, \d -> H.konst d * (uFirst `H.outer` vFirst))
-- | Lifted 'H.norm_Inf' for vectors.
--
-- Entries whose absolute value equals the norm receive the incoming
-- gradient times their sign; all other entries receive zero.
norm_InfV
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s H.โ
norm_InfV = liftOp1 $ op1 $ \x ->
    let peak :: H.โ
        peak = H.norm_Inf x
        backward d = H.vecR (SVS.map perEntry (H.rVec x))
          where
            perEntry e
                | abs e == peak = signum e * d
                | otherwise     = 0
    in  (peak, backward)
-- | Lifted 'H.norm_Inf' for matrices.
--
-- Only rows whose 1-norm equals the matrix norm receive a gradient (the
-- incoming scalar times the row's element-wise sign); every other row gets
-- zero.
norm_InfM
    :: (Reifies s W, KnownNat n, KnownNat m)
    => BVar s (H.L n m)
    -> BVar s H.โ
norm_InfM = liftOp1 $ op1 $ \x ->
    let peak = H.norm_Inf x
        backward d = H.rowsL (SV.map perRow (H.lRows x))
          where
            dVec = H.konst d
            perRow r
                | H.norm_1 r == peak = dVec * signum r
                | otherwise          = 0
    in  (peak, backward)
-- | Lifted 'H.mean': the arithmetic mean of a vector's entries.
--
-- Every entry contributes @1 / n@ to the mean, so the gradient is the
-- incoming scalar divided by the vector length, replicated across all
-- entries. The length comes from the type-level @n@, so the result does
-- not depend on how many entries of the input are zero.
mean
    :: forall s n. (Reifies s W, KnownNat n, 1 <= n)
    => BVar s (H.R n)
    -> BVar s H.โ
mean = liftOp1 $ op1 $ \x -> (H.mean x, H.konst . (/ len))
  where
    -- Vector length as a scalar; at least 1 by the @1 <= n@ constraint.
    len = fromInteger (natVal (Proxy @n))
-- | Gradient of the covariance with respect to the data matrix, given the
-- data, its mean row and the incoming gradient of the covariance.
--
-- The rows are centred on the mean, multiplied by the transposed covariance
-- gradient and scaled; the row-average of that result is then subtracted
-- from each row.
--
-- NOTE(review): the scale factor is @2 / n@ with @n@ the number of columns.
-- A factor based on the number of rows might be expected -- confirm against
-- the normalisation used by 'H.meanCov'.
gradCov
    :: forall m n. (KnownNat m, KnownNat n)
    => H.L m n
    -> H.R n
    -> H.Sym n
    -> H.L m n
gradCov x mu dSigma =
    H.rowsL (SV.map (subtract (dCentredSum / rowCount)) (H.lRows dCentred))
  where
    rowCount    = fromIntegral (natVal (Proxy @m))
    colCount    = fromIntegral (natVal (Proxy @n))
    centred     = H.rowsL (SV.map (subtract mu) (H.lRows x))
    dCentred    = H.konst (2 / colCount) * (centred H.<> H.tr (H.unSym dSigma))
    dCentredSum = sum (H.toRows dCentred)
-- | Lifted 'H.meanCov': the mean row and covariance of a matrix's rows.
--
-- The backward pass adds two contributions: the mean gradient divided by
-- the row count and replicated to every row, plus 'gradCov' applied to the
-- covariance gradient.
meanCov
    :: forall m n s. (Reifies s W, KnownNat n, KnownNat m, 1 <= m)
    => BVar s (H.L m n)
    -> (BVar s (H.R n), BVar s (H.Sym n))
meanCov v = (both ^^. _1, both ^^. _2)
  where
    rowCount = fromInteger (natVal (Proxy @m))
    both = ($ v) . liftOp1 . op1 $ \x ->
        let (mu, sigma) = H.meanCov x
        in  ( T2 mu sigma
            , \(T2 dMu dSigma) ->
                let fromMean = H.rowsL (SV.replicate (dMu / H.konst rowCount))
                in  fromMean + gradCov x mu dSigma
            )
-- | Mean row of a matrix (first component of 'H.meanCov').
--
-- The gradient is the incoming vector divided by the row count, replicated
-- to every row.
meanL
    :: forall m n s. (Reifies s W, KnownNat n, KnownNat m, 1 <= m)
    => BVar s (H.L m n)
    -> BVar s (H.R n)
meanL = liftOp1 $ op1 $ \x ->
    ( fst (H.meanCov x)
    , \d -> H.rowsL (SV.replicate (d / H.konst rowCount))
    )
  where
    rowCount = fromInteger (natVal (Proxy @m))
-- | Covariance of a matrix's rows (second component of 'H.meanCov'); the
-- gradient is computed by 'gradCov'.
cov
    :: forall m n s. (Reifies s W, KnownNat n, KnownNat m, 1 <= m)
    => BVar s (H.L m n)
    -> BVar s (H.Sym n)
cov = liftOp1 $ op1 $ \x ->
    let (mu, sigma) = H.meanCov x
    in  (sigma, \dSigma -> gradCov x mu dSigma)
-- | Lifted matrix-matrix product ('H.mul').
--
-- For @z = x * y@ with incoming gradient @d@, the gradients are
-- @d * tr y@ for @x@ and @tr x * d@ for @y@.
mul :: ( Reifies s W
       , KnownNat m
       , KnownNat k
       , KnownNat n
       , H.Domain field vec mat
       , Num (mat m k)
       , Num (mat k n)
       , Num (mat m n)
       , HU.Transposable (mat m k) (mat k m)
       , HU.Transposable (mat k n) (mat n k)
       )
    => BVar s (mat m k)
    -> BVar s (mat k n)
    -> BVar s (mat m n)
mul = liftOp2 $ op2 $ \a b ->
    ( H.mul a b
    , \g -> (H.mul g (H.tr b), H.mul (H.tr a) g)
    )
-- | Lifted matrix-vector product ('H.app').
--
-- With incoming gradient @d@, the matrix receives the outer product of @d@
-- and the vector, and the vector receives the transposed matrix applied to
-- @d@.
app :: ( Reifies s W
       , KnownNat m
       , KnownNat n
       , H.Domain field vec mat
       , Num (mat m n)
       , Num (vec n)
       , Num (vec m)
       , HU.Transposable (mat m n) (mat n m)
       )
    => BVar s (mat m n)
    -> BVar s (vec n)
    -> BVar s (vec m)
app = liftOp2 $ op2 $ \a v ->
    ( H.app a v
    , \g -> (H.outer g v, H.app (H.tr a) g)
    )
-- | Lifted dot product ('H.dot').
--
-- Each argument's gradient is the other argument scaled by the incoming
-- scalar gradient.
dot :: ( Reifies s W
       , KnownNat n
       , H.Domain field vec mat
       , H.Sized field (vec n) d
       , Num (vec n)
       )
    => BVar s (vec n)
    -> BVar s (vec n)
    -> BVar s field
dot = liftOp2 $ op2 $ \u v ->
    ( H.dot u v
    , \g ->
        let gVec = H.konst g
        in  (gVec * v, u * gVec)
    )
-- | Lifted cross product of two 3-vectors ('H.cross').
--
-- With incoming gradient @d@, the first argument receives @y x d@ and the
-- second receives @d x x@.
cross
    :: ( Reifies s W
       , H.Domain field vec mat
       , Num (vec 3)
       )
    => BVar s (vec 3)
    -> BVar s (vec 3)
    -> BVar s (vec 3)
cross = liftOp2 $ op2 $ \u v ->
    ( H.cross u v
    , \g -> (H.cross v g, H.cross g u)
    )
-- | Lifted 'H.diagR': a rectangular matrix built from a fill value and a
-- vector of diagonal entries.
--
-- The fill value's gradient is the sum of the incoming gradient masked by
-- @diagR 1 0@. The diagonal vector's gradient is the diagonal of the
-- incoming gradient.
--
-- The diagonal gradient is rebuilt with @fromJust . H.create@, which fails
-- if the extracted diagonal does not have exactly @k@ entries.
diagR
    :: forall m n k field vec mat s.
       ( Reifies s W
       , H.Domain field vec mat
       , Num (vec k)
       , Num (mat m n)
       , KnownNat m
       , KnownNat n
       , KnownNat k
       , HU.Container HU.Vector field
       , H.Sized field (mat m n) HU.Matrix
       , H.Sized field (vec k) HU.Vector
       )
    => BVar s field
    -> BVar s (vec k)
    -> BVar s (mat m n)
diagR = liftOp2 $ op2 $ \fill diagonal ->
    ( H.diagR fill diagonal
    , \g ->
        ( HU.sumElements (H.extract (H.diagR 1 (0 :: vec k) * g))
        , fromJust (H.create (HU.takeDiag (H.extract g)))
        )
    )
-- | Map a differentiable scalar function over every entry of a vector.
--
-- Each entry is run through 'backprop' once, giving its value and
-- derivative together. The gradient is the incoming gradient times the
-- vector of derivatives.
dvmap
    :: ( Reifies s W
       , Num (vec n)
       , Storable field
       , Storable (field, field)
       , H.Sized field (vec n) HU.Vector
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field)
    -> BVar s (vec n)
    -> BVar s (vec n)
dvmap f = liftOp1 $ op1 $ \x ->
    let (values, slopes) = HU.unzipVector (VG.map (backprop f) (H.extract x))
    in  ( fromJust (H.create values)
        , \g -> g * fromJust (H.create slopes)
        )
-- | Variant of 'dvmap' built on 'H.dvmap'.
--
-- The forward pass maps 'evalBP' over the entries, and the backward pass
-- separately maps 'gradBP' over them and multiplies by the incoming
-- gradient.
dvmap'
    :: ( Reifies s W
       , KnownNat n
       , H.Domain field vec mat
       , Num (vec n)
       , Num field
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field)
    -> BVar s (vec n)
    -> BVar s (vec n)
dvmap' f = liftOp1 $ op1 $ \x ->
    ( H.dvmap (evalBP f) x
    , \g -> H.dvmap (gradBP f) x * g
    )
-- | Map a differentiable scalar function over every entry of a matrix.
--
-- The matrix is flattened and each entry is run through 'backprop' once.
-- The values and derivatives are then reshaped back to @m@ columns.
dmmap
    :: forall n m mat field s.
       ( Reifies s W
       , KnownNat m
       , Num (mat n m)
       , Storable (field, field)
       , H.Sized field (mat n m) HU.Matrix
       , HU.Element field
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field)
    -> BVar s (mat n m)
    -> BVar s (mat n m)
dmmap f = liftOp1 $ op1 $ \x ->
    let (values, slopes) =
            HU.unzipVector (VG.map (backprop f) (HU.flatten (H.extract x)))
    in  ( fromJust (H.create (HU.reshape cols values))
        , \g -> fromJust (H.create (HU.reshape cols slopes)) * g
        )
  where
    -- Column count, taken from the type-level @m@.
    cols :: Int
    cols = fromInteger (natVal (Proxy @m))
-- | Variant of 'dmmap' built on 'H.dmmap'.
--
-- The forward pass maps 'evalBP' over the entries, and the backward pass
-- separately maps 'gradBP' over them and multiplies by the incoming
-- gradient.
dmmap'
    :: ( Reifies s W
       , KnownNat n
       , KnownNat m
       , H.Domain field vec mat
       , Num (mat n m)
       , Num field
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field)
    -> BVar s (mat n m)
    -> BVar s (mat n m)
dmmap' f = liftOp1 $ op1 $ \x ->
    ( H.dmmap (evalBP f) x
    , \g -> H.dmmap (gradBP f) x * g
    )
-- | Lifted outer product ('H.outer').
--
-- With incoming gradient @d@, the first vector receives @d@ applied to the
-- second vector, and the second receives the transpose of @d@ applied to
-- the first.
outer
    :: ( Reifies s W
       , KnownNat m
       , KnownNat n
       , H.Domain field vec mat
       , HU.Transposable (mat n m) (mat m n)
       , Num (vec n)
       , Num (vec m)
       , Num (mat n m)
       )
    => BVar s (vec n)
    -> BVar s (vec m)
    -> BVar s (mat n m)
outer = liftOp2 $ op2 $ \u v ->
    ( H.outer u v
    , \g -> (H.app g v, H.app (H.tr g) u)
    )
-- | Combine two vectors entry by entry with a differentiable binary scalar
-- function.
--
-- Each pair of entries is run through 'backprop2' once, giving the value
-- and both partial derivatives. The gradients are the incoming gradient
-- times the respective vector of partials.
zipWithVector
    :: ( Reifies s W
       , Num (vec n)
       , Storable field
       , Storable (field, field, field)
       , H.Sized field (vec n) HU.Vector
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field -> BVar s' field)
    -> BVar s (vec n)
    -> BVar s (vec n)
    -> BVar s (vec n)
zipWithVector f = liftOp2 $ op2 $ \(H.extract -> xs) (H.extract -> ys) ->
    let step a b = case backprop2 f a b of
            (out, (da, db)) -> (out, da, db)
        (outs, dxs, dys) = VG.unzip3 (VG.zipWith step xs ys)
    in  ( fromJust (H.create outs)
        , \g -> (g * fromJust (H.create dxs), g * fromJust (H.create dys))
        )
-- | Variant of 'zipWithVector' built on 'H.zipWithVector'.
--
-- The forward pass uses 'evalBP2'. The backward pass calls 'gradBP2' in
-- two separate passes, one per partial derivative.
zipWithVector'
    :: ( Reifies s W
       , KnownNat n
       , H.Domain field vec mat
       , Num (vec n)
       , Num field
       )
    => (forall s'. Reifies s' W => BVar s' field -> BVar s' field -> BVar s' field)
    -> BVar s (vec n)
    -> BVar s (vec n)
    -> BVar s (vec n)
zipWithVector' f = liftOp2 $ op2 $ \xs ys ->
    ( H.zipWithVector (evalBP2 f) xs ys
    , \g ->
        let partialX = H.zipWithVector (\a b -> fst (gradBP2 f a b)) xs ys
            partialY = H.zipWithVector (\a b -> snd (gradBP2 f a b)) xs ys
        in  (g * partialX, g * partialY)
    )
-- | Lifted determinant ('H.det').
--
-- The gradient is the incoming scalar times the determinant times the
-- transposed inverse of the input.
det :: ( Reifies s W
       , KnownNat n
       , Num (mat n n)
       , H.Domain field vec mat
       , H.Sized field (mat n n) d
       , HU.Transposable (mat n n) (mat n n)
       )
    => BVar s (mat n n)
    -> BVar s field
det = liftOp1 $ op1 $ \x ->
    let dx    = H.det x
        invTr = H.tr (H.inv x)
    in  (dx, \g -> H.konst (g * dx) * invTr)
-- | Lifted 'H.invlndet': returns the inverse together with the pair of the
-- log of the absolute determinant and the determinant's sign.
--
-- Gradients from the inverse and from the log-determinant are summed. The
-- gradient arriving at the sign component is ignored.
--
-- * Through the inverse: @d(X^-1) = -X^-1 dX X^-1@, so the contribution is
--   @-(tr i * dI * tr i)@.
-- * Through the log-determinant: the incoming scalar times @tr i@.
invlndet
    :: forall n mat field vec d s.
       ( Reifies s W
       , KnownNat n
       , Num (mat n n)
       , H.Domain field vec mat
       , H.Sized field (mat n n) d
       , HU.Transposable (mat n n) (mat n n)
       )
    => BVar s (mat n n)
    -> (BVar s (mat n n), (BVar s field, BVar s field))
invlndet v = (t ^^. _1, (t ^^. _2, t ^^. _3))
  where
    o :: Op '[mat n n] (T3 (mat n n) field field)
    o = op1 $ \x ->
        let (i, (ldet, s)) = H.invlndet x
            iTr            = H.tr i
        in  ( T3 i ldet s
            , \(T3 dI dLDet _) ->
                let gradI    = negate (iTr `H.mul` dI `H.mul` iTr)
                    gradLDet = H.konst dLDet * iTr
                in  gradI + gradLDet
            )
    t = liftOp1 o v
-- | Log-determinant component of 'H.invlndet'.
--
-- The gradient is the incoming scalar times the transposed inverse of the
-- input.
lndet
    :: forall n mat field vec d s.
       ( Reifies s W
       , KnownNat n
       , Num (mat n n)
       , H.Domain field vec mat
       , H.Sized field (mat n n) d
       , HU.Transposable (mat n n) (mat n n)
       )
    => BVar s (mat n n)
    -> BVar s field
lndet = liftOp1 $ op1 $ \x ->
    let (xInv, (logDet, _)) = H.invlndet x
    in  (logDet, \g -> H.konst g * H.tr xInv)
-- | Lifted matrix inverse ('H.inv').
--
-- Since @d(X^-1) = -X^-1 dX X^-1@, the gradient for an incoming @d@ is
-- @-(tr xInv * d * tr xInv)@.
inv :: ( Reifies s W
       , KnownNat n
       , Num (mat n n)
       , H.Domain field vec mat
       , HU.Transposable (mat n n) (mat n n)
       )
    => BVar s (mat n n)
    -> BVar s (mat n n)
inv = liftOp1 . op1 $ \x ->
    let xInv   = H.inv x
        xInvTr = H.tr xInv
    in  ( xInv
        , \d -> negate (xInvTr `H.mul` d `H.mul` xInvTr)
        )
-- | Split a matrix 'BVar' into a sized vector of row 'BVar's, using
-- 'H.lRows' forwards and 'H.rowsL' backwards.
toRows
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> SV.Vector m (BVar s (H.R n))
toRows mat = sequenceVar (liftOp1 (opIso H.lRows H.rowsL) mat)
-- | Split a matrix 'BVar' into a sized vector of column 'BVar's, using
-- 'H.lCols' forwards and 'H.colsL' backwards.
toColumns
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> SV.Vector n (BVar s (H.R m))
toColumns mat = sequenceVar (liftOp1 (opIso H.lCols H.colsL) mat)
-- | Assemble a matrix 'BVar' from a sized vector of row 'BVar's, using
-- 'H.rowsL' forwards and 'H.lRows' backwards.
fromRows
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => SV.Vector m (BVar s (H.R n))
    -> BVar s (H.L m n)
fromRows rows = liftOp1 (opIso H.rowsL H.lRows) (collectVar rows)
-- | Assemble a matrix 'BVar' from a sized vector of column 'BVar's, using
-- 'H.colsL' forwards and 'H.lCols' backwards.
fromColumns
    :: forall m n s. (Reifies s W, KnownNat m, KnownNat n)
    => SV.Vector n (BVar s (H.R m))
    -> BVar s (H.L m n)
fromColumns cols = liftOp1 (opIso H.colsL H.lCols) (collectVar cols)
-- | Lifted 'H.konst': a sized container filled with one value.
--
-- The gradient of the fill value is the sum of all entries of the incoming
-- gradient.
konst
    :: forall t s d q. (Reifies q W, H.Sized t s d, HU.Container d t, Num s)
    => BVar q t
    -> BVar q s
konst = liftOp1 $ op1 $ \x ->
    ( H.konst x
    , \g -> HU.sumElements (H.extract g)
    )
-- | Sum of all entries of a sized container.
--
-- The gradient is the incoming scalar replicated to every entry
-- ('H.konst').
sumElements
    :: forall t s d q. (Reifies q W, H.Sized t s d, HU.Container d t, Num s)
    => BVar q s
    -> BVar q t
sumElements = liftOp1 $ op1 $ \x ->
    ( HU.sumElements (H.extract x)
    , \g -> H.konst g
    )
-- | Lifted 'H.extract' for vectors: unwrap a statically sized vector into a
-- dynamically sized one.
--
-- The incoming gradient may not match the original length. It is truncated
-- when longer and zero-padded at the end when shorter, then rewrapped with
-- 'H.create'.
extractV
    :: forall t s q.
       ( Reifies q W
       , H.Sized t s HU.Vector
       , Num s
       , HU.Konst t Int HU.Vector
       , HU.Container HU.Vector t
       , Num (HU.Vector t)
       )
    => BVar q s
    -> BVar q (HU.Vector t)
extractV = liftOp1 . op1 $ \x ->
    let n = H.size x
    in  ( H.extract x
        , \d ->
            let m  = HU.size d
                m' = case compare n m of
                    LT -> HU.subVector 0 n d
                    EQ -> d
                    -- Pad with the (n - m) missing zeros.
                    GT -> HU.vjoin [d, HU.konst 0 (n - m)]
            in  fromJust . H.create $ m'
        )
-- | Lifted 'H.extract' for matrices: unwrap a statically sized matrix into
-- a dynamically sized one.
--
-- The incoming gradient may not match the original shape. Each dimension is
-- handled independently: it is truncated when the gradient is larger and
-- zero-padded at the bottom or right when smaller. The result is rewrapped
-- with 'H.create'.
extractM
    :: forall t s q.
       ( Reifies q W
       , H.Sized t s HU.Matrix
       , Num s
       , HU.Konst t (Int, Int) HU.Matrix
       , HU.Container HU.Matrix t
       , Num (HU.Matrix t)
       )
    => BVar q s
    -> BVar q (HU.Matrix t)
extractM = liftOp1 . op1 $ \x ->
    let (xI, xJ) = H.size x
    in  ( H.extract x
        , \d ->
            let (dI, dJ) = HU.size d
                -- First component compares row counts, second column counts.
                m' = case (compare xI dI, compare xJ dJ) of
                    (LT, LT) -> d HU.?? (HU.Take xI, HU.Take xJ)
                    (LT, EQ) -> d HU.?? (HU.Take xI, HU.All)
                    (LT, GT) -> d HU.?? (HU.Take xI, HU.All)
                                  HU.||| HU.konst 0 (xI, xJ - dJ)
                    (EQ, LT) -> d HU.?? (HU.All, HU.Take xJ)
                    (EQ, EQ) -> d
                    (EQ, GT) -> d HU.?? (HU.All, HU.All)
                                  HU.||| HU.konst 0 (xI, xJ - dJ)
                    (GT, LT) -> d HU.?? (HU.All, HU.Take xJ)
                                  HU.=== HU.konst 0 (xI - dI, xJ)
                    (GT, EQ) -> d HU.?? (HU.All, HU.All)
                                  HU.=== HU.konst 0 (xI - dI, xJ)
                    (GT, GT) -> HU.fromBlocks
                        [ [d, 0]
                        , [0, HU.konst 0 (xI - dI, xJ - dJ)]
                        ]
            in  fromJust . H.create $ m'
        )
-- | Lifted 'H.create': wrap a dynamically sized container as a statically
-- sized one, giving 'Nothing' when 'H.create' does.
--
-- On the backward pass a missing gradient ('Nothing') is treated as zero;
-- otherwise the gradient is unwrapped with 'H.extract'.
create
    :: forall t s d q. (Reifies q W, H.Sized t s d, Num s, Num (d t))
    => BVar q (d t)
    -> Maybe (BVar q s)
create v =
    unANum . sequenceVar $
        liftOp1
            (opIso (ANum . H.create) (maybe 0 H.extract . unANum))
            v
-- | Lifted 'H.takeDiag': the diagonal of a square matrix.
--
-- The gradient places the incoming vector on the diagonal and zero
-- everywhere else (@diagR 0@).
takeDiag
    :: ( Reifies s W
       , KnownNat n
       , H.Diag (mat n n) (vec n)
       , H.Domain field vec mat
       , Num (vec n)
       , Num (mat n n)
       , Num field
       )
    => BVar s (mat n n)
    -> BVar s (vec n)
takeDiag = liftOp1 $ op1 $ \x ->
    ( H.takeDiag x
    , \g -> H.diagR 0 g
    )
-- | Lifted 'H.sym'.
--
-- The backward pass unwraps the incoming gradient, passes it through
-- 'H.sym' again and unwraps the result.
sym :: (Reifies s W, KnownNat n)
    => BVar s (H.Sq n)
    -> BVar s (H.Sym n)
sym = liftOp1 $ op1 $ \x ->
    ( H.sym x
    , \g -> H.unSym (H.sym (H.unSym g))
    )
-- | Lifted 'H.mTm'.
--
-- The gradient is twice the input multiplied by the unwrapped incoming
-- gradient.
mTm :: (Reifies s W, KnownNat m, KnownNat n)
    => BVar s (H.L m n)
    -> BVar s (H.Sym n)
mTm = liftOp1 $ op1 $ \x ->
    ( H.mTm x
    , \g -> 2 * (x H.<> H.unSym g)
    )
-- | Lifted 'H.unSym': forget that a matrix is symmetric.
--
-- The incoming square-matrix gradient is reinterpreted as a 'H.Sym' with
-- 'unsafeCoerce', without symmetrising it.
--
-- NOTE(review): this presumably relies on 'H.Sym' being a newtype over a
-- square matrix -- confirm.
unSym
    :: (Reifies s W, KnownNat n)
    => BVar s (H.Sym n)
    -> BVar s (H.Sq n)
unSym = liftOp1 $ opIso H.unSym unsafeCoerce
-- | Dot product of two vectors; an operator alias for 'dot' specialised to
-- 'H.R'.
(<ยท>)
    :: (Reifies s W, KnownNat n)
    => BVar s (H.R n)
    -> BVar s (H.R n)
    -> BVar s H.โ
(<ยท>) u v = dot u v
infixr 8 <ยท>