module Numeric.LinearAlgebra.Algorithms (
    Field(),
    linearSolve,
    mbLinearSolve,
    luSolve,
    cholSolve,
    linearSolveLS,
    linearSolveSVD,
    inv, pinv, pinvTol,
    det, invlndet,
    rank, rcond,
    svd,
    fullSVD,
    thinSVD,
    compactSVD,
    singularValues,
    leftSV, rightSV,
    eig, eigSH, eigSH',
    eigenvalues, eigenvaluesSH, eigenvaluesSH',
    geigSH',
    qr, rq, qrRaw, qrgr,
    chol, cholSH, mbCholSH,
    hess,
    schur,
    lu, luPacked,
    expm,
    sqrtm,
    matFunc,
    nullspacePrec,
    nullVector,
    nullspaceSVD,
    orthSVD,
    orth,
    Normed(..), NormType(..),
    relativeError', relativeError,
    eps, peps, i,
    haussholder,
    unpackQR, unpackHess,
    ranksv
) where
import Data.Packed
import Numeric.LinearAlgebra.LAPACK as LAPACK
import Data.List(foldl1')
import Data.Array
import Data.Packed.Internal.Numeric
import Data.Packed.Internal(shSize)
-- | Element types supported by the decompositions and solvers of this module.
-- The instances in this module cover 'Double' and 'Complex Double'. The
-- superclass equation @RealOf t ~ Double@ fixes the real counterpart of the
-- element type to 'Double', which is why singular values and symmetric /
-- Hermitian eigenvalues are returned as 'Double' vectors.
class ( Product t
      , Convert t
      , Container Vector t
      , Container Matrix t
      , Normed Matrix t
      , Normed Vector t
      , Floating t
      , RealOf t ~ Double
      ) => Field t where
    -- singular value decompositions (full and thin) and singular values alone
    svd', thinSVD' :: Matrix t -> (Matrix t, Vector Double, Matrix t)
    sv'            :: Matrix t -> Vector Double
    -- LU factorization in packed form (with pivot list) and solving with it
    luPacked' :: Matrix t -> (Matrix t, [Int])
    luSolve'  :: (Matrix t, [Int]) -> Matrix t -> Matrix t
    -- linear system solvers; the Maybe variant signals failure without error
    mbLinearSolve' :: Matrix t -> Matrix t -> Maybe (Matrix t)
    linearSolve', cholSolve', linearSolveSVD', linearSolveLS'
        :: Matrix t -> Matrix t -> Matrix t
    -- eigensystems: general (complex results) and symmetric/Hermitian
    eig'      :: Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
    eigSH''   :: Matrix t -> (Vector Double, Matrix t)
    eigOnly   :: Matrix t -> Vector (Complex Double)
    eigOnlySH :: Matrix t -> Vector Double
    -- Cholesky factorization, plain and Maybe-returning
    cholSH'   :: Matrix t -> Matrix t
    mbCholSH' :: Matrix t -> Maybe (Matrix t)
    -- compact QR (packed matrix plus Householder coefficients, the form
    -- unpackQR consumes); qrgr' takes an Int plus that compact form
    -- (presumably the number of Q columns to generate -- confirm)
    qr'   :: Matrix t -> (Matrix t, Vector t)
    qrgr' :: Int -> (Matrix t, Vector t) -> Matrix t
    -- Hessenberg and Schur decompositions
    hess', schur' :: Matrix t -> (Matrix t, Matrix t)
-- Real instance: each method forwards to a real-valued routine (svdRd, lusR,
-- eigS, ...) that is not defined in this file; presumably these come from the
-- Numeric.LinearAlgebra.LAPACK import. hess' additionally unpacks the raw
-- result with unpackHess.
instance Field Double where
    -- decompositions
    svd'     = svdRd
    thinSVD' = thinSVDRd
    sv'      = svR
    qr'      = qrR
    qrgr'    = qrgrR
    hess'    = unpackHess hessR
    schur'   = schurR
    -- LU
    luPacked' = luR
    luSolve' (packed, pivots) = lusR packed pivots
    -- linear solvers
    linearSolve'    = linearSolveR
    mbLinearSolve'  = mbLinearSolveR
    cholSolve'      = cholSolveR
    linearSolveLS'  = linearSolveLSR
    linearSolveSVD' = linearSolveSVDR Nothing
    -- eigensystems
    eig'      = eigR
    eigOnly   = eigOnlyR
    eigSH''   = eigS
    eigOnlySH = eigOnlyS
    -- Cholesky
    cholSH'   = cholS
    mbCholSH' = mbCholS
-- Complex instance: each method forwards to a complex-valued routine defined
-- outside this file (presumably the LAPACK import). When the NOZGESDD macro is
-- defined at build time the SVDs use svdC and thinSVDC instead of svdCd and
-- thinSVDCd.
instance Field (Complex Double) where
#ifdef NOZGESDD
    svd'     = svdC
    thinSVD' = thinSVDC
#else
    svd'     = svdCd
    thinSVD' = thinSVDCd
#endif
    sv'      = svC
    qr'      = qrC
    qrgr'    = qrgrC
    hess'    = unpackHess hessC
    schur'   = schurC
    -- LU
    luPacked' = luC
    luSolve' (packed, pivots) = lusC packed pivots
    -- linear solvers
    linearSolve'    = linearSolveC
    mbLinearSolve'  = mbLinearSolveC
    cholSolve'      = cholSolveC
    linearSolveLS'  = linearSolveLSC
    linearSolveSVD' = linearSolveSVDC Nothing
    -- eigensystems
    eig'      = eigC
    eigOnly   = eigOnlyC
    eigSH''   = eigH
    eigOnlySH = eigOnlyH
    -- Cholesky
    cholSH'   = cholH
    mbCholSH' = mbCholH
-- | True when the matrix has as many rows as columns.
square :: Matrix t -> Bool
square m = rows m == cols m

-- | True when the matrix has at least as many rows as columns.
vertical :: Matrix t -> Bool
vertical m = rows m >= cols m

-- | True when the matrix is exactly (no tolerance) equal to its conjugate
-- transpose, as decided by 'equal'.
exactHermitian :: Field t => Matrix t -> Bool
exactHermitian m = m `equal` ctrans m
-- | Singular value decomposition computed by the element type's 'Field'
-- instance; the middle component holds the singular values.
svd :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t)
svd m = svd' m

