```{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, FlexibleInstances #-}
-----------------------------------------------------------------------------
-- |
-- Module     : Data.Matrix.Tri
-- Maintainer : Patrick Perry <patperry@stanford.edu>
-- Stability  : experimental
--

module Data.Matrix.Tri (
Tri(..),
UpLo(..), Diag(..),

fromBase,
toBase,
mapTri,

lower,
lowerU,
upper,
upperU,

) where

import qualified BLAS.Elem as E
import BLAS.Matrix
import BLAS.Tensor
import BLAS.Types ( UpLo(..), Diag(..), flipUpLo )

-- | A base matrix of type @a mn e@ bundled with an 'UpLo' flag, a 'Diag'
-- flag, and a scalar of the element type @e@.
--
-- Within this module the scalar acts as a pending multiplier: the
-- 'Scalable' instance multiplies into it instead of touching the base
-- matrix, and the 'Show' instance prints it as a @(k) *>@ prefix whenever
-- it differs from 1.
data Tri a mn e
    = Tri
        UpLo      -- triangle selector ('Lower' or 'Upper'); exact semantics
                  -- are defined by BLAS.Types -- NOTE(review): confirm there
        Diag      -- diagonal flag ('Unit' or 'NonUnit'); see BLAS.Types
        e         -- scalar factor carried alongside the base matrix
        (a mn e)  -- the wrapped base matrix, stored as given

-- | Apply a function to the base matrix inside a 'Tri', leaving the
-- 'UpLo' flag, the 'Diag' flag and the scalar factor untouched.
--
-- The function may change the base matrix type (@a@ to @b@) but, per the
-- signature, not the shape index @(n,n)@ or the element type @e@.
mapTri :: (a (n,n) e -> b (n,n) e) -> Tri a (n,n) e -> Tri b (n,n) e
-- Plain application is used for the last argument; the previous spelling
-- had a stray backslash in front of the @$@ operator, which does not parse.
mapTri f (Tri uplo diag k a) = Tri uplo diag k (f a)

-- | Assemble a 'Tri' from its four components: the 'UpLo' flag, the 'Diag'
-- flag, the scalar factor, and the base matrix.  This is the 'Tri'
-- constructor with its type specialised to the shape index @(n,n)@.
fromBase :: UpLo -> Diag -> e -> a (n,n) e -> Tri a (n,n) e
fromBase uplo diag k base = Tri uplo diag k base

-- | Take a 'Tri' apart into the tuple @(uplo, diag, scalar, base)@.
-- Inverse of 'fromBase' up to currying.
toBase :: Tri a (n,n) e -> (UpLo, Diag, e, a (n,n) e)
toBase t =
    case t of
        Tri uplo diag k base -> (uplo, diag, k, base)

-- | Wrap a base matrix as a 'Tri' flagged 'Lower' and 'NonUnit', with a
-- scalar factor of 1.  The base matrix itself is stored unchanged.
lower :: (Num e) => a (n,n) e -> Tri a (n,n) e
lower base = Tri Lower NonUnit 1 base

-- | Wrap a base matrix as a 'Tri' flagged 'Lower' and 'Unit', with a
-- scalar factor of 1.  The base matrix itself is stored unchanged.
lowerU :: (Num e) => a (n,n) e -> Tri a (n,n) e
lowerU base = Tri Lower Unit 1 base

-- | Wrap a base matrix as a 'Tri' flagged 'Upper' and 'NonUnit', with a
-- scalar factor of 1.  The base matrix itself is stored unchanged.
upper :: (Num e) => a (n,n) e -> Tri a (n,n) e
upper base = Tri Upper NonUnit 1 base

-- | Wrap a base matrix as a 'Tri' flagged 'Upper' and 'Unit', with a
-- scalar factor of 1.  The base matrix itself is stored unchanged.
upperU :: (Num e) => a (n,n) e -> Tri a (n,n) e
upperU base = Tri Upper Unit 1 base

-- | A 'Tri' is a 'Matrix' whenever its base matrix type is: dimension
-- queries are forwarded to the base matrix, and 'herm' is applied
-- component-wise.
instance Matrix a => Matrix (Tri a) where
    -- Dimensions are exactly those reported by the wrapped base matrix.
    numRows (Tri _ _ _ base) = numRows base
    numCols (Tri _ _ _ base) = numCols base

    -- The 'UpLo' flag is passed through 'flipUpLo', the scalar factor
    -- through 'E.conj', and the base matrix through its own 'herm'.
    -- The 'Diag' flag is carried over as is.
    herm (Tri uplo diag k base) = Tri uplo' diag k' base'
      where
        uplo' = flipUpLo uplo
        k'    = E.conj k
        base' = herm base

-- | Scaling a 'Tri' only updates its stored scalar factor; the flags and
-- the base matrix are returned unchanged.
instance (Num e) => Scalable (Tri a nn) e where
    -- The new factor is the product @c * k@, with the incoming scalar on
    -- the left.
    c *> (Tri uplo diag k base) = Tri uplo diag (c * k) base

-- | Render a 'Tri' as an expression built from this module's
-- constructors.
--
-- * With a scalar factor of 1 the output is @\<ctor\> (\<base\>)@, where
--   @\<ctor\>@ is one of @lower@, @lowerU@, @upper@, @upperU@ according to
--   the 'UpLo' and 'Diag' flags.
--
-- * With any other scalar factor @k@ the output is the above prefixed by
--   @(k) *> @.
--
-- The @Eq e@ constraint is needed by the comparison of the scalar factor
-- against 1; @Num e@ alone does not provide '/=' on compilers where 'Eq'
-- is not a superclass of 'Num'.
instance (Show (a mn e), Show e, Eq e, Num e) => Show (Tri a mn e) where
    show (Tri u d k a)
        -- Non-trivial factor: print it, then the same matrix with the
        -- factor reset to 1 (so the recursion takes the other branch).
        | k /= 1    = "(" ++ show k ++ ") *> " ++ show (Tri u d 1 a)
        | otherwise = constructor ++ " (" ++ show a ++ ")"
      where
        -- Name of the smart constructor matching the two flags.
        constructor = case (u, d) of
            (Lower, NonUnit) -> "lower"
            (Lower, Unit   ) -> "lowerU"
            (Upper, NonUnit) -> "upper"
            (Upper, Unit   ) -> "upperU"

```