```---------------------------------------------------------------------------
-- | Module    : Math.Statistics.Dirichlet.Matrix
-- Copyright   : (c) 2009-2012 Felipe Lessa
-- License     : BSD3
--
-- Maintainer  : felipe.lessa@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- Implements matrices as plain unboxed 'U.Vector's of 'Double's,
-- with the data stored in row-major order (i.e. the first elements
-- correspond to the first row).
--
--------------------------------------------------------------------------

module Math.Statistics.Dirichlet.Matrix
( -- * Basic
Matrix(..)
, size
, (!)
-- * Constructing
, replicate
, replicateRows
, fromVector
, fromVectorT
-- * Rows
, rows
, (!!!)
-- * Columns
, cols
, col
-- * Maps and zips
, umap
, map
, imap
, rowmap
, irowmap
, uzipWith
, zipWith
, izipWith
, rzipWith
-- * Other
, transpose
) where

import Prelude hiding (replicate, map, zipWith)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.Vector as V
import qualified Data.Vector.Fusion.Stream as S
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU

-- | A dense matrix of 'Double's.  All elements live in a single
-- unboxed vector in row-major order: element @(r, c)@ sits at
-- offset @r * mCols + c@ of 'mData'.  The constructor itself does
-- not check that the length of 'mData' equals @mRows * mCols@.
data Matrix = M
  { mRows :: !Int               -- ^ Number of rows.
  , mCols :: !Int               -- ^ Number of columns.
  , mData :: !(U.Vector Double) -- ^ Elements, one row after another.
  } deriving (Eq, Ord, Show)

-- | Dimensions of the matrix as @(rows, columns)@.
size :: Matrix -> (Int, Int)
size ~(M r c _) = (r, c)

-- | @m ! (r, c)@ is the element in row @r@, column @c@.
--
-- Only the flattened offset is bounds-checked (by 'U.!'), so a
-- column index outside @[0, mCols m)@ whose offset still lands
-- inside the data silently reads an element of a neighbouring row.
(!) :: Matrix -> (Int, Int) -> Double
m ! (r, c) = mData m U.! offset
  where offset = r * mCols m + c

-- | @replicate (r, c) x@ is an @r@-by-@c@ matrix in which every
-- element equals @x@.
replicate :: (Int, Int) -> Double -> Matrix
replicate (r, c) x = M r c (U.replicate (r * c) x)

-- | @replicateRows r v@ is a matrix with @r@ rows, each one a copy
-- of @v@.  The number of columns is the length of @v@.
replicateRows :: Int -> U.Vector Double -> Matrix
replicateRows nRows row = M nRows width (U.generate (nRows * width) pick)
  where
    width  = U.length row
    -- A flat offset maps back into @row@ through its column number.
    pick k = row U.! (k `mod` width)

-- | Creates a matrix from a vector of row vectors: the @i@-th inner
-- vector becomes the @i@-th row.  It *is not* verified that the
-- vectors have the right length; the number of columns is taken
-- from the first one.  An empty outer vector yields a @0@-by-@0@
-- matrix.
fromVector :: (G.Vector v (w Double), G.Vector w Double)
           => v (w Double) -> Matrix
fromVector v =
  M { mRows = G.length v
    , mCols = width
    , mData = G.unstream $ S.concatMap G.stream $ G.stream v }
  where
    -- 'G.head' is partial and 'mCols' is a strict field, so without
    -- this guard an empty input crashed instead of giving 0x0.
    width | G.null v  = 0
          | otherwise = G.length (G.head v)

-- | Creates a matrix from a vector of vectors.  The vectors are
-- transposed, so @fromVectorT@ is the same as @transpose
-- . fromVector@: the @j@-th inner vector becomes the @j@-th column.
-- It *is* verified that every inner vector has the same length as
-- the first one; otherwise an error is raised.  An empty outer
-- vector yields a @0@-by-@0@ matrix.
fromVectorT :: (G.Vector v (w Double), G.Vector w Double)
            => v (w Double) -> Matrix
fromVectorT v
  | G.any ((/= c) . G.length) v =
      materror "fromVectorT" $ "inner vectors must all have length " ++ show c
  | otherwise =
      M { mRows = c
        , mCols = r
        , mData = U.generate (r * c) elemAt }
  where
    -- r input vectors of length c give a c-by-r result.
    r = G.length v
    -- Guard the empty case: 'G.head' is partial.
    c | G.null v  = 0
      | otherwise = G.length (G.head v)
    -- The result has r columns, so flat offset k is result cell
    -- (i, j) = k `divMod` r, which is element i of input vector j.
    -- When r is 0 the length r * c is 0 and this is never called.
    elemAt k = let (i, j) = k `divMod` r
               in (v G.! j) G.! i

-- | /O(rows)/ The rows of the matrix, in order.  Each row is an
-- /O(1)/ slice of the matrix's underlying data.
rows :: Matrix -> V.Vector (U.Vector Double)
rows m = V.generate (mRows m) rowAt
  where
    width   = mCols m
    rowAt i = U.unsafeSlice (i * width) width (mData m)

-- | /O(1)/ @m !!! i@ is row @i@ of the matrix, taken as a
-- bounds-checked slice of the underlying data.
(!!!) :: Matrix -> Int -> U.Vector Double
(!!!) m i = U.slice start width (mData m)
  where
    width = mCols m
    start = i * width

-- | /O(rows*cols)/ The columns of the matrix, in order.  Each column
-- is a freshly built vector costing /O(rows)/ time and space.
cols :: Matrix -> V.Vector (U.Vector Double)
cols m = V.map (col m) (V.enumFromN 0 (mCols m))

-- | /O(rows)/ @m `col` i@ copies the @i@-th column of the matrix into
-- a new vector.
col :: Matrix -> Int -> U.Vector Double
col m i = U.backpermute (mData m) offsets
  where
    -- Column i's elements appear every mCols m positions, from i on.
    offsets = U.enumFromStepN i (mCols m) (mRows m)

-- | Transforms the flat, row-major element vector while keeping the
-- dimensions.  The function is expected to preserve the length of
-- the vector; this is not checked.
umap :: (U.Vector Double -> U.Vector Double) -> Matrix -> Matrix
umap f (M r c d) = M r c (f d)

-- | Applies a function to every element of the matrix.
map :: (Double -> Double) -> Matrix -> Matrix
map f m = umap (U.map f) m

-- | Like 'map', but the function also receives the @(row, column)@
-- position of each element.
imap :: ((Int,Int) -> Double -> Double) -> Matrix -> Matrix
imap f m = umap (U.imap (\k x -> f (indices m k) x)) m

-- | Reduces each row to a single value, producing one result per row.
rowmap :: (U.Vector Double -> Double) -> Matrix -> U.Vector Double
rowmap f m = U.generate (mRows m) $ \i ->
  f (U.unsafeSlice (i * mCols m) (mCols m) (mData m))

-- | Like 'rowmap', but the function also receives the row index.
irowmap :: (Int -> U.Vector Double -> Double) -> Matrix -> U.Vector Double
irowmap f m = U.generate (mRows m) $ \i ->
  f i (U.unsafeSlice (i * mCols m) (mCols m) (mData m))

-- | Combines the flat element vectors of two matrices of equal
-- dimensions; raises an error when the row or column counts differ.
-- The function is expected to return a vector of the same length;
-- this is not checked.
uzipWith :: (U.Vector Double -> U.Vector Double -> U.Vector Double)
         -> Matrix -> Matrix -> Matrix
uzipWith f m n =
  case (mRows m == mRows n, mCols m == mCols n) of
    (False, _)    -> materror "uzipWith" "mRows"
    (True, False) -> materror "uzipWith" "mCols"
    (True, True)  -> M (mRows m) (mCols m) (f (mData m) (mData n))

-- | Element-wise combination of two matrices of equal dimensions.
zipWith :: (Double -> Double -> Double) -> Matrix -> Matrix -> Matrix
zipWith f m n = uzipWith (U.zipWith f) m n

-- | Like 'zipWith', but the function also receives the
-- @(row, column)@ position, computed from the first matrix's shape.
izipWith :: ((Int,Int) -> Double -> Double -> Double)
         -> Matrix -> Matrix -> Matrix
izipWith f m n = uzipWith (U.izipWith (\k -> f (indices m k))) m n

-- | @rzipWith f m n@ is a matrix with the same number of rows as
-- @m@.  The @i@-th row is obtained by applying @f@ to @i@ and the
-- @i@-th rows of @m@ and @n@.  Both matrices must have the same
-- dimensions, otherwise an error is raised.  The number of columns
-- of the result is the length of the first row returned by @f@; it
-- is not verified that @f@ returns rows of equal length.  With no
-- rows at all the result is a @0@-by-@0@ matrix.
rzipWith :: (Int -> U.Vector Double -> U.Vector Double -> U.Vector Double)
         -> Matrix -> Matrix -> Matrix
rzipWith f m n
  | rm /= rn  = materror "rzipWith" $ "mRows " ++ s
  | cm /= cn  = materror "rzipWith" $ "mCols " ++ s
  | rm <= 0   = emptyResult
  | otherwise = fromVector $ V.izipWith f (rows m) (rows n)
  where rm = mRows m; cm = mCols m
        rn = mRows n; cn = mCols n
        s  = show ((rm,cm),(rn,cn))
        -- Without rows, 'fromVector' has no first row to measure
        -- (it used 'G.head'), so build the empty result directly.
        emptyResult = M { mRows = 0, mCols = 0, mData = U.empty }

-- | Splits a flat row-major offset into its @(row, column)@ pair.
indices :: Matrix -> Int -> (Int, Int)
indices m = (`divMod` mCols m)

-- | Swaps rows and columns: element @(r, c)@ of the result is
-- element @(c, r)@ of the input.  Builds a new matrix in
-- /O(rows*cols)/.
transpose :: Matrix -> Matrix
transpose m = M { mRows = w, mCols = h, mData = U.generate (h * w) pick }
  where
    h = mRows m
    w = mCols m
    -- The result has h columns, so flat offset k is result cell
    -- (k `div` h, k `mod` h); read the mirrored input cell.
    pick k = let (r', c') = k `divMod` h
             in m ! (c', r')

-- Rewrite rules that cancel a pair of transpositions, or fold a
-- transposition into construction from nested vectors.
--
-- NOTE(review): 'fromVector' is documented as not verifying row
-- lengths while 'fromVectorT' is documented as verifying them, so the
-- two fromVector rules can change how ragged input behaves -- confirm
-- this is acceptable.
-- NOTE(review): no INLINE/NOINLINE phase control for 'transpose',
-- 'fromVector' or 'fromVectorT' is visible here; if GHC inlines them
-- before these rules get a chance to match, the rules may never
-- fire.  Verify.
{-# RULES
"transpose/transpose"   forall m. transpose (transpose m) = m;
"transpose/fromVector"  forall v. transpose (fromVector v) = fromVectorT v;
"transpose/fromVectorT" forall v. transpose (fromVectorT v) = fromVector v;
#-}

-- | Raises an 'error' whose message is prefixed with this module's
-- name and the name of the function that failed.
materror :: String -> String -> a
materror fun msg =
  error (concat ["Math.Statistics.Dirichlet.Matrix.", fun, ": ", msg])
```