{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Subroutine.Householder
( householderReflectionInplaceR
, householderReflectionInplaceL
) where
import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Scalar.Internal
-- | Apply, in place, a Householder reflection \(H\) chosen so that the
--   column segment \(R_{k:,l}\) (pivot @(k, l)@ given as 'Idxs') collapses to
--   a single non-zero entry at \(R_{k,l}\).
--
--   On success \(R\) is updated as \(H R\) and \(P\) as \(P H\), and the
--   result is 'True'. The first argument is scratch space for the reflection
--   vector (entries @k..n-1@ are overwritten).
--
--   Only columns @l..m-1@ of rows @k..n-1@ of \(R\) are touched. The update
--   therefore equals the full product \(H R\) only if the entries of those
--   rows to the left of column @l@ are already zero.
--
--   Returns 'False' when the reflection vector is numerically zero. In that
--   case \(H\) is not applied and \(P\) is untouched.
householderReflectionInplaceL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[n]
    -> STDataFrame s t '[n,n]
    -> STDataFrame s t '[n,m]
    -> Idxs '[n,m]
    -> ST s Bool
householderReflectionInplaceL uBuf pMat rMat (Idx pivotRow :* Idx pivotCol :* U) =
    householderReflectionInplaceL' uBuf pMat rMat
      rowCount colCount (asInt pivotRow) (asInt pivotCol)
  where
    rowCount = asInt (dimVal' @n)
    colCount = asInt (dimVal' @m)
    asInt :: Integral a => a -> Int
    asInt = fromIntegral
-- | Worker for 'householderReflectionInplaceL'. It takes plain 'Int' sizes
--   and pivot coordinates instead of type-level dims and 'Idxs'.
--
--   Both matrices are addressed through flat offsets @row*cols + col@.
--
--   Let \(x = R_{k..n-1,\,l}\). The worker forms
--   \(\alpha = -\mathrm{sign}(x_0)\lVert x\rVert\) (taking \(\mathrm{sign}(0) = +\))
--   and the reflection vector \(u = x - \alpha e_k\), stored in @uPtr[k..n-1]@.
--   If \(u^\intercal u >\) 'M_EPS', it applies
--   \(H = I - 2uu^\intercal/(u^\intercal u)\):
--
--   * column @l@ of R becomes @alpha@ at row @k@ and @0@ below it;
--   * columns @l+1..m-1@ (rows @k..n-1@) of R are replaced by \(H R\);
--   * all rows of P (columns @k..n-1@) are replaced by \(P H\);
--
--   and it returns 'True'. Otherwise it returns 'False' and leaves P untouched.
--   Even then, @uPtr[k..n-1]@ has been overwritten, and any entry of
--   \(R_{k..n-1,\,l}\) with absolute value @<= M_EPS@ has already been set to 0.
householderReflectionInplaceL' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[n]   -- ^ scratch buffer for u (entries k..n-1 used)
    -> STDataFrame s t '[n,n] -- ^ P, updated as P := P * H
    -> STDataFrame s t '[n,m] -- ^ R, updated as R := H * R
    -> Int                    -- ^ n, number of rows of R
    -> Int                    -- ^ m, number of columns of R
    -> Int                    -- ^ k, pivot row
    -> Int                    -- ^ l, pivot column
    -> ST s Bool              -- ^ whether the reflection was applied
householderReflectionInplaceL' uPtr pPtr rPtr n m k l = do
    alpha <- getAlphaAndUpdateU
    u2 <- getU2
    -- Skip the reflection when u is numerically zero, since c = 2/u2
    -- would be huge or undefined.
    if u2 > M_EPS
      then do
        let c = 2 / u2
        updateRl alpha
        -- Columns left of l are not updated. NOTE(review): this presumably
        -- relies on the caller having already zeroed R[k.., 0..l-1].
        forM_ [l+1..m-1] $ updateRi c
        forM_ [0..n-1] $ updatePi c
        return True
      else return False
  where
    -- Length of the reduced column segment (rows k..n-1).
    n' = n - k
    -- Flat offset of the pivot R[k,l].
    rOff0 = k*m + l
    -- Returns alpha = -sign(x0) * ||R[k..n-1, l]|| and stores
    -- u = x - alpha*e_k into uPtr[k..n-1]. As a side effect, entries of that
    -- column with |x| <= M_EPS are set to 0 in R. Their squares are still
    -- added to the norm because they are accumulated before being cleared.
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU = do
      alpha' <- sqrt . fst <$> nTimesM n'
        (\(r, off) -> do
          x <- readDataFrameOff rPtr off
          when (abs x <= M_EPS) $ writeDataFrameOff rPtr off 0
          return (r + x*x, off + m)
        ) (0, rOff0)
      -- x0 is re-read after the cleanup above, so it may now be exactly 0.
      x0 <- readDataFrameOff rPtr rOff0
      -- Choose the sign opposite to x0 so that x0 - alpha does not cancel.
      let alpha = if x0 >= 0 then negate alpha' else alpha'
      writeDataFrameOff uPtr k (x0 - alpha)
      -- Copy the rest of the column, R[k+1..n-1, l], into uPtr[k+1..n-1].
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\(i, off) -> (i+1, off+m) <$
          (readDataFrameOff rPtr off >>= writeDataFrameOff uPtr i)
        ) (k+1, rOff0+m)
      return alpha
    -- Writes the reduced column: R[k,l] = alpha and R[k+1..n-1, l] = 0.
    updateRl :: Scalar t -> ST s ()
    updateRl alpha = do
      writeDataFrameOff rPtr rOff0 alpha
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\off -> (off+m) <$ writeDataFrameOff rPtr off 0) (rOff0+m)
    -- Applies H to column i of R (rows k..n-1):
    --   R[j,i] -= c * (sum_{j'} u[j'] * R[j',i]) * u[j]
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
      uRi <- fmap fst . flip (nTimesM n') (0, (k, k*m+i)) $ \(r, (j, off)) -> do
        ju <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        return (r + ju * jiR, (j+1, off+m))
      let c' = c * uRi
      void $ flip (nTimesM n') (k, k*m+i) $ \(j, off) -> do
        ju <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        writeDataFrameOff rPtr off $ jiR - c'*ju
        return (j+1, off+m)
    -- Applies H from the right to row i of P (columns k..n-1):
    --   P[i,j] -= c * (sum_{j'} P[i,j'] * u[j']) * u[j]
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
      let off0 = i*n
      uPi <- fmap fst . flip (nTimesM n') (0, k) $ \(r, j) -> do
        ju <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        return (r + ju * ijP, j+1)
      let c' = c * uPi
      forM_ [k..n-1] $ \j -> do
        ju <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        writeDataFrameOff pPtr (off0 + j) $ ijP - c'*ju
    -- Squared norm of u: sum of uPtr[j]^2 for j = k..n-1.
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM n'
      (\(r, off) -> (\x -> (r + x*x, off + 1)) <$> readDataFrameOff uPtr off) (0, k)
-- | Apply, in place, a Householder reflection \(H\) chosen so that the row
--   segment \(R_{k,l:}\) (pivot @(k, l)@ given as 'Idxs') collapses to a
--   single non-zero entry at \(R_{k,l}\).
--
--   On success \(R\) is updated as \(R H\) and the @m×m@ matrix \(P\) as
--   \(P H\), and the result is 'True'. The first argument is scratch space for
--   the reflection vector (entries @l..m-1@ are overwritten).
--
--   Only rows @k..n-1@ (columns @l..m-1@) of \(R\) are touched. The update
--   therefore equals the full product \(R H\) only if the rows above @k@ are
--   already zero in those columns.
--
--   Returns 'False' when the reflection vector is numerically zero. In that
--   case \(H\) is not applied and \(P\) is untouched.
householderReflectionInplaceR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[m]
    -> STDataFrame s t '[m,m]
    -> STDataFrame s t '[n,m]
    -> Idxs '[n,m]
    -> ST s Bool
householderReflectionInplaceR uBuf pMat rMat (Idx pivotRow :* Idx pivotCol :* U) =
    householderReflectionInplaceR' uBuf pMat rMat
      rowCount colCount (asInt pivotRow) (asInt pivotCol)
  where
    rowCount = asInt (dimVal' @n)
    colCount = asInt (dimVal' @m)
    asInt :: Integral a => a -> Int
    asInt = fromIntegral
-- | Worker for 'householderReflectionInplaceR'. It takes plain 'Int' sizes
--   and pivot coordinates instead of type-level dims and 'Idxs'.
--
--   Both matrices are addressed through flat offsets @row*cols + col@.
--
--   Let \(x = R_{k,\,l..m-1}\) (a row segment). The worker forms
--   \(\alpha = -\mathrm{sign}(x_0)\lVert x\rVert\) (taking \(\mathrm{sign}(0) = +\))
--   and \(u = x - \alpha e_l\), stored in @uPtr[l..m-1]@.
--   If \(u^\intercal u >\) 'M_EPS', it applies
--   \(H = I - 2uu^\intercal/(u^\intercal u)\) from the right:
--
--   * row @k@ of R becomes @alpha@ at column @l@ and @0@ to its right;
--   * rows @k+1..n-1@ (columns @l..m-1@) of R are replaced by \(R H\);
--   * all rows of the @m×m@ matrix P (columns @l..m-1@) are replaced by \(P H\);
--
--   and it returns 'True'. Otherwise it returns 'False' and leaves P untouched.
--   Even then, @uPtr[l..m-1]@ has been overwritten, and any entry of
--   \(R_{k,\,l..m-1}\) with absolute value @<= M_EPS@ has already been set to 0.
householderReflectionInplaceR' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[m]   -- ^ scratch buffer for u (entries l..m-1 used)
    -> STDataFrame s t '[m,m] -- ^ P, updated as P := P * H
    -> STDataFrame s t '[n,m] -- ^ R, updated as R := R * H
    -> Int                    -- ^ n, number of rows of R
    -> Int                    -- ^ m, number of columns of R
    -> Int                    -- ^ k, pivot row
    -> Int                    -- ^ l, pivot column
    -> ST s Bool              -- ^ whether the reflection was applied
householderReflectionInplaceR' uPtr pPtr rPtr n m k l = do
    alpha <- getAlphaAndUpdateU
    u2 <- getU2
    -- Skip the reflection when u is numerically zero, since c = 2/u2
    -- would be huge or undefined.
    if u2 > M_EPS
      then do
        let c = 2 / u2
        updateRk alpha
        -- Rows above k are not updated. NOTE(review): this presumably relies
        -- on the caller having already zeroed R[0..k-1, l..m-1].
        forM_ [k+1..n-1] $ updateRi c
        forM_ [0..m-1] $ updatePi c
        return True
      else return False
  where
    -- Length of the reduced row segment (columns l..m-1).
    m' = m - l
    -- Flat offset of the pivot R[k,l].
    rOff0 = k*m + l
    -- Returns alpha = -sign(x0) * ||R[k, l..m-1]|| and stores
    -- u = x - alpha*e_l into uPtr[l..m-1]. As a side effect, entries of that
    -- row segment with |x| <= M_EPS are set to 0 in R. Their squares are
    -- still added to the norm because they are accumulated before being
    -- cleared.
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU = do
      alpha' <- sqrt . fst <$> nTimesM m'
        (\(r, off) -> do
          x <- readDataFrameOff rPtr off
          when (abs x <= M_EPS) $ writeDataFrameOff rPtr off 0
          return (r + x*x, off + 1)
        ) (0, rOff0)
      -- x0 is re-read after the cleanup above, so it may now be exactly 0.
      x0 <- readDataFrameOff rPtr rOff0
      -- Choose the sign opposite to x0 so that x0 - alpha does not cancel.
      let alpha = if x0 >= 0 then negate alpha' else alpha'
      writeDataFrameOff uPtr l (x0 - alpha)
      -- Copy the rest of the row segment, R[k, l+1..m-1], into uPtr[l+1..m-1].
      forM_ [1..m'-1] $ \i ->
        readDataFrameOff rPtr (rOff0 + i) >>= writeDataFrameOff uPtr (l + i)
      return alpha
    -- Writes the reduced row: R[k,l] = alpha and R[k, l+1..m-1] = 0.
    updateRk :: Scalar t -> ST s ()
    updateRk alpha = do
      writeDataFrameOff rPtr rOff0 alpha
      forM_ [rOff0+1..rOff0+m'-1] $ flip (writeDataFrameOff rPtr) 0
    -- Applies H from the right to row i of R (columns l..m-1):
    --   R[i,j] -= c * (sum_{j'} R[i,j'] * u[j']) * u[j]
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
      let off0 = i*m
      uRi <- fmap fst . flip (nTimesM m') (0, l) $ \(r, j) -> do
        ju <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr (off0 + j)
        return (r + ju * jiR, j+1)
      let c' = c * uRi
      forM_ [l..m-1] $ \j -> do
        ju <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr (off0 + j)
        writeDataFrameOff rPtr (off0 + j) $ jiR - c'*ju
    -- Applies H from the right to row i of P (columns l..m-1):
    --   P[i,j] -= c * (sum_{j'} P[i,j'] * u[j']) * u[j]
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
      let off0 = i*m
      uPi <- fmap fst . flip (nTimesM m') (0, l) $ \(r, j) -> do
        ju <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        return (r + ju * ijP, j+1)
      let c' = c * uPi
      forM_ [l..m-1] $ \j -> do
        ju <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        writeDataFrameOff pPtr (off0 + j) $ ijP - c'*ju
    -- Squared norm of u: sum of uPtr[j]^2 for j = l..m-1.
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM m'
      (\(r, off) -> (\x -> (r + x*x, off + 1)) <$> readDataFrameOff uPtr off) (0, l)
-- | Run a monadic step @count@ times, feeding each result into the next
--   step, so that @nTimesM 3 f x = f x >>= f >>= f@.
--
--   A count of zero or less runs no steps and returns the input unchanged.
--   A negative count is therefore harmless: it no longer recurses past the
--   base case forever.
nTimesM :: Monad m => Int -> (a -> m a) -> a -> m a
nTimesM count step = go count
  where
    go remaining acc
      | remaining <= 0 = pure acc
      | otherwise      = step acc >>= go (remaining - 1)