-- | Thin singular value decomposition computed by the 'Field' instance.
thinSVD :: Field t => Matrix t -> (Matrix t, Vector Double, Matrix t)
thinSVD m = thinSVD' m

-- | Singular values alone, computed by the 'Field' instance.
singularValues :: Field t => Matrix t -> Vector Double
singularValues m = sv' m
-- | Like 'svd', but the singular values are returned as a rows-by-cols
-- rectangular diagonal matrix (built with 'diagRect').
fullSVD :: Field t => Matrix t -> (Matrix t, Matrix Double, Matrix t)
fullSVD m = (u, diagRect 0 s (rows m) (cols m), v)
  where
    (u, s, v) = svd m
-- | Thin SVD truncated to the numerical rank. Keeps the first @keep@ columns
-- of @u@ and @v@ and the first @keep@ singular values. @keep@ is the rank
-- estimated by 'rankSVD' with tolerance 'eps', and is never less than 1.
compactSVD :: Field t  => Matrix t -> (Matrix t, Vector Double, Matrix t)
compactSVD m = (takeColumns keep u, subVector 0 keep s, takeColumns keep v)
  where
    (u, s, v) = thinSVD m
    keep = max 1 (rankSVD eps m s)
-- | Singular values together with the right singular vectors. Uses the thin
-- SVD when the matrix has at least as many rows as columns, the full SVD
-- otherwise.
rightSV :: Field t => Matrix t -> (Vector Double, Matrix t)
rightSV m = (s, v)
  where
    (_, s, v) = if vertical m then thinSVD m else svd m

-- | Left singular vectors together with the singular values. Uses the full
-- SVD when the matrix has at least as many rows as columns, the thin SVD
-- otherwise.
leftSV :: Field t => Matrix t -> (Matrix t, Vector Double)
leftSV m = (u, s)
  where
    (u, s, _) = if vertical m then svd m else thinSVD m
-- | Packed LU factorization: the combined L/U matrix and the pivot list.
luPacked :: Field t => Matrix t -> (Matrix t, [Int])
luPacked m = luPacked' m

-- | Solves a linear system from a factorization produced by 'luPacked'.
luSolve :: Field t => (Matrix t, [Int]) -> Matrix t -> Matrix t
luSolve factors b = luSolve' factors b

-- | Solves the linear system @a x = b@ via the 'Field' instance.
linearSolve :: Field t => Matrix t -> Matrix t -> Matrix t
linearSolve a b = linearSolve' a b

