{-# OPTIONS_GHC -Wall #-}
{-# LANGUAGE Trustworthy #-}

{- | 
Module      :  Physics.Learn.Schrodinger1D
Copyright   :  (c) Scott N. Walck 2015-2018
License     :  BSD3 (see LICENSE)
Maintainer  :  Scott N. Walck <walck@lvc.edu>
Stability   :  experimental

This module contains functions to
solve the (time dependent) Schrodinger equation
in one spatial dimension for a given potential function.
-}

module Physics.Learn.Schrodinger1D
    (
    -- * Potentials
      freeV
    , harmonicV
    , squareWell
    , doubleWell
    , stepV
    , wall
    -- * Initial wavefunctions
    --    , harm
    , coherent
    , gaussian
    , movingGaussian
    -- * Utilities
    , stateVectorFromWavefunction
    , hamiltonianMatrix
    , expectX
    , picture
    , xRange
    , listForm
    )
    where

import Data.Complex
    ( Complex(..)
    , magnitude
    )
import Graphics.Gloss
    ( Picture(..)
    , yellow
    , black
    , Display(..)
    , display
    )
-- import Math.Polynomial.Hermite
--     ( evalPhysHermite
--     )
import Numeric.LinearAlgebra
    ( R
    , C
    , Vector
    , Matrix
    , (|>)
    , (<.>)
    , fromLists
    , toList
    , size
    )
import Physics.Learn.QuantumMat
    ( probVector
    , timeEv
    )

--i :: Complex Double
--i = 0 :+ 1

----------------
-- Potentials --
----------------

-- | Free potential.
--   Every position has potential energy zero, so the particle is
--   subject to no force.
freeV
    :: Double  -- ^ position
    -> Double  -- ^ potential energy
freeV = const 0

-- | Harmonic potential @k x^2 / 2@.
--   This is the potential energy stored in a linear spring
--   with spring constant @k@, displaced by @x@ from equilibrium.
harmonicV
    :: Double  -- ^ spring constant
    -> Double  -- ^ position
    -> Double  -- ^ potential energy
harmonicV k x =
    let stiffnessTimesSquare = k * x ** 2
    in  stiffnessTimesSquare / 2

-- | A double well potential.
--   The potential energy is the quartic
--   @v0 * ((x^2 - a^2) / a^2)^2@.
--   It vanishes at @x = ±a@, the bottoms of the two wells, where each
--   well is approximately harmonic. It rises to @v0@ at the origin,
--   which is the top of the barrier between the wells.
doubleWell
    :: Double  -- ^ width (for both wells and well separation)
    -> Double  -- ^ energy height of barrier between wells
    -> Double  -- ^ position
    -> Double  -- ^ potential energy
doubleWell a v0 x = v0 * relativeOffset ** 2
  where
    -- zero at the well bottoms, -1 at the origin
    relativeOffset = (x ** 2 - a ** 2) / a ** 2

-- | Finite square well potential, centered at the origin.
--   The potential energy is zero strictly inside the well (@|x| < l/2@).
--   It equals the constant @v0@ everywhere else, including exactly
--   at the edges.
squareWell
    :: Double  -- ^ well width
    -> Double  -- ^ energy height of well
    -> Double  -- ^ position
    -> Double  -- ^ potential energy
squareWell l v0 x =
    if abs x < l / 2
        then 0
        else v0

-- | A step barrier potential.
--   The potential energy is zero for negative positions.
--   It is @v0@ from the origin (inclusive) rightward.
stepV
    :: Double  -- ^ energy height of barrier (to the right of origin)
    -> Double  -- ^ position
    -> Double  -- ^ potential energy
stepV v0 x =
    if x < 0
        then 0
        else v0

-- | A potential barrier with thickness and height.
--   The potential energy is @v0@ strictly within @w/2@ of the barrier
--   center @x0@, and zero everywhere else (including at the edges).
wall
    :: Double  -- ^ thickness of wall
    -> Double  -- ^ energy height of barrier
    -> Double  -- ^ position of center of barrier
    -> Double  -- ^ position
    -> Double  -- ^ potential energy
wall w v0 x0 x
    | insideBarrier = v0
    | otherwise     = 0
  where
    insideBarrier = abs (x - x0) < w / 2

---------------------------
-- Initial wavefunctions --
---------------------------

-- -- | Harmonic oscillator stationary state
-- harm :: Int          -- ^ nonnegative integer n identifying stationary state
--      -> Double       -- ^ x / sqrt(hbar/(m * omega)), i.e. position
--                      --   in units of sqrt(hbar/(m * omega))
--      -> C            -- ^ complex amplitude
-- harm n u
--     = exp (-u**2/2) * evalPhysHermite n u / sqrt (2^n * fact n * sqrt pi) :+ 0

-- | Harmonic-oscillator coherent state with complex parameter @z@.
--
--   With @u = x / l@, the wavefunction is
--
--   @psi(x) = (pi l^2)^(-1/4) exp(-|z|^2/2) exp(-u^2/2 - z^2/2 + sqrt 2 u z)@
--
--   The @exp(-|z|^2/2)@ factor makes the state normalized (the integral
--   of @|psi|^2@ over x is 1), matching 'gaussian' and 'movingGaussian'.
--   Without that factor the squared norm would be @exp(|z|^2)@.
coherent
    :: R       -- ^ length scale = sqrt(hbar / m omega)
    -> C       -- ^ parameter z
    -> R -> C  -- ^ wavefunction
coherent l z x
    = ((1/(pi*l**2))**0.25 * exp(-x**2/(2*l**2)) * normZ :+ 0)
      * exp(-z**2/2 + (sqrt(2/l**2) * x :+ 0) * z)
  where
    -- exp(-|z|^2/2): normalization constant of the coherent state
    normZ = exp (-(magnitude z ** 2) / 2)

-- | Gaussian wave packet at rest, centered at @x0@.
--   The value is the real number
--   @exp(-(x-x0)^2 / (2 a^2)) / sqrt(a * sqrt pi)@ with zero imaginary
--   part. The denominator makes the integral of @|psi|^2@ equal to 1.
gaussian
    :: R       -- ^ width parameter
    -> R       -- ^ center of wave packet
    -> R -> C  -- ^ wavefunction
gaussian a x0 x = amplitude :+ 0
  where
    envelope      = exp (negate ((x - x0) ** 2 / (2 * a ** 2)))
    normalization = sqrt (a * sqrt pi)
    amplitude     = envelope / normalization

-- | Gaussian wave packet centered at @x0@ carrying the plane-wave
--   phase @exp(i x / l0)@, where @l0 = hbar / p0@.
--   The magnitude is the same normalized envelope used by 'gaussian'.
movingGaussian
    :: R       -- ^ width parameter
    -> R       -- ^ center of wave packet
    -> R       -- ^ l0 = hbar / p0
    -> R -> C  -- ^ wavefunction
movingGaussian a x0 l0 x = exp logAmplitude / (normalization :+ 0)
  where
    phase         = 0 :+ x / l0
    spread        = (x - x0) ** 2 / (2 * a ** 2)
    logAmplitude  = phase - (spread :+ 0)
    normalization = sqrt (a * sqrt pi)

---------------
-- Utilities --
---------------

-- | Factorial of a nonnegative integer, computed as a 'Double'.
--   A negative argument is rejected with 'error' rather than
--   recursing forever.
fact :: Int -> Double
fact n
    | n < 0     = error "Physics.Learn.Schrodinger1D.fact: negative argument"
    | n == 0    = 1
    | otherwise = fromIntegral n * fact (n - 1)

-- | @linspace left right num@ gives @num@ evenly spaced values from
--   @left@ to @right@ inclusive, like NumPy's @linspace@.
--   If @num <= 0@ the result is the empty list. If @num == 1@ it is
--   @[left]@; that case needs special handling because the spacing
--   @(right - left) / 0@ would otherwise turn the single point into NaN.
linspace :: Double -> Double -> Int -> [Double]
linspace left right num
    | num <= 0  = []
    | num == 1  = [left]
    | otherwise = [left + dx * fromIntegral n | n <- [0 .. num - 1]]
  where
    dx = (right - left) / fromIntegral (num - 1)

-- | Transform a wavefunction into a state vector.
--   The wavefunction is sampled at the @num@ evenly spaced points that
--   'linspace' produces between the lowest and highest x. The samples
--   are then packed, in order, into a vector of dimension @num@.
stateVectorFromWavefunction :: R         -- ^ lowest x
                            -> R         -- ^ highest x
                            -> Int       -- ^ dimension of state vector
                            -> (R -> C)  -- ^ wavefunction
                            -> Vector C  -- ^ state vector
stateVectorFromWavefunction left right num psi
    = num |> samples
  where
    samples = map psi (linspace left right num)

-- | Finite-difference Hamiltonian on @num@ evenly spaced grid points
--   between the lowest and highest x (the grid from 'linspace').
--
--   The kinetic term uses the three-point second-derivative stencil
--   with coefficient @-hbar^2 / (2 m)@, so the matrix is tridiagonal.
--   Each diagonal entry is @-2 c / dx^2 + V(x_i)@, where @c@ is that
--   coefficient. Each entry directly above or below the diagonal is
--   @c / dx^2@, and every other entry is zero.
hamiltonianMatrix :: R         -- ^ lowest x
                  -> R         -- ^ highest x
                  -> Int       -- ^ dimension of state vector
                  -> R         -- ^ hbar
                  -> R         -- ^ mass
                  -> (R -> R)  -- ^ potential energy function
                  -> Matrix C  -- ^ Hamiltonian Matrix
hamiltonianMatrix xmin xmax num hbar m pe
    = fromLists (zipWith row [1 .. num] grid)
  where
    grid         = linspace xmin xmax num
    spacing      = (xmax - xmin) / fromIntegral (num - 1)
    kineticScale = -hbar ** 2 / (2 * m)
    diagonalKE   = -2 * kineticScale / spacing ** 2
    neighbourKE  = kineticScale / spacing ** 2
    -- row i of the matrix; x is the grid point belonging to that row
    row i x = [entry i j x | j <- [1 .. num]]
    entry i j x
        | i == j           = (diagonalKE + pe x) :+ 0
        | abs (i - j) == 1 = neighbourKE :+ 0
        | otherwise        = 0

-- | Expectation value of position.
--   Computes the dot product of the state's probability vector
--   (from 'probVector') with the vector of x values.
--   NOTE(review): this is a true expectation value only if the entries
--   of 'probVector' sum to 1; this module does not show whether
--   'probVector' normalizes — confirm.
expectX :: Vector C  -- ^ state vector
        -> Vector R  -- ^ vector of x values
        -> R         -- ^ <X>, expectation value of X
expectX psi xs = probabilities <.> xs
  where
    probabilities = probVector psi


-- | Map an x value linearly onto horizontal screen coordinates.
--   @xmin@ goes to @-screenWidth/2@ and @xmax@ goes to
--   @screenWidth/2@; the result is converted to 'Float' for gloss.
glossScaleX :: Int -> (Double,Double) -> Double -> Float
glossScaleX screenWidth (xmin,xmax) x = realToFrac scaled
  where
    width  = fromIntegral screenWidth :: Double
    scaled = (x - xmin) / (xmax - xmin) * width - width / 2

-- | Map a y value linearly onto vertical screen coordinates.
--   @ymin@ goes to @-screenHeight/2@ and @ymax@ goes to
--   @screenHeight/2@; the result is converted to 'Float' for gloss.
glossScaleY :: Int -> (Double,Double) -> Double -> Float
glossScaleY screenHeight (ymin,ymax) y = realToFrac scaled
  where
    height = fromIntegral screenHeight :: Double
    scaled = (y - ymin) / (ymax - ymin) * height - height / 2

-- | Map a data point @(x, y)@ to gloss screen coordinates.
--   Each axis is scaled independently by 'glossScaleX' and 'glossScaleY'.
glossScalePoint :: (Int,Int)        -- ^ (screenWidth,screenHeight)
                -> (Double,Double)  -- ^ (xmin,xmax)
                -> (Double,Double)  -- ^ (ymin,ymax)
                -> (Double,Double)  -- ^ (x,y)
                -> (Float,Float)
glossScalePoint (screenWidth,screenHeight) xMinMax yMinMax (x,y)
    = (toScreenX x, toScreenY y)
  where
    toScreenX = glossScaleX screenWidth  xMinMax
    toScreenY = glossScaleY screenHeight yMinMax


-- | Produce a gloss 'Picture' of state vector
--   for 1D wavefunction.
--   Draws a yellow line through the points @(x, |psi|^2)@, scaled for a
--   1000 x 750 window. The horizontal range runs from the first to the
--   last x value, and the vertical range is the given y range.
--   An empty list of x values gives an empty line.
picture :: (Double, Double)    -- ^ y range
        -> [Double]            -- ^ xs
        -> Vector C            -- ^ state vector
        -> Picture
picture (ymin,ymax) xs psi = Color yellow (Line points)
  where
    screenSize    = (1000, 750) :: (Int, Int)
    probDensity z = magnitude z ** 2
    points = case xs of
        []       -> []
        (x0 : _) -> map (glossScalePoint screenSize (x0, last xs) (ymin,ymax))
                        (zip xs (map probDensity (toList psi)))

-- options for representing wave functions
-- 1.  A function R -> C
-- 2.  ([R],Vector C), where lengths match
-- 3.  [(R,C)]
-- 4.  (R,R,Vector C)  -- xmin, xmax, state vector (assumes even spacing)

-- 2,4 are best for evolution

-- | Convert the @(xmin, xmax, state vector)@ representation into an
--   explicit list of x values paired with the same state vector.
--   The x list has exactly one entry per vector component, evenly spaced
--   from @xmin@ to @xmax@ inclusive. A single component gives @[xmin]@
--   and an empty vector gives @[]@.
listForm :: (R,R,Vector C) -> ([R],Vector C)
listForm (xmin,xmax,v)
    = (xs, v)
  where
    n  = size v
    dt = (xmax - xmin) / fromIntegral (n - 1)
    -- Index-based grid instead of [xmin, xmin + dt .. xmax]. That
    -- enumeration can be infinite when dt is infinite (one component) or
    -- zero (xmin == xmax), and its length is decided by a floating-point
    -- comparison rather than by the vector size.
    xs | n <= 0    = []
       | n == 1    = [xmin]
       | otherwise = [xmin + dt * fromIntegral k | k <- [0 .. n - 1]]


{-
-- | Given an initial state vector and
--   state propagation function, produce a simulation.
--   The 'Float' in the state propagation function is the time
--   interval for one timestep.
simulate1D :: [Double] -> Vector C -> (Float -> (Float,[Double],Vector C) -> (Float,[Double],Vector C)) -> IO ()
simulate1D xs initial statePropFunc
    = simulate display black 10 (0,initial) displayFunc (const statePropFunc)
      where
        display = InWindow "Animation" (screenWidth,screenHeight) (10,10)
        displayFunc (_t,v) = Color yellow (Line [(
      
      white (\tFloat -> Pictures [Color blue (Line (points (realToFrac tFloat)))
                                 ,axes (screenWidth,screenHeight) (xmin,xmax) (ymin,ymax)])

-- | Produce a state propagation function from a time-dependent Hamiltonian.
--   The float is dt.
statePropGloss :: (Double -> Matrix C) -> Float -> (Float,Vector C) -> (Float,Vector C)
statePropGloss ham dt (tOld,v)
    = (tNew, timeEv (realToFrac dt) (ham tMid) v)
      where
        tNew = tOld + dt
        tMid = realToFrac $ (tNew + tOld) / 2

-- | Given an initial state vector and a time-dependent Hamiltonian,
--   produce a visualization of a 1D wavefunction.
evolutionBlochSphere :: Vector C -> (Double -> Matrix C) -> IO ()
evolutionBlochSphere psi0 ham
    = simulateBlochSphere 0.01 psi0 (stateProp ham)

-}


{-
def triDiagMatrixMult(square_arr,arr):
    num = len(arr)
    result = array([0 for n in range(num)],dtype=complex128)
    result[0] = square_arr[0][0] * arr[0] + square_arr[0][1] * arr[1]
    for n in range(1,num-1):
        result[n] = square_arr[n][n-1] * arr[n-1] + square_arr[n][n] * arr[n] \
            + square_arr[n][n+1] * arr[n+1]
    result[num-1] = square_arr[num-1][num-2] * arr[num-2] \
        + square_arr[num-1][num-1] * arr[num-1]
    return result
-}

------------------
-- Main program --
------------------

-- | @xRange xmin xmax n@ gives @n@ evenly spaced points from @xmin@ to
--   @xmax@ inclusive, i.e. @n - 1@ intervals. A nonpositive @n@ gives
--   the empty list, and @n == 1@ gives @[xmin]@.
--
--   The points are computed by index with the same formula as 'linspace',
--   so for @n >= 2@ they coincide with the grid 'hamiltonianMatrix' uses.
--   The enumeration @[xmin, xmin + dt .. xmax]@ was dropped because it
--   can produce an infinite list when @n == 1@ (infinite @dt@) or when
--   @xmin == xmax@ (zero @dt@).
xRange :: R -> R -> Int -> [R]
xRange xmin xmax n
    | n <= 0    = []
    | n == 1    = [xmin]
    | otherwise = [xmin + dt * fromIntegral k | k <- [0 .. n - 1]]
  where
    dt = (xmax - xmin) / fromIntegral (n - 1)


{-
if __name__ == '__main__':
    m = 1
    omega = 10
    xmin = -2.0
    xmax =  2.0
    num = 256
    num = 128
    dt = 0.0002
    dt = 0.01
    xs = linspace(xmin,xmax,num)
    dx = xs[1] - xs[0]

    super = lambda x: (harm0(m,omega)(x) + harm1(m,omega)(x))/sqrt(2)
    shiftedHarm = lambda x: harm0(m,omega)(x-1)
    coh = coherent(m,omega,1)

    print sum(conj(psi)*psi)*dx

    harmV = harmonicV(m * omega**2)

    V = doubleWell(1,0.1*hbar*omega)
    V = squareWell(1.0,hbar*omega)
    V = harmonicV(m*omega**2)
    V = stepV(10*hbar*omega)
    V = wall(0.1,14.0*hbar*omega,0)
    V = freeV

    H = matrixH(m,xmin,xmax,num,V)
    I = matrixI(num)

    (vals,vecs) = eigh(H)

    E0 = vals[0]
    E1 = vals[1]
    psi0 = normalize(transpose(vecs)[0],dx)
    psi1 = normalize(transpose(vecs)[1],dx)

    psi = func2psi(gaussian(0.3,1),xmin,xmax,num)
    psi = func2psi(coh,xmin,xmax,num)
    psi = func2psi(movingGaussian(0.3,10,-1),xmin,xmax,num)

    psi = psi0
    psi = psi1
    psi = (psi0 + psi1)/sqrt(2)

    E = sum(conj(psi)*triDiagMatrixMult(H,psi)).real*dx

    Escale = hbar*omega

    print E
    print Escale

    leftM  = I + 0.5 * i * H / hbar * dt
    rightM = I - 0.5 * i * H / hbar * dt

    box = display(title='Schrodinger Equation',width=1000,height=1000)

    c = curve(pos = psi2rho(psi,xs))
    c.color = color.blue
    c.radius = 0.02

    ball = sphere(radius=0.05,color=color.red,pos=(expectX(psi,xs),0,0))

    pot_curve = [(x,V(x)/Escale,0) for x in xs if V(x)/Escale < xmax]
    pot = curve(color=color.green,pos=pot_curve,radius=0.01)

    Eline = curve(color=(1,1,0),pos=[(x,E/Escale) for x in xs])
    axis = curve(color=color.white,pos=[(x,0) for x in xs])

    while 1:
        psi = solve(leftM,triDiagMatrixMult(rightM,psi))
        c.pos = psi2rho(psi,xs)
        ball.x = expectX(psi,xs)

To Do:
add combinators for potentials
to shift horizontally and vertically,
and to add potentials

-}

-- Are we committed to SI units for hbar?  No.
-- The harmonic oscillator functions depend only on sqrt(hbar/(m omega)),
-- which is a length parameter.
-- For the moving Gaussian, we give hbar/p0 instead of p0.
-- (The de Broglie wavelength is h/p0, so hbar/p0 = h/(2 pi p0) is the
-- de Broglie wavelength divided by 2 pi, the "reduced" wavelength.)