module Data.Matrix.Tri (
Tri(..),
UpLo(..), Diag(..),
fromBase,
toBase,
mapTri,
lower,
lowerFat,
lowerTall,
lowerU,
lowerUFat,
lowerUTall,
upper,
upperFat,
upperTall,
upperU,
upperUFat,
upperUTall,
coerceTri,
) where
import BLAS.Internal ( checkSquare, checkFat, checkTall )
import BLAS.Matrix
import BLAS.Tensor
import BLAS.Types ( UpLo(..), Diag(..), flipUpLo )
import Unsafe.Coerce
-- | A triangular view of a base matrix of type @a@.
--
-- The phantom index @mn@ is the shape of the /view/. It need not equal the
-- base matrix's actual dimensions: for example, 'lowerFat' re-types an
-- @(m,n)@ base as an @(m,m)@ view.
data Tri a mn e =
    Tri UpLo      -- ^ which triangle ('Lower' or 'Upper') is referenced
        Diag      -- ^ diagonal tag ('Unit' or 'NonUnit')
        (a mn e)  -- ^ the base matrix, carrying the view's shape index
-- | Change the phantom shape index of a triangular view without touching
-- its contents.
--
-- 'Data.Coerce.coerce' cannot be used here. The role of @mn@ is nominal,
-- because it is the argument of an abstract type constructor @a@. So the
-- base matrix is re-typed with 'unsafeCoerce', and the 'UpLo' and 'Diag'
-- tags are copied over unchanged. The caller is responsible for the new
-- index being meaningful.
coerceTri :: Tri a mn e -> Tri a mn' e
coerceTri (Tri uplo diag base) = Tri uplo diag (unsafeCoerce base)
-- | Transform the base matrix of a triangular view, keeping its triangle
-- and diagonal tags.
mapTri :: (a (m,n) e -> b (m,n) e) -> Tri a (m,n) e -> Tri b (m,n) e
mapTri f (Tri uplo diag base) = Tri uplo diag (f base)
-- | Wrap a base matrix as a triangular view with the given triangle and
-- diagonal tags. No shape check is performed; see 'lower', 'upper', etc.
-- for the checked constructors.
fromBase :: UpLo -> Diag -> a (m,n) e -> Tri a (m,n) e
fromBase uplo diag base = Tri uplo diag base
-- | Split a triangular view into its triangle tag, diagonal tag and base
-- matrix. This is the inverse of 'fromBase'.
toBase :: Tri a (m,n) e -> (UpLo, Diag, a (m,n) e)
toBase t = case t of
    Tri uplo diag base -> (uplo, diag, base)
-- | Lower-triangular view of a square matrix, tagged 'NonUnit'.
-- The base matrix's shape is passed to 'checkSquare' first (presumably
-- failing when the matrix is not square).
lower :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
lower mat = checkSquare (shape mat) (Tri Lower NonUnit mat)
-- | Lower-triangular view of a fat matrix, tagged 'NonUnit'.
--
-- The shape is validated with 'checkFat'. The @(m,n)@ base is then re-typed
-- with 'unsafeCoerce' so the view is indexed as @(m,m)@. Only the type index
-- changes; the stored matrix is the same value.
lowerFat :: (Matrix a) => a (m,n) e -> Tri a (m,m) e
lowerFat mat = checkFat (shape mat) view
  where
    view = Tri Lower NonUnit (unsafeCoerce mat)
-- | Lower-triangular view of a tall matrix, tagged 'NonUnit'. The view keeps
-- the base's @(m,n)@ index; the shape is validated with 'checkTall'.
lowerTall :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
lowerTall mat =
    let view = Tri Lower NonUnit mat
    in  checkTall (shape mat) view
-- | Like 'lower', but the view is tagged 'Unit' instead of 'NonUnit'.
-- The shape is validated with 'checkSquare'.
lowerU :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
lowerU mat = checkSquare (shape mat) (Tri Lower Unit mat)
-- | Like 'lowerFat', but tagged 'Unit'. The shape is validated with
-- 'checkFat', and the @(m,n)@ base is re-typed via 'unsafeCoerce' to an
-- @(m,m)@ view.
lowerUFat :: (Matrix a) => a (m,n) e -> Tri a (m,m) e
lowerUFat mat = checkFat (shape mat) view
  where
    view = Tri Lower Unit (unsafeCoerce mat)