-- | Like 'linearSolve', but reports failure with 'Nothing'.
mbLinearSolve :: Field t => Matrix t -> Matrix t -> Maybe (Matrix t)
mbLinearSolve a b = mbLinearSolve' a b

-- | Cholesky-based solver, via the 'Field' instance.
cholSolve :: Field t => Matrix t -> Matrix t -> Matrix t
cholSolve a b = cholSolve' a b

-- | SVD-based solver, via the 'Field' instance.
linearSolveSVD :: Field t => Matrix t -> Matrix t -> Matrix t
linearSolveSVD a b = linearSolveSVD' a b

-- | Least-squares solver, via the 'Field' instance.
linearSolveLS :: Field t => Matrix t -> Matrix t -> Matrix t
linearSolveLS a b = linearSolveLS' a b
-- | Eigenvalues and eigenvectors of a general matrix, always in complex form.
eig :: Field t => Matrix t -> (Vector (Complex Double), Matrix (Complex Double))
eig m = eig' m

-- | Eigenvalues only of a general matrix.
eigenvalues :: Field t => Matrix t -> Vector (Complex Double)
eigenvalues m = eigOnly m

-- | Symmetric/Hermitian eigensystem without checking that the input is
-- actually symmetric/Hermitian.
eigSH' :: Field t => Matrix t -> (Vector Double, Matrix t)
eigSH' m = eigSH'' m

-- | Symmetric/Hermitian eigenvalues without checking the input.
eigenvaluesSH' :: Field t => Matrix t -> Vector Double
eigenvaluesSH' m = eigOnlySH m
-- | Checked version of 'eigSH''. Calls 'error' unless the input is exactly
-- equal to its conjugate transpose.
eigSH :: Field t => Matrix t -> (Vector Double, Matrix t)
eigSH m
    | not (exactHermitian m) = error "eigSH requires complex hermitian or real symmetric matrix"
    | otherwise              = eigSH' m

-- | Checked version of 'eigenvaluesSH''. Calls 'error' unless the input is
-- exactly equal to its conjugate transpose.
eigenvaluesSH :: Field t => Matrix t -> Vector Double
eigenvaluesSH m
    | not (exactHermitian m) = error "eigenvaluesSH requires complex hermitian or real symmetric matrix"
    | otherwise              = eigenvaluesSH' m
-- | QR factorization: the compact result of the 'Field' instance is
-- expanded into explicit @(q, r)@ with 'unpackQR'.
qr :: Field t => Matrix t -> (Matrix t, Matrix t)
qr = unpackQR . qr'

-- | Compact QR exactly as returned by the 'Field' instance: a packed matrix
-- plus the vector of Householder coefficients (the form 'unpackQR' expands).
qrRaw :: Field t => Matrix t -> (Matrix t, Vector t)
qrRaw m = qr' m

-- | Checked wrapper around qrgr'. Calls 'error' when the coefficient vector
-- is longer than @min rows cols@ of the packed matrix, or when @n@ lies
-- outside @[0, dim t]@.
qrgr :: Field t => Int -> (Matrix t, Vector t) -> Matrix t
qrgr n (a,t)
    | dim t > min (cols a) (rows a) || n < 0 || n > dim t = error "qrgr expects k <= min(rows,cols)"
    | otherwise = qrgr' n (a,t)
-- | RQ factorization. Computes 'qr' of the transpose of the matrix with its
-- rows and columns reversed, then transposes and reverses both factors back.
rq :: Field t => Matrix t -> (Matrix t, Matrix t)
rq m = (reverseOut (trans r'), reverseOut (trans q'))
  where
    (q', r') = qr (trans (reverseIn m))
    reverseIn = flipud . fliprl
    reverseOut = fliprl . flipud
-- | Hessenberg decomposition, via the 'Field' instance.
hess :: Field t => Matrix t -> (Matrix t, Matrix t)
hess m = hess' m

-- | Schur decomposition, via the 'Field' instance.
schur :: Field t => Matrix t -> (Matrix t, Matrix t)
schur m = schur' m

-- | Cholesky factor of a symmetric/Hermitian matrix, or 'Nothing' when the
-- 'Field' instance reports failure. The input is not checked for symmetry.
mbCholSH :: Field t => Matrix t -> Maybe (Matrix t)
mbCholSH m = mbCholSH' m

-- | Cholesky factor of a symmetric/Hermitian matrix. The input is not
-- checked for symmetry.
cholSH :: Field t => Matrix t -> Matrix t
cholSH m = cholSH' m
-- | Checked Cholesky factorization. Calls 'error' unless the input is exactly
-- equal to its conjugate transpose.
chol :: Field t => Matrix t ->  Matrix t
chol m
    | not (exactHermitian m) = error "chol requires positive definite complex hermitian or real symmetric matrix"
    | otherwise              = cholSH m
