{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE CPP #-}

{- | 
Module      :  Physics.Learn.QuantumMat
Copyright   :  (c) Scott N. Walck 2016-2018
License     :  BSD3 (see LICENSE)
Maintainer  :  Scott N. Walck <walck@lvc.edu>
Stability   :  experimental

This module contains state vectors and matrices
for quantum mechanics.
-}

-- Using only Complex Double here, no cyclotomic

module Physics.Learn.QuantumMat
    (
    -- * Complex numbers
      C
    -- * State Vectors
    , xp
    , xm
    , yp
    , ym
    , zp
    , zm
    , np
    , nm
    , dim
    , scaleV
    , inner
    , norm
    , normalize
    , probVector
    , gramSchmidt
    , conjV
    , fromList
    , toList
    -- * Matrices (operators)
    , sx
    , sy
    , sz
    , scaleM
    , (<>)
    , (#>)
    , (<#)
    , conjugateTranspose
    , fromLists
    , toLists
    , size
    , matrixFunction
    -- * Density matrices
    , couter
    , dm
    , trace
    , normalizeDM
    , oneQubitMixed
    -- * Quantum Dynamics
    , timeEvMat
    , timeEv
    , timeEvMatSpec
    -- * Composition
    , Kronecker(..)
    -- * Measurement
    , possibleOutcomes
    , outcomesProjectors
    , outcomesProbabilities
    -- * Vector and Matrix
    , Vector
    , Matrix
    )
    where

import Numeric.LinearAlgebra
    ( C
    , Vector
    , Matrix
    , Herm
    , iC        -- square root of negative one
    , (><)      -- matrix definition
    , ident
    , scale
    , norm_2
    , inv
    , (<\>)
    , sym
    , eigenvaluesSH
    , eigSH
    , cmap
    , takeDiag
    , conj
    , dot
    , tr
    )
--    , (<>)      -- matrix product (not * !!!!)
--    , (#>)      -- matrix-vector product
--    , fromList  -- vector definition

import qualified Numeric.LinearAlgebra as H
-- because H.outer does not conjugate
import Data.Complex
    ( Complex(..)
    , magnitude
    , realPart
    )
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif

-- | The state resulting from a measurement of
--   spin angular momentum in the x direction
--   on a spin-1/2 particle
--   when the result of the measurement is hbar/2.
--   Normalized version of @[1, 1]@.
xp :: Vector C
xp = normalize $ fromList [1, 1]

-- | The state resulting from a measurement of
--   spin angular momentum in the x direction
--   on a spin-1/2 particle
--   when the result of the measurement is -hbar/2.
--   Normalized version of @[1, -1]@.
xm :: Vector C
xm = normalize $ fromList [1, -1]

-- | The state resulting from a measurement of
--   spin angular momentum in the y direction
--   on a spin-1/2 particle
--   when the result of the measurement is hbar/2.
--   Normalized version of @[1, i]@.
yp :: Vector C
yp = normalize $ fromList [1, iC]

-- | The state resulting from a measurement of
--   spin angular momentum in the y direction
--   on a spin-1/2 particle
--   when the result of the measurement is -hbar/2.
--   Normalized version of @[1, -i]@.
ym :: Vector C
ym = normalize $ fromList [1, -iC]

-- | The state resulting from a measurement of
--   spin angular momentum in the z direction
--   on a spin-1/2 particle
--   when the result of the measurement is hbar/2.
--   This is @[1, 0]@.
zp :: Vector C
zp = normalize $ fromList [1, 0]

-- | The state resulting from a measurement of
--   spin angular momentum in the z direction
--   on a spin-1/2 particle
--   when the result of the measurement is -hbar/2.
--   This is @[0, 1]@.
zm :: Vector C
zm = normalize $ fromList [0, 1]

-- | The state resulting from a measurement of
--   spin angular momentum in the direction
--   specified by spherical angles theta (polar angle)
--   and phi (azimuthal angle)
--   on a spin-1/2 particle
--   when the result of the measurement is hbar/2.
--   The components are @[cos (theta/2), e^(i phi) sin (theta/2)]@,
--   which already has unit length, so no normalization is applied.
np :: Double  -- ^ polar angle theta
   -> Double  -- ^ azimuthal angle phi
   -> Vector C
np theta phi = fromList [ cos (theta/2) :+ 0
                        , exp (0 :+ phi) * (sin (theta/2) :+ 0) ]

-- | The state resulting from a measurement of
--   spin angular momentum in the direction
--   specified by spherical angles theta (polar angle)
--   and phi (azimuthal angle)
--   on a spin-1/2 particle
--   when the result of the measurement is -hbar/2.
--   The components are @[sin (theta/2), -e^(i phi) cos (theta/2)]@,
--   a unit vector orthogonal to @np theta phi@.
nm :: Double  -- ^ polar angle theta
   -> Double  -- ^ azimuthal angle phi
   -> Vector C
nm theta phi = fromList [ sin (theta/2) :+ 0
                        , -exp (0 :+ phi) * (cos (theta/2) :+ 0) ]

-- | Dimension of a vector (its number of components).
dim :: Vector C -> Int
dim = H.size

-- | Scale a complex vector by a complex number
--   (multiply every component by it).
scaleV :: C -> Vector C -> Vector C
scaleV = scale

-- | Complex inner product @\<v|w\>@.  The first vector gets conjugated.
inner :: Vector C -> Vector C -> C
inner = dot

-- | Length (Euclidean 2-norm) of a complex vector.
norm :: Vector C -> Double
norm = norm_2

-- | Return a normalized version of a given state vector,
--   i.e. the vector divided by its 2-norm.
--   The zero vector has no normalized version; its entries come back
--   non-finite because of the division by zero.
normalize :: Vector C -> Vector C
normalize v = scale (1 / norm_2 v :+ 0) v

-- | Return a vector of probabilities for a given state vector:
--   the squared magnitude of each component.  The entries sum to one
--   only when the state vector is normalized.
probVector :: Vector C       -- ^ state vector
           -> Vector Double  -- ^ vector of probabilities
probVector = cmap (\c -> magnitude c ** 2)

-- | Complex-conjugate every entry of a vector.
conjV :: Vector C -> Vector C
conjV = conj

-- | Construct a vector from a list of complex numbers.
fromList :: [C] -> Vector C
fromList = H.fromList

-- | Produce a list of complex numbers from a vector,
--   in component order.
toList :: Vector C -> [C]
toList = H.toList

--------------
-- Matrices --
--------------

-- | The Pauli X matrix, @[[0, 1], [1, 0]]@.
sx :: Matrix C
sx = (2><2) [ 0, 1
            , 1, 0 ]

-- | The Pauli Y matrix, @[[0, -i], [i, 0]]@.
sy :: Matrix C
sy = (2><2) [  0, -iC
            , iC,   0 ]

-- | The Pauli Z matrix, @[[1, 0], [0, -1]]@.
sz :: Matrix C
sz = (2><2) [ 1,  0
            , 0, -1 ]

-- | Scale a complex matrix by a complex number
--   (multiply every entry by it).
scaleM :: C -> Matrix C -> Matrix C
scaleM = scale

-- | Matrix product (hmatrix's dense product specialized to complex
--   matrices).  Use this rather than @*@, which is not the matrix product.
(<>) :: Matrix C -> Matrix C -> Matrix C
(<>) = (H.<>)

-- | Matrix-vector product: @m #> v@ applies @m@ to the column vector @v@.
(#>) :: Matrix C -> Vector C -> Vector C
(#>) = (H.#>)

-- | Vector-matrix product: @v <# m@ treats @v@ as a row vector
--   multiplied on the left of @m@.
(<#) :: Vector C -> Matrix C -> Vector C
(<#) = (H.<#)

-- | Conjugate transpose (Hermitian adjoint) of a matrix.
conjugateTranspose :: Matrix C -> Matrix C
conjugateTranspose = tr

-- | Construct a matrix from a list of lists of complex numbers,
--   each inner list being one row.
fromLists :: [[C]] -> Matrix C
fromLists = H.fromLists

-- | Produce a list of lists of complex numbers from a matrix,
--   one inner list per row.
toLists :: Matrix C -> [[C]]
toLists = H.toLists

-- | Size of a matrix, as @(rows, columns)@.
size :: Matrix C -> (Int,Int)
size = H.size

-- | Apply a scalar function to a matrix through its eigendecomposition:
--   if @m = V D V^-1@ with @D@ diagonal, the result is @V f(D) V^-1@.
--
--   Assumes the matrix is diagonalizable.  Every normal matrix
--   (e.g. Hermitian or unitary) is diagonalizable.  The inverse of the
--   eigenvector matrix is computed explicitly instead of being taken as
--   its conjugate transpose, because @H.eig@ returns unit-length
--   eigenvectors but does not guarantee they are mutually orthogonal
--   when eigenvalues are (nearly) degenerate.
matrixFunction :: (C -> C) -> Matrix C -> Matrix C
matrixFunction f m
    = let (valv, vecm) = H.eig m   -- eigenvalues, eigenvectors as columns
      in vecm <> H.diag (cmap f valv) <> inv vecm

----------------------
-- Density Matrices --
----------------------

-- | Complex outer product: @couter v w@ is the matrix @|v\>\<w|@ with
--   entries @v_i * conjugate (w_j)@.  The second vector is conjugated here
--   because @H.outer@ by itself does not conjugate.
couter :: Vector C -> Vector C -> Matrix C
couter v w = v `H.outer` conj w

-- | Build a pure-state density matrix @|psi\>\<psi|@ from a state vector.
--   The result has trace one only if the state vector is normalized.
dm :: Vector C -> Matrix C
dm cvec = cvec `couter` cvec

-- | Trace of a matrix: the sum of its diagonal entries.
trace :: Matrix C -> C
trace = sum . toList . takeDiag

-- | Normalize a density matrix so that it has trace one,
--   by dividing every entry by the trace.
--   A matrix with zero trace yields non-finite entries.
normalizeDM :: Matrix C -> Matrix C
normalizeDM rho = scale (1 / trace rho) rho

-- | The one-qubit totally mixed state, the 2x2 identity divided by 2.
oneQubitMixed :: Matrix C
oneQubitMixed = normalizeDM $ ident 2

----------------------
-- Quantum Dynamics --
----------------------

-- | Given a time step and a Hamiltonian matrix,
--   produce a unitary time evolution matrix
--   @(I + i H dt/2)^-1 (I - i H dt/2)@ (Cayley / Crank-Nicolson form).
--   Unless you really need the time evolution matrix,
--   it is better to use 'timeEv', which gives the
--   same numerical results without doing an explicit
--   matrix inversion.  The function assumes hbar = 1,
--   and calls 'error' if the Hamiltonian is not square.
timeEvMat :: Double    -- ^ time step dt
          -> Matrix C  -- ^ Hamiltonian H
          -> Matrix C
timeEvMat dt h
    = let ah :: Matrix C
          ah = scale (0 :+ dt / 2) h          -- i H dt / 2
          (l, m) = size h
          n :: Int
          n = if l == m then m else error "timeEvMat needs square Hamiltonian"
          identity :: Matrix C
          identity = ident n
      in inv (identity + ah) <> (identity - ah)

-- | Given a time step and a Hamiltonian matrix,
--   advance the state vector using the Schrodinger equation.
--   The new state is @(I + i H dt/2)^-1 (I - i H dt/2) v@, obtained by
--   solving a linear system rather than calculating an inverse matrix,
--   so this method should be faster than using 'timeEvMat'.
--   The function assumes hbar = 1, and calls 'error' if the Hamiltonian
--   is not square.
timeEv :: Double    -- ^ time step dt
       -> Matrix C  -- ^ Hamiltonian H
       -> Vector C  -- ^ state vector before the step
       -> Vector C  -- ^ state vector after the step
timeEv dt h v
    = let ah :: Matrix C
          ah = scale (0 :+ dt / 2) h          -- i H dt / 2
          (l, m) = size h
          n :: Int
          n = if l == m then m else error "timeEv needs square Hamiltonian"
          identity :: Matrix C
          identity = ident n
      in (identity + ah) <\> ((identity - ah) #> v)

-- | Given a Hamiltonian matrix, return a function from time
--   to evolution matrix @exp (-i H t)@.  Uses spectral decomposition:
--   the exponential is applied to the eigenvalues via 'matrixFunction'.
--   Assumes hbar = 1.
timeEvMatSpec :: Matrix C -> Double -> Matrix C
timeEvMatSpec m t = matrixFunction (\e -> exp (-iC * e * (t :+ 0))) m

-----------------
-- Composition --
-----------------

-- | Types that support a Kronecker (tensor) product, used to build
--   states and operators of composite systems.
class Kronecker a where
    -- | Kronecker product of two values.
    kron :: a -> a -> a

-- Vectors: the result lists @c1 * c2@ for every component @c1@ of the
-- first vector and @c2@ of the second, with the first vector's index
-- varying slowest.
instance H.Product t => Kronecker (Vector t) where
    kron v1 v2 = H.fromList [c1 * c2 | c1 <- H.toList v1, c2 <- H.toList v2]

-- Matrices: delegates to hmatrix's Kronecker product.
instance H.Product t => Kronecker (Matrix t) where
    kron = H.kronecker

-----------------
-- Measurement --
-----------------

-- | The possible outcomes of a measurement
--   of an observable.
--   These are the eigenvalues of the matrix
--   of the observable.  The matrix is first passed through 'sym', so
--   only its Hermitian part is used.  Each eigenvalue appears once per
--   multiplicity, in the order produced by 'eigenvaluesSH' (descending);
--   degenerate eigenvalues are therefore repeated here.
possibleOutcomes :: Matrix C -> [Double]
possibleOutcomes observable
    = H.toList $ eigenvaluesSH (sym observable)

-- From a Hermitian matrix, a list of pairs of eigenvalues and eigenvectors,
-- in the order produced by 'eigSH' (eigenvalues descending).  The
-- eigenvectors are the columns of the eigenvector matrix 'eigSH' returns.
valsVecs :: Herm C -> [(Double,Vector C)]
valsVecs h = let (valv, vecm) = eigSH h
             in zip (H.toList valv) (H.toColumns vecm)

-- From a Hermitian matrix, a list of pairs of eigenvalues and projectors.
-- Each projector is @|e\>\<e|@ for the corresponding eigenvector @e@
-- returned by 'valsVecs'; degenerate eigenvalues are not yet merged.
valsPs :: Herm C -> [(Double,Matrix C)]
valsPs h = [(val, couter vec vec) | (val, vec) <- valsVecs h]

-- Merge runs of ADJACENT pairs whose first components are equal (by '=='),
-- summing their second components.  Pairs with equal keys that are not
-- adjacent are left separate, so the input should already be grouped
-- (e.g. sorted by key).
combineFst :: (Eq a, Num b) => [(a,b)] -> [(a,b)]
combineFst ((x1, m1) : (x2, m2) : ps)
    | x1 == x2  = combineFst ((x1, m1 + m2) : ps)
    | otherwise = (x1, m1) : combineFst ((x2, m2) : ps)
combineFst ps = ps   -- zero or one pair: nothing to merge

-- | Given an observable, return a list of pairs
--   of possible outcomes and projectors
--   for each outcome.  The observable is passed through 'sym' (only its
--   Hermitian part is used).  Projectors for adjacent eigenvalues that
--   compare exactly equal are summed.  Eigenvalues that differ only by
--   rounding error are NOT merged.
outcomesProjectors :: Matrix C -> [(Double,Matrix C)]
outcomesProjectors m = combineFst (valsPs (sym m))

-- | Given an observable and a state vector, return a list of pairs
--   of possible outcomes and probabilities
--   for each outcome.  Each probability is the real part of
--   @\<v|P v\>@ for that outcome's projector @P@.  The state vector is
--   assumed to be normalized; otherwise every value is scaled by its
--   squared norm.
outcomesProbabilities :: Matrix C -> Vector C -> [(Double,Double)]
outcomesProbabilities m v
    = [(a, realPart (inner v (p #> v))) | (a, p) <- outcomesProjectors m]

------------------
-- Gram-Schmidt --
------------------

-- | Form an orthonormal list of complex vectors
--   from a linearly independent list of complex vectors.
--   The output has the same order as the input.  The list is processed
--   from the end: each vector is orthogonalized against the orthonormal
--   vectors already produced for the vectors after it, then normalized.
--   Projections are removed one at a time from the running residual
--   (modified Gram-Schmidt).  In floating point this keeps the result
--   closer to orthonormal than subtracting all projections of the
--   original vector at once.  If the input is linearly dependent, some
--   residual is (numerically) zero and normalizing it gives meaningless
--   or non-finite entries.
gramSchmidt :: [Vector C] -> [Vector C]
gramSchmidt [] = []
gramSchmidt (v:vs) = nv : nvs
    where
      nvs = gramSchmidt vs
      nv  = normalize (removeProjections v nvs)
      -- Subtract the component along each orthonormal vector in turn.
      removeProjections :: Vector C -> [Vector C] -> Vector C
      removeProjections u []     = u
      removeProjections u (w:ws) = removeProjections (u - scale (inner w u) w) ws

-- To Do
--   Generate higher spin operators and state vectors
--   eigenvectors
--   projection operators