{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Algorithm.Massiv.Utils where
import Data.Massiv.Array hiding ( forM_, unzip, map, init, zipWith, zip, tail, replicate, take )
import qualified Data.Massiv.Array as A
import qualified Data.Massiv.Array.Unsafe as UMA
import qualified Data.Massiv.Array.Mutable as MMA
import Control.Monad
import Data.Vector.Storable ((//))
import System.IO.Unsafe
import Control.Arrow
import Data.List(unfoldr)
import Data.SRTree.Eval
-- | Mutable massiv array in the 'S' representation, indexed by 'Ix2' and
-- holding 'Double's, whose state token is that of the monad @m@.
type MMassArray m = MMA.MArray (PrimState m) S Ix2 Double
-- | Split a matrix into its outer slices and collect them in a boxed vector.
--
-- Each element of the result is one outer slice of the input (one slice per
-- index of the outermost dimension), kept in the input's own representation.
getRows :: SRMatrix -> Array B Ix1 PVector
getRows = computeAs B . outerSlices
{-# INLINE getRows #-}
-- | Split a matrix into its inner slices and collect them in a boxed vector.
--
-- 'innerSlices' yields delayed slices, so each one is materialised with
-- @computeAs S@ before the outer vector itself is computed as boxed.
getCols :: SRMatrix -> Array B Ix1 PVector
getCols = computeAs B . A.map (computeAs S) . innerSlices
{-# INLINE getCols #-}
-- | Append the vector @v@ as a new last outer slice of the matrix @xs@.
--
-- The existing outer slices are extended with @v@ via 'snoc', turned into a
-- list and re-stacked with 'stackOuterSlicesM'; the result is computed as an
-- 'S' array. Any failure is whatever 'stackOuterSlicesM' throws in @m@
-- (presumably a size mismatch between @v@ and the existing slices -- confirm
-- against the massiv documentation).
appendRow :: MonadThrow m => SRMatrix -> PVector -> m SRMatrix
appendRow xs v = computeAs S <$> stackOuterSlicesM slices
  where
    -- All original outer slices followed by the new one.
    slices = toList . computeAs B $ snoc (outerSlices xs) v
{-# INLINE appendRow #-}
-- | Append the vector @v@ as a new last inner slice of the matrix @xs@.
--
-- The existing inner slices are materialised as 'S' vectors, extended with
-- @v@ via 'snoc', turned into a list and re-stacked with
-- 'stackInnerSlicesM'; the result is computed as an 'S' array. Any failure is
-- whatever 'stackInnerSlicesM' throws in @m@ (presumably a size mismatch
-- between @v@ and the existing slices -- confirm against the massiv
-- documentation).
appendCol :: MonadThrow m => SRMatrix -> PVector -> m SRMatrix
appendCol xs v = computeAs S <$> stackInnerSlicesM slices
  where
    -- All original inner slices followed by the new one.
    slices = toList . computeAs B $ snoc (A.map (computeAs S) $ innerSlices xs) v
{-# INLINE appendCol #-}
-- | Return a copy of @vec@ with the given @(index, value)@ replacements
-- applied.
--
-- The array is viewed as a storable vector, updated in bulk with the storable
-- vector operator '//', and wrapped back into an 'S' array using the
-- module-wide 'compMode' computation strategy. Index validity is whatever
-- '//' enforces.
updateS :: Array S Ix1 Double -> [(Int, Double)] -> Array S Ix1 Double
updateS vec new = fromStorableVector compMode (toStorableVector vec // new)
-- | @linSpace num (lo, hi)@ produces @num@ values starting at @lo@ and
-- advancing by a constant step of @(hi - lo) / (num - 1)@.
--
-- Values are obtained by repeatedly adding the step to the previous value.
-- A non-positive @num@ yields the empty list; @num == 1@ yields @[lo]@
-- (the step is never forced in that case).
linSpace :: Int -> (Double, Double) -> [Double]
linSpace num (lo, hi) = Prelude.take num (iterate (+ step) lo)
  where
    -- Distance between two consecutive samples.
    step = (hi - lo) / (fromIntegral num - 1)
{-# INLINE linSpace #-}
-- | Outer product of two vectors.
--
-- For non-empty inputs the result has size @Sz2 m1 m2@, where @m1@ and @m2@
-- are the sizes of @arr1@ and @arr2@, and element @(i :. j)@ equals
-- @arr1[i] * arr2[j]@. If either input is empty, an empty matrix is returned.
-- The computation strategy of the result combines those of both inputs.
outer :: (MonadThrow m)
      => PVector
      -> PVector
      -> m SRMatrix
outer arr1 arr2
  | isEmpty arr1 || isEmpty arr2 = pure $ setComp comp empty
  | otherwise =
      -- Unchecked indexing is in range here: makeArray only generates
      -- indices below (m1, m2), which are the sizes of the two vectors.
      pure $ makeArray comp (Sz2 m1 m2) $ \(i :. j) ->
        UMA.unsafeIndex arr1 i * UMA.unsafeIndex arr2 j
  where
    comp = getComp arr1 <> getComp arr2
    Sz1 m1 = size arr1
    Sz1 m2 = size arr2
{-# INLINE outer #-}
-- | Determinant-style quantity derived from the 'lu' factorisation.
--
-- Returns @1@ when either dimension of the matrix is zero. Otherwise it
-- multiplies the diagonal entries @(i :. i)@, for @i@ in @[0 .. m-1]@, of the
-- first matrix returned by 'lu' and squares that product.
--
-- NOTE(review): the product is squared exactly as in 'detChol'. For a
-- Cholesky factor that gives the determinant; for an LU factorisation whose
-- second factor has a unit diagonal the determinant would be the unsquared
-- product. Confirm against the contract of 'lu' before relying on this value.
det :: SRMatrix -> Double
det mtx
  | m == 0 || n == 0 = 1
  | otherwise = diagProd * diagProd
  where
    Sz (m :. n) = size mtx
    -- 'lu' is run in IO through unsafePerformIO; anything it throws
    -- therefore surfaces as an exception from this pure function.
    -- The binding is lazy, so 'lu' only runs in the non-empty branch.
    (l, _) = unsafePerformIO (lu mtx)
    diagProd = Prelude.product [l ! (i :. i) | i <- [0 .. m - 1]]
-- | Determinant computed from the 'cholesky' factor.
--
-- Returns @1@ when either dimension of the matrix is zero. Otherwise it
-- multiplies the diagonal entries @(i :. i)@, for @i@ in @[0 .. m-1]@, of the
-- factor returned by 'cholesky' and squares that product.
detChol :: SRMatrix -> Double
detChol mtx
  | m == 0 || n == 0 = 1
  | otherwise = diagProd * diagProd
  where
    Sz (m :. n) = size mtx
    -- 'cholesky' is run in IO through unsafePerformIO; the exceptions it
    -- throws (size mismatch, 'NegDef') therefore surface as exceptions from
    -- this pure function. The binding is lazy, so it only runs in the
    -- non-empty branch.
    cho = unsafePerformIO (cholesky mtx)
    diagProd = Prelude.product [cho ! (i :. i) | i <- [0 .. m - 1]]
-- NOTE(review): this pragma names 'det', not 'detChol'; it is kept exactly as
-- found. Confirm whether 'detChol' was meant.
{-# INLINE det #-}
-- | Dot product of two equally long runs of a mutable array, addressed by
-- linear index.
--
-- @rangedLinearDotProd r1 r2 len arr@ returns the sum over @k@ in
-- @[0 .. len-1]@ of @arr[r1 + k] * arr[r2 + k]@. A non-positive @len@ yields
-- @0@. Reads are unchecked ('UMA.unsafeLinearRead'), so the caller must keep
-- both runs inside the array.
rangedLinearDotProd :: PrimMonad m => Int -> Int -> Int -> MMassArray m -> m Double
rangedLinearDotProd r1 r2 len arr = go 0 0
  where
    -- Strict accumulator so no chain of thunks builds up across iterations.
    go !acc k
      | k < len = do
          x <- UMA.unsafeLinearRead arr (r1 + k)
          y <- UMA.unsafeLinearRead arr (r2 + k)
          go (acc + x * y) (k + 1)
      | otherwise = pure acc
{-# INLINE rangedLinearDotProd #-}
-- | Exception thrown by 'cholesky' when a diagonal pivot is not strictly
-- positive.
data NegDef = NegDef
  deriving Show

instance Exception NegDef
-- | Cholesky factorisation producing a lower-triangular factor.
--
-- * Throws 'SizeMismatchException' when the matrix is not square.
-- * Returns an empty matrix for an empty input.
-- * Throws 'NegDef' when a diagonal pivot (the input's diagonal entry minus
--   the squared norm of the already computed part of that row) is @<= 0@.
--
-- Entries above the diagonal are set to @0@. Cells are filled in row-major
-- order, so every entry an update reads has already been written.
cholesky :: (PrimMonad m, MonadThrow m, MonadIO m)
         => SRMatrix
         -> m SRMatrix
cholesky arr
  | m /= n = throwM $ SizeMismatchException (size arr) (size arr)
  | isEmpty arr = pure $ setComp comp empty
  | otherwise = MMA.createArrayS_ (size arr) create
  where
    comp = getComp arr
    (Sz2 m n) = size arr

    -- Visit every cell of the m-by-m result in row-major order.
    create l = Prelude.mapM_ (update l) [i :. j | i <- [0 .. m - 1], j <- [0 .. m - 1]]

    update l ix@(i :. j)
      | i < j = UMA.unsafeWrite l ix 0
      | otherwise = do
          let cur = UMA.unsafeIndex arr ix
              rowI = i * m -- linear offset of row i
              rowJ = j * m -- linear offset of row j
          -- Dot product of the first j entries of rows i and j of the factor.
          tot <- rangedLinearDotProd rowI rowJ j l
          let delta = cur - tot
          if i == j
            then if delta <= 0
                   then throwM NegDef
                   else UMA.unsafeLinearWrite l (rowI + j) (sqrt delta)
            else do
              -- The diagonal entry of row j is only needed off the diagonal,
              -- and has been written by the time row i > j is processed.
              xjj <- UMA.unsafeLinearRead l (rowJ + j)
              UMA.unsafeLinearWrite l (rowI + j) (delta / xjj)
{-# INLINE cholesky #-}
-- | Matrix computation built on the 'cholesky' factor of @arr@ (by its name,
-- presumably the inverse of @arr@ -- confirm with callers).
--
-- Steps, as performed below:
--
-- 1. Factorise @arr@ with 'cholesky' (propagating its exceptions) and thaw
--    the factor into a mutable working copy @mtx@.
-- 2. For each row @i@: replace the diagonal entry by its reciprocal, then for
--    each @j < i@ write @-tot / lII@ at @(j :. i)@ and zero @(i :. j)@,
--    where @tot@ is a ranged dot product over @mtx@ and @lII@ is the
--    original diagonal entry. After this pass the strictly lower part of
--    @mtx@ is zero.
-- 3. Build a fresh symmetric matrix @mm@ whose entry @(i :. j)@, @j >= i@,
--    is the ranged dot product of row @i@ and row @j@ of @mtx@ starting at
--    column @j@; the value is mirrored to @(j :. i)@.
--
-- The statement order inside the loops matters: every read must see the
-- values written by earlier iterations.
invChol :: (PrimMonad m, MonadThrow m, MonadIO m) => SRMatrix -> m SRMatrix
invChol arr = do
  l <- cholesky arr
  mtx <- thawS l
  forM_ [0 .. m - 1] $ \i -> do
    lII <- UMA.unsafeRead mtx (i :. i)
    UMA.unsafeWrite mtx (i :. i) (1 / lII)
    forM_ [0 .. i - 1] $ \j -> do
      tot <- rangedLinearDotProd (i * m + j) (j * m + j) (i - j) mtx
      UMA.unsafeWrite mtx (j :. i) ((-tot) / lII)
      UMA.unsafeWrite mtx (i :. j) 0
  mm <- newMArray (Sz2 m m) 0
  forM_ [0 .. m - 1] $ \i -> do
    dii <- rangedLinearDotProd (i * m + i) (i * m + i) (m - i) mtx
    UMA.unsafeWrite mm (i :. i) dii
    forM_ [i + 1 .. m - 1] $ \j -> do
      dij <- rangedLinearDotProd (i * m + j) (j * m + j) (m - j) mtx
      -- The result is symmetric: fill both triangles with the same value.
      UMA.unsafeWrite mm (i :. j) dij
      UMA.unsafeWrite mm (j :. i) dij
  freezeS mm
  where
    -- Only the row count is used; 'cholesky' has already rejected
    -- non-square input by the time the loops run.
    Sz2 m _ = size arr
{-# INLINE invChol #-}
lu :: (PrimMonad m, MonadThrow m, MonadIO m) => SRMatrix -> m (SRMatrix, SRMatrix)
lu :: forall (m :: * -> *).
(PrimMonad m, MonadThrow m, MonadIO m) =>
SRMatrix -> m (SRMatrix, SRMatrix)
lu SRMatrix
mtx = do
let (Sz2 Int
m Int
n) = SRMatrix -> Sz Ix2
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
size SRMatrix
mtx
MArray (PrimState m) S Ix2 Double
u <- SRMatrix -> m (MArray (PrimState m) S Ix2 Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Array r ix e -> m (MArray (PrimState m) r ix e)
thawS (SRMatrix -> m (MArray (PrimState m) S Ix2 Double))
-> SRMatrix -> m (MArray (PrimState m) S Ix2 Double)
forall a b. (a -> b) -> a -> b
$ S -> Array DL Ix2 Double -> SRMatrix
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
computeAs S
S (Array DL Ix2 Double -> SRMatrix)
-> Array DL Ix2 Double -> SRMatrix
forall a b. (a -> b) -> a -> b
$ Sz Int -> Array DL Ix2 Double
forall e. Num e => Sz Int -> Matrix DL e
identityMatrix (Int -> Sz Int
forall ix. Index ix => ix -> Sz ix
Sz Int
m)
MArray (PrimState m) S Ix2 Double
l <- SRMatrix -> m (MArray (PrimState m) S Ix2 Double)
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Array r ix e -> m (MArray (PrimState m) r ix e)
thawS (SRMatrix -> m (MArray (PrimState m) S Ix2 Double))
-> SRMatrix -> m (MArray (PrimState m) S Ix2 Double)
forall a b. (a -> b) -> a -> b
$ Comp -> Sz Ix2 -> Double -> SRMatrix
forall r ix e. Load r ix e => Comp -> Sz ix -> e -> Array r ix e
A.replicate Comp
compMode (Int -> Int -> Sz Ix2
Sz2 Int
m Int
n) Double
0
let buildLVal :: Int -> Int -> m ()
buildLVal !Int
i !Int
j = do
let go :: Int -> Double -> f Double
go !Int
k !Double
s
| Int
k Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
j = Double -> f Double
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Double
s
| Bool
otherwise = do Double
lik <- MArray (PrimState f) S Ix2 Double -> Ix2 -> f Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
MArray (PrimState f) S Ix2 Double
l (Int
i Int -> Int -> Ix2
:. Int
k)
Double
ukj <- MArray (PrimState f) S Ix2 Double -> Ix2 -> f Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
MArray (PrimState f) S Ix2 Double
u (Int
k Int -> Int -> Ix2
:. Int
j)
Int -> Double -> f Double
go (Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) ( Double
s Double -> Double -> Double
forall a. Num a => a -> a -> a
+ (Double
lik Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
ukj) )
Double
s' <- Int -> Double -> m Double
forall {f :: * -> *}.
(PrimState f ~ PrimState m, PrimMonad f) =>
Int -> Double -> f Double
go Int
0 Double
0
MArray (PrimState m) S Ix2 Double -> Ix2 -> Double -> m ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
UMA.unsafeWrite MArray (PrimState m) S Ix2 Double
MArray (PrimState m) S Ix2 Double
l (Int
i Int -> Int -> Ix2
:. Int
j) ((SRMatrix
mtx SRMatrix -> Ix2 -> Double
forall r ix e.
(HasCallStack, Manifest r e, Index ix) =>
Array r ix e -> ix -> e
! (Int
i Int -> Int -> Ix2
:. Int
j)) Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
s')
buildL :: Int -> Int -> f ()
buildL !Int
i !Int
j
= Bool -> f () -> f ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Int
i Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
n) (f () -> f ()) -> f () -> f ()
forall a b. (a -> b) -> a -> b
$ do Int -> Int -> f ()
forall {m :: * -> *}.
(PrimState m ~ PrimState m, PrimMonad m) =>
Int -> Int -> m ()
buildLVal Int
i Int
j
Int -> Int -> f ()
buildL (Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) Int
j
buildUVal :: Int -> Int -> m ()
buildUVal !Int
i !Int
j = do
let go :: Int -> Double -> m Double
go !Int
k !Double
s
| Int
k Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
j = Double -> m Double
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Double
s
| Bool
otherwise = do Double
ljk <- MArray (PrimState m) S Ix2 Double -> Ix2 -> m Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
l (Int
j Int -> Int -> Ix2
:. Int
k)
Double
uki <- MArray (PrimState m) S Ix2 Double -> Ix2 -> m Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
u (Int
k Int -> Int -> Ix2
:. Int
i)
Int -> Double -> m Double
go (Int
kInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) (Double
s Double -> Double -> Double
forall a. Num a => a -> a -> a
+ Double
ljk Double -> Double -> Double
forall a. Num a => a -> a -> a
* Double
uki)
Double
s' <- Int -> Double -> m Double
go Int
0 Double
0
Double
ljj <- MArray (PrimState m) S Ix2 Double -> Ix2 -> m Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState m) S Ix2 Double
l (Int
j Int -> Int -> Ix2
:. Int
j)
MArray (PrimState m) S Ix2 Double -> Ix2 -> Double -> m ()
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> e -> m ()
UMA.unsafeWrite MArray (PrimState m) S Ix2 Double
u (Int
j Int -> Int -> Ix2
:. Int
i) (((SRMatrix
mtx SRMatrix -> Ix2 -> Double
forall r ix e.
(HasCallStack, Manifest r e, Index ix) =>
Array r ix e -> ix -> e
! (Int
j Int -> Int -> Ix2
:. Int
i)) Double -> Double -> Double
forall a. Num a => a -> a -> a
- Double
s') Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ (Double
ljj))
buildU :: Int -> Int -> m ()
buildU !Int
i !Int
j
= Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Int
i Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
n) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$ do Int -> Int -> m ()
buildUVal Int
i Int
j
Int -> Int -> m ()
buildU (Int
iInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1) Int
j
buildLU :: Int -> m ()
buildLU !Int
j
= Bool -> m () -> m ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Int
j Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
n) (m () -> m ()) -> m () -> m ()
forall a b. (a -> b) -> a -> b
$
do Int -> Int -> m ()
forall {m :: * -> *}.
(PrimState m ~ PrimState m, PrimMonad m) =>
Int -> Int -> m ()
buildL Int
j Int
j
Int -> Int -> m ()
buildU Int
j Int
j
Int -> m ()
buildLU (Int
jInt -> Int -> Int
forall a. Num a => a -> a -> a
+Int
1)
Int -> m ()
buildLU Int
0
SRMatrix
finalL <- MArray (PrimState m) S Ix2 Double -> m SRMatrix
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
freezeS MArray (PrimState m) S Ix2 Double
l
SRMatrix
finalU <- MArray (PrimState m) S Ix2 Double -> m SRMatrix
forall r ix e (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> m (Array r ix e)
freezeS MArray (PrimState m) S Ix2 Double
u
(SRMatrix, SRMatrix) -> m (SRMatrix, SRMatrix)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SRMatrix
finalL, SRMatrix
finalU)
-- | Solve @a * x = b@ by forward substitution.
--
-- Only the entries of @a@ on or below the main diagonal are read, so @a@ is
-- treated as lower triangular. The length of the solution is taken from @b@.
-- For each row @i@, from first to last:
--
-- > x_i = (b_i - sum_{j < i} a_ij * x_j) / a_ii
--
-- No pivoting is performed and the diagonal is not checked: a zero @a_ii@
-- propagates as @Infinity@/@NaN@ through the usual 'Double' division rules.
forwardSub :: (PrimMonad m, MonadThrow m, MonadIO m) => SRMatrix -> PVector -> m PVector
forwardSub a b = do
    let (Sz m) = size b
    -- Mutable solution vector, zero-initialised.
    x <- thawS $ A.replicate compMode (Sz1 m) 0
    let -- Accumulates a_i0 * x_0 + ... + a_i(i-1) * x_(i-1); every x_j read
        -- here was written by an earlier row.
        rowSum !i !j !acc
          | j == i = pure acc
          | otherwise = do
              let aij = a ! (i :. j)
              xj <- UMA.unsafeRead x j
              rowSum i (j + 1) (acc + aij * xj)
        -- Solves rows i, i + 1, ..., m - 1 in that order.
        solveRow !i = when (i /= m) $ do
          let bi = b ! i
              aii = a ! (i :. i)
          c <- rowSum i 0 0
          UMA.unsafeWrite x i ((bi - c) / aii)
          solveRow (i + 1)
    solveRow 0
    freezeS x
-- | Solve @a * x = b@ by backward substitution.
--
-- Only the entries of @a@ on or above the main diagonal are read, so @a@ is
-- treated as upper triangular. The length of the solution is taken from @b@.
-- For each row @i@, from last to first:
--
-- > x_i = (b_i - sum_{j > i} a_ij * x_j) / a_ii
--
-- No pivoting is performed and the diagonal is not checked: a zero @a_ii@
-- propagates as @Infinity@/@NaN@ through the usual 'Double' division rules.
backwardSub :: (PrimMonad m, MonadThrow m, MonadIO m) => SRMatrix -> PVector -> m PVector
backwardSub a b = do
    let (Sz m) = size b
    -- Mutable solution vector, zero-initialised.
    x <- thawS $ A.replicate compMode (Sz1 m) 0
    let -- Accumulates a_ij * x_j for j running from the given start column
        -- up to m - 1; every x_j read here was written by a later row.
        rowSum !i !j !acc
          | j == m = pure acc
          | otherwise = do
              let aij = a ! (i :. j)
              xj <- UMA.unsafeRead x j
              rowSum i (j + 1) (acc + aij * xj)
        -- Solves rows i, i - 1, ..., 0 in that order.
        solveRow !i = when (i >= 0) $ do
          let bi = b ! i
              aii = a ! (i :. i)
          c <- rowSum i (i + 1) 0
          UMA.unsafeWrite x i ((bi - c) / aii)
          solveRow (i - 1)
    solveRow (m - 1)
    freezeS x
-- | Solve the linear system @a * x = b@ through an LU factorisation.
--
-- The matrix is factorised with 'lu' into @(l, u)@; the system @l * y = b@
-- is then solved with 'forwardSub' and @u * x = y@ with 'backwardSub'.
luSolve :: (PrimMonad m, MonadThrow m, MonadIO m) => SRMatrix -> PVector -> m PVector
luSolve a b = do
    (l, u) <- lu a
    y <- forwardSub l b
    backwardSub u y
-- | Per-knot spline data as produced by 'cubicSplineCoefficients':
-- the knot abscissa, the knot ordinate, and the value solved for at that knot.
type PolyCos = (Double, Double, Double)

-- | Pair every input point @(x_i, y_i)@ with the solution @M_i@ of the
-- tridiagonal system
--
-- > h_(i-1) * M_(i-1) + 2 * (h_(i-1) + h_i) * M_i + h_i * M_(i+1) = 6 * (d_i - d_(i-1))
--
-- for the interior knots, where @h_i = x_(i+1) - x_i@ and
-- @d_i = (y_(i+1) - y_i) / h_i@, with @M = 0@ at both end knots.
--
-- The points are used in the order given. Repeated abscissae make some
-- @h_i@ zero and therefore yield @NaN@/@Infinity@ entries. An empty input
-- yields an empty result and a single point yields @[(x, y, 0)]@.
--
-- The system is solved with one forward-elimination sweep followed by one
-- back-substitution sweep, each linear in the number of points.
cubicSplineCoefficients :: [(Double, Double)] -> [PolyCos]
cubicSplineCoefficients xs = Prelude.zip3 x y z'
  where
    x = map fst xs
    y = map snd xs

    -- Interval widths h_i = x_(i+1) - x_i.
    xdiff = zipWith (-) (Prelude.drop 1 x) x

    -- Secant slopes d_i = (y_(i+1) - y_i) / h_i.
    dydx =
      Prelude.zipWith3
        (\yNext yCur h -> (yNext - yCur) / h)
        (Prelude.drop 1 y)
        y
        xdiff

    -- One entry per interior knot i = 1 .. n-2:
    -- ((h_(i-1), h_i), (d_(i-1), d_i)).
    interior =
      zip
        (zip xdiff (Prelude.drop 1 xdiff))
        (zip dydx (Prelude.drop 1 dydx))

    -- Forward elimination. Entry i holds the modified super-diagonal
    -- coefficient w_i and the modified right-hand side z_i for knots
    -- 0 .. n-2; the seed (0, 0) is the first end condition.
    sweep = Prelude.scanl eliminate (0, 0) interior

    eliminate (wPrev, zPrev) ((hPrev, hCur), (dPrev, dCur)) =
      let m = hPrev * (2 - wPrev) + 2 * hCur
       in (hCur / m, (6 * (dCur - dPrev) - hPrev * zPrev) / m)

    -- Back substitution, M_i = z_i - w_i * M_(i+1), seeded with the last
    -- end condition M_(n-1) = 0. Each step must consume the already
    -- back-substituted M_(i+1), not the forward-sweep value z_(i+1).
    z' = Prelude.scanr (\(wi, zi) next -> zi - wi * next) 0 sweep
-- | Split a list into consecutive chunks of @n@ elements; the final chunk
-- holds whatever is left and may be shorter. An empty list gives no chunks.
--
-- For @n <= 0@ and a non-empty list 'splitAt' never consumes anything, so
-- the result is an endless list of empty chunks.
chunkBy :: Int -> [t] -> [[t]]
chunkBy n = unfoldr nextChunk
  where
    nextChunk [] = Nothing
    nextChunk remaining = Just (splitAt n remaining)
-- | Build a piecewise-cubic interpolant through @pts@ and evaluate it at a
-- query abscissa.
--
-- The per-knot data come from 'cubicSplineCoefficients'. Segments are
-- scanned in input order and the first one whose upper knot is not below
-- the query is used. Queries before the first knot are evaluated with the
-- first segment's polynomial and queries past the last knot with the last
-- segment's polynomial.
--
-- At least two points are required; with fewer, evaluating the function
-- raises an error.
genSplineFun :: [(Double, Double)] -> Double -> Double
-- The lambda keeps the knot and segment lists outside the per-query code so
-- that a partially applied @genSplineFun pts@ can reuse them across queries.
genSplineFun pts = \x -> locate x knots segments
  where
    knots = map fst pts
    coefs = cubicSplineCoefficients pts
    -- Neighbouring knots paired up: one entry per segment.
    segments = zip coefs (Prelude.drop 1 coefs)

    -- Cubic on the segment between knots (a1, b1, c1) and (a2, b2, c2),
    -- evaluated at y.
    evalAt (a1, b1, c1) (a2, b2, c2) y =
      let hi1 = a2 - a1
       in c1 / (6 * hi1) * (a2 - y) ^ (3 :: Int)
            + c2 / (6 * hi1) * (y - a1) ^ (3 :: Int)
            + (b2 / hi1 - c2 * hi1 / 6) * (y - a1)
            + (b1 / hi1 - c1 * hi1 / 6) * (a2 - y)

    -- Walk the knots and segments in lockstep until the query is covered.
    locate x [_, _] [(lo, hi)] = evalAt lo hi x
    locate x (x1 : rest@(x2 : _)) ((lo, hi) : more)
      | x < x1 = evalAt lo hi x
      | x >= x1 && x <= x2 = evalAt lo hi x
      | otherwise = locate x rest more
    locate _ _ _ = error "genSplineFun: at least two points are required"