-- | Inverse, log of the absolute determinant, and sign factor of the
-- determinant, all derived from one 'luPacked' factorization. Calls 'error'
-- on a non-square matrix.
invlndet :: Field t
         => Matrix t
         -> (Matrix t, (t, t))
invlndet m
    | not (square m) = error $ "invlndet of nonsquare "++ shSize m ++ " matrix"
    | otherwise      = (luSolve factors (ident n), (logAbsDet, signDet))
  where
    n = rows m
    factors@(packed, pivots) = luPacked m
    diagonal = toList (takeDiag packed)
    -- log |det| is the sum of log |u_ii| over the U diagonal
    logAbsDet = sum [log (abs x) | x <- diagonal]
    -- permutation sign times the signs of the diagonal entries
    signDet = signlp n pivots * product [signum x | x <- diagonal]
-- | Determinant: the product of the diagonal of the packed LU factors, times
-- the permutation sign from 'signlp'. Calls 'error' on a non-square matrix.
det :: Field t => Matrix t -> t
det m
    | not (square m) = error $ "det of nonsquare "++ shSize m ++ " matrix"
    | otherwise      = signlp (rows m) pivots * product (toList (takeDiag packed))
  where
    (packed, pivots) = luPacked m
-- | Explicit LU factorization @(l, u, p, sign)@, obtained by expanding
-- 'luPacked' with 'luFact'.
lu :: Field t => Matrix t -> (Matrix t, Matrix t, Matrix t, t)
lu m = luFact (luPacked m)
-- | Matrix inverse, computed by solving against the identity. Calls 'error'
-- on a non-square matrix.
inv :: Field t => Matrix t -> Matrix t
inv m
    | not (square m) = error $ "inv of nonsquare "++ shSize m ++ " matrix"
    | otherwise      = linearSolve m (ident (rows m))
-- | Pseudoinverse with the default tolerance factor of 1 (see 'pinvTol').
pinv :: Field t => Matrix t -> Matrix t
pinv m = pinvTol 1 m
-- | Pseudoinverse with tolerance factor @t@.
--
-- The cutoff is @tol = max rows cols * g * t * eps@, where @g@ is the first
-- singular value (assumed to be the largest). Singular values not above
-- @tol@ are treated as zero; all others are inverted.
pinvTol :: Field t => Double -> Matrix t -> Matrix t
pinvTol t m = conj v' `mXm` diag s' `mXm` ctrans u' where
    (u,s,v) = thinSVD m
    sl = toList s
    -- first singular value; 0 when there are none, so the binding is total
    g = case sl of
        (x:_) -> x
        []    -> 0
    s' = real . fromList . map rec $ sl
    -- tol already scales with g, so compare against it directly. Discarded
    -- values contribute 0; inverting or keeping them would inject noise.
    rec x = if x <= tol then 0 else 1/x
    tol = fromIntegral (max r c) * g * t * eps
    r = rows m
    c = cols m
    d = dim s
    u' = takeColumns d u
    v' = takeColumns d v
-- | Numerical rank from precomputed singular values. Delegates to 'ranksv'
-- using the larger of the two matrix dimensions.
rankSVD :: Element t
        => Double        -- ^ tolerance factor
        -> Matrix t      -- ^ the matrix (only its shape is used)
        -> Vector Double -- ^ its singular values
        -> Int           -- ^ estimated rank
rankSVD teps m s = ranksv teps maxDim (toList s)
  where
    maxDim = rows m `max` cols m
-- | Numerical rank from a list of singular values.
--
-- A value counts toward the rank when it exceeds @maxdim * g * teps@, where
-- @g@ is the largest value. The rank is 0 when @g@ itself does not exceed
-- @teps@, and also for an empty list.
ranksv ::  Double   -- ^ tolerance factor @teps@
        -> Int      -- ^ @maxdim@, the larger matrix dimension
        -> [Double] -- ^ singular values
        -> Int      -- ^ estimated rank
ranksv _ _ [] = 0
ranksv teps maxdim s = k where
    g = maximum s
    tol = fromIntegral maxdim * g * teps
    s' = filter (>tol) s
    k = if g > teps then length s' else 0
-- | Fixed machine-precision constant for 'Double': 2^-52 rounded to 15
-- significant digits.
eps :: Double
eps =  2.22044604925031e-16