-- | Like 'lowerTall', but tagged 'Unit'. The shape is validated with
-- 'checkTall'.
lowerUTall :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
lowerUTall mat =
    let view = Tri Lower Unit mat
    in  checkTall (shape mat) view
-- | Upper-triangular view of a square matrix, tagged 'NonUnit'.
-- The base matrix's shape is passed to 'checkSquare' first (presumably
-- failing when the matrix is not square).
upper :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
upper mat = checkSquare (shape mat) (Tri Upper NonUnit mat)
-- | Upper-triangular view of a fat matrix, tagged 'NonUnit'. The view keeps
-- the base's @(m,n)@ index; the shape is validated with 'checkFat'.
upperFat :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
upperFat mat =
    let view = Tri Upper NonUnit mat
    in  checkFat (shape mat) view
-- | Upper-triangular view of a tall matrix, tagged 'NonUnit'.
--
-- The shape is validated with 'checkTall'. The @(m,n)@ base is then re-typed
-- with 'unsafeCoerce' so the view is indexed as @(n,n)@. Only the type index
-- changes; the stored matrix is the same value.
upperTall :: (Matrix a) => a (m,n) e -> Tri a (n,n) e
upperTall mat = checkTall (shape mat) view
  where
    view = Tri Upper NonUnit (unsafeCoerce mat)
-- | Like 'upper', but the view is tagged 'Unit' instead of 'NonUnit'.
-- The shape is validated with 'checkSquare'.
upperU :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
upperU mat = checkSquare (shape mat) (Tri Upper Unit mat)
-- | Like 'upperFat', but tagged 'Unit'. The shape is validated with
-- 'checkFat'.
upperUFat :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
upperUFat mat =
    let view = Tri Upper Unit mat
    in  checkFat (shape mat) view
-- | Like 'upperTall', but tagged 'Unit'. The shape is validated with
-- 'checkTall', and the @(m,n)@ base is re-typed via 'unsafeCoerce' to an
-- @(n,n)@ view.
upperUTall :: (Matrix a) => a (m,n) e -> Tri a (n,n) e
upperUTall mat = checkTall (shape mat) view
  where
    view = Tri Upper Unit (unsafeCoerce mat)
-- | Dimensions of a triangular view are those of the referenced triangle:
--
-- * a 'Lower' view keeps the base's row count and has
--   @min rows cols@ columns;
-- * an 'Upper' view has @min rows cols@ rows and keeps the base's column
--   count.
--
-- 'herm' applies 'herm' to the base matrix and swaps the triangle tag with
-- 'flipUpLo'; the diagonal tag is kept as is.
instance Matrix a => Matrix (Tri a) where
    numRows (Tri uplo _ base) = case uplo of
        Lower -> numRows base
        Upper -> min (numRows base) (numCols base)

    numCols (Tri uplo _ base) = case uplo of
        Lower -> min (numRows base) (numCols base)
        Upper -> numCols base

    herm (Tri uplo diag base) = Tri (flipUpLo uplo) diag (herm base)
-- | Renders a triangular view as the smart-constructor call that names it,
-- e.g. @lowerFat (...)@ or @upperU (...)@. The base matrix's own 'show'
-- output is wrapped in parentheses.
--
-- The name is the triangle/diagonal prefix ('lower', 'lowerU', 'upper',
-- 'upperU') plus a suffix taken from the base matrix's shape: @""@ if it is
-- square, @"Fat"@ if it is fat, @"Tall"@ otherwise.
--
-- 'showsPrec' is defined, rather than just 'show', so that the result is
-- parenthesised when it appears as a constructor argument (precedence > 10).
-- For example, @Just t@ shows as @Just (lower (...))@. Top-level 'show'
-- output is unchanged.
instance (Show (a (m,n) e), Matrix a) => Show (Tri a (m,n) e) where
    showsPrec p (Tri u d a) =
        showParen (p > 10) $
              showString constructor
            . showString suffix
            . showString " ("
            . shows a
            . showChar ')'
      where
        constructor = case (u, d) of
            (Lower, NonUnit) -> "lower"
            (Lower, Unit   ) -> "lowerU"
            (Upper, NonUnit) -> "upper"
            (Upper, Unit   ) -> "upperU"
        suffix
            | isSquare a = ""
            | isFat a    = "Fat"
            | otherwise  = "Tall"