{-# LANGUAGE NoImplicitPrelude #-}
module TBit.Framework ( meshBZ
                      , meshBigBZ
                      , bzIntegral
                      , bzIntegral'
                      , bzIntegral''
                      , kPath ) where

import TBit.Types
import TBit.Parameterization
import TBit.Numerical.Integration

import Data.Foldable hiding (foldl', concatMap)
import Data.List hiding (sum) -- .Stream hiding (sum)
import Data.Traversable (traverse)
import qualified Data.Map as M

import Numeric.LinearAlgebra.HMatrix (norm_2, scale, (!))

import Control.Monad (liftM)
import Control.Monad.State (get)
import Prelude.Listless 

{-|
    Given a list of points in /k/-space, return a list of points that
    interpolates straight-line paths between consecutive points in turn;
    a typical usage is

    > kPath [gammaPoint, kPoint, mPoint, gammaPoint]

    which is used in the 'Plots.bandPlot' function. Consecutive points on
    each segment are separated by the mesh spacing read via 'getMesh'
    (the 'Types.meshingData' parameter). Each segment starts at its first
    point and stops strictly short of its end point; that end point is
    emitted as the start of the next segment, and the last input point
    closes the path.

    The spacing must be positive: a non-positive spacing on a segment of
    non-zero length is reported with 'error', since it would otherwise
    generate an infinite list.
-}
kPath :: [Wavevector] -> Parameterized ([Wavevector])
kPath []        = return []
kPath (k:[])    = return [k]
kPath (k:k':ks) = do (Spacing eps) <- getMesh
                     nextPath <- kPath (k':ks)
                     return $ segment eps ++ nextPath

    where len = norm_2 (k'-k)
          -- Points k, k + eps*u, k + 2*eps*u, ... where u is the unit
          -- vector from k towards k', stopping strictly before k'.
          segment eps
              | eps <= 0 && len > 0
                  = error ("kPath: mesh spacing must be positive, got "
                           ++ show eps)
              | otherwise
                  = [ k + scale n step
                    | n <- takeWhile (\m -> len > eps * m) [0.0, 1.0 ..] ]
              -- Lazily evaluated, so a zero-length segment (where this
              -- would divide by zero) never forces it.
              where step = scale eps (scale (1.0 / len) (k'-k))

{-|
    Integrates a function defined on the Brillouin zone using the
    quadrature routine 'integrate' (described by the original author as
    Takahashi and Mori's tanh-sinh method, via Edward Kmett's integration
    library). The zone is parameterized as @s*b1 + t*b2@ for @(s,t)@ in the
    unit square, where @b1@ and @b2@ are the first two reciprocal lattice
    vectors, and the result is scaled by the Jacobian @|det(b1,b2)|@.
    Only implemented in 2D.

    This is the value component of @bzIntegral'@; use that function when
    an error estimate is also wanted.
-}
bzIntegral'' :: (Wavevector -> Parameterized Double) -> Parameterized Double
bzIntegral'' f = liftM fst (bzIntegral' f)


{-|
    As @bzIntegral''@ (tanh-sinh integration over the unit square in
    fractional coordinates, scaled by the Jacobian @|det(b1,b2)|@), but also
    returns an estimate of the absolute error of the integral as the second
    element of the tuple. Only implemented in 2D.

    NOTE(review): assumes the second component returned by 'integrate' is an
    absolute error estimate for the unit-square integral, as this
    documentation has always stated; that is why it is scaled by the
    Jacobian together with the value.
-}
bzIntegral' :: (Wavevector -> Parameterized Double) -> Parameterized (Double, Double)
bzIntegral' f = do (b1:b2:_) <- recipPrimitiveLattice
                   ps        <- get
                   -- Integrand pulled back to the unit square through
                   -- (s,t) mapped to s*b1 + t*b2; each evaluation of f is
                   -- run against the current parameters, and a failure is
                   -- raised with 'error'.
                   let g s t = either e id
                             . flip crunch ps
                             $ f (scale s b1 + scale t b2)
                       (val, errEst) = integrate g (0.0,1.0) (0.0,1.0)
                       j = jac b1 b2
                   -- The change of variables multiplies the integral and
                   -- its absolute error alike by the Jacobian.
                   return (val * j, errEst * j)
    where e err = error (show err)
          jac a b = abs $ (a!0) * (b!1) - (a!1) * (b!0)

{-|
    Integrates a function over the Brillouin zone with a simple grid sum
    over the points of 'meshBZ'.

    Each point is weighted by the area of one mesh cell. The mesh steps
    along the first two reciprocal lattice vectors @b1@ and @b2@ have length
    @dk@, so a cell has area @dk*dk*|sin theta|@, where @theta@ is the angle
    between @b1@ and @b2@; for a rectangular lattice this is just @dk*dk@.
    Only the 2D case is handled: the weight uses the first two lattice
    vectors and their first two components. With fewer than two lattice
    vectors the weight falls back to @dk*dk@.

    NOTE(review): 'meshBZ' samples both ends of every axis, so for a
    function periodic over the zone the boundary is effectively counted
    twice; this bias shrinks as the mesh is refined.
-}
bzIntegral :: (Wavevector -> Parameterized Double) -> Parameterized Double
bzIntegral f = do kGrid        <- meshBZ
                  (Spacing dk) <- getMesh
                  bs           <- recipPrimitiveLattice
                  let weight = dk * dk * cellShape bs
                  liftM sum (traverse (liftM (* weight) . f) kGrid)
    where
      -- Absolute sine of the angle between the first two lattice vectors:
      -- the area of the parallelogram they span divided by the product of
      -- their lengths.
      cellShape (b1:b2:_) = jac b1 b2 / (norm_2 b1 * norm_2 b2)
      cellShape _         = 1.0
      jac a b = abs $ (a!0) * (b!1) - (a!1) * (b!0)

{-|
    Returns a mesh of points within the /n/-parallelepiped spanned by the
    reciprocal lattice vectors, keyed by grid index. Each axis runs from the
    origin to its lattice vector in steps of the mesh spacing, and includes
    both ends. This covers the entire Brillouin zone, though not in the shape
    you might expect, and not in a way that is pretty for graphing.
-}
meshBZ :: Parameterized (Grid Wavevector)
meshBZ = do
    bs           <- recipPrimitiveLattice
    (Spacing dk) <- getMesh
    return . M.fromList . populate $ map (makeAxis dk) bs

{-|
    As 'meshBZ', but every axis also extends in the negative direction, so
    the mesh covers parallelepipeds in each quadrant, octant, ... /n/-ant of
    the reciprocal lattice basis. Should cover more than the entire first
    Brillouin zone. The positive end of each axis includes the lattice vector
    itself, while the negative end stops short of its negation; grid indices
    count from 0 at the most negative point of each axis.
-}
meshBigBZ :: Parameterized (Grid Wavevector)
meshBigBZ = do
    bs           <- recipPrimitiveLattice
    (Spacing dk) <- getMesh
    let axes = map (makeDoubleAxis dk) bs
    return (M.fromList (populate axes))

--

-- | Sample points along the lattice vector @b@: the origin, then successive
-- steps of length @dk@ in the direction of @b@ while the step count times
-- @dk@ stays below the length of @b@, and finally @b@ itself. The list is
-- infinite if @dk <= 0@ and @b@ is non-zero.
makeAxis :: Double -> Wavevector -> [Wavevector]
makeAxis dk b = steps ++ [b]
    where len    = norm_2 b
          stride = scale (dk / len) b
          steps  = map (`scale` stride) (takeWhile (\m -> m * dk < len) [0..])

-- | As 'makeAxis', but preceded by the points at negative multiples of the
-- same step whose distance from the origin is below the length of @b@,
-- ordered from the most negative towards the origin. The negation of @b@
-- itself is not included.
makeDoubleAxis :: Double -> Wavevector -> [Wavevector]
makeDoubleAxis dk b = below ++ makeAxis dk b
    where len    = norm_2 b
          stride = scale (dk / len) b
          below  = reverse [ scale n stride
                           | n <- takeWhile (\m -> abs (m * dk) < len)
                                            (map negate [1..]) ]

-- | Cartesian product of the given axes: every combination of one point
-- from each axis, summed into a single vector and keyed by the list of
-- per-axis positions (counted from 0). The first axis varies fastest.
populate :: [[Wavevector]] -> [(GridIndex, Wavevector)]
populate []       = []
populate [ks]     = [ (GID [n], k) | (n, k) <- zip [0..] ks ]
populate (ks:kss) = [ (GID (j:i), k + k')
                    | (GID i, k') <- populate kss
                    , (j, k)      <- zip [0..] ks ]