-- | Machine precision of any 'RealFloat' type, computed from its mantissa
-- width as @2 ** (1 - floatDigits x)@ (2^-52 for 'Double', 2^-23 for 'Float').
peps :: RealFloat x => x
peps = x where x = 2.0 ** fromIntegral (1 - floatDigits x)

-- | The imaginary unit.
i :: Complex Double
i = 0:+1
-- | Null-space basis from a precomputed SVD. Drops the first @k@ columns of
-- @v@ and conjugates the rest. @k@ is given directly as @Right k@, or
-- estimated with 'rankSVD' using the tolerance in @Left tol@.
nullspaceSVD :: Field t
             => Either Double Int          -- ^ @Left@ tolerance or @Right@ rank
             -> Matrix t                   -- ^ the matrix
             -> (Vector Double, Matrix t)  -- ^ singular values and right singular vectors
             -> Matrix t                   -- ^ null-space basis as columns
nullspaceSVD hint a (s,v) = conj (dropColumns k v)
  where
    k = either (\tol -> rankSVD tol a s) id hint
-- | Null-space basis vectors. The rank tolerance is @t * eps@, and the
-- singular values and right singular vectors come from 'rightSV'.
nullspacePrec :: Field t
              => Double     -- ^ tolerance factor, in units of 'eps'
              -> Matrix t   -- ^ input matrix
              -> [Vector t] -- ^ null-space basis vectors
nullspacePrec t m = toColumns (nullspaceSVD (Left (t * eps)) m (rightSV m))
-- | The last vector of the null-space basis computed with tolerance factor 1.
-- Fails (through 'last') when that basis is empty.
nullVector :: Field t => Matrix t -> Vector t
nullVector m = last (nullspacePrec 1 m)
-- | Range basis from a precomputed SVD: the first @k@ columns of @v@. @k@ is
-- given directly as @Right k@, or estimated with 'rankSVD' using the
-- tolerance in @Left tol@.
orthSVD :: Field t
        => Either Double Int          -- ^ @Left@ tolerance or @Right@ rank
        -> Matrix t                   -- ^ the matrix
        -> (Matrix t, Vector Double)  -- ^ singular vectors and singular values (presumably from 'leftSV')
        -> Matrix t                   -- ^ basis as columns
orthSVD hint a (v,s) = takeColumns k v
  where
    k = either (\tol -> rankSVD tol a s) id hint
-- | Orthonormal basis of the range: the leading columns of the 'compactSVD'
-- factor @u@, as many as the rank estimated with tolerance 'eps'.
orth :: Field t => Matrix t -> [Vector t]
orth m = take rnk (toColumns u)
  where
    (u, s, _) = compactSVD m
    rnk = rankSVD eps m s
-- | Householder reflector @I - tau * w * w^H@, where @w@ is @v@ taken as a
-- column. The spelling of the name is kept because it is exported.
haussholder :: (Field a) => a -> Vector a -> Matrix a
haussholder tau v = ident (dim v) `sub` scale tau (w `mXm` ctrans w)
  where
    w = asColumn v
-- | @zh k v@ is @v@ with its first @k-1@ entries set to 0, entry @k-1@
-- (0-based) set to 1, and entries from index @k@ on unchanged. This is the
-- reflector vector that unpackQR/uH pass to 'haussholder'.
zh k v = fromList $ replicate (k-1) 0 ++ (1:drop k xs)
              where xs = toList v

