{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
module Numeric.Subroutine.SolveTriangular
( solveUpperTriangularR
, solveUpperTriangularL
, solveLowerTriangularR
, solveLowerTriangularL
) where
import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions
-- | Solve @R x = b@ in place by back substitution, treating @R@ as an
--   upper-triangular matrix.
--
--   The mutable buffer holds @b@ on entry and @x@ on exit. Element
--   @(row, col)@ of @R@ is read at flat offset @m*row + col@; only the
--   diagonal and the entries above it within the first @m@ rows are accessed.
--   Element @(row, j)@ of the buffer lives at offset @step*row + j@, where
--   @step@ is the second entry of the buffer's 'CumulDims'.
--
--   When a diagonal entry of @R@ is exactly zero, the corresponding row of the
--   buffer is overwritten with zeros and nothing is propagated from it.
solveUpperTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => DataFrame t '[n,m]        -- ^ matrix @R@
    -> STDataFrame s t (m :+ ds) -- ^ @b@ on entry, overwritten with @x@
    -> ST s ()
solveUpperTriangularR r bPtr
  | Dict <- Dict @(m <= n) = mapM_ eliminateRow (reverse [0 .. rowCount - 1])
  where
    rowCount = fromIntegral (dimVal' @m) :: Int
    rowStep = fromIntegral rawStep :: Int
    CumulDims (_:rawStep:_) = getDataFrameSteps bPtr

    -- Process row i of the buffer: it is the last unresolved row at this point.
    eliminateRow :: Int -> ST s ()
    eliminateRow i =
        if pivot == 0
          then mapM_ clearEntry [0 .. rowStep - 1]
          else mapM_ solveEntry [0 .. rowStep - 1]
      where
        rowOff = rowStep * i
        pivot = scalar (ixOff (rowCount * i + i) r)
        -- Only demanded in the non-zero branch.
        pivotInv = recip pivot

        clearEntry, solveEntry :: Int -> ST s ()
        clearEntry j = writeDataFrameOff bPtr (rowOff + j) 0
        solveEntry j = do
          xij <- (pivotInv *) <$> readDataFrameOff bPtr (rowOff + j)
          writeDataFrameOff bPtr (rowOff + j) xij
          -- Remove the contribution of x_ij from every row above i.
          forM_ [0 .. i - 1] $ \above -> do
            let coef = scalar (ixOff (rowCount * above + i) r)
                off = rowStep * above + j
            old <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (old - coef * xij)
-- | Solve @x R = b@ in place by forward substitution, treating @R@ as an
--   upper-triangular matrix.
--
--   The mutable buffer holds @b@ on entry and @x@ on exit. It is addressed as
--   a sequence of runs of @m@ consecutive elements (offset @m*j + i@); the
--   number of runs is the first entry of the buffer's 'CumulDims' divided by
--   @m@. Element @(row, col)@ of @R@ is read at flat offset @m*row + col@;
--   only the diagonal and the entries to its right within the first @m@ rows
--   are accessed.
--
--   When a diagonal entry of @R@ is exactly zero, the corresponding position
--   in every run is overwritten with zero and nothing is propagated from it.
solveUpperTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim m, m <= n)
    => STDataFrame s t (ds +: m) -- ^ @b@ on entry, overwritten with @x@
    -> DataFrame t '[n,m]        -- ^ matrix @R@
    -> ST s ()
solveUpperTriangularL bPtr r
  | Dict <- Dict @(m <= n) = mapM_ eliminateColumn [0 .. width - 1]
  where
    width = fromIntegral (dimVal' @m) :: Int
    -- Evaluated lazily: never demanded when there are no columns to process.
    runCount = fromIntegral totalElems `quot` width
    CumulDims (totalElems:_) = getDataFrameSteps bPtr

    -- Resolve position i in every run, then update the positions after it.
    eliminateColumn :: Int -> ST s ()
    eliminateColumn i =
        if pivot == 0
          then mapM_ clearEntry [0 .. runCount - 1]
          else mapM_ solveEntry [0 .. runCount - 1]
      where
        pivotRow = width * i
        pivot = scalar (ixOff (pivotRow + i) r)
        -- Only demanded in the non-zero branch.
        pivotInv = recip pivot

        clearEntry, solveEntry :: Int -> ST s ()
        clearEntry j = writeDataFrameOff bPtr (width * j + i) 0
        solveEntry j = do
          let base = width * j
          xji <- (pivotInv *) <$> readDataFrameOff bPtr (base + i)
          writeDataFrameOff bPtr (base + i) xji
          -- Remove the contribution of x_ji from the later positions.
          forM_ [i + 1 .. width - 1] $ \c -> do
            let coef = scalar (ixOff (pivotRow + c) r)
            old <- readDataFrameOff bPtr (base + c)
            writeDataFrameOff bPtr (base + c) (old - coef * xji)
-- | Solve @L x = b@ in place by forward substitution, treating @L@ as a
--   lower-triangular matrix.
--
--   The mutable buffer holds @b@ on entry and @x@ on exit. Element
--   @(row, col)@ of @L@ is read at flat offset @m*row + col@; only the
--   diagonal and the entries below it within the first @n@ columns are
--   accessed. Element @(row, j)@ of the buffer lives at offset @step*row + j@,
--   where @step@ is the second entry of the buffer's 'CumulDims'.
--
--   When a diagonal entry of @L@ is exactly zero, the corresponding row of the
--   buffer is overwritten with zeros and nothing is propagated from it.
solveLowerTriangularR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => DataFrame t '[n,m]        -- ^ matrix @L@
    -> STDataFrame s t (n :+ ds) -- ^ @b@ on entry, overwritten with @x@
    -> ST s ()
solveLowerTriangularR l bPtr
  | Dict <- Dict @(n <= m) = mapM_ eliminateRow [0 .. rowCount - 1]
  where
    colCount = fromIntegral (dimVal' @m) :: Int
    rowCount = fromIntegral (dimVal' @n) :: Int
    rowStep = fromIntegral rawStep :: Int
    CumulDims (_:rawStep:_) = getDataFrameSteps bPtr

    -- Process row i of the buffer: it is the first unresolved row at this point.
    eliminateRow :: Int -> ST s ()
    eliminateRow i =
        if pivot == 0
          then mapM_ clearEntry [0 .. rowStep - 1]
          else mapM_ solveEntry [0 .. rowStep - 1]
      where
        rowOff = rowStep * i
        pivot = scalar (ixOff (colCount * i + i) l)
        -- Only demanded in the non-zero branch.
        pivotInv = recip pivot

        clearEntry, solveEntry :: Int -> ST s ()
        clearEntry j = writeDataFrameOff bPtr (rowOff + j) 0
        solveEntry j = do
          xij <- (pivotInv *) <$> readDataFrameOff bPtr (rowOff + j)
          writeDataFrameOff bPtr (rowOff + j) xij
          -- Remove the contribution of x_ij from every row below i.
          forM_ [i + 1 .. rowCount - 1] $ \below -> do
            let coef = scalar (ixOff (colCount * below + i) l)
                off = rowStep * below + j
            old <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (old - coef * xij)
-- | Solve @x L = b@ in place by back substitution, treating @L@ as a
--   lower-triangular matrix.
--
--   The mutable buffer holds @b@ on entry. It is addressed as a sequence of
--   runs of @m@ consecutive elements (offset @m*j + i@); the number of runs is
--   the first entry of the buffer's 'CumulDims' divided by @m@. Only the first
--   @n@ positions of each run are read or written; on exit they hold @x@.
--   Element @(row, col)@ of @L@ is read at flat offset @m*row + col@; only the
--   diagonal and the entries to its left are accessed.
--
--   When a diagonal entry of @L@ is exactly zero, the corresponding position
--   in every run is overwritten with zero and nothing is propagated from it.
solveLowerTriangularL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat) (ds :: [Nat])
     . (PrimBytes t, Fractional t, Eq t, KnownDim n, KnownDim m, n <= m)
    => STDataFrame s t (ds +: m) -- ^ @b@ on entry, overwritten with @x@
    -> DataFrame t '[n,m]        -- ^ matrix @L@
    -> ST s ()
solveLowerTriangularL bPtr l
  | Dict <- Dict @(n <= m) = mapM_ eliminateColumn (reverse [0 .. rowCount - 1])
  where
    width = fromIntegral (dimVal' @m) :: Int
    rowCount = fromIntegral (dimVal' @n) :: Int
    -- Evaluated lazily: never demanded when there are no columns to process.
    runCount = fromIntegral totalElems `quot` width
    CumulDims (totalElems:_) = getDataFrameSteps bPtr

    -- Resolve position i in every run, then update the positions before it.
    eliminateColumn :: Int -> ST s ()
    eliminateColumn i =
        if pivot == 0
          then mapM_ clearEntry [0 .. runCount - 1]
          else mapM_ solveEntry [0 .. runCount - 1]
      where
        pivotRow = width * i
        pivot = scalar (ixOff (pivotRow + i) l)
        -- Only demanded in the non-zero branch.
        pivotInv = recip pivot

        clearEntry, solveEntry :: Int -> ST s ()
        clearEntry j = writeDataFrameOff bPtr (width * j + i) 0
        solveEntry j = do
          let base = width * j
          xji <- (pivotInv *) <$> readDataFrameOff bPtr (base + i)
          writeDataFrameOff bPtr (base + i) xji
          -- Remove the contribution of x_ji from the earlier positions.
          forM_ [0 .. i - 1] $ \c -> do
            let coef = scalar (ixOff (pivotRow + c) l)
                off = base + c
            old <- readDataFrameOff bPtr off
            writeDataFrameOff bPtr off (old - coef * xji)