{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash, UnboxedTuples, DataKinds #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.Matrix.Base.FloatXNM
-- Copyright   :  (c) Artem Chirkin
-- License     :  MIT
--
-- Maintainer  :  chirkin@arch.ethz.ch
--
--
-----------------------------------------------------------------------------

module Numeric.Matrix.Base.FloatXNM () where

#include "MachDeps.h"
#include "HsBaseConfig.h"

import GHC.Base (runRW#)
import GHC.Prim
import GHC.Types
import GHC.TypeLits


import Numeric.Commons
import Numeric.Matrix.Class
import Numeric.Matrix.Family




instance (KnownNat n, KnownNat m) => Show (MFloatXNM n m) where
  show x@(MFloatXNM arr) = "{" ++ drop 2 (loop' (n -# 1#) (m -# 1#) " }")
    where
      loop' i j acc | isTrue# (i ==# -1#) = acc
                    | isTrue# (j ==# -1#) = loop' (i -# 1#) (m -# 1#) ('\n':acc)
                    | otherwise           = loop' i (j -# 1#) (", " ++ show (F# (indexFloatArray# arr (i +# n *# j))) ++ acc)
      n = dimN# x
      m = dimM# x



instance (KnownNat n, KnownNat m) => Eq (MFloatXNM n m) where
  a == b = accumV2 (\x y r -> r && isTrue# (x `eqFloat#` y)) a b True
  {-# INLINE (==) #-}
  a /= b = accumV2 (\x y r -> r || isTrue# (x `neFloat#` y)) a b False
  {-# INLINE (/=) #-}




-- | Implement partial ordering for `>`, `<`, `>=`, `<=` and lexicographical ordering for `compare`
instance (KnownNat n, KnownNat m) => Ord (MFloatXNM n m) where
  a > b = accumV2 (\x y r -> r && isTrue# (x `gtFloat#` y)) a b True
  {-# INLINE (>) #-}
  a < b = accumV2 (\x y r -> r && isTrue# (x `ltFloat#` y)) a b True
  {-# INLINE (<) #-}
  a >= b = accumV2 (\x y r -> r && isTrue# (x `geFloat#` y)) a b True
  {-# INLINE (>=) #-}
  a <= b = accumV2 (\x y r -> r && isTrue# (x `leFloat#` y)) a b True
  {-# INLINE (<=) #-}
  -- | Compare lexicographically
  compare a b = accumV2 (\x y r -> r `mappend`
                          if isTrue# (x `gtFloat#` y)
                          then GT
                          else if isTrue# (x `ltFloat#` y)
                               then LT
                               else EQ
                        ) a b EQ
  {-# INLINE compare #-}




instance (KnownNat n, KnownNat m) => Num (MFloatXNM n m) where
  (+) = zipV plusFloat#
  {-# INLINE (+) #-}
  (-) = zipV minusFloat#
  {-# INLINE (-) #-}
  (*) = zipV timesFloat#
  {-# INLINE (*) #-}
  negate = mapV negateFloat#
  {-# INLINE negate #-}
  abs = mapV (\x -> if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x)
  {-# INLINE abs #-}
  signum = mapV (\x -> if isTrue# (x `gtFloat#` 0.0#) then 1.0# else if isTrue# (x `ltFloat#` 0.0#) then -1.0# else 0.0#)
  {-# INLINE signum #-}
  fromInteger = broadcastMat . fromInteger
  {-# INLINE fromInteger #-}



instance (KnownNat n, KnownNat m) => Fractional (MFloatXNM n m) where
  (/) = zipV divideFloat#
  {-# INLINE (/) #-}
  recip = mapV (divideFloat# 1.0#)
  {-# INLINE recip #-}
  fromRational = broadcastMat . fromRational
  {-# INLINE fromRational #-}



instance (KnownNat n, KnownNat m) => Floating (MFloatXNM n m) where
  pi = broadcastMat pi
  {-# INLINE pi #-}
  exp = mapV expFloat#
  {-# INLINE exp #-}
  log = mapV logFloat#
  {-# INLINE log #-}
  sqrt = mapV sqrtFloat#
  {-# INLINE sqrt #-}
  sin = mapV sinFloat#
  {-# INLINE sin #-}
  cos = mapV cosFloat#
  {-# INLINE cos #-}
  tan = mapV tanFloat#
  {-# INLINE tan #-}
  asin = mapV asinFloat#
  {-# INLINE asin #-}
  acos = mapV acosFloat#
  {-# INLINE acos #-}
  atan = mapV atanFloat#
  {-# INLINE atan #-}
  sinh = mapV sinFloat#
  {-# INLINE sinh #-}
  cosh = mapV coshFloat#
  {-# INLINE cosh #-}
  tanh = mapV tanhFloat#
  {-# INLINE tanh #-}
  (**) = zipV powerFloat#
  {-# INLINE (**) #-}

  logBase = zipV (\x y -> logFloat# y `divideFloat#` logFloat# x)
  {-# INLINE logBase #-}
  asinh = mapV (\x -> logFloat# (x `plusFloat#` sqrtFloat# (1.0# `plusFloat#` timesFloat# x x)))
  {-# INLINE asinh #-}
  acosh = mapV (\x ->  case plusFloat# x 1.0# of
                 y -> logFloat# ( x `plusFloat#` timesFloat# y (sqrtFloat# (minusFloat# x 1.0# `divideFloat#` y)))
               )
  {-# INLINE acosh #-}
  atanh = mapV (\x -> 0.5# `timesFloat#` logFloat# (plusFloat# 1.0# x `divideFloat#` minusFloat# 1.0# x))
  {-# INLINE atanh #-}








instance (KnownNat n, KnownNat m) => MatrixCalculus Float n m (MFloatXNM n m) where
  broadcastMat (F# x) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\i s' -> writeFloatArray# marr i x s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> MFloatXNM r
    where
      n = dimN# (undefined :: MFloatXNM n m) *# dimM# (undefined :: MFloatXNM n m)
      bs = n *# SIZEOF_HSFLOAT#
  {-# INLINE broadcastMat #-}
  indexMat (I# i) (I# j) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
      | isTrue# ( (i ># n)
           `orI#` (i <=# 0#)
           `orI#` (j ># _m)
           `orI#` (j <=# 0#)
          )       = error $ "Bad index (" ++ show (I# i) ++ ", " ++ show (I# j) ++ ") for "
                          ++ show (I# n) ++ "x" ++ show (I# _m) ++ "D matrix"
      | otherwise
#endif
                  = F# (indexFloatArray# arr (i -# 1# +# n *# (j -# 1#)))
    where
      n = dimN# x
      _m = dimM# x
  {-# INLINE indexMat #-}
  transpose x@(MFloatXNM arr) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop2# n m
               (\i j s' -> writeFloatArray# marr (i +# n *# j) (indexFloatArray# arr (i *# m +# j)) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n = dimN# x
      m = dimM# x
      bs = n *# m *# SIZEOF_HSFLOAT#
  dimN x = I# (dimN# x)
  {-# INLINE dimN #-}
  dimM x = I# (dimM# x)
  {-# INLINE dimM #-}
  indexCol (I# j) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
      | isTrue# ( (j ># dimM# x)
           `orI#` (j <=# 0#)
          )       = error $ "Bad column index " ++ show (I# j) ++ " for "
                          ++ show (I# n) ++ "x" ++ show (dimM x) ++ "D matrix"
      | otherwise
#endif
                  = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case copyByteArray# arr (bs *# (j -# 1#)) marr 0# bs s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n = dimN# x
      bs = n *# SIZEOF_HSFLOAT#
  indexRow (I# i) x@(MFloatXNM arr)
#ifndef UNSAFE_INDICES
      | isTrue# ( (i ># n)
           `orI#` (i <=# 0#)
          )       = error $ "Bad row index " ++ show (I# i) ++ " for "
                          ++ show (I# n) ++ "x" ++ show (I# m) ++ "D matrix"
      | otherwise
#endif
                  = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# m
               (\j s' -> writeFloatArray# marr j (indexFloatArray# arr (i -# 1# +# n *# j)) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n = dimN# x
      m = dimM# x
      bs = m *# SIZEOF_HSFLOAT#


instance (KnownNat n, KnownNat m) => PrimBytes (MFloatXNM n m) where
  toBytes (MFloatXNM a) = a
  {-# INLINE toBytes #-}
  fromBytes = MFloatXNM
  {-# INLINE fromBytes #-}
  byteSize x = SIZEOF_HSFLOAT# *# dimN# x *# dimM# x
  {-# INLINE byteSize #-}
  byteAlign _ = ALIGNMENT_HSFLOAT#
  {-# INLINE byteAlign #-}

instance FloatBytes (MFloatXNM n m) where
  ixF i (MFloatXNM a) = indexFloatArray# a i
  {-# INLINE ixF #-}


instance KnownNat n => SquareMatrixCalculus Float n (MFloatXNM n n) where
  eye = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\j s' -> writeFloatArray# marr (j *# n1) 1.0# s'
               ) (setByteArray# marr 0# bs 0# s1) of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n1 = n +# 1#
      n = dimN# (undefined :: MFloatXNM n n)
      bs = n *# n *# SIZEOF_HSFLOAT#
  {-# INLINE eye #-}
  diag (F# v) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\j s' -> writeFloatArray# marr (j *# n1) v s'
               ) (setByteArray# marr 0# bs 0# s1) of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n1 = n +# 1#
      n = dimN# (undefined :: MFloatXNM n n)
      bs = n *# n *# SIZEOF_HSFLOAT#
  {-# INLINE diag #-}


  det v@(MFloatXNM arr) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, mat #) -> case newByteArray# (n *# SIZEOF_HSFLOAT#) (copyByteArray# arr 0# mat 0# bs s1) of
           (# s2, vec #) ->
              let f i x s | isTrue# (i >=# n) = (# s, x #)
                          | otherwise =
                              let (# s' , j  #) = maxInRowRem# n n i mat s
                                  (# s'', x' #) = if isTrue# (i /=# j) then (# swapCols# n i j vec mat s', negateFloat# x #)
                                                                       else (# s', x #)
                                  (# s''', y #) = clearRowEnd# n n i mat s''
                              in if isTrue# (eqFloat# 0.0# y) then (# s''', 0.0# #)
                                                              else f (i +# 1#) (timesFloat# x' y) s'''
              in f 0# 1.0# s2
     ) of (# _, r #) -> F# r
    where
      n = dimN# v
      bs = n *# n *# SIZEOF_HSFLOAT#

  {-# INLINE det #-}



  trace x@(MFloatXNM a) = F# (loop' 0# 0.0#)
    where
      n1 = n +# 1#
      n = dimN# x
      nn = n *# n
      loop' i acc | isTrue# (i ># nn) = acc
                  | otherwise = loop' (i +# n1) (indexFloatArray# a i `plusFloat#` acc)
  {-# INLINE trace #-}
  fromDiag x@(MFloatXNM a) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\j s' -> writeFloatArray# marr j (indexFloatArray# a (j *# n1)) s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      n1 = n +# 1#
      n = dimN# x
      bs = n *# SIZEOF_HSFLOAT#
  {-# INLINE fromDiag #-}
  toDiag x = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\j s' -> writeFloatArray# marr (j *# n1) (indexFloatArray# a j) s'
               ) (setByteArray# marr 0# bs 0# s1) of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> fromBytes r
    where
      a = toBytes x
      n1 = n +# 1#
      n = dimN# (undefined :: MFloatXNM n n)
      bs = n *# n *# SIZEOF_HSFLOAT#
  {-# INLINE toDiag #-}


instance KnownNat n => MatrixInverse (MFloatXNM n n) where
  inverse v@(MFloatXNM arr) = case runRW#
     ( \s0 -> case newByteArray# (bs *# 2#) s0 of
         (# s1, mat #) -> case newByteArray# (n *# SIZEOF_HSFLOAT#)
                -- copy original matrix to the top of an augmented matrix
                (loop# n (\i s -> writeFloatArray# mat (i *# nn +# i +# n) 1.0# (copyByteArray# arr (i *# vs) mat (2# *# i *# vs) vs s))
                         (setByteArray# mat 0# (bs *# 2#) 0# s1)
                ) of
           (# s2, vec #) ->
              let f i s | isTrue# (i >=# n) = s
                        | otherwise =
                            let (# s' , j  #) = maxInRowRem# nn n i mat s
                                s''           = if isTrue# (i /=# j) then swapCols# nn i j vec mat s'
                                                                     else s'
                                (# s''', _ #) = clearRowAll# nn n i mat s''
                            in f (i +# 1#) s'''
              in unsafeFreezeByteArray# mat
                  ( shrinkMutableByteArray# mat bs
                   (-- copy inverse matrix from the augmented part
                    loop# n (\i s -> copyMutableByteArray# mat (2# *# i *# vs +# vs) mat (i *# vs) vs s)
                   (f 0# s2)
                   )
                  )
     ) of (# _, r #) -> MFloatXNM r
    where
      nn = 2# *# n
      n = dimN# v
      vs = n *# SIZEOF_HSFLOAT#
      bs = n *# n *# SIZEOF_HSFLOAT#


-----------------------------------------------------------------------------
-- Helpers
-----------------------------------------------------------------------------


proxyN# :: MFloatXNM n m -> Proxy# n
proxyN# _ = proxy#

dimN# :: KnownNat n => MFloatXNM n m -> Int#
dimN# x = case fromInteger (natVal' (proxyN# x)) of I# n -> n
{-# INLINE dimN# #-}

dimM# :: KnownNat m => MFloatXNM n m -> Int#
dimM# x = case fromInteger (natVal x) of I# n -> n
{-# INLINE dimM# #-}

-- | Do something in a loop for int i from 0 to n
loop# :: Int# -> (Int# -> State# s -> State# s) -> State# s -> State# s
loop# n f = loop' 0#
  where
    loop' i s | isTrue# (i ==# n) = s
              | otherwise = case f i s of s1 -> loop' (i +# 1#) s1
{-# INLINE loop# #-}

-- | Do something in a loop for int i from 0 to n-1 and j from 0 to m-1
loop2# :: Int# -> Int# -> (Int# -> Int#-> State# s -> State# s) -> State# s -> State# s
loop2# n m f = loop' 0# 0#
  where
    loop' i j s | isTrue# (j ==# m) = s
                | isTrue# (i ==# n) = loop' 0# (j +# 1#) s
                | otherwise         = case f i j s of s1 -> loop' (i +# 1#) j s1
{-# INLINE loop2# #-}

zipV :: (KnownNat n, KnownNat m) => (Float# -> Float# -> Float#) -> MFloatXNM n m -> MFloatXNM n m -> MFloatXNM n m
zipV f x@(MFloatXNM a) (MFloatXNM b) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\i s' -> case f (indexFloatArray# a i) (indexFloatArray# b i) of
                 r -> writeFloatArray# marr i r s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> MFloatXNM r
  where
    n = dimN# x *# dimM# x
    bs = n *# SIZEOF_HSFLOAT#

mapV :: (KnownNat n, KnownNat m) => (Float# -> Float#) -> MFloatXNM n m -> MFloatXNM n m
mapV f x@(MFloatXNM a) = case runRW#
     ( \s0 -> case newByteArray# bs s0 of
         (# s1, marr #) -> case loop# n
               (\i s' -> case f (indexFloatArray# a i) of
                 r -> writeFloatArray# marr i r s'
               ) s1 of
             s2 -> unsafeFreezeByteArray# marr s2
     ) of (# _, r #) -> MFloatXNM r
  where
    n = dimN# x *# dimM# x
    bs = n *# SIZEOF_HSFLOAT#


--accumVFloat :: (KnownNat n, KnownNat m) => (Float# -> Float# -> Float#) -> MFloatXNM n m -> Float# -> Float#
--accumVFloat f x@(MFloatXNM a) = loop' 0#
--  where
--    loop' i acc | isTrue# (i ==# n) = acc
--                | otherwise = loop' (i +# 1#) (f (indexFloatArray# a i) acc)
--    n = dimN# x *# dimM# x

accumV2 :: (KnownNat n, KnownNat m) => (Float# -> Float# -> a -> a) -> MFloatXNM n m -> MFloatXNM n m -> a -> a
accumV2 f x@(MFloatXNM a) (MFloatXNM b) = loop' 0#
  where
    loop' i acc | isTrue# (i ==# n) = acc
                | otherwise = loop' (i +# 1#) (f (indexFloatArray# a i) (indexFloatArray# b i) acc)
    n = dimN# x *# dimM# x



-- | Swap columns i and j. Does not check if i or j is larger than matrix width m
swapCols# :: Int# -- n
          -> Int# -- ith column to swap
          -> Int# -- jth column to swap
          -> MutableByteArray# s -- buffer byte array of length of n elems
          -> MutableByteArray# s -- byte array of matrix
          -> State# s -- previous state
          -> State# s -- next state
swapCols# n i j vec mat s0 =
  -- copy ith column to bugger vec
  case copyMutableByteArray# mat (i *# bs) vec 0# bs s0 of
    s1 -> case copyMutableByteArray# mat (j *# bs) mat (i *# bs) bs s1 of
      s2 -> copyMutableByteArray# vec 0# mat (j *# bs) bs s2
 where
  bs = n *# SIZEOF_HSFLOAT#

-- | Starting from i-th row and i+1-th column, substract a multiple of i-th column from i+1 .. m columns,
--   such that there are only zeroes in i-th row and i+1..m columns elements.
clearRowEnd# :: Int# -- n
             -> Int# -- m
             -> Int# -- ith column to remove from all others
             -> MutableByteArray# s -- byte array of matrix
             -> State# s -- previous state
             -> (# State# s, Float# #) -- next state and a diagonal element
clearRowEnd# n m i mat s0 = (# loop' (i +# 1#) s1, y' #)
  where
    y0 = (n +# 1#) *# i +# 1# -- first element in source column
    (# s1, y' #) = readFloatArray# mat ((n +# 1#) *# i) s0 -- diagonal element, must be non-zero
    yrc = 1.0# `divideFloat#` y'
    n' = n -# i -# 1#
    loop' k s | isTrue# (k >=# m) = s
              | otherwise = loop' (k +# 1#)
       ( let x0 = k *# n +# i
             (# s', a' #) = readFloatArray# mat x0 s
             s'' = writeFloatArray# mat x0 0.0# s'
             a  = a' `timesFloat#` yrc
         in multNRem# n' (x0 +# 1#) y0 a mat s''
       )

-- | Substract a multiple of i-th column from 0 .. i-1 and i+1 .. m columns,
--   such that there are only zeroes in i-th row and i+1..m columns elements.
--   Assuming that elements 0..i-1 in i-th row are zeroes, so they do not affect other columns.
--   After all columns updated, divide i-th row by its diagonal element
clearRowAll# :: Int# -- n
             -> Int# -- m
             -> Int# -- ith column to remove from all others
             -> MutableByteArray# s -- byte array of matrix
             -> State# s -- previous state
             -> (# State# s, Float# #) -- next state and a diagonal element
clearRowAll# n m i mat s0 = (# divLoop (i +# 1#) (writeFloatArray# mat ((n +# 1#) *# i) 1.0# (loop' 0# i (loop' (i +# 1#) m s1))), y' #)
  where
    y0 = (n +# 1#) *# i +# 1# -- first element in source column
    (# s1, y' #) = readFloatArray# mat ((n +# 1#) *# i) s0 -- diagonal element, must be non-zero
    yrc = 1.0# `divideFloat#` y'
    n' = n -# i -# 1#
    loop' k km s | isTrue# (k >=# km) = s
                 | otherwise = loop' (k +# 1#) km
       ( let x0 = k *# n +# i
             (# s', a' #) = readFloatArray# mat x0 s
             s'' = writeFloatArray# mat x0 0.0# s'
             a  = a' `timesFloat#` yrc
         in multNRem# n' (x0 +# 1#) y0 a mat s''
       )
    divLoop k s | isTrue# (k >=# n) = s
                | otherwise = divLoop (k +# 1#)
       ( let x0 = n *# i +# k
             (# s', x #) = readFloatArray# mat x0 s
         in writeFloatArray# mat x0 (timesFloat# x yrc) s'
       )


-- | Remove a multiple of one row from another one.
--   do: xi = xi - yi*a
multNRem# :: Int# -- n - nr of elements to go through
          -> Int# -- start idx of y
          -> Int# -- start idx of x
          -> Float# -- multiplier a
          -> MutableByteArray# s -- byte array of matrix
          -> State# s -- previous state
          -> State# s -- next state
multNRem# 0# _ _  _ _ s = s
multNRem# n x0 y0 a mat s = multNRem# (n -# 1#) (x0 +# 1#) (y0 +# 1#) a mat
  ( case readFloatArray# mat y0 s of
     (# s1, y #) -> case readFloatArray# mat x0 s1 of
       (# s2, x #) -> writeFloatArray# mat x0 (x `minusFloat#` timesFloat# y a) s2
  )



-- | Gives index of maximum (absolute) element in i-th row, starting from i-th element only.
--   If i >= m then returns i.
maxInRowRem# :: Int# -- n
             -> Int# -- m
             -> Int# -- ith column to start to search for and a row to look in
             -> MutableByteArray# s -- byte array of matrix
             -> State# s -- previous state
             -> (# State# s, Int# #) -- next state
maxInRowRem# n m i mat s0 = loop' i (abs# v) i s1
  where
    (# s1, v #) = readFloatArray# mat ((n +# 1#) *# i) s0
    abs# x = if isTrue# (x `geFloat#` 0.0#) then x else negateFloat# x
    loop' ok ov k s | isTrue# (k >=# m) = (# s, ok #)
                    | otherwise = case readFloatArray# mat (n *# k +# i) s of
                        (# s', v' #) -> if isTrue# (abs# v' `gtFloat#` ov)
                                        then loop' k (abs# v') (k +# 1#) s'
                                        else loop' ok ov (k +# 1#) s'




--accumV2Float :: (KnownNat n, KnownNat m) => (Float# -> Float# -> Float# -> Float#) -> MFloatXNM n m -> MFloatXNM n m -> Float# -> Float#
--accumV2Float f x@(MFloatXNM a) (MFloatXNM b) = loop' 0#
--  where
--    loop' i acc | isTrue# (i ==# n) = acc
--                | otherwise = loop' (i +# 1#) (f (indexFloatArray# a i) (indexFloatArray# b i) acc)
--    n = dimN# x *# dimM# x
--
--
--accumVReverse :: (KnownNat n, KnownNat m) => (Float# -> a -> a) -> MFloatXNM n m -> a -> a
--accumVReverse f x@(MFloatXNM a) = loop' (n -# 1#)
--  where
--    loop' i acc | isTrue# (i ==# -1#) = acc
--                | otherwise = loop' (i -# 1#) (f (indexFloatArray# a i) acc)
--    n = dimN# x *# dimM# x