{-# LANGUAGE FlexibleContexts, FlexibleInstances #-}
{-# LANGUAGE RecordWildCards #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Internal.CG(
    cgSolve, cgSolve',
    CGState(..), R, V
) where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import Internal.Element
import Internal.IO
import Internal.Container
import Internal.Sparse
import Numeric.Vector()
import Internal.Algorithms(linearSolveLS, linearSolve, relativeError, pnorm, NormType(..))
import Control.Arrow((***))

{-
import Util.Misc(debug, debugMat)

(//) :: Show a => a -> String -> a
infix 0 // -- , ///
a // b = debug b id a

(///) :: V -> String -> V
infix 0 ///
v /// b = debugMat b 2 asRow v
-}

-- | Shorthand for a real vector; used in the solver signatures below.
type V = Vector R

-- | One state of the conjugate gradient iteration.
--
-- 'cgSolve'' returns a list of these states; the approximate solution is
-- 'cgx' of the last one.
data CGState = CGState
    { cgp  :: Vector R  -- ^ conjugate gradient (current search direction)
    , cgr  :: Vector R  -- ^ residual
    , cgr2 :: R         -- ^ squared norm of residual
    , cgx  :: Vector R  -- ^ current solution
    , cgdx :: R         -- ^ normalized size of correction: norm of the step divided by max 1 (norm of the previous solution)
    }

-- | One conjugate-gradient step.
--
-- With @sym = True@ this is plain CG on @a x = b@ (CG requires @a@ to be
-- symmetric positive definite; this is not checked). With @sym = False@ it
-- is CG on the normal equations @a^T a x = a^T b@, where @at@ applies
-- @a^T@; the residual in the state must then be that of the normal
-- equations, as set up by 'solveG'.
cg :: Bool -> (V -> V) -> (V -> V) -> CGState -> CGState
cg sym at a (CGState p r r2 x _) = CGState p' r' r'2 x' rdx
  where
    -- effective operator applied to the search direction: a p or a^T a p
    ap1 = a p
    ap  | sym       = ap1
        | otherwise = at ap1
    -- p^T A p for the effective operator; for a^T a this equals |a p|^2
    pap | sym       = p <.> ap1
        | otherwise = norm2 ap1 ** 2
    alpha = r2 / pap
    dx    = scale alpha p
    x'    = x + dx
    r'    = r - scale alpha ap
    r'2   = r' <.> r'
    beta  = r'2 / r2
    p'    = r' + scale beta p

    -- step size relative to the previous iterate; the denominator is at
    -- least 1, so for small solutions this is an absolute measure
    rdx = norm2 dx / max 1 (norm2 x)

-- | CG iterates for the sparse system @a x = b@, produced lazily by
-- 'solveG' and ending at the first state that meets a tolerance.
--
-- Arguments after the matrix: right-hand side, initial solution, residual
-- tolerance, correction tolerance. The transpose product is passed as the
-- normal-equations operator used when @sym@ is 'False'.
conjugrad
  :: Bool -> GMatrix -> V -> V -> R -> R -> [CGState]
conjugrad sym a = solveG sym (tr a !#>) (a !#>) (cg sym)

-- | Generic driver for CG-like iterations.
--
-- @solveG sym mat ma meth rawb x0' ϵb ϵx@ iterates @meth mat ma@ from the
-- initial state. It keeps every state up to and including the first one
-- that satisfies either test:
--
--   * the squared residual 'cgr2' is at most @ϵb^2 * (b <.> b)@, or
--   * the normalized correction 'cgdx' is below @ϵx@.
--
-- If neither test is ever met, the list is infinite; 'cgSolve'' bounds it
-- with 'take'.
--
-- With @sym = False@ the system is replaced by the normal equations: the
-- operator becomes @mat . ma@ and the right-hand side @mat rawb@.
-- Passing the literal @0@ as @x0'@ (as 'cgSolve' does) starts from the zero
-- vector of the right size.
solveG
    :: Bool
    -> (V -> V) -> (V -> V)
    -> ((V -> V) -> (V -> V) -> CGState -> CGState)
    -> V
    -> V
    -> R -> R
    -> [CGState]
solveG sym mat ma meth rawb x0' ϵb ϵx
    = takeUntil ok . iterate (meth mat ma) $ CGState p0 r0 r20 x0 1
  where
    a = if sym then ma else mat . ma
    b = if sym then rawb else mat rawb
    -- the literal 0 stands for "zero vector with the dimension of b"
    x0  = if x0' == 0 then konst 0 (dim b) else x0'
    r0  = b - a x0
    r20 = r0 <.> r0
    p0  = r0
    nb2 = b <.> b
    -- '<=' (not '<'): if b is zero, or ϵb is 0, an exactly-zero residual must
    -- stop the iteration. Otherwise the next step computes alpha = 0/0 and
    -- the solution turns into NaN.
    ok CGState {..}
        =  cgr2 <= nb2 * ϵb ** 2
        || cgdx < ϵx


-- | Take elements until the predicate holds, keeping the element that first
-- satisfies it. Returns the whole list if no element matches. It is lazy, so
-- on an infinite list it stops at the first match.
takeUntil :: (a -> Bool) -> [a] -> [a]
takeUntil q = foldr step []
  where
    step y rest
      | q y       = [y]
      | otherwise = y : rest

-- | Solve a sparse linear system using the conjugate gradient method with default parameters.
--
-- Runs 'cgSolve'' with residual tolerance 1E-4, correction tolerance 1E-3,
-- the zero vector as initial solution, and at most
-- @max 10 (round (sqrt (dim b)))@ states. With @sym = False@ the system is
-- solved through its normal equations, giving a least-squares solution.
cgSolve
  :: Bool          -- ^ is symmetric
  -> GMatrix       -- ^ coefficient matrix
  -> Vector R      -- ^ right-hand side
  -> Vector R      -- ^ solution
cgSolve sym a b = cgx $ last $ cgSolve' sym 1E-4 1E-3 n a b 0
  where
    -- 'last' cannot fail: n >= 10 and the state list always contains at
    -- least the initial state
    n = max 10 (round $ sqrt (fromIntegral (dim b) :: Double))

-- | Solve a sparse linear system using the conjugate gradient method with
-- explicit tolerances, iteration limit and initial solution.
--
-- The result starts with the initial state and ends at the first state that
-- meets either tolerance, or after @n@ states, whichever comes first. The
-- approximate solution is 'cgx' of the last element.
cgSolve'
  :: Bool      -- ^ symmetric
  -> R         -- ^ relative tolerance for the residual (e.g. 1E-4)
  -> R         -- ^ relative tolerance for δx (e.g. 1E-3)
  -> Int       -- ^ maximum number of states returned, including the initial one
  -> GMatrix   -- ^ coefficient matrix
  -> Vector R  -- ^ right-hand side
  -> Vector R  -- ^ initial solution (the literal 0 means the zero vector)
  -> [CGState] -- ^ successive iteration states; the solution is 'cgx' of the last one
cgSolve' sym er es n a b x = take n $ conjugrad sym a b x er es


--------------------------------------------------------------------------------

instance Testable GMatrix
  where
    -- Self-test. It compares sparse matrix-vector products (plain and
    -- transposed, general and diagonal) with their dense counterparts. It also
    -- compares CG solutions (non-symmetric least-squares and symmetric) with
    -- dense solvers. Returns the pass flag and an action that prints the
    -- intermediate results.
    checkT _ = (ok, info)
      where
        sma = convo2 20 3
        x1 = vect [1..20]
        x2 = vect [1..40]
        sm = mkSparse sma
        dm = toDense sma

        s1 = sm !#> x1
        d1 = dm #> x1

        s2 = tr sm !#> x2
        d2 = tr dm #> x2

        sdia = mkDiagR 40 20 (vect [1..10])
        s3 =    sdia !#> x1
        s4 = tr sdia !#> x2
        ddia = diagRect 0 (vect [1..10]) 40 20
        d3 = ddia #> x1
        d4 = tr ddia #> x2

        -- 40x20 system: cgSolve False solves it via the normal equations,
        -- compared against the dense least-squares solution
        v = testb 40
        s5 = cgSolve False sm v
        d5 = denseSolve dm v

        -- small symmetric system for the sym = True path
        symassoc :: AssocMatrix
        symassoc = [((0,0),1.0),((1,1),2.0),((0,1),0.5),((1,0),0.5)]
        b = vect [3,4]
        d6 = flatten $ linearSolve (toDense symassoc) (asColumn b)
        s6 = cgSolve True (mkSparse symassoc) b

        info = do
            print sm
            disp (toDense sma)
            print s1; print d1
            print s2; print d2
            print s3; print d3
            print s4; print d4
            print s5; print d5
            print $ relativeError (pnorm Infinity) s5 d5
            print s6; print d6
            print $ relativeError (pnorm Infinity) s6 d6

        -- products of small integer-valued matrices are exact, so '==' is
        -- used for them; the iterative solutions are checked with a tolerance
        ok = s1==d1
          && s2==d2
          && s3==d3
          && s4==d4
          && relativeError (pnorm Infinity) s5 d5 < 1E-10
          && relativeError (pnorm Infinity) s6 d6 < 1E-10

        disp = putStr . dispf 2

        vect = fromList :: [Double] -> Vector Double

        -- n x n circulant band: row i has ones at columns i .. i+k-1 (mod n)
        convomat :: Int -> Int -> AssocMatrix
        convomat n k = [ ((i,j `mod` n),1) | i<-[0..n-1], j <- [i..i+k-1]]

        -- two copies of convomat stacked vertically (2n x n)
        convo2 :: Int -> Int -> AssocMatrix
        convo2 n k = m1 ++ m2
          where
            m1 = convomat n k
            m2 = map (((+n) *** id) *** id) m1

        testb n = vect $ take n $ cycle ([0..10]++[9,8..1])

        denseSolve a = flatten . linearSolveLS a . asColumn

        -- mkDiag v = mkDiagR (dim v) (dim v) v