module Data.Matrix.TriBase
where
import Unsafe.Coerce
import Data.Matrix.Class
import Data.Tensor.Class
import BLAS.Types ( UpLoEnum(..), DiagEnum(..), flipUpLo )
-- | A triangular view of an underlying matrix-like value of type @a np e@.
--
-- The 'UpLoEnum' field selects which triangle is viewed ('Lower' or
-- 'Upper'); the 'DiagEnum' field is the diagonal flag ('Unit' or
-- 'NonUnit').  NOTE(review): presumably these follow the BLAS convention
-- in which 'Unit' means the diagonal is implicitly all ones -- confirm
-- against "BLAS.Types".  @np@ is a shape index (written @(n,p)@ in most
-- signatures of this module); @e@ is presumably the element type.
data Tri a np e = Tri UpLoEnum DiagEnum (a np e)
-- | Reinterpret the shape index of a triangular view, keeping its
-- triangle and diagonal flags unchanged.
--
-- Only the wrapped payload goes through 'unsafeCoerce'; the 'Tri' wrapper
-- is rebuilt with its flags as-is.  'Data.Coerce.coerce' cannot be used
-- because @np@ is an argument of the type variable @a@ and so has a
-- nominal role.  Safety assumption (unchecked by the type system): @a np e@
-- and @a np' e@ share a runtime representation, i.e. @np@ is only a
-- phantom index of @a@ -- confirm for every @a@ this is used with.
coerceTri :: Tri a np e -> Tri a np' e
coerceTri (Tri uplo diag base) = Tri uplo diag (unsafeCoerce base)
-- | Apply a function to the wrapped matrix, leaving the triangle and
-- diagonal flags as they were.
mapTri :: (a np e -> b np' e) -> Tri a np e -> Tri b np' e
mapTri f t = case t of
    Tri uplo diag base -> Tri uplo diag (f base)
-- | Wrap a matrix as a triangular view with the given triangle and
-- diagonal flags.  No validation of the flags or the matrix is done.
triFromBase :: UpLoEnum -> DiagEnum -> a (n,p) e -> Tri a (n,p) e
triFromBase uplo diag base = Tri uplo diag base
-- | Unwrap a triangular view into its triangle flag, diagonal flag and
-- underlying matrix (the inverse of 'triFromBase').
triToBase :: Tri a (n,p) e -> (UpLoEnum, DiagEnum, a (n,p) e)
triToBase t = case t of
    Tri uplo diag base -> (uplo, diag, base)
-- | View a matrix as its lower triangle, with diagonal flag 'NonUnit'.
lower :: (MatrixShaped a) => a (n,p) e -> Tri a (n,p) e
lower base = Tri Lower NonUnit base
-- | View a matrix as its lower triangle, with diagonal flag 'Unit'.
lowerU :: (MatrixShaped a) => a (n,p) e -> Tri a (n,p) e
lowerU base = Tri Lower Unit base
-- | View a matrix as its upper triangle, with diagonal flag 'NonUnit'.
upper :: (MatrixShaped a) => a (n,p) e -> Tri a (n,p) e
upper base = Tri Upper NonUnit base
-- | View a matrix as its upper triangle, with diagonal flag 'Unit'.
upperU :: (MatrixShaped a) => a (n,p) e -> Tri a (n,p) e
upperU base = Tri Upper Unit base
-- | Shape of a triangular view.
--
-- A 'Lower' view keeps every row of the underlying matrix but at most as
-- many columns as there are rows; an 'Upper' view keeps every column but
-- at most as many rows as there are columns.
instance (MatrixShaped a) => Shaped (Tri a) (Int,Int) where
    shape (Tri Lower _ a) = (numRows a, min (numRows a) (numCols a))
    shape (Tri Upper _ a) = (min (numRows a) (numCols a), numCols a)

    -- Zero-based, inclusive index bounds: for an @m@-by-@n@ shape the last
    -- valid index is @(m-1, n-1)@.
    bounds a = ((0,0), (m-1, n-1))
      where
        (m,n) = shape a
-- | 'herm' on a triangular view applies 'herm' to the underlying matrix,
-- flips the triangle flag with 'flipUpLo', and keeps the diagonal flag.
-- (Assuming 'herm' is the conjugate transpose, a lower triangle becomes
-- an upper one and vice versa.)
instance (MatrixShaped a) => MatrixShaped (Tri a) where
    herm t = case t of
        Tri uplo diag base -> Tri (flipUpLo uplo) diag (herm base)
-- | Renders a triangular view as the smart-constructor name matching its
-- flags (@lower@, @lowerU@, @upper@, @upperU@).  The name is followed by
-- a shape suffix for the underlying matrix: empty when 'isSquare' holds,
-- @Fat@ when 'isFat' holds, @Tall@ otherwise.  Last comes the
-- parenthesised 'show' of the underlying matrix.
instance (Show (a (n,p) e), MatrixShaped a) => Show (Tri a (n,p) e) where
    show (Tri u d a) =
        constructor ++ suffix ++ " (" ++ show a ++ ")"
      where
        constructor = case (u,d) of
            (Lower, NonUnit) -> "lower"
            (Lower, Unit   ) -> "lowerU"
            (Upper, NonUnit) -> "upper"
            (Upper, Unit   ) -> "upperU"
        -- Plain guards, checked in order; no need for a dummy
        -- @case undefined of@ scrutinee to get a multi-way conditional.
        suffix
            | isSquare a = ""
            | isFat a    = "Fat"
            | otherwise  = "Tall"