{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Matrix.Tri
-- Copyright  : Copyright (c) , Patrick Perry
-- License    : BSD3
-- Maintainer : Patrick Perry
-- Stability  : experimental
--
-- Triangular (and, for non-square shapes, trapezoidal) views of a base
-- matrix.  A 'Tri' pairs the base matrix with an 'UpLo' flag selecting the
-- lower or upper part and a 'Diag' flag ('Unit' / 'NonUnit').
-- NOTE(review): 'Unit' presumably follows the BLAS convention of an implicit
-- all-ones diagonal -- confirm against "BLAS.Types".
--

module Data.Matrix.Tri
    ( Tri(..)
    , UpLo(..)
    , Diag(..)

    , fromBase
    , toBase
    , mapTri

    , lower
    , lowerFat
    , lowerTall
    , lowerU
    , lowerUFat
    , lowerUTall

    , upper
    , upperFat
    , upperTall
    , upperU
    , upperUFat
    , upperUTall

    , coerceTri
    ) where

import BLAS.Internal ( checkSquare, checkFat, checkTall )
import BLAS.Matrix
import BLAS.Tensor
import BLAS.Types ( UpLo(..), Diag(..), flipUpLo )

import Unsafe.Coerce

-- | A triangular view of a base matrix of type @a mn e@: which part is
-- referenced ('UpLo'), the diagonal flag ('Diag'), and the base matrix.
data Tri a mn e = Tri UpLo Diag (a mn e)

-- | Coerce the shape type.
--
-- The value is returned unchanged; only the type-level shape index differs.
-- NOTE(review): this (and the 'unsafeCoerce' calls in the Fat/Tall
-- constructors below) is only sound if the shape index of every base matrix
-- type @a@ is purely phantom -- confirm for each 'Matrix' instance.
coerceTri :: Tri a mn e -> Tri a mn' e
coerceTri = unsafeCoerce

-- | Apply a function to the base matrix, keeping the 'UpLo' and 'Diag'
-- flags unchanged.
mapTri :: (a (m,n) e -> b (m,n) e) -> Tri a (m,n) e -> Tri b (m,n) e
mapTri f (Tri u d a) = Tri u d $ f a

-- | Wrap a base matrix with the given flags.  No shape check is performed.
fromBase :: UpLo -> Diag -> a (m,n) e -> Tri a (m,n) e
fromBase = Tri

-- | Unpack a view into its flags and base matrix (inverse of 'fromBase').
toBase :: Tri a (m,n) e -> (UpLo, Diag, a (m,n) e)
toBase (Tri u d a) = (u,d,a)

-- | Lower-triangular view of a square matrix, diagonal taken from the base
-- ('NonUnit').  The shape is validated with 'checkSquare'.
lower :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
lower a = checkSquare (shape a) $ Tri Lower NonUnit a

-- | Lower-triangular @m x m@ view of a fat @m x n@ matrix ('NonUnit').
-- The shape is validated with 'checkFat'.  The base is not sliced: its
-- shape index is only retyped to @(m,m)@ with 'unsafeCoerce'.
lowerFat :: (Matrix a) => a (m,n) e -> Tri a (m,m) e
lowerFat a = checkFat (shape a) $ Tri Lower NonUnit (unsafeCoerce a)

-- | Lower-trapezoidal @m x n@ view of a tall matrix ('NonUnit').
-- The shape is validated with 'checkTall'.
lowerTall :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
lowerTall a = checkTall (shape a) $ Tri Lower NonUnit a

-- | Like 'lower', but with the 'Unit' diagonal flag.
lowerU :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
lowerU a = checkSquare (shape a) $ Tri Lower Unit a

-- | Like 'lowerFat', but with the 'Unit' diagonal flag.  The shape index is
-- retyped to @(m,m)@ with 'unsafeCoerce'.
lowerUFat :: (Matrix a) => a (m,n) e -> Tri a (m,m) e
lowerUFat a = checkFat (shape a) $ Tri Lower Unit (unsafeCoerce a)

-- | Like 'lowerTall', but with the 'Unit' diagonal flag.
lowerUTall :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
lowerUTall a = checkTall (shape a) $ Tri Lower Unit a

-- | Upper-triangular view of a square matrix, diagonal taken from the base
-- ('NonUnit').  The shape is validated with 'checkSquare'.
upper :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
upper a = checkSquare (shape a) $ Tri Upper NonUnit a

-- | Upper-trapezoidal @m x n@ view of a fat matrix ('NonUnit').
-- The shape is validated with 'checkFat'.
upperFat :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
upperFat a = checkFat (shape a) $ Tri Upper NonUnit a

-- | Upper-triangular @n x n@ view of a tall @m x n@ matrix ('NonUnit').
-- The shape is validated with 'checkTall'.  The base is not sliced: its
-- shape index is only retyped to @(n,n)@ with 'unsafeCoerce'.
upperTall :: (Matrix a) => a (m,n) e -> Tri a (n,n) e
upperTall a = checkTall (shape a) $ Tri Upper NonUnit (unsafeCoerce a)

-- | Like 'upper', but with the 'Unit' diagonal flag.
upperU :: (Matrix a) => a (n,n) e -> Tri a (n,n) e
upperU a = checkSquare (shape a) $ Tri Upper Unit a

-- | Like 'upperFat', but with the 'Unit' diagonal flag.
upperUFat :: (Matrix a) => a (m,n) e -> Tri a (m,n) e
upperUFat a = checkFat (shape a) $ Tri Upper Unit a

-- | Like 'upperTall', but with the 'Unit' diagonal flag.  The shape index
-- is retyped to @(n,n)@ with 'unsafeCoerce'.
upperUTall :: (Matrix a) => a (m,n) e -> Tri a (n,n) e
upperUTall a = checkTall (shape a) $ Tri Upper Unit (unsafeCoerce a)

-- The reported dimensions are those of the referenced part, not of the base:
-- for an m x n base, a 'Lower' view is m x (min m n) and an 'Upper' view is
-- (min m n) x n.
instance Matrix a => Matrix (Tri a) where
    numRows (Tri Lower _ a) = numRows a
    numRows (Tri Upper _ a) = min (numRows a) (numCols a)

    numCols (Tri Lower _ a) = min (numRows a) (numCols a)
    numCols (Tri Upper _ a) = numCols a

    -- Applying 'herm' to the base swaps which triangle is referenced, so
    -- the 'UpLo' flag is flipped to match.
    herm (Tri u d a) = Tri (flipUpLo u) d (herm a)

-- Shows the view as the name of the constructor matching its flags and the
-- base matrix's shape, applied to the shown base, e.g. @lowerUFat (...)@.
instance (Show (a (m,n) e), Matrix a) => Show (Tri a (m,n) e) where
    show (Tri u d a) = constructor ++ suffix ++ " (" ++ show a ++ ")"
      where
        constructor = case (u, d) of
            (Lower, NonUnit) -> "lower"
            (Lower, Unit   ) -> "lowerU"
            (Upper, NonUnit) -> "upper"
            (Upper, Unit   ) -> "upperU"

        -- Square is tested first; a square base would presumably also
        -- satisfy 'isFat'.
        suffix
            | isSquare a = ""
            | isFat a    = "Fat"
            | otherwise  = "Tall"