{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Subroutine.Householder
  ( householderReflectionInplaceR
  , householderReflectionInplaceL
  ) where


import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.ST
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Scalar.Internal


{- | Run a Householder transformation inplace.

     Given some orthogonal matrix \(P\), some matrix \(R\) and index \((k,l)\),
     reflects \(R\) along some hyperplane, such that all elements of \(R\)
     below index \( (k, l) \) become zeros,
     then updates \(P\) with the inverse of the same transform as \(R\).

     Notes and invariants:

       1. The transformation happens inplace for both matrices \(P\) and \(R\);
          if \( R = P^\intercal A \), then \( R' = P^*PR = P'^\intercal A \), where
           \( P' \) and \( R' \) are the updated versions of the input matrices,
           \( P^* \) and \( A \) are implicit matrices.

       2. All elements below and to the left of index \(k,l\) in \(R\)
          are assumed (and not checked) to be zeros;
          these are not touched by the subroutine to save flops.

       3. A logical starting value for \(P\) is an identity matrix.
          The subroutine can be used for a QR decomposition:
            \( Q = P \).

     Returns @True@ if reflection has been performed, and @False@ if it was not needed.
     This can be used to track the sign of @det P@.
 -}
householderReflectionInplaceL ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[n]   -- ^ Scratch vector that receives the Householder axis
    -> STDataFrame s t '[n,n] -- ^ Current state of \(P^\intercal\)
    -> STDataFrame s t '[n,m] -- ^ Current state of \(R\)
    -> Idxs '[n,m]            -- ^ Pivot element
    -> ST s Bool
householderReflectionInplaceL axisBuf pMat rMat (Idx pivRow :* Idx pivCol :* U)
    = householderReflectionInplaceL' axisBuf pMat rMat
        rowCount colCount (fromIntegral pivRow) (fromIntegral pivCol)
  where
    -- The worker addresses elements by plain Int offsets, so the matrix
    -- extents are recovered here from the type-level dimensions.
    rowCount, colCount :: Int
    rowCount = fromIntegral (dimVal' @n)
    colCount = fromIntegral (dimVal' @m)

-- | Worker for 'householderReflectionInplaceL' that addresses elements
--   by plain 'Int' offsets.
--
--   Elements of \(R\) are accessed at offset @row*m + col@ and elements of the
--   second matrix at offset @row*n + col@.
--   This function itself performs no range checks on @k@ and @l@.
--
--   Steps:
--
--     1. Read rows @k..n-1@ of column @l@ of \(R\), overwrite entries whose
--        absolute value is at most @M_EPS@ with zero (this happens even if
--        @False@ is returned), and store the Householder axis @u@ into
--        entries @k..n-1@ of the buffer.
--
--     2. If \(|u|^2\) does not exceed @M_EPS@, return @False@ without
--        further writes.
--
--     3. Otherwise write the new pivot and zeros below it into column @l@,
--        reflect columns @l+1..m-1@ of \(R\) and entries @k..n-1@ of every
--        row of the second matrix, and return @True@.
householderReflectionInplaceL' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[n] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[n,n]  -- ^ \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ \(R\)
    -> Int -- ^ \(n\)
    -> Int -- ^ \(m\)
    -> Int -- ^ \( 0 \leq k < n \)
    -> Int -- ^ \( 0 \leq l < m \)
    -> ST s Bool
householderReflectionInplaceL' uPtr pPtr rPtr n m k l = do
    -- alpha is the pivot element (k,l) of the new R;
    -- this call also fills entries k..n-1 of the buffer with the axis u
    alpha <- getAlphaAndUpdateU
    u2 <- getU2
    -- if |u|^2 does not exceed M_EPS the axis is numerically zero:
    -- nothing is reflected and False is returned
    if u2 > M_EPS
    then do
      let c = 2 / u2 -- a mult constant for updating matrices
      -- update R: column l is written directly, the columns to the right of it
      -- are reflected one by one (columns to the left are not touched)
      updateRl alpha
      forM_ [l+1..m-1] $ updateRi c
      -- update P: every row is visited
      forM_ [0..n-1] $ updatePi c
      return True
    else return False
  where
    n' = n - k -- remaining rows (at least 1 when k < n, as the signature requires)
    rOff0 = k*m + l -- offset of element (k,l) in matrix R

    -- Computes alpha and stores
    --   u = (column l of R, rows k..n-1) - alpha * e_k
    -- into entries k..n-1 of the buffer.
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU = do
      -- alpha' = square root of the sum of squares of rows k..n-1 of column l;
      -- the offset advances by m, i.e. one row down per step
      alpha' <- sqrt . fst <$> nTimesM n'
        (\(r, off) -> do
          x <- readDataFrameOff rPtr off
          -- entries with |x| <= M_EPS are overwritten with zero in R;
          -- the running sum still uses the value as it was read
          when (abs x <= M_EPS) $ writeDataFrameOff rPtr off 0
          return (r + x*x, off + m)
        ) (0, rOff0)
      -- re-read the pivot: it may just have been overwritten with zero above
      x0 <- readDataFrameOff rPtr rOff0
      -- alpha gets the sign opposite to x0, so (x0 - alpha) adds magnitudes
      let alpha = if x0 >= 0 then negate alpha' else alpha'
      -- update (lower part of) u: first entry is x0 - alpha,
      -- the remaining n'-1 entries are copied from column l of R
      writeDataFrameOff uPtr k (x0 - alpha)
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\(i, off) -> (i+1, off+m) <$
          (readDataFrameOff rPtr off >>= writeDataFrameOff uPtr i)
        ) (k+1, rOff0+m)
      return alpha

    -- l-th column of R: the pivot becomes alpha, the n'-1 entries below it
    -- are set to zero
    updateRl :: Scalar t -> ST s ()
    updateRl alpha = do
      writeDataFrameOff rPtr rOff0 alpha
      when (n' >= 1) $ void $ nTimesM (n' - 1)
        (\off -> (off+m) <$ writeDataFrameOff rPtr off 0) (rOff0+m)

    -- update i-th column of R (rows k..n-1 only):
    --   Ri <- Ri - c * (u . Ri) * u
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
      -- dot product of u and Ri; state is (sum, (index into u, offset into R))
      uRi <- fmap fst . flip (nTimesM n') (0, (k, k*m+i)) $ \(r, (j, off)) -> do
        ju  <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        return (r + ju * jiR, (j+1, off+m))
      let c' = c * uRi
      -- update each element, walking the same n' positions again
      void $ flip (nTimesM n') (k, k*m+i) $ \(j, off) -> do
        ju  <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr off
        writeDataFrameOff rPtr off $ jiR - c'*ju
        return (j+1, off+m)

    -- update i-th row of P (entries k..n-1 only; row i starts at offset i*n):
    --   Pi <- Pi - c * (u . Pi) * u
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
      let off0 = i*n
      -- dot product of u and Pi
      uPi <- fmap fst . flip (nTimesM n') (0, k) $ \(r, j) -> do
        ju  <- readDataFrameOff uPtr  j
        ijP <- readDataFrameOff pPtr (off0 + j)
        return (r + ju * ijP, j+1)
      let c' = c * uPi
      -- update each element
      forM_ [k..n-1] $ \j -> do
        ju  <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        writeDataFrameOff pPtr (off0 + j) $ ijP - c'*ju

    -- get module squared of u (for Q = I - 2 u*uT / |u|^2 );
    -- only entries k..n-1 of the buffer are read
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM n'
      (\(r, off) -> (\x -> (r + x*x, off + 1)) <$> readDataFrameOff uPtr off) (0, k)

{- | Run a Householder transformation inplace.

  Similar to `householderReflectionInplaceL`, but the reflection is applied
  from the right - use it to zero the elements to the right of the pivot.

     Returns @True@ if reflection has been performed, and @False@ if it was not needed.
     This can be used to track the sign of @det P@.
 -}
householderReflectionInplaceR ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
    => STDataFrame s t '[m]   -- ^ Scratch vector that receives the Householder axis
    -> STDataFrame s t '[m,m] -- ^ Current state of \(P^\intercal\)
    -> STDataFrame s t '[n,m] -- ^ Current state of \(R\)
    -> Idxs '[n,m]            -- ^ Pivot element
    -> ST s Bool
householderReflectionInplaceR axisBuf pMat rMat (Idx pivRow :* Idx pivCol :* U)
    = householderReflectionInplaceR' axisBuf pMat rMat
        rowCount colCount (fromIntegral pivRow) (fromIntegral pivCol)
  where
    -- The worker addresses elements by plain Int offsets, so the matrix
    -- extents are recovered here from the type-level dimensions.
    rowCount, colCount :: Int
    rowCount = fromIntegral (dimVal' @n)
    colCount = fromIntegral (dimVal' @m)

-- | Worker for 'householderReflectionInplaceR' that addresses elements
--   by plain 'Int' offsets.
--
--   Elements of both \(R\) and the second matrix are accessed at offset
--   @row*m + col@.
--   This function itself performs no range checks on @k@ and @l@.
--
--   Steps:
--
--     1. Read columns @l..m-1@ of row @k@ of \(R\), overwrite entries whose
--        absolute value is at most @M_EPS@ with zero (this happens even if
--        @False@ is returned), and store the Householder axis @u@ into
--        entries @l..m-1@ of the buffer.
--
--     2. If \(|u|^2\) does not exceed @M_EPS@, return @False@ without
--        further writes.
--
--     3. Otherwise write the new pivot and zeros to the right of it into
--        row @k@, reflect rows @k+1..n-1@ of \(R\) and entries @l..m-1@ of
--        every row of the second matrix, and return @True@.
householderReflectionInplaceR' ::
       forall (s :: Type) (t :: Type) (n :: Nat) (m :: Nat)
     . (PrimBytes t, Epsilon t, Ord t)
    => STDataFrame s t '[m] -- ^ Temporary buffer for a Householder axis vector
    -> STDataFrame s t '[m,m]  -- ^ \(P^\intercal\)
    -> STDataFrame s t '[n,m]  -- ^ \(R\)
    -> Int -- ^ \(n\)
    -> Int -- ^ \(m\)
    -> Int -- ^ \( 0 \leq k < n \)
    -> Int -- ^ \( 0 \leq l < m \)
    -> ST s Bool
householderReflectionInplaceR' uPtr pPtr rPtr n m k l = do
    -- alpha is the pivot element (k,l) of the new R;
    -- this call also fills entries l..m-1 of the buffer with the axis u
    alpha <- getAlphaAndUpdateU
    u2 <- getU2
    -- if |u|^2 does not exceed M_EPS the axis is numerically zero:
    -- nothing is reflected and False is returned
    if u2 > M_EPS
    then do
      let c = 2 / u2 -- a mult constant for updating matrices
      -- update R: row k is written directly, the rows below it
      -- are reflected one by one (rows above are not touched)
      updateRk alpha
      forM_ [k+1..n-1] $ updateRi c
      -- update P: every row is visited
      forM_ [0..m-1] $ updatePi c
      return True
    else return False
  where
    m' = m - l -- remaining cols (at least 1 when l < m, as the signature requires)
    rOff0 = k*m + l -- offset of element (k,l) in matrix R

    -- Computes alpha and stores
    --   u = (row k of R, columns l..m-1) - alpha * e_l
    -- into entries l..m-1 of the buffer.
    getAlphaAndUpdateU :: ST s (Scalar t)
    getAlphaAndUpdateU = do
      -- alpha' = square root of the sum of squares of columns l..m-1 of row k;
      -- the offset advances by 1, i.e. one column to the right per step
      alpha' <- sqrt . fst <$> nTimesM m'
        (\(r, off) -> do
          x <- readDataFrameOff rPtr off
          -- entries with |x| <= M_EPS are overwritten with zero in R;
          -- the running sum still uses the value as it was read
          when (abs x <= M_EPS) $ writeDataFrameOff rPtr off 0
          return (r + x*x, off + 1)
        ) (0, rOff0)
      -- re-read the pivot: it may just have been overwritten with zero above
      x0 <- readDataFrameOff rPtr rOff0
      -- alpha gets the sign opposite to x0, so (x0 - alpha) adds magnitudes
      let alpha = if x0 >= 0 then negate alpha' else alpha'
      -- update (trailing part of) u: first entry is x0 - alpha,
      -- the remaining m'-1 entries are copied from row k of R
      writeDataFrameOff uPtr l (x0 - alpha)
      forM_ [1..m'-1] $ \i ->
        readDataFrameOff rPtr (rOff0 + i) >>= writeDataFrameOff uPtr (l + i)
      return alpha

    -- k-th row of R: the pivot becomes alpha, the m'-1 entries to the right
    -- of it are set to zero
    updateRk :: Scalar t -> ST s ()
    updateRk alpha = do
      writeDataFrameOff rPtr rOff0 alpha
      forM_ [rOff0+1..rOff0+m'-1] $ flip (writeDataFrameOff rPtr) 0

    -- update i-th row of R (columns l..m-1 only; row i starts at offset i*m):
    --   Ri <- Ri - c * (u . Ri) * u
    updateRi :: Scalar t -> Int -> ST s ()
    updateRi c i = do
      let off0 = i*m
      -- dot product of u and Ri
      uRi <- fmap fst . flip (nTimesM m') (0, l) $ \(r, j) -> do
        ju  <- readDataFrameOff uPtr  j
        jiR <- readDataFrameOff rPtr (off0 + j)
        return (r + ju * jiR, j+1)
      let c' = c * uRi
      -- update each element
      forM_ [l..m-1] $ \j -> do
        ju  <- readDataFrameOff uPtr j
        jiR <- readDataFrameOff rPtr (off0 + j)
        writeDataFrameOff rPtr (off0 + j) $ jiR - c'*ju

    -- update i-th row of P (entries l..m-1 only; row i starts at offset i*m):
    --   Pi <- Pi - c * (u . Pi) * u
    updatePi :: Scalar t -> Int -> ST s ()
    updatePi c i = do
      let off0 = i*m
      -- dot product of u and Pi
      uPi <- fmap fst . flip (nTimesM m') (0, l) $ \(r, j) -> do
        ju  <- readDataFrameOff uPtr  j
        ijP <- readDataFrameOff pPtr (off0 + j)
        return (r + ju * ijP, j+1)
      let c' = c * uPi
      -- update each element
      forM_ [l..m-1] $ \j -> do
        ju  <- readDataFrameOff uPtr j
        ijP <- readDataFrameOff pPtr (off0 + j)
        writeDataFrameOff pPtr (off0 + j) $ ijP - c'*ju

    -- get module squared of u (for Q = I - 2 u*uT / |u|^2 );
    -- only entries l..m-1 of the buffer are read
    getU2 :: ST s (Scalar t)
    getU2 = fst <$> nTimesM m'
      (\(r, off) -> (\x -> (r + x*x, off + 1)) <$> readDataFrameOff uPtr off) (0, l)


-- | Apply a monadic step a fixed number of times, threading the state through.
--
--   @nTimesM n step x@ runs @step@ on @x@, then on the result of that, and
--   so on, @n@ times in total; effects are sequenced in that order.
--
--   A non-positive @n@ performs no steps and returns @x@ unchanged.
--   Matching on @0@ alone would let a negative count (e.g. produced by an
--   out-of-range pivot index) recurse without reaching the base case.
nTimesM :: Monad m => Int -> (a -> m a) -> a -> m a
nTimesM n step x
  | n <= 0    = pure x
  | otherwise = step x >>= nTimesM (n - 1) step