module TBit.Topological.Curvature where 

import TBit.Types
import TBit.Framework
import TBit.Parameterization
import TBit.Numerical.Derivative
import TBit.Hamiltonian.Eigenstates

import Control.Monad

import Data.Map (elems)
import Data.List (delete) -- .Stream (delete)
import qualified Data.Traversable as T

import Numeric.LinearAlgebra.HMatrix

{-|
   Berry curvature of the single band @n@ at wavevector @k@. Bands are
   indexed from zero, so the lowest band is 'BandIndex' 0.

   The curvature is evaluated as

   > -2 * Im sum_{m /= n} <n|dH/dkx|m><m|dH/dky|n> / (E_m - E_n)^2

   where the sum runs over every other band of the Hamiltonian. The
   Hamiltonian derivatives are taken numerically with 'diff' using a fixed
   step of 0.0005 (documented as a five-point stencil). Band indices are
   looked up with '!!', so an out-of-range @n@ raises an error.
-}
bandCurvature :: BandIndex -> Hamiltonian -> Wavevector -> Parameterized Curvature
bandCurvature n h k = do
    kets     <- eigenkets h k
    bras     <- eigenbras h k
    energies <- eigenenergies h k
    let -- scalar matrix element <i| op |j>
        matEl i op j = ((bras !! i) <> op <> (kets !! j)) ! 0 ! 0
        term m = matEl n dHdx m * matEl m dHdy n
                   / ((energies !! m - energies !! n) ^ 2 :+ 0.0)
        otherBands = filter (/= n) [0 .. pred bandCount]
    return (negate (2.0 * imagPart (sum (map term otherBands))))
  where
    bandCount = fst (size (h k))
    step      = 0.0005
    kx        = k ! 0
    ky        = k ! 1
    -- partial derivatives of H along kx and ky, holding the other component fixed
    dHdx      = diff step (\x -> h (vector [realPart x, ky])) (kx :+ 0.0)
    dHdy      = diff step (\y -> h (vector [kx, realPart y])) (ky :+ 0.0)

{-|
   Calculate the total Berry curvature of the occupied bands. The first
   argument is the number of filled bands: bands @0 .. b-1@ are treated as
   occupied and bands @b .. dim-1@ as unoccupied. For example, to find the
   curvature due to the occupied bands of a 4-band system at half filling,
   pass in 2 for the 'BandIndex'.

   The result is

   > -2 * Im sum_{m occ, n unocc} <n|dH/dkx|m><m|dH/dky|n> / (E_m - E_n)^2

   The Hamiltonian derivatives are taken numerically with 'diff'. The step is
   the mesh 'Spacing' obtained from 'getMesh', unlike 'bandCurvature', which
   uses a fixed step of 0.0005.

   NOTE(review): here the /unoccupied/ band sits in the outer bra. In
   'bandCurvature' the band of interest sits in the outer bra, so each term
   here is the complex conjugate of the matching term there. Assuming the
   eigenbras are the adjoints of the eigenkets and H is Hermitian, this
   function therefore returns the negative of summing 'bandCurvature' over
   the occupied bands. Confirm which sign convention is intended before
   relying on the two agreeing.
-}
occupiedCurvature :: BandIndex -> Hamiltonian -> Wavevector -> Parameterized Curvature
occupiedCurvature b h k = do
    kets     <- eigenkets h k
    bras     <- eigenbras h k
    energies <- eigenenergies h k
    (Spacing s) <- getMesh
    let -- Hoisted so the numerical derivatives are computed once per call
        -- rather than potentially once per (occupied, unoccupied) pair.
        dHdx = diff (s :+ 0.0) (\x -> h (vector [realPart x, ky])) (kx :+ 0.0)
        dHdy = diff (s :+ 0.0) (\y -> h (vector [kx, realPart y])) (ky :+ 0.0)
        -- scalar matrix element <i| op |j>; multiplying two 1x1 elements
        -- avoids forming a dim x dim outer product in a long matrix chain
        matEl i op j = ((bras !! i) <> op <> (kets !! j)) ! 0 ! 0
        term m n = matEl n dHdx m * matEl m dHdy n
                     / ((energies !! m - energies !! n) ^ 2 :+ 0.0)
        terms = [ term m n | m <- occ, n <- unocc ]
    return (negate (2.0 * imagPart (sum terms)))
  where
    occ   = [0 .. pred b]
    unocc = [b .. pred dim]
    dim   = fst (size (h k))
    kx    = k ! 0
    ky    = k ! 1


-- | Pair every wavevector of 'meshBigBZ' with the Berry curvature of band
--   @n@ at that point, as computed by 'bandCurvature'. The list is returned
--   in the key order of the mesh map.
--
--   NOTE: marked "Deprecated?" by the original author.
curvatureFieldBand :: BandIndex -> Hamiltonian -> Parameterized [(Wavevector, Curvature)]
curvatureFieldBand n h = meshBigBZ >>= liftM elems . T.mapM withCurvature
  where
    withCurvature k = liftM ((,) k) (bandCurvature n h k)
    
-- | Pair every wavevector of 'meshBZ' with the total Berry curvature of the
--   lowest @n@ (occupied) bands at that point, as computed by
--   'occupiedCurvature'. The list is returned in the key order of the mesh
--   map.
--
--   NOTE: marked "Deprecated?" by the original author.
curvatureFieldOcc :: BandIndex -> Hamiltonian -> Parameterized [(Wavevector, Curvature)]
curvatureFieldOcc n h = meshBZ >>= liftM elems . T.mapM withCurvature
  where
    withCurvature k = liftM ((,) k) (occupiedCurvature n h k)