{-# LANGUAGE FlexibleContexts, TypeFamilies #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.LinearAlgebra.Array.Solve
-- Copyright   :  (c) Alberto Ruiz 2009
-- License     :  BSD3
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
-- Solution of general multidimensional linear and multilinear systems.
--
-----------------------------------------------------------------------------

module Numeric.LinearAlgebra.Array.Solve (
-- * Linear systems
    solve,
    solveHomog, solveHomog1, solveH,
    solveP,
-- *  Multilinear systems
-- ** General
    ALSParam(..), defaultParameters,
    mlSolve, mlSolveH, mlSolveP,
-- ** Factorized
    solveFactors, solveFactorsH,
-- * Utilities
    eqnorm, infoRank,
    solve', solveHomog', solveHomog1', solveP'
) where

import Numeric.LinearAlgebra.Array.Util
import Numeric.LinearAlgebra.Exterior
import Numeric.LinearAlgebra.Array.Internal(mkNArray, selDims, debug, namesR)
import Numeric.LinearAlgebra.HMatrix hiding (scalar,size)
--import qualified Numeric.LinearAlgebra.HMatrix as LA
import Data.List
import System.Random


-- | Solves the linear system @a x = b@ for @x@, where @a@ and @b@ are
-- general multidimensional arrays.
--
-- The names of @a@ that do not occur in @b@ index the unknown; the
-- structure and dimension names of the result are inferred from the two
-- arguments. The flattened matrix system is solved with hmatrix's
-- 'linearSolveSVD'.
solve :: (Compat i, Coord t, Field t)
      => NArray i t -- ^ coefficients (a)
      -> NArray i t -- ^ target       (b)
      -> NArray i t -- ^ result       (x)
solve a b = solve' id a b

-- Worker for 'solve'. The function @g@ is applied to both the coefficient
-- matrix and the right-hand-side matrix before they are handed to
-- 'linearSolveSVD' (e.g. 'id', or a debugging hook such as 'infoRank').
solve' g a b = mkNArray resultDims (flatten solution)
  where
    unknownNames = namesR a \\ namesR b     -- names of a absent from b
    sharedNames  = namesR a \\ unknownNames -- names of a that b also has
    extraNames   = namesR b \\ namesR a     -- names only b has
    coeffMatrix  = g (matrixator a sharedNames unknownNames)
    rhsMatrix    = g (matrixator b sharedNames extraNames)
    solution     = linearSolveSVD coeffMatrix rhsMatrix
    -- the unknown's own dims get the opposite variance of a's, followed by
    -- the dims carried over from b
    resultDims   = map opos (selDims (dims a) unknownNames)
                   ++ selDims (dims b) extraNames


-- | Computes a basis for the solutions of the homogeneous linear system
-- @a x = 0@, where @a@ is a general multidimensional array.
--
-- If the system is overconstrained, pass the theoretical rank (@Right r@)
-- to get an MSE solution instead of relying on a numeric zero threshold.
solveHomog :: (Compat i, Coord t, Field t)
           => NArray i t        -- ^ coefficients (a)
           -> [Name]            -- ^ desired dimensions for the result
                                --   (names that @a@ does not have are ignored)
           -> Either Double Int -- ^ Left \"numeric zero\" (e.g. eps), Right \"theoretical\" rank
           -> [NArray i t]      -- ^ basis for the solutions (x)
solveHomog a names hint = solveHomog' id a names hint

-- Worker for 'solveHomog'. The function @g@ preprocesses the coefficient
-- matrix before its nullspace is computed with 'nullspaceSVD'; every basis
-- column is turned back into an array.
solveHomog' g a nx' hint = map (mkNArray resultDims) basis
  where
    aNames     = namesR a
    wanted     = [nm | nm <- nx', nm `elem` aNames] -- requested names a really has
    remaining  = aNames \\ wanted
    coeffs     = g (matrixator a remaining wanted)
    basis      = toColumns (nullspaceSVD hint coeffs (rightSV coeffs))
    resultDims = map opos (selDims (dims a) wanted)

-- | Convenience wrapper around 'solveHomog' that returns a single solution.
-- The rank hint is @Right (k-1)@, where @k@ is the total size of the
-- requested dimensions; for an overconstrained system this gives the MSE
-- solution.
solveHomog1 :: (Compat i, Coord t, Field t)
            => NArray i t -- ^ coefficients (a)
            -> [Name]     -- ^ desired dimensions for the result
            -> NArray i t -- ^ solution (x)
solveHomog1 m ns = solveHomog1' id m ns

-- Worker for 'solveHomog1'. It calls 'solveHomog'' with the rank hint
-- @Right (k-1)@, where @k@ is the product of the sizes of the requested
-- dimensions @ns@ in @m@, and keeps the first basis element.
-- NOTE(review): presumably this leaves a one-dimensional nullspace; the
-- exact count depends on nullspaceSVD. The empty case is reported
-- explicitly instead of through the uninformative failure of 'head'.
solveHomog1' g m ns =
    case solveHomog' g m ns (Right (k-1)) of
        (x:_) -> x
        []    -> error "solveHomog1: the computed nullspace basis is empty"
  where
    k = product $ map iDim $ selDims (dims m) ns

-- | 'solveHomog1' for single-letter index names: every character of the
-- string becomes one (one-character) dimension name.
solveH :: (Compat i, Coord t, Field t) => NArray i t -> [Char] -> NArray i t
solveH m ns = solveHomog1 m [[c] | c <- ns]


-- | Solves the linear system @a x = b@, where @a@ and @b@ are general
-- multidimensional arrays, with homogeneous equality (equality up to
-- scale) along the given index.
solveP :: Tensor Double   -- ^ coefficients (a)
       -> Tensor Double   -- ^ desired result (b)
       -> Name            -- ^ the homogeneous dimension
       -> Tensor Double   -- ^ result (x)
solveP a b h = solveP' id a b h

-- Worker for 'solveP'. It applies 'solveP1' (one right-hand side at a
-- time) through 'mapTat' over the names of @b@ that are neither the
-- homogeneous index @h@ nor names of @a@; @g@ preprocesses each system.
solveP' g a b h = mapTat solveOne otherNames b
  where
    solveOne rhs = solveP1 g h a rhs
    otherNames   = namesR b \\ (h : namesR a)

-- solveP for a single right hand side.
--
-- Turns "a x equals b up to scale along the homogeneous index nh" into a
-- homogeneous system and solves it with solveHomog1':
--
--   * k is the size of nh in b.
--   * epsi is (leviCivita k) with its indices renamed to nh, "e2", "e3", ...
--     (k names in total; the Char enumeration ['2'..] only yields "eN"
--     names while k <= 9). It is made contravariant when nh is covariant
--     in b, and covariant otherwise (see t).
--   * b' is b with its nh index renamed to "e2", so it shares that name
--     with the second index of epsi rather than the first.
--   * ns lists the dimensions of the unknown: the names of a missing from b,
--     plus nh itself when a has no nh index.
--
-- NOTE(review): presumably the product with the Levi-Civita array acts as
-- a generalized cross product whose vanishing expresses proportionality
-- along nh; confirm against the definitions of (.*) and leviCivita.
solveP1 g nh a b = solveHomog1' g ou ns where
    k = size nh b
    epsi = t $ leviCivita k `renameO` (nh : (take (k-1) $ (map (('e':).(:[])) ['2'..])))
    ou = a .* b' * epsi
    ns = (namesR a \\ namesR b) ++ x
    b' = renameExplicit [(nh,"e2")] b
    x = if nh `elem` (namesR a) then [] else [nh]
    t = if typeOf nh b == Co then contrav else cov
        -- alternative kept from the original: mapTypes (const (opos $ typeOf nh b))

-----------------------------------------------------------------------

-- | Parameters for the alternating-least-squares solvers ('mlSolve',
-- 'mlSolveH', 'mlSolveP', 'solveFactors', 'solveFactorsH'). See
-- 'defaultParameters' for a starting point.
data ALSParam i t = ALSParam
    { -- | Maximum number of iterations.
      nMax    :: Int
      -- | Minimum relative improvement between consecutive iterations, in
      -- percent (e.g. 0.1). The optimization stops once the improvement
      -- falls below this value.
    , delta   :: Double
      -- | Error threshold below which the optimization stops. For
      -- nonhomogeneous problems it is the relative reconstruction error in
      -- percent (e.g. 1E-3). For homogeneous problems it is the Frobenius
      -- norm of the product that is expected to vanish.
    , epsilon :: Double
      -- | Post-processing applied to the whole solution after each full
      -- iteration (e.g. 'id').
    , post    :: [NArray i t] -> [NArray i t]
      -- | Post-processing applied to the k-th factor right after it is
      -- updated (e.g. 'const' 'id').
    , postk   :: Int -> NArray i t -> NArray i t
      -- | Preprocessing applied to the matrix of every linear system before
      -- it is solved (e.g. 'id', or 'infoRank').
    , presys  :: Matrix t -> Matrix t
    }


-- | Generic iterative driver shared by the alternating-least-squares
-- solvers.
--
-- Starting at @s0@, it repeatedly applies @method@ and evaluates at most
-- @max 1 (nMax p)@ candidates (the starting point included). It stops at
-- the first candidate whose error is below 'epsilon', or, when the relative
-- change of the error between two consecutive candidates (in percent)
-- is below 'delta', at the later of the two. Otherwise it stops at the
-- last candidate.
--
-- The error history contains the error of every evaluated candidate, most
-- recent first.
optimize :: (x -> x)      -- ^ method
         -> (x -> Double) -- ^ error function
         -> x             -- ^ starting point
         -> ALSParam i t  -- ^ optimization parameters
         -> (x, [Double]) -- ^ solution and error history (most recent first)
optimize method errfun s0 p = (sol,e) where
    sols = take (max 1 (nMax p)) $ iterate method s0
    errs = map errfun sols
    (sol,e) = convergence (zip sols errs) []
    convergence [] _  = error "impossible"
    convergence [(s,err)] prev = (s, err:prev)
    convergence ((s1,e1):(s2,e2):ses) prev
        | e1 < epsilon p = (s1, e1:prev)
          -- e1 was evaluated too, so it must stay in the history
          -- (previously it was dropped on this branch).
        | abs (100*(e1 - e2)/e1) < delta p = (s2, e2:e1:prev)
        | otherwise = convergence ((s2,e2):ses) (e1:prev)

-- Relative reconstruction error, in percent, of the factor list @s@ with
-- respect to the target @t@ (compared through 'smartProduct').
percent t s = 100 * frobT residual / frobT t
  where
    residual = t - smartProduct s

-- Like 'percent', but for a homogeneous index @h@. Before comparing, both
-- the target and the reconstruction are transformed with 'mapTat' over all
-- names except @h@, dividing each piece by its entry at the last position
-- of @h@. Presumably this removes the arbitrary scale of homogeneous
-- coordinates; confirm against mapTat.
percentP h t s = 100 * frobT (normTarget - normRecon) / frobT normTarget
  where
    lastIdx         = size h t - 1
    scaleByLast v   = v / atT v [lastIdx]
    normalizeArr x  = mapTat scaleByLast (namesR t \\ [h]) x
    normTarget      = normalizeArr t
    normRecon       = normalizeArr (smartProduct s)

-- Frobenius norm of an array: the 2-norm of its coordinate vector,
-- converted with 'realToFrac'. (Kept as a function binding, not
-- point-free, so the monomorphism restriction does not apply.)
frobT t = realToFrac (norm_2 (coords t))
--unitT t = t / scalar (frobT t)

-- Removes the element at (0-based) position k; out-of-range or negative
-- positions leave the list unchanged.
dropElemPos :: Int -> [a] -> [a]
dropElemPos k xs = [y | (i, y) <- zip [0 ..] xs, i /= k]
-- Puts v at (0-based) position k, discarding the element previously
-- there. A position past the end appends v; a negative one prepends it.
replaceElemPos :: Int -> a -> [a] -> [a]
replaceElemPos k v xs = front ++ [v] ++ back
  where
    front = take k xs
    back  = drop (k + 1) xs

-- Splits a list into consecutive chunks of the given lengths, stopping when
-- the lengths run out. The chunks are lazy, so an infinite source is fine.
takes :: [Int] -> [a] -> [[a]]
takes lens xs = case lens of
    []       -> []
    (m : ms) -> let (chunk, rest) = splitAt m xs
                in chunk : takes ms rest

----------------------------------------------------------------------

-- One full alternating-least-squares sweep. The per-factor update
-- (f params a k) is applied for k = 0, 1, ..., length x - 1, in that order,
-- and each update sees the result of the previous one.
--
-- The composition is built with foldr (.) id. For non-empty lists this is
-- the same function as the previous foldl1' (.), but it is total: with no
-- factors to update, the sweep is the identity instead of a crash.
alsStep f params a x = foldr (.) id [f params a k | k <- [n, n-1 .. 0]] x
    where n = length x - 1

-----------------------------------------------------------------------

-- | Solves the multilinear system @a x y z ... = b@ by alternating least
-- squares. Each unknown factor is re-solved in turn while the others are
-- fixed. The error is the relative reconstruction error of @b@, in percent.
mlSolve
  :: (Compat i, Coord t, Field t, Num (NArray i t), Show (NArray i t))
     => ALSParam i t  -- ^ optimization parameters
     -> [NArray i t]  -- ^ coefficients (a), given as a list of factors.
     -> [NArray i t]  -- ^ initial solution [x,y,z...]
     -> NArray i t    -- ^ target (b)
     -> ([NArray i t], [Double]) -- ^ Solution and error history
mlSolve params a x0 b = optimize sweep reconstructionError x0 params
  where
    sweep               = post params . alsStep (alsArg b) params a
    reconstructionError = percent b . (a ++)

-- Single ALS update for 'mlSolve'. Factor k of xs is re-solved with
-- solve' as the unknown of a linear system whose coefficients are the
-- product of a and all the other factors, then post-processed with postk.
alsArg _ _ _ _ [] = error "alsArg _ _ []"
alsArg b params a k xs = replaceElemPos k updated xs
  where
    others  = smartProduct (a ++ dropElemPos k xs)
    updated = postk params k (solve' (presys params) others b)

----------------------------------------------------------

-- | Solves the homogeneous multilinear system @a x y z ... = 0@ by
-- alternating least squares. The error is the Frobenius norm of the
-- product of the coefficients and the current factors.
mlSolveH
  :: (Compat i, Coord t, Field t, Num (NArray i t), Show (NArray i t))
     => ALSParam  i t -- ^ optimization parameters
     -> [NArray i t]  -- ^ coefficients (a), given as a list of factors.
     -> [NArray i t]  -- ^ initial solution [x,y,z...]
     -> ([NArray i t], [Double]) -- ^ Solution and error history
mlSolveH params a x0 = optimize sweep residualNorm x0 params
  where
    sweep        = post params . alsStep alsArgH params a
    residualNorm = frobT . smartProduct . (a ++)

-- Single ALS update for 'mlSolveH'. Factor k of xs is replaced by the
-- solveHomog1' solution (over factor k's own dimension names) of the
-- system formed by a and all the other factors, then post-processed
-- with postk.
alsArgH _ _ _ [] = error "alsArgH _ _ []"
alsArgH params a k xs = replaceElemPos k updated xs
  where
    current = xs !! k
    others  = smartProduct (a ++ dropElemPos k xs)
    updated = postk params k (solveHomog1' (presys params) others (namesR current))

----------------------------------------------------------

-- | Solves the multilinear system @a x y z ... = b@, with homogeneous
-- equality along the index @h@, by alternating least squares. Each factor
-- is updated with the 'solveP' machinery. The error is the relative
-- reconstruction error in percent, after normalizing along @h@.
mlSolveP
     :: ALSParam Variant Double     -- ^ optimization parameters
     -> [Tensor Double]  -- ^ coefficients (a), given as a list of factors.
     -> [Tensor Double]  -- ^ initial solution [x,y,z...]
     -> Tensor Double    -- ^ target (b)
     -> Name             -- ^ homogeneous index
     -> ([Tensor Double], [Double]) -- ^ Solution and error history
mlSolveP params a x0 b h = optimize sweep scaledError x0 params
  where
    sweep       = post params . alsStep (alsArgP b h) params a
    scaledError = percentP h b . (a ++)

-- Single ALS update for 'mlSolveP'. Factor k of xs is re-solved with
-- solveP' (homogeneous index h) against the product of a and the other
-- factors, then post-processed with postk.
alsArgP _ _ _ _ _ [] = error "alsArgP _ _ []"
alsArgP b h params a k xs = replaceElemPos k updated xs
  where
    others  = smartProduct (a ++ dropElemPos k xs)
    updated = postk params k (solveP' (presys params) others b h)

-------------------------------------------------------------

-- | Given a source array @a@ (itself given as a list of factors) and a
-- target array @b@, tries to compute linear transformations @x, y, z, ...@,
-- one per index pair, such that @product [a,x,y,z,...] == b@.
--
-- Each space-separated word of the pair string is an index of the source
-- followed by the matching index of the target. The factors are
-- initialized randomly from the seed and refined with 'mlSolve'.
-- ('eqnorm' or 'id' are useful choices for 'post'.)
solveFactors :: (Coord t, Field t, Random t, Compat i, Num (NArray i t), Show (NArray i t))
             => Int          -- ^ seed for random initialization
             -> ALSParam i t -- ^ optimization parameters
             -> [NArray i t] -- ^ source (also factorized)
             -> String       -- ^ index pairs for the factors separated by spaces
             -> NArray i t   -- ^ target
             -> ([NArray i t],[Double]) -- ^ solution and error history
solveFactors seed params a pairs b = mlSolve params a initial b
  where
    initial = initFactorsRandom seed (smartProduct a) pairs b

-- Builds the initial factors for 'solveFactors' from the coefficient list
-- rs (an infinite random stream in this file).
--
-- Each word of pairs is an index of the source a followed by an index of
-- the target b. Together the pairs must cover exactly the indices that a
-- and b do not share; otherwise this fails. Each factor consumes the next
-- (target size * source size) coefficients. Its dimensions are the target
-- dimension and the source dimension with opposite variance, named
-- [target index, source index].
initFactorsSeq rs a pairs b
    | not pairsOk = error "solveFactors index pairs"
    | otherwise   = zipWith5 mkFactor chunks tgtNames srcNames tgtDims srcDims
  where
    (srcNames, tgtNames) = unzip (map sep (words pairs))
    common        = intersect (namesR a) (namesR b)
    sameSet xs ys = sort xs == sort ys
    pairsOk       = sameSet (namesR b \\ common) tgtNames
                 && sameSet (namesR a \\ common) srcNames
    tgtDims       = selDims (dims b) tgtNames
    srcDims       = selDims (dims a) srcNames
    sizes         = zipWith (*) (map iDim tgtDims) (map iDim srcDims)
    chunks        = takes sizes rs
    mkFactor c i1 i2 d1 d2 = mkNArray [d1, opos d2] (fromList c) `renameO` [i1, i2]

-- Random variant of 'initFactorsSeq'. Coefficients come from randomRs in
-- the range (-1, 1) with a generator seeded by seed. The third argument is
-- the pair string; the target array is the remaining, not yet supplied
-- argument of initFactorsSeq.
initFactorsRandom seed src pairs = initFactorsSeq (randomRs (-1, 1) (mkStdGen seed)) src pairs


-- | Homogeneous factorized system. Given a coefficient array @a@, itself
-- given as a list of factors, and a space-separated string of index pairs
-- (e.g. @\"pi qj rk\"@), tries to compute linear transformations
-- @x!\"pi\"@, @y!\"qj\"@, @z!\"rk\"@, etc. such that
-- @product [a,x,y,z,...] == 0@. The factors are initialized randomly from
-- the seed and refined with 'mlSolveH'.
solveFactorsH
  :: (Coord t, Random t, Field t, Compat i, Num (NArray i t), Show (NArray i t))
     => Int           -- ^ seed for random initialization
     -> ALSParam  i t -- ^ optimization parameters
     -> [NArray i t]  -- ^ coefficient array (a), (also factorized)
     -> String        -- ^ index pairs for the factors separated by spaces
     -> ([NArray i t], [Double]) -- ^ solution and error history
solveFactorsH seed params a pairs = mlSolveH params a initial
  where
    initial = initFactorsHRandom seed (smartProduct a) pairs

-- Builds the initial factors for 'solveFactorsH' from the coefficient list
-- rs. Each word of pairs is a two-letter pair of indices of a. The matching
-- factor consumes the next (size of first * size of second) coefficients,
-- has both of those dimensions with opposite variance, and is named with
-- the two indices in order.
initFactorsHSeq rs a pairs = zipWith5 mkFactor chunks firsts seconds firstDims secondDims
  where
    (firsts, seconds) = unzip (map sep (words pairs))
    firstDims  = selDims (dims a) firsts
    secondDims = selDims (dims a) seconds
    sizeIn i   = size i a
    chunks     = takes (zipWith (*) (map sizeIn firsts) (map sizeIn seconds)) rs
    mkFactor c i1 i2 d1 d2 = mkNArray (map opos [d1, d2]) (fromList c) `renameO` [i1, i2]

-- Random variant of 'initFactorsHSeq'. Coefficients come from randomRs in
-- the range (-1, 1) with a generator seeded by seed.
initFactorsHRandom seed src pairs = initFactorsHSeq (randomRs (-1, 1) (mkStdGen seed)) src pairs

-- Splits a two-element list (an index pair such as "pi") into a pair of
-- singleton lists; any other length is rejected with an error.
sep :: [a] -> ([a], [a])
sep xs = case xs of
    [p, q] -> ([p], [q])
    _      -> error "impossible pattern in hTensor initFactors"

----------------------------------

-- | Post-processing function that rescales a list of tensors so that they
-- all have the same Frobenius norm: the geometric mean of their original
-- norms, so the product of the norms is unchanged. Fails on an empty list.
eqnorm :: (Compat i,Show (NArray i Double))
       => [NArray i Double] -> [NArray i Double]
eqnorm [] = error "eqnorm []"
eqnorm ts = [ x * scalar (target / nrm) | (x, nrm) <- zip ts norms ]
  where
    norms  = map frobT ts
    target = product norms ** (1 / fromIntegral (length ts))

-- | Default optimization parameters, with no pre- or post-processing:
-- @nMax = 20, delta = 1, epsilon = 1E-3, post = id, postk = const id, presys = id@.
defaultParameters :: ALSParam i t
defaultParameters = ALSParam
    { nMax    = 20
    , delta   = 1
    , epsilon = 1E-3
    , post    = id
    , postk   = const id
    , presys  = id
    }

-- | Debugging hook (e.g. for 'presys'). It passes the matrix to 'debug'
-- together with a summary of its number of rows, number of columns and
-- rank, and presumably returns it unchanged (see 'debug').
infoRank :: Field t => Matrix t -> Matrix t
infoRank a = debug "" summary a
  where
    summary _ = (rows a, cols a, rank a)