module TBit.Topological.Chern (chern, chernRaw, chernBand) where

import TBit.Types
import TBit.Framework
import TBit.Hamiltonian.Eigenstates
import TBit.Topological.Curvature

import Numeric.LinearAlgebra.HMatrix hiding ((!))

import Prelude hiding (lookup)
import Data.Map hiding (map)
import Data.Maybe (fromJust, isJust)
import qualified Data.Traversable as T

import Control.Monad
import Control.Monad.Except (throwError)


{-|
    Chern number of a single band, obtained by integrating that band's Berry
    curvature over the Brillouin zone with 'bzIntegral'. The 'BandIndex'
    argument is handed unchanged to
    'TBit.Topological.Curvature.bandCurvature', so it follows that function's
    conventions for selecting a band (bands indexed from 0).

    The integral is scaled by 1/(2π) so the result is normalized as a Chern
    number. The quadrature itself (TanhSinh, via
    'TBit.Numerical.Integration.integrate') is performed by 'bzIntegral'.
-}
chernBand :: BandIndex -> Hamiltonian -> Parameterized Chern
chernBand n h = fmap (* normalization) (bzIntegral (bandCurvature n h))
    where normalization = 0.5 / pi

{-|
    Unrounded Chern number of the first @n@ bands. The wavefunction grid
    is built by 'waveGrid'. The result is the sum, over every grid point, of
    the plaquette contribution returned by 'curvature' (a phase divided by
    2π). Any error raised while building the grid or evaluating a plaquette
    is propagated.
-}
chernRaw :: BandIndex -> Hamiltonian -> Parameterized Chern
chernRaw n h = do
    states <- waveGrid h n
    -- Keys are visited in ascending order, matching the Map's own traversal.
    fluxes <- T.mapM (curvature states n) (keys states)
    return (sum fluxes)

{-|
    Chern number of the first @n@ occupied bands, computed on a discretized
    Brillouin zone. The Berry phase around every elementary plaquette of the
    grid is accumulated by 'chernRaw', and the total is passed through
    'round'. The result is therefore always an integer value.

    If, at some grid point, the energy of the @n@th state is exactly equal to
    that of the next state, the Chern number is reported as undefined: an
    'UndefinedError' is raised with 'throwError' instead of returning a
    number.
-}
chern :: BandIndex -> Hamiltonian -> Parameterized Chern
chern n h = fmap nearestInteger (chernRaw n h)
    where nearestInteger x = fromIntegral (round x :: Integer)

{-|
    Grid over the Brillouin-zone mesh ('meshBZ') whose entry at each point
    holds the first @b@ eigenstates of @h@ at that point. The states are
    selected by 'safelyTake', so a dimension mismatch or a degeneracy at any
    point aborts the whole computation with an error.
-}
waveGrid :: Hamiltonian
         -> BandIndex
         -> Parameterized (Grid [Eigenstate])
waveGrid h b = meshBZ >>= T.mapM (safelyTake b . eigensystem h)

{-|
    Keep the eigenstates (dropping the energies) of the first @n@ entries of
    an eigensystem.

    * Throws a 'DimensionalityError' if @n < 1@, or if @n@ is larger than
      the number of available states.
    * Throws an 'UndefinedError' if the energy of the last kept state is
      exactly equal to the energy of the first discarded state. In that case
      the kept set is not separated from the remaining states.
-}
safelyTake :: Int -> Parameterized [(Energy, Eigenstate)] -> Parameterized [Eigenstate]
safelyTake n msys = do
    sys <- msys
    -- Without this guard, n < 1 reaches the degeneracy test below and makes
    -- (!!) index at a negative position, crashing instead of throwing.
    when (n < 1) $ throwError countErr
    -- Safe: the LT branch guarantees 1 <= n < length sys.
    let energyAt i = fst (sys !! i)
    case compare n (length sys) of
      GT -> throwError dimErr
      EQ -> return (map snd sys)
      LT | energyAt (n - 1) == energyAt n -> throwError chernErr
         -- Prelude.take: Data.Map (imported unqualified) also exports 'take'
         -- in newer containers releases, which would make it ambiguous.
         | otherwise -> return (map snd (Prelude.take n sys))
  where
    countErr = DimensionalityError ("Cannot compute the Chern number for fewer "
                                  ++"than one energy band.")
    dimErr   = DimensionalityError ("Cannot compute the Chern number for more "
                                  ++"energy bands than the system supports.")
    chernErr = UndefinedError ("The Chern number is undefined due to a band "
                             ++"degeneracy.")

{-|
    Contribution to the Chern number from the elementary plaquette whose
    lower corner is the 2D grid index @[m, n]@. The plaquette is walked as
    @[m,n] -> [m+1,n] -> [m+1,n+1] -> [m,n+1] -> [m,n]@. The 'phaseDiff'
    overlap of each edge is multiplied together in that order, and the
    phase of the product is divided by 2π. A grid index that is not
    two-dimensional raises a 'DimensionalityError'.
-}
curvature :: Grid [Eigenstate]
          -> BandIndex
          -> GridIndex
          -> Parameterized Double
curvature gr b (GID [m, n]) = do
    links <- mapM (uncurry (phaseDiff gr b)) edges
    -- foldl1 keeps the left-associated product ((u1*u2)*u3)*u4.
    return (phase (foldl1 (*) links) / (2 * pi))
  where
    here  = GID [m, n]
    right = GID [m + 1, n]
    there = GID [m + 1, n + 1]
    up    = GID [m, n + 1]
    edges = [(here, right), (right, there), (there, up), (up, here)]

curvature _ _ _ = throwError (DimensionalityError ("Curvature calculations "
                                                ++ "are only implemented for 2D."))

{-|
    Overlap determinant between the occupied states stored at two grid
    points. It is the determinant of the @b × b@ matrix whose @(i, j)@ entry
    is @dot v_i w_j@, where @v_i@ are the states at index @m@ and @w_j@ the
    states at index @n@. 'curvature' multiplies four of these around a
    plaquette and takes the phase.

    An index that is missing from the grid is wrapped by @cyc@, which resets
    every coordinate equal to that index's largest coordinate to 0. This is
    only correct if the grid indices form a cube (the same range along every
    axis). If wrapping no longer changes the indices and they are still
    missing, a 'DimensionalityError' is thrown instead of recursing forever.
-}
phaseDiff :: Grid [Eigenstate]
          -> BandIndex
          -> GridIndex
          -> GridIndex
          -> Parameterized (Complex Double)
phaseDiff gr b m n =
    case (lookup m gr, lookup n gr) of
      (Just e1, Just e2) -> return . det . (b >< b) . concat
                          $ map (\v -> map (dot v) e2) e1
      (Just _ , Nothing) -> retry m       (cyc n)
      (Nothing, Just _ ) -> retry (cyc m) n
      (Nothing, Nothing) -> retry (cyc m) (cyc n)
  where
    cyc (GID xs) = GID (map (\x -> if x == maximum xs then 0 else x) xs)
    -- Only recurse when wrapping actually moved an index; otherwise the
    -- same lookups would fail again and the recursion would never end.
    retry m' n'
      | m' == m && n' == n = throwError missingErr
      | otherwise          = phaseDiff gr b m' n'
    missingErr = DimensionalityError ("Grid index is missing from the Brillouin "
                                    ++"zone mesh even after periodic wrapping.")