module Internal.Matrix where
import Data.Complex ( Complex )
import Foreign.C.String ( CString, newCString, withCString )
import Foreign.C.Types ( CInt(..) )
import Foreign.Marshal.Alloc ( free )
import Foreign.Marshal.Array ( newArray, withArray )
import Foreign.Ptr ( Ptr )
import Foreign.Storable ( Storable )
import System.IO.Unsafe ( unsafePerformIO )
import Text.Printf

import Control.DeepSeq ( NFData(..) )

import Internal.Devel
import Internal.Vector
import Internal.Vectorized hiding ((#), (#!))
-- | Memory layout of a matrix buffer: with 'RowMajor' the elements of a row
-- are adjacent (column stride 1), with 'ColumnMajor' the elements of a column
-- are adjacent (row stride 1); see 'matrixFromVector'.
data MatrixOrder
    = RowMajor
    | ColumnMajor
    deriving (Eq, Show)
-- | A matrix stored as a flat 'Vector' together with its strides.
--
-- Element @(i,j)@ lives at offset @i * xRow + j * xCol@ of 'xdat' (see
-- 'atM'').  A row-major matrix has @xCol == 1@ and a column-major one has
-- @xRow == 1@ (see 'matrixFromVector').  The buffer may hold more than
-- @irows * icols@ elements when the matrix is a strided view ('isSlice').
data Matrix t = Matrix
    { irows ::  !Int          -- ^ number of rows
    , icols ::  !Int          -- ^ number of columns
    , xRow  ::  !Int          -- ^ buffer distance between consecutive rows
    , xCol  ::  !Int          -- ^ buffer distance between consecutive columns
    , xdat  ::  !(Vector t)   -- ^ element storage
    }
-- | Number of rows of a matrix.
rows :: Matrix t -> Int
rows m = irows m
-- | Number of columns of a matrix.
cols :: Matrix t -> Int
cols m = icols m
-- | Matrix dimensions as @(rows, columns)@.
size :: Matrix t -> (Int, Int)
size m = (irows m, icols m)
-- | Row-contiguous layout: unit column stride, or only a single column.
rowOrder :: Matrix t -> Bool
rowOrder Matrix { xCol = sc, icols = nc } = sc == 1 || nc == 1
-- | Column-contiguous layout: unit row stride, or only a single row.
colOrder :: Matrix t -> Bool
colOrder Matrix { xRow = sr, irows = nr } = sr == 1 || nr == 1
-- | True for a single row or a single column.
is1d :: Matrix t -> Bool
is1d m = rows m == 1 || cols m == 1
-- | True when the buffer holds more elements than the matrix covers, i.e. the
-- matrix is a strided view such as one produced by 'subMatrix'.
isSlice m = rows m * cols m < dim (xdat m)
-- | Layout of a matrix; row order takes precedence when both apply.
orderOf :: Matrix t -> MatrixOrder
orderOf m
    | rowOrder m = RowMajor
    | otherwise  = ColumnMajor
-- | Debug dump of a matrix's internal representation.  Prints the
-- dimensions, "slice" or "full", the layout ("1d", "rows" or "cols"), the
-- row:column strides and the buffer length.
showInternal :: Storable t => Matrix t -> IO ()
showInternal m =
    printf "%dx%d %s %s %d:%d (%d)\n"
        (rows m) (cols m) kind layout (xRow m) (xCol m) (dim (xdat m))
  where
    kind
      | isSlice m = "slice"
      | otherwise = "full"
    layout
      | is1d m     = "1d"
      | rowOrder m = "rows"
      | otherwise  = "cols"
-- | Transpose without copying: swap the dimensions and the strides while
-- keeping the same buffer.
trans :: Matrix t -> Matrix t
trans m = m { irows = icols m, icols = irows m, xRow = xCol m, xCol = xRow m }
-- | Ensure a row-ordered layout.  Copies only when the matrix is not already
-- row-ordered.
cmat :: (Element t) => Matrix t -> Matrix t
cmat m = if rowOrder m then m else extractAll RowMajor m
-- | Ensure a column-ordered layout.  Copies only when the matrix is not
-- already column-ordered.
fmat :: (Element t) => Matrix t -> Matrix t
fmat m = if colOrder m then m else extractAll ColumnMajor m
-- | Compact marshalling of a matrix for a C call.  The curried C function
-- @g@ receives rows, columns and the data pointer (no strides), and its
-- result is handed to the continuation @f@.
amatr :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> Ptr a -> f) -> IO r
amatr x f g =
    let nr = fi (rows x)
        nc = fi (cols x)
    in unsafeWith (xdat x) (\ptr -> f (g nr nc ptr))
-- | Strided marshalling of a matrix for a C call.  The curried C function
-- @g@ receives rows, columns, row stride, column stride and the data
-- pointer, and its result is handed to the continuation @f@.
amat :: Storable a => Matrix a -> (f -> IO r) -> (CInt -> CInt -> CInt -> CInt -> Ptr a -> f) -> IO r
amat x f g =
    let nr = fi (rows x)
        nc = fi (cols x)
        sr = fi (xRow x)
        sc = fi (xCol x)
    in unsafeWith (xdat x) (\ptr -> f (g nr nc sr sc ptr))
-- | Marshalling of a 'Matrix' into curried C-function arguments, used by the
-- '#' and '#!' combinators.
instance Storable t => TransArray (Matrix t)
  where
    -- Compact form: rows, cols, data pointer (see 'amatr').
    type TransRaw (Matrix t) b = CInt -> CInt -> Ptr t -> b
    -- Strided form: rows, cols, row stride, column stride, data pointer
    -- (see 'amat').
    type Trans (Matrix t) b    = CInt -> CInt -> CInt -> CInt -> Ptr t -> b
    apply = amat
    
    applyRaw = amatr
    
infixr 1 #
-- | Marshal one array argument of a C call.  @(a # k) g@ applies the curried
-- C function @g@ to @a@'s dimensions, strides and data pointer (via 'apply')
-- and passes the result on to the continuation @k@.
(#) a b = apply a b
-- | Marshal the last two arrays of a chain, terminating it with 'id'.  This is
-- @a # b # id@ with '#' associating to the right.
(#!) a b = a # (b # id)
-- | Copy every element of @m@ into a freshly allocated matrix with layout
-- @ord@.
--
-- Mode 0 tells 'extractR' that each index vector holds an inclusive
-- @[first, last]@ range (see 'extractAux'), so this selects rows
-- @0 .. rows m - 1@ and columns @0 .. cols m - 1@.
copy :: Element t => MatrixOrder -> Matrix t -> IO (Matrix t)
copy ord m = extractR ord m 0 (idxs [0, rows m - 1]) 0 (idxs [0, cols m - 1])
-- | Pure version of 'copy'.  The wrapped action only allocates a fresh matrix
-- and fills it, so running it with 'unsafePerformIO' is presumably safe.
extractAll ord m = unsafePerformIO action
  where
    action = copy ord m
-- | All elements in row-major order as one vector.  The existing buffer is
-- returned directly when it is already exactly that; otherwise a row-major
-- copy is made.
flatten :: Element t => Matrix t -> Vector t
flatten m
    | contiguousRows = xdat m
    | otherwise      = xdat (extractAll RowMajor m)
  where
    contiguousRows = not (isSlice m) && rowOrder m
-- | Matrix as a list of rows, each row a list of elements.
toLists :: (Element t) => Matrix t -> [[t]]
toLists m = [toList row | row <- toRows m]
-- | Common length of a list of sizes, where a size of 1 is compatible with
-- anything (it can be broadcast).  Returns 'Nothing' for an empty list or for
-- two different sizes that are both not 1.
compatdim :: [Int] -> Maybe Int
compatdim []     = Nothing
compatdim (d:ds) = go d ds
  where
    go acc [] = Just acc
    go acc (x:xs)
      | acc == x || acc == 1 = go x xs
      | x == 1               = go acc xs
      | otherwise            = Nothing
-- | Stack vectors as the rows of a matrix.  Length-1 vectors are broadcast
-- to the common length (see 'compatdim').  Any other length mismatch raises
-- an error.
fromRows :: Element t => [Vector t] -> Matrix t
fromRows [] = emptyM 0 0
fromRows vs =
    case compatdim lens of
      Nothing -> error $ "fromRows expects vectors with equal sizes (or singletons), given: " ++ show lens
      Just 0  -> emptyM nrows 0
      Just nc -> matrixFromVector RowMajor nrows nc (vjoin [stretch nc v | v <- vs])
  where
    lens  = map dim vs
    nrows = length vs
    -- Make every row exactly nc long, repeating a singleton's only value.
    stretch nc v
      | nc == 0     = fromList []
      | dim v == nc = v
      | otherwise   = constantD (v @> 0) nc
-- | Split a matrix into its rows.
--
-- For row-ordered matrices each row is taken from the buffer with
-- 'subVector'.  Otherwise each row is copied out with 'extractR': mode 1
-- passes an explicit row-index list, and mode 0 passes an inclusive column
-- range.
toRows :: Element t => Matrix t -> [Vector t]
toRows m
    | rowOrder m = map sub rowRange
    | otherwise  = map ext rowRange
  where
    rowRange = [0 .. rows m - 1]
    sub k = subVector (k * xRow m) (cols m) (xdat m)
    ext k = xdat $ unsafePerformIO $
        extractR RowMajor m 1 (idxs [k]) 0 (idxs [0, cols m - 1])
-- | Build a matrix whose columns are the given vectors, with singletons
-- broadcast as in 'fromRows'.  Implemented as the transpose of 'fromRows'.
fromColumns :: Element t => [Vector t] -> Matrix t
fromColumns vs = trans (fromRows vs)
-- | Columns of a matrix, obtained as the rows of its transpose.
toColumns :: Element t => Matrix t -> [Vector t]
toColumns m = toRows (trans m)
-- | Bounds-checked element access at @(row, column)@.  Raises an error when
-- the index lies outside the matrix.
(@@>) :: Storable t => Matrix t -> (Int,Int) -> t
infixl 9 @@>
m @@> (i,j)
    | outOfRange = error "matrix indexing out of range"
    | otherwise  = atM' m i j
  where
    outOfRange = i < 0 || i >= rows m || j < 0 || j >= cols m
-- | Unchecked element access: reads buffer offset @i * xRow + j * xCol@.
atM' m i j = at' (xdat m) offset
  where
    offset = i * xRow m + j * xCol m
-- | Wrap a flat vector as an @r x c@ matrix with the given layout.
--
-- Single-row and single-column requests are handled first.  Their length is
-- taken from the vector itself, so the requested @c@ (resp. @r@) is not
-- checked, and the layout argument is ignored.  In every other case the
-- vector length must equal @r * c@, otherwise an error is raised.
-- NOTE(review): the 1-row / 1-column clauses silently accept a vector whose
-- length differs from the requested extent; confirm callers rely on this.
matrixFromVector _ 1 _ v@(dim->d) = Matrix { irows = 1, icols = d, xdat = v, xRow = d, xCol = 1 }
matrixFromVector _ _ 1 v@(dim->d) = Matrix { irows = d, icols = 1, xdat = v, xRow = 1, xCol = d }
matrixFromVector o r c v
    | r * c == dim v = m
    | otherwise = error $ "can't reshape vector dim = "++ show (dim v)++" to matrix " ++ shSize m
  where
    -- Row-major: rows are c apart, columns adjacent; column-major: the reverse.
    m | o == RowMajor = Matrix { irows = r, icols = c, xdat = v, xRow = c, xCol = 1 }
      | otherwise     = Matrix { irows = r, icols = c, xdat = v, xRow = 1, xCol = r }
-- | Allocate an @r x c@ matrix with the given layout over a new buffer from
-- 'createVector'.  Contents are whatever 'createVector' leaves, presumably
-- uninitialised; callers fill it.
createMatrix :: (Storable a) => MatrixOrder -> Int -> Int -> IO (Matrix a)
createMatrix ord r c = matrixFromVector ord r c <$> createVector (r * c)
-- | Interpret a vector as a row-major matrix with @c@ columns.  The row count
-- is @dim v `div` c@; @c == 0@ yields a 0x0 matrix.
reshape :: Storable t => Int -> Vector t -> Matrix t
reshape c v
    | c == 0    = matrixFromVector RowMajor 0 0 v
    | otherwise = matrixFromVector RowMajor (dim v `div` c) c v
-- | Apply a vector function to a matrix's elements.  Slices are flattened to
-- row-major order first, so @f@ only sees the matrix's own elements.
-- Contiguous matrices pass their buffer as-is and keep their layout.
liftMatrix :: (Element a, Element b) => (Vector a -> Vector b) -> Matrix a -> Matrix b
liftMatrix f m =
    if isSlice m
      then matrixFromVector RowMajor nr nc (f (flatten m))
      else matrixFromVector (orderOf m) nr nc (f (xdat m))
  where
    (nr, nc) = size m
-- | Combine two same-sized matrices elementwise with a vector function.
-- Both operands are flattened in the layout of the first (row- or
-- column-major), and the result uses that layout.
liftMatrix2 :: (Element t, Element a, Element b) => (Vector a -> Vector b -> Vector t) -> Matrix a -> Matrix b -> Matrix t
liftMatrix2 f m1 m2
    | size m1 /= size m2 = error "nonconformant matrices in liftMatrix2"
    | rowOrder m1        = matrixFromVector RowMajor r c (f (flatten m1) (flatten m2))
    | otherwise          = matrixFromVector ColumnMajor r c (f (colFlat m1) (colFlat m2))
  where
    (r, c) = size m1
    -- Column-major element order, obtained as the row-major order of the transpose.
    colFlat :: Element e => Matrix e -> Vector e
    colFlat = flatten . trans
-- | Element types with native (C) implementations of the low-level kernels
-- used in this module.  Each instance binds the methods to type-specific
-- foreign routines.
class (Storable a) => Element a where
    -- | Vector of the given length filled with a single value (used to
    -- broadcast singletons, e.g. in 'conformVTo').
    constantD  :: a -> Int -> Vector a
    -- | Copy a selection of rows and columns into a new matrix with the
    -- given layout.  For each axis, mode 0 means the index vector is an
    -- inclusive @[first, last]@ range.  Any other mode means it lists the
    -- indices explicitly (see 'extractAux').
    extractR :: MatrixOrder -> Matrix a -> CInt -> Vector CInt -> CInt -> Vector CInt -> IO (Matrix a)
    -- | Takes two offsets and two matrices.  Presumably writes the first
    -- matrix into the second at offset @(i, j)@.  NOTE(review): confirm
    -- against the C setRect routines.
    setRect  :: Int -> Int -> Matrix a -> Matrix a -> IO ()
    -- | Index vector computed by the C sort_index routines, presumably the
    -- permutation that sorts the input.
    sortI    :: Ord a => Vector a -> Vector CInt
    -- | Sorted copy from the C sort_values routines (order presumably
    -- ascending; confirm).
    sortV    :: Ord a => Vector a -> Vector a
    -- | Elementwise comparison producing one 'CInt' code per element (codes
    -- defined by the C compare routines).
    compareV :: Ord a => Vector a -> Vector a -> Vector CInt
    -- | Elementwise choice among three vectors driven by a 'CInt' condition
    -- vector (C choose routines).
    selectV  :: Vector CInt -> Vector a -> Vector a -> Vector a -> Vector a
    -- | Builds a matrix shaped like the first index matrix (see 'remapG').
    -- Presumably element (r,c) is read from the source at the row/column
    -- indices stored in the two index matrices; confirm.
    remapM   :: Matrix CInt -> Matrix CInt -> Matrix a -> Matrix a
    -- | C rowop routine applied to the matrix argument.  Arguments are an
    -- operation code, a scalar and the bounds i1 i2 j1 j2.
    rowOp    :: Int -> a -> Int -> Int -> Int -> Int -> Matrix a -> IO ()
    -- | C gemm routine on a coefficient vector and three matrices.  The last
    -- matrix is presumably the output; confirm.
    gemm     :: Vector a -> Matrix a -> Matrix a -> Matrix a -> IO ()
    -- | Rearrange a flattened array (see 'reorderVector').
    reorderV :: Vector CInt-> Vector CInt-> Vector a -> Vector a 
-- | 'Float' elements: every method dispatches to a Float-specific C routine.
instance Element Float where
    constantD  = constantAux cconstantF
    extractR   = extractAux c_extractF
    setRect    = setRectAux c_setRectF
    sortI      = sortIdxF
    sortV      = sortValF
    compareV   = compareF
    selectV    = selectF
    remapM     = remapF
    rowOp      = rowOpAux c_rowOpF
    gemm       = gemmg c_gemmF
    reorderV   = reorderAux c_reorderF
-- | 'Double' elements: every method dispatches to a Double-specific C routine.
instance Element Double where
    constantD  = constantAux cconstantR
    extractR   = extractAux c_extractD
    setRect    = setRectAux c_setRectD
    sortI      = sortIdxD
    sortV      = sortValD
    compareV   = compareD
    selectV    = selectD
    remapM     = remapD
    rowOp      = rowOpAux c_rowOpD
    gemm       = gemmg c_gemmD
    reorderV   = reorderAux c_reorderD
-- | Single-precision complex elements.  'sortI', 'sortV' and 'compareV'
-- require 'Ord', which 'Complex' has no instance of, so they cannot be
-- called at this type and are left 'undefined'.
instance Element (Complex Float) where
    constantD  = constantAux cconstantQ
    extractR   = extractAux c_extractQ
    setRect    = setRectAux c_setRectQ
    sortI      = undefined
    sortV      = undefined
    compareV   = undefined
    selectV    = selectQ
    remapM     = remapQ
    rowOp      = rowOpAux c_rowOpQ
    gemm       = gemmg c_gemmQ
    reorderV   = reorderAux c_reorderQ
-- | Double-precision complex elements.  'sortI', 'sortV' and 'compareV'
-- require 'Ord', which 'Complex' has no instance of, so they cannot be
-- called at this type and are left 'undefined'.
instance Element (Complex Double) where
    constantD  = constantAux cconstantC
    extractR   = extractAux c_extractC
    setRect    = setRectAux c_setRectC
    sortI      = undefined
    sortV      = undefined
    compareV   = undefined
    selectV    = selectC
    remapM     = remapC
    rowOp      = rowOpAux c_rowOpC
    gemm       = gemmg c_gemmC
    reorderV   = reorderAux c_reorderC
-- | 'CInt' elements.  Row operations and gemm use the int32_t C routines
-- (see c_rowOpI / c_gemmI).
instance Element (CInt) where
    constantD  = constantAux cconstantI
    extractR   = extractAux c_extractI
    setRect    = setRectAux c_setRectI
    sortI      = sortIdxI
    sortV      = sortValI
    compareV   = compareI
    selectV    = selectI
    remapM     = remapI
    rowOp      = rowOpAux c_rowOpI
    gemm       = gemmg c_gemmI
    reorderV   = reorderAux c_reorderI
-- | 'Z' elements (presumably a 64-bit integer type; row operations and gemm
-- use the int64_t C routines, see c_rowOpL / c_gemmL).
instance Element Z where
    constantD  = constantAux cconstantL
    extractR   = extractAux c_extractL
    setRect    = setRectAux c_setRectL
    sortI      = sortIdxL
    sortV      = sortValL
    compareV   = compareL
    selectV    = selectL
    remapM     = remapL
    rowOp      = rowOpAux c_rowOpL
    gemm       = gemmg c_gemmL
    reorderV   = reorderAux c_reorderL
-- | Extract a rectangular block as a strided view.  The strides are kept and
-- only the dimensions change; 'xdat' is narrowed with 'subVector'.
--
-- A request with a non-positive extent yields an empty matrix.  Otherwise the
-- block must lie inside @m@, or an error is raised.
subMatrix :: Element a
            => (Int,Int)   -- ^ @(r0,c0)@: row and column of the top-left element
            -> (Int,Int)   -- ^ @(rt,ct)@: number of rows and columns to take
            -> Matrix a    -- ^ source matrix
            -> Matrix a    -- ^ the requested block
subMatrix (r0,c0) (rt,ct) m
    | rt <= 0 || ct <= 0 = matrixFromVector RowMajor (max 0 rt) (max 0 ct) (fromList [])
    -- rt and ct are known to be positive here.
    | 0 <= r0 && r0+rt <= rows m &&
      0 <= c0 && c0+ct <= cols m = res
    | otherwise = error $ "wrong subMatrix "++show ((r0,c0),(rt,ct))++" of "++shSize m
  where
    -- Buffer offset of the block's first element.
    p = r0 * xRow m + c0 * xCol m
    -- Buffer span from the first to the last element of the block, inclusive.
    -- In row order this is (rt-1) row strides plus ct adjacent elements;
    -- column order is the mirror image.
    tot | rowOrder m = ct + (rt-1) * xRow m
        | otherwise  = rt + (ct-1) * xCol m
    res = m { irows = rt, icols = ct, xdat = subVector p tot (xdat m) }
-- | Largest entry, except that the result is 0 whenever the smallest entry
-- is 0.  This keeps an empty dimension empty when broadcasting sizes (see
-- 'conformMs').  An empty collection also yields 0 rather than crashing in
-- 'minimum'.
maxZ :: (Foldable f, Ord a, Num a) => f a -> a
maxZ xs
    | null xs         = 0
    | minimum xs == 0 = 0
    | otherwise       = maximum xs
-- | Broadcast matrices to a common size.  Along each axis the target is the
-- largest count, or 0 if any matrix has 0 there (see 'maxZ').
conformMs ms = [conformMTo target m | m <- ms]
  where
    target = (maxZ (map rows ms), maxZ (map cols ms))
-- | Broadcast vectors to a common length: the largest length, or 0 if any
-- vector is empty (see 'maxZ').
conformVs vs = [conformVTo target v | v <- vs]
  where
    target = maxZ (map dim vs)
-- | Expand a matrix to size @(r,c)@.  A 1x1 matrix is filled with its only
-- value, a column is repeated, and a row is repeated.  Any other size
-- mismatch raises an error.
conformMTo (r,c) m = case size m of
    s | s == (r,c) -> m
      | s == (1,1) -> matrixFromVector RowMajor r c (constantD (m@@>(0,0)) (r*c))
      | s == (r,1) -> repCols c m
      | s == (1,c) -> repRows r m
      | otherwise  -> error $ "matrix " ++ shSize m ++ " cannot be expanded to " ++ shDim (r,c)
-- | Expand a vector to length @n@.  A singleton is repeated; any other length
-- mismatch raises an error.
conformVTo n v = case dim v of
    d | d == n    -> v
      | d == 1    -> constantD (v@>0) n
      | otherwise -> error $ "vector of dim=" ++ show d ++ " cannot be expanded to dim=" ++ show n
-- | Stack @n@ copies of the flattened matrix as rows (broadcasts a row in
-- 'conformMTo').
repRows n = fromRows . replicate n . flatten
-- | Stack @n@ copies of the flattened matrix as columns (broadcasts a column
-- in 'conformMTo').
repCols n = fromColumns . replicate n . flatten
-- | Matrix size rendered as "(RxC)", for error messages.
shSize m = shDim (size m)
-- | Render a pair of dimensions as "(RxC)".
shDim (r,c) = concat ["(", show r, "x", show c, ")"]
-- | A matrix over an empty buffer.  The callers here ('fromRows') pass a zero
-- extent.
emptyM nr nc = matrixFromVector RowMajor nr nc $ fromList []
instance (Storable t, NFData t) => NFData (Matrix t)
  where
    -- Forces the matrix and, when the buffer is non-empty, its first element;
    -- no other element is touched.
    rnf m =
        let v = xdat m
        in if dim v > 0 then rnf (v @> 0) else ()
-- | Shared implementation of 'extractR'.  Allocates the result matrix with
-- layout @ord@, runs the C extract routine @f@ on the row/column selections
-- and @m@, and raises an error tagged "extract" on a nonzero status.
extractAux f ord m moder vr modec vc = do
    let nr = extent moder vr
        nc = extent modec vc
    r <- createMatrix ord nr nc
    (vr # vc # m #! r) (f moder modec)  #|"extract"
    return r
  where
    -- Number of selected rows/columns.  Mode 0: the vector is an inclusive
    -- [first, last] range.  Otherwise it lists the indices one by one.
    extent mode v
        | mode == 0 = fromIntegral ((v @> 1) - (v @> 0) + 1)
        | otherwise = dim v
-- | Signature of the C extract routines.  Arguments are the row mode, the
-- column mode, the row-index vector, the column-index vector, the source
-- matrix and the destination matrix.  The returned status is checked with
-- '#|' in 'extractAux'.
type Extr x = CInt -> CInt -> CIdxs (CIdxs (OM x (OM x (IO CInt))))
foreign import ccall unsafe "extractD" c_extractD :: Extr Double
foreign import ccall unsafe "extractF" c_extractF :: Extr Float
foreign import ccall unsafe "extractC" c_extractC :: Extr (Complex Double)
foreign import ccall unsafe "extractQ" c_extractQ :: Extr (Complex Float)
foreign import ccall unsafe "extractI" c_extractI :: Extr CInt
foreign import ccall unsafe "extractL" c_extractL :: Extr Z
-- | Shared implementation of 'setRect'.  Calls the C routine @f@ with the
-- offsets @i@, @j@ and the matrices @m@ and @r@, and raises an error tagged
-- "setRect" on a nonzero status.
setRectAux f i j m r = call #| "setRect"
  where
    call = (m #! r) (f (fi i) (fi j))
-- | Signature of the C setRect routines: two integer offsets (i, j), then two
-- matrices (m, then r, in 'setRectAux'); returns a status code.
type SetRect x = I -> I -> x ::> x::> Ok
foreign import ccall unsafe "setRectD" c_setRectD :: SetRect Double
foreign import ccall unsafe "setRectF" c_setRectF :: SetRect Float
foreign import ccall unsafe "setRectC" c_setRectC :: SetRect (Complex Double)
foreign import ccall unsafe "setRectQ" c_setRectQ :: SetRect (Complex Float)
foreign import ccall unsafe "setRectI" c_setRectI :: SetRect I
foreign import ccall unsafe "setRectL" c_setRectL :: SetRect Z
-- | Run a C sort routine @f@ on @v@, returning a new output vector of the
-- same length.  Raises an error tagged "sortG" on a nonzero status.  The
-- action only writes its freshly allocated output, which is why it is run
-- with 'unsafePerformIO'.
sortG f v = unsafePerformIO $ do
    out <- createVector (dim v)
    (v #! out) f #|"sortG"
    pure out
-- Sorting kernels for each orderable element type, all via 'sortG'.
-- sortIdx* return CInt index vectors; sortVal* return sorted values.
sortIdxD v = sortG c_sort_indexD v
sortIdxF v = sortG c_sort_indexF v
sortIdxI v = sortG c_sort_indexI v
sortIdxL v = sortG c_sort_indexL v
sortValD v = sortG c_sort_valD v
sortValF v = sortG c_sort_valF v
sortValI v = sortG c_sort_valI v
sortValL v = sortG c_sort_valL v
-- C sort routines.  Each takes the input vector followed by the output vector
-- (see 'sortG') and returns a status code.  sort_index* write CInt indices;
-- sort_values* write values of the input type.
foreign import ccall unsafe "sort_indexD" c_sort_indexD :: CV Double (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexF" c_sort_indexF :: CV Float  (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexI" c_sort_indexI :: CV CInt   (CV CInt (IO CInt))
foreign import ccall unsafe "sort_indexL" c_sort_indexL :: Z :> I :> Ok
foreign import ccall unsafe "sort_valuesD" c_sort_valD :: CV Double (CV Double (IO CInt))
foreign import ccall unsafe "sort_valuesF" c_sort_valF :: CV Float  (CV Float (IO CInt))
foreign import ccall unsafe "sort_valuesI" c_sort_valI :: CV CInt   (CV CInt (IO CInt))
foreign import ccall unsafe "sort_valuesL" c_sort_valL :: Z :> Z :> Ok
-- | Run a C comparison routine @f@ on @u@ and @v@, producing a new vector
-- sized like @v@.  Raises an error tagged "compareG" on a nonzero status.
compareG f u v = unsafePerformIO $ do
    out <- createVector (dim v)
    (u # v #! out) f #|"compareG"
    pure out
-- Elementwise comparison kernels for each orderable element type.
compareD u v = compareG c_compareD u v
compareF u v = compareG c_compareF u v
compareI u v = compareG c_compareI u v
compareL u v = compareG c_compareL u v
-- C compare routines: two input vectors, then a CInt output vector (see
-- 'compareG'); each returns a status code.
foreign import ccall unsafe "compareD" c_compareD :: CV Double (CV Double (CV CInt (IO CInt)))
foreign import ccall unsafe "compareF" c_compareF :: CV Float (CV Float  (CV CInt (IO CInt)))
foreign import ccall unsafe "compareI" c_compareI :: CV CInt (CV CInt   (CV CInt (IO CInt)))
foreign import ccall unsafe "compareL" c_compareL :: Z :> Z :> I :> Ok
-- | Run a C selection routine @f@ with condition vector @c@ and value vectors
-- @u@, @v@, @w@, producing a new vector sized like @v@.  Raises an error
-- tagged "selectG" on a nonzero status.
selectG f c u v w = unsafePerformIO $ do
    out <- createVector (dim v)
    (c # u # v # w #! out) f #|"selectG"
    pure out
-- Elementwise selection kernels, one per element type.
selectD c u v w = selectG c_selectD c u v w
selectF c u v w = selectG c_selectF c u v w
selectI c u v w = selectG c_selectI c u v w
selectL c u v w = selectG c_selectL c u v w
selectC c u v w = selectG c_selectC c u v w
selectQ c u v w = selectG c_selectQ c u v w
-- | Signature of the C choose routines.  Arguments are a CInt condition
-- vector, three value vectors and the output vector (see 'selectG').
-- Returns a status code.
type Sel x = CV CInt (CV x (CV x (CV x (CV x (IO CInt)))))
foreign import ccall unsafe "chooseD" c_selectD :: Sel Double
foreign import ccall unsafe "chooseF" c_selectF :: Sel Float
foreign import ccall unsafe "chooseI" c_selectI :: Sel CInt
foreign import ccall unsafe "chooseC" c_selectC :: Sel (Complex Double)
foreign import ccall unsafe "chooseQ" c_selectQ :: Sel (Complex Float)
foreign import ccall unsafe "chooseL" c_selectL :: Sel Z
-- | Run a C remap routine @f@ with index matrices @i@, @j@ and source @m@.
-- The result is a new row-major matrix with the shape of @i@.  Raises an
-- error tagged "remapG" on a nonzero status.
remapG f i j m = unsafePerformIO $ do
    out <- createMatrix RowMajor (rows i) (cols i)
    (i # j # m #! out) f #|"remapG"
    pure out
-- Remap kernels, one per element type.
remapD i j m = remapG c_remapD i j m
remapF i j m = remapG c_remapF i j m
remapI i j m = remapG c_remapI i j m
remapL i j m = remapG c_remapL i j m
remapC i j m = remapG c_remapC i j m
remapQ i j m = remapG c_remapQ i j m
-- | Signature of the C remap routines.  Arguments are two CInt index
-- matrices, the source matrix and the output matrix (see 'remapG').  Returns
-- a status code.
type Rem x = OM CInt (OM CInt (OM x (OM x (IO CInt))))
foreign import ccall unsafe "remapD" c_remapD :: Rem Double
foreign import ccall unsafe "remapF" c_remapF :: Rem Float
foreign import ccall unsafe "remapI" c_remapI :: Rem CInt
foreign import ccall unsafe "remapC" c_remapC :: Rem (Complex Double)
foreign import ccall unsafe "remapQ" c_remapQ :: Rem (Complex Float)
foreign import ccall unsafe "remapL" c_remapL :: Rem Z
-- | Shared implementation of 'rowOp'.  Passes the scalar @x@ to the C routine
-- @f@ through a temporary one-element buffer, together with the operation
-- code @c@, the bounds @i1 i2 j1 j2@ and the matrix @m@.  The C side
-- presumably updates @m@ in place.  A nonzero status raises an error tagged
-- "rowOp".
--
-- 'withArray' releases the temporary buffer even when that error is raised.
rowOpAux f c x i1 i2 j1 j2 m =
    withArray [x] $ \px ->
        (m # id) (f (fi c) px (fi i1) (fi i2) (fi j1) (fi j2)) #|"rowOp"
-- | Signature of the C rowop routines.  Arguments are the operation code, a
-- pointer to the scalar, the bounds i1 i2 j1 j2 and the matrix (see
-- 'rowOpAux').  Returns a status code.
-- The _mod_ variants take an extra leading argument (presumably a modulus);
-- they are not used in this excerpt.
type RowOp x = CInt -> Ptr x -> CInt -> CInt -> CInt -> CInt -> x ::> Ok
foreign import ccall unsafe "rowop_double"  c_rowOpD :: RowOp R
foreign import ccall unsafe "rowop_float"   c_rowOpF :: RowOp Float
foreign import ccall unsafe "rowop_TCD"     c_rowOpC :: RowOp C
foreign import ccall unsafe "rowop_TCF"     c_rowOpQ :: RowOp (Complex Float)
foreign import ccall unsafe "rowop_int32_t" c_rowOpI :: RowOp I
foreign import ccall unsafe "rowop_int64_t" c_rowOpL :: RowOp Z
foreign import ccall unsafe "rowop_mod_int32_t" c_rowOpMI :: I -> RowOp I
foreign import ccall unsafe "rowop_mod_int64_t" c_rowOpML :: Z -> RowOp Z
-- | Shared implementation of 'gemm'.  Calls the C routine @f@ on the vector
-- @v@ and the matrices @m1@, @m2@, @m3@; @m3@ is passed last, presumably as
-- the output.  A nonzero status raises an error tagged "gemmg".
gemmg f v m1 m2 m3 = call #| "gemmg"
  where
    call = (v # m1 # m2 #! m3) f
-- | Signature of the C gemm routines: one vector, then three matrices (see
-- 'gemmg').  Returns a status code.
-- The _mod_ variants take an extra leading argument (presumably a modulus);
-- they are not used in this excerpt.
type Tgemm x = x :> x ::> x ::> x ::> Ok
foreign import ccall unsafe "gemm_double"  c_gemmD :: Tgemm R
foreign import ccall unsafe "gemm_float"   c_gemmF :: Tgemm Float
foreign import ccall unsafe "gemm_TCD"     c_gemmC :: Tgemm C
foreign import ccall unsafe "gemm_TCF"     c_gemmQ :: Tgemm (Complex Float)
foreign import ccall unsafe "gemm_int32_t" c_gemmI :: Tgemm I
foreign import ccall unsafe "gemm_int64_t" c_gemmL :: Tgemm Z
foreign import ccall unsafe "gemm_mod_int32_t" c_gemmMI :: I -> Tgemm I
foreign import ccall unsafe "gemm_mod_int64_t" c_gemmML :: Z -> Tgemm Z
-- | Shared implementation of 'reorderV'.  Calls the C routine @f@ with a
-- scratch vector sized like @s@ (presumably working space), @s@, @d@, the
-- input @v@ and a new output of the same length as @v@.  A nonzero status
-- raises an error tagged "reorderV".
reorderAux f s d v = unsafePerformIO $ do
    scratch <- createVector (dim s)
    out     <- createVector (dim v)
    (scratch # s # d # v #! out) f #| "reorderV"
    pure out
-- | Signature of the C reorder routines.  Arguments are three CInt vectors
-- (scratch, s, d in 'reorderAux'), then the input and output vectors.
-- Returns a status code.
type Reorder x = CV CInt (CV CInt (CV CInt (CV x (CV x (IO CInt)))))
foreign import ccall unsafe "reorderD" c_reorderD :: Reorder Double
foreign import ccall unsafe "reorderF" c_reorderF :: Reorder Float
foreign import ccall unsafe "reorderI" c_reorderI :: Reorder CInt
foreign import ccall unsafe "reorderC" c_reorderC :: Reorder (Complex Double)
foreign import ccall unsafe "reorderQ" c_reorderQ :: Reorder (Complex Float)
foreign import ccall unsafe "reorderL" c_reorderL :: Reorder Z
-- | Rearrange a flattened array; implemented by 'reorderV'.
-- NOTE(review): the roles of the two index vectors are inferred from their
-- names in 'reorderAux' (s, d); confirm against the C reorder routines.
reorderVector :: Element a
                    => Vector CInt -- ^ presumably the strides of the input array
                    -> Vector CInt -- ^ presumably the dimensions of the result
                    -> Vector a    -- ^ flattened input array
                    -> Vector a    -- ^ flattened result, same length as the input
reorderVector = reorderV
-- | C routine behind 'saveMatrix'.  Arguments are the file name, the format
-- string and the matrix.  Returns a status code.
foreign import ccall unsafe "saveMatrix" c_saveMatrix
    :: CString -> CString -> Double ::> Ok
-- | Write a 'Double' matrix to a file through the C routine saveMatrix.  A
-- nonzero status raises an error tagged "saveMatrix".
--
-- Both strings are marshalled with 'withCString', so their temporary C copies
-- are released even when that error is raised.
saveMatrix
    :: FilePath      -- ^ output file name
    -> String        -- ^ format string passed to C; presumably a printf-style element format such as "%.2f" (confirm)
    -> Matrix Double -- ^ matrix to save
    -> IO ()
saveMatrix name format m =
    withCString name $ \cname ->
        withCString format $ \cformat ->
            (m # id) (c_saveMatrix cname cformat) #|"saveMatrix"