-- | @zt k v@ is @v@ with its last @k@ entries set to 0.
zt 0 v = v
zt k v = vjoin [subVector 0 (dim v - k) v, konst' 0 k]
-- | Expands a compact QR (packed matrix, Householder coefficients @tau@) into
-- explicit @(q, r)@.
--
-- @r@ keeps the upper-triangular part of the packed columns. @q@ is the
-- product of the reflectors built from the below-diagonal vectors ('zh')
-- and the matching coefficients in @tau@.
unpackQR :: (Field t) => (Matrix t, Vector t) -> (Matrix t, Matrix t)
unpackQR (pq, tau) =   (q,r)
    where cs = toColumns pq
          m = rows pq
          n = cols pq
          mn = min m n
          -- column j (1-based) keeps its first j entries, so m-j are zeroed;
          -- columns past the (m-1)-th are kept whole
          r = fromColumns $ zipWith zt ([m-1, m-2 .. 1] ++ repeat 0) cs
          vs = zipWith zh [1..mn] cs
          hs = zipWith haussholder (toList tau) vs
          q = foldl1' mXm hs
-- | Runs a raw Hessenberg routine @hf@ and expands its compact result with
-- 'uH'. A single-row matrix is returned as is, paired with the 1x1 identity.
unpackHess :: (Field t) => (Matrix t -> (Matrix t,Vector t)) -> Matrix t -> (Matrix t, Matrix t)
unpackHess hf m
    | rows m /= 1 = uH (hf m)
    | otherwise   = ((1><1) [1], m)
-- | Hessenberg analogue of 'unpackQR'. The packed columns keep one
-- subdiagonal entry, and the reflector vectors start one row lower.
-- Returns the orthogonal/unitary factor @p@ and the Hessenberg matrix @h@.
uH :: Field t => (Matrix t, Vector t) -> (Matrix t, Matrix t)
uH (pq, tau) = (p,h)
    where cs = toColumns pq
          m = rows pq
          n = cols pq
          mn = min m n
          -- column j (1-based) keeps its first j+1 entries, so m-j-1 are zeroed
          h = fromColumns $ zipWith zt ([m-2, m-3 .. 1] ++ repeat 0) cs
          vs = zipWith zh [2..mn] cs
          hs = zipWith haussholder (toList tau) vs
          p = foldl1' mXm hs
-- | Ratio of the last to the first singular value. This is the reciprocal
-- 2-norm condition number when the values are in decreasing order, which
-- is assumed here. Fails on a matrix with no singular values.
rcond :: Field t => Matrix t -> Double
rcond m = smallest / largest
  where
    sv = toList (singularValues m)
    largest = head sv
    smallest = last sv
-- | Numerical rank: 'rankSVD' with tolerance 'eps' applied to the matrix's
-- singular values.
rank :: Field t => Matrix t -> Int
rank m = rankSVD eps m sv
  where
    sv = singularValues m
-- | Eigen-decomposition @(eigenvalues, eigenvector matrix)@ when the
-- eigenvector matrix has full rank, 'Nothing' otherwise. Input that is
-- exactly Hermitian goes through 'eigSH' (its real eigenvalues converted
-- with 'real'); any other input goes through 'eig'.
diagonalize :: Matrix (Complex Double) -> Maybe (Vector (Complex Double), Matrix (Complex Double))
diagonalize m
    | rank vecs == rows m = Just (vals, vecs)
    | otherwise           = Nothing
  where
    (vals, vecs)
        | exactHermitian m = let (l', v') = eigSH m in (real l', v')
        | otherwise        = eig m
-- | Applies a scalar function to a diagonalizable matrix: @v * diag (f l) *
-- inv v@, using the decomposition from 'diagonalize'. Calls 'error' when the
-- matrix is not diagonalizable.
matFunc :: (Complex Double -> Complex Double) -> Matrix (Complex Double) -> Matrix (Complex Double)
matFunc f m = maybe (error "Sorry, matFunc requires a diagonalizable matrix") apply (diagonalize m)
  where
    apply (l, v) = v `mXm` diag (mapVector f l) `mXm` inv v
-- | The quantity @2^(3-p-q) * p! * q! / ((p+q)! * (p+q+1)!)@. 'geps' uses it
-- to pick the Pade degree in 'expGolub'.
golubeps :: Integer -> Integer -> Double
golubeps p q = a * fromIntegral b / fromIntegral c where
    a = 2^^(3-p-q)
    b = fact p * fact q
    c = fact (p+q) * fact (p+q+1)
    fact n = product [1..n]
-- | Infinite list of pairs @(k, golubeps k k)@ for k = 1, 2, ...
epslist :: [(Int,Double)]
epslist = map (\k -> (fromIntegral k, golubeps k k)) [1..]

-- | Smallest @k@ whose @golubeps k k@ is below @delta@.
geps :: Double -> Int
geps delta = fst (head (filter ((< delta) . snd) epslist))
-- | Matrix exponential, computed by 'expGolub'.
expm :: Field t => Matrix t -> Matrix t
expm m = expGolub m
-- | Matrix exponential by scaling and squaring with a diagonal Pade
-- approximant.
--
-- The matrix is divided by 2^j so that its infinity norm is below 2. The
-- degree-q Pade numerator N and denominator D are accumulated term by term,
-- D^-1 N is formed with 'linearSolve', and the result is squared j times.
expGolub :: Field t => Matrix t -> Matrix t
expGolub m = iterate msq f !! j
    where j = max 0 $ floor $ logBase 2 $ pnorm Infinity m
          a = m */ fromIntegral ((2::Int)^j)
          -- Pade degree: the first k whose golubeps k k falls below eps
          q = geps eps
          eye = ident (rows m)
          -- one Pade term: c_k = c_(k-1) * (q-k+1) / ((2q-k+1) * k) and
          -- x_k = a^k; the denominator alternates sign with (-1)^k
          work (k,c,x,n,d) = (k',c',x',n',d')
              where k' = k+1
                    c' = c * fromIntegral (q-k+1) / fromIntegral ((2*q-k+1)*k)
                    x' = a <> x
                    n' = n |+| (c' .* x')
                    d' = d |+| (((-1)^k * c') .* x')
          (_,_,_,nf,df) = iterate work (1,1,eye,eye,eye) !! q
          f = linearSolve df nf
          msq x = x <> x
          (<>) = multiply
          v */ x = scale (recip x) v
          (.*) = scale
          (|+|) = add
-- | Matrix square root, computed by 'sqrtmInv'.
sqrtm ::  Field t => Matrix t -> Matrix t
sqrtm m = sqrtmInv m
-- | Square root by the coupled (Denman-Beavers) iteration
-- Y <- (Y + Z^-1)/2, Z <- (Y^-1 + Z)/2, starting from (x, I).
-- Iteration stops when two successive Y differ by less than 'peps' in the
-- 1-norm; the Y of that step is returned.
sqrtmInv x = fst (converge (iterate step (x, ident (rows x))))
  where
    converge (a:b:rest)
        | pnorm PNorm1 (fst a |-| fst b) < peps = a
        | otherwise                             = converge (b:rest)
    converge _ = error "fixedpoint with impossible inputs"
    step (y,z) = ( 0.5 .* (y |+| inv z)
                 , 0.5 .* (inv y |+| z) )
    (.*) = scale
    (|+|) = add
    (|-|) = sub
-- | Sign (+1 or -1) of the row permutation described by an LU pivot list,
-- as returned by 'luPacked'. Each position @a@ (of the first @r@) whose
-- pivot @b@ differs from @a@ is one interchange and flips the sign.
signlp :: Num a => Int -> [Int] -> a
signlp r vals = product [ if a /= b then -1 else 1 | (a, b) <- zip [0 .. r-1] vals ]
-- | One step of replaying a pivot list on an array. Exchanges entries @a@
-- and @b@ and negates the running sign when they differ; otherwise leaves
-- both unchanged.
swap :: (Ix i, Num s) => (Array i e, s) -> (i, i) -> (Array i e, s)
swap (arr,s) (a,b) | a /= b    = (arr // [(a, arr!b),(b,arr!a)], -s)
                   | otherwise = (arr,s)
-- | Permutation matrix and its sign, obtained by replaying the pivot list on
-- the columns of the r x r identity with 'swap'.
fixPerm r vals = (fromColumns $ elems res, sign)
    where v = [0..r-1]
          s = toColumns (ident r)
          (res,sign) = foldl swap (listArray (0,r-1) s, 1) (zip v vals)
-- | r x c matrix whose entry (p, q) is @v@ when @q - p >= h@ and @1 - v@
-- otherwise. With v = 1 it masks the upper triangle starting at diagonal
-- offset h; with v = 0 it masks the complementary part.
triang r c h v = (r><c) [el s t | s<-[0..r-1], t<-[0..c-1]]
    where el p q = if q-p>=h then v else 1 - v
-- | Expands a packed LU factorization into @(l, u, p, sign)@.
--
-- The masks from 'triang' select the strictly lower part, to which a unit
-- diagonal is added to form L, and the upper part with its diagonal, which
-- forms U. Masking uses 'mul'. For an r x c input with r <= c, L is trimmed
-- to r x r; otherwise U is trimmed to c rows. P and the sign come from
-- 'fixPerm'.
luFact (packed, perm)
    | r > c     = (lTall, uTall, p, s)
    | otherwise = (lWide, uWide, p, s)
  where
    r = rows packed
    c = cols packed
    upperMask = triang r c 0 1
    lowerMask = triang r c 0 0
    lWide = takeColumns r (packed |*| lowerMask) |+| diagRect 0 (konst' 1 r) r r
    uWide = packed |*| upperMask
    lTall = (packed |*| lowerMask) |+| diagRect 0 (konst' 1 c) r c
    uTall = takeRows c (packed |*| upperMask)
    (p, s) = fixPerm r perm
    (|+|) = add
    (|*|) = mul
-- | Selects which norm 'pnorm' computes.
data NormType
    = Infinity
    | PNorm1
    | PNorm2
    | Frobenius
-- | Containers whose norms can be computed. The result has the real
-- counterpart type 'RealOf' of the element type.
class (RealFloat (RealOf t)) => Normed c t where
    -- | The norm selected by the 'NormType'.
    pnorm :: NormType -> c t -> RealOf t
-- Vector norms. For a vector the Frobenius norm coincides with the 2-norm.
instance Normed Vector Double where
    pnorm nt = case nt of
        PNorm1    -> norm1
        PNorm2    -> norm2
        Infinity  -> normInf
        Frobenius -> norm2

instance Normed Vector (Complex Double) where
    pnorm nt = case nt of
        PNorm1    -> norm1
        PNorm2    -> norm2
        Infinity  -> normInf
        Frobenius -> norm2

instance Normed Vector Float where
    pnorm nt = case nt of
        PNorm1    -> norm1
        PNorm2    -> norm2
        Infinity  -> normInf
        Frobenius -> norm2

instance Normed Vector (Complex Float) where
    pnorm nt = case nt of
        PNorm1    -> norm1
        PNorm2    -> norm2
        Infinity  -> normInf
        Frobenius -> norm2
-- Matrix norms:
--   PNorm1    - maximum column 1-norm
--   PNorm2    - first singular value (assumed to be the largest)
--   Infinity  - PNorm1 of the transpose
--   Frobenius - vector 2-norm of the flattened matrix
-- The Float instances compute the singular values after converting with
-- 'double', then convert the result back.
instance Normed Matrix Double where
    pnorm nt = case nt of
        PNorm1    -> maximum . map (pnorm PNorm1) . toColumns
        PNorm2    -> (@>0) . singularValues
        Infinity  -> pnorm PNorm1 . trans
        Frobenius -> pnorm PNorm2 . flatten

instance Normed Matrix (Complex Double) where
    pnorm nt = case nt of
        PNorm1    -> maximum . map (pnorm PNorm1) . toColumns
        PNorm2    -> (@>0) . singularValues
        Infinity  -> pnorm PNorm1 . trans
        Frobenius -> pnorm PNorm2 . flatten

instance Normed Matrix Float where
    pnorm nt = case nt of
        PNorm1    -> maximum . map (pnorm PNorm1) . toColumns
        PNorm2    -> realToFrac . (@>0) . singularValues . double
        Infinity  -> pnorm PNorm1 . trans
        Frobenius -> pnorm PNorm2 . flatten

instance Normed Matrix (Complex Float) where
    pnorm nt = case nt of
        PNorm1    -> maximum . map (pnorm PNorm1) . toColumns
        PNorm2    -> realToFrac . (@>0) . singularValues . double
        Infinity  -> pnorm PNorm1 . trans
        Frobenius -> pnorm PNorm2 . flatten
-- | Order of magnitude of the relative difference, measured in the infinity
-- norm: @round (logBase 10 (||x - y|| / ||x||))@.
relativeError' :: (Normed c t, Container c t) => c t -> c t -> Int
relativeError' x y = round (logBase 10 ratio)
  where
    ratio = realToFrac (pnorm Infinity (x `sub` y) / pnorm Infinity x) :: Double
-- | Relative difference of @a@ and @b@ in the chosen norm:
-- @||a - b|| / max ||a|| ||b||@. When the smaller of the two norms is below
-- 'peps', the larger norm is returned instead.
relativeError :: (Normed c t, Num (c t)) => NormType -> c t -> c t -> Double
relativeError t a b = realToFrac r
  where
    norm = pnorm t
    na = norm a
    nb = norm b
    nab = norm (a - b)
    mx = max na nb
    mn = min na nb
    r = if mn < peps
        then mx
        else nab/mx
-- | Generalized symmetric/Hermitian eigensystem of the pair @(a, b)@.
--
-- With @u = cholSH b@ and @iu = inv u@, it eigen-decomposes
-- @ctrans iu * a * iu@ with 'eigSH'' and maps the eigenvectors back by
-- multiplying with @iu@. Neither input is checked for symmetry.
geigSH' :: Field t
        => Matrix t -- ^ A
        -> Matrix t -- ^ B
        -> (Vector Double, Matrix t)
geigSH' a b = (l, iu `mXm` v)
  where
    iu = inv (cholSH b)
    (l, v) = eigSH' ((ctrans iu `mXm` a) `mXm` iu)