-- | Shared type definitions for TBit.
--
-- This module exports the parameter record ('Parameters') together with the
-- monad stack that carries it ('Parameterized' / 'Parameterizable'), the
-- error type 'TBError', type aliases for the quantities the rest of the
-- code passes around, the graph types 'CellGraph' and 'AdjMatrix', and the
-- three Pauli matrices 'sigmaX', 'sigmaY' and 'sigmaZ'.
--
-- 'Parameters' is exported with its constructor and with its field
-- selectors listed individually; 'Meshing' and 'GridIndex' are exported
-- with their single constructors.
module TBit.Types ( Parameterized
                  , Parameters (Parameters)
                  , latticeData , meshingData
                  , scalarParams , vectorParams
                  , decomData
                  , Lattice
                  , TBError (..)
                  , Grid
                  , BandIndex
                  , Filling
                  , Meshing (Spacing)
                  , GridIndex (GID)
                  , Chern
                  , Curvature
                  , Wavevector
                  , KPath
                  , Hamiltonian
                  , Magnetization
                  , Moment
                  , ChemEnergy
                  , Energy
                  , Eigenstate
                  , Eigenbra
                  , Eigenket
                  , AFOrder   
                  , Hopping   
                  , OnSite   
                  , Rashba    
                  , SOC
                  , Parameterizable
                  , SiteData (..)
                  , Displacement
                  , CellGraph
                  , AdjMatrix
                  , Term
                  , sigmaX
                  , sigmaY
                  , sigmaZ
                  ) where

import Control.Applicative
import Control.DeepSeq
import Data.Functor()
import Data.Foldable
import Data.Traversable
import Data.Map (Map)
import Control.Monad.Except (ExceptT)
import Control.Monad.State  (State)
import Numeric.LinearAlgebra.HMatrix (Vector, Matrix, iC, (><))
import Data.Complex (Complex (..))
import Data.Graph.Inductive hiding ((><))
import Control.DeepSeq ()

-- Computational types
-- | Monad stack for computations that thread a 'Parameters' value as state
-- and can abort with a 'TBError'.  It is left unapplied (kind @* -> *@), so
-- it is used as @Parameterized a@.  'Parameterizable' names the same stack.
type Parameterized = ExceptT TBError (State Parameters)
-- | The state carried by 'Parameterized' computations.
--
-- Fields:
--
--   * 'latticeData'  - the 'Lattice', i.e. a list of real vectors.
--
--   * 'meshingData'  - the 'Meshing' specification.
--
--   * 'decomData'    - a list of labelled graph edges whose labels are
--     'Displacement's.  NOTE(review): the exact meaning of "decom" is not
--     visible in this module; confirm with the code that populates it.
--
--   * 'scalarParams' - complex scalar values looked up by name.
--
--   * 'vectorParams' - complex vector values looked up by name.
--
-- Only 'Show' is derived; the 'NFData' instance is written by hand
-- elsewhere in this module.
data Parameters = Parameters { latticeData  :: Lattice
                             , meshingData  :: Meshing
                             , decomData    :: [LEdge Displacement]
                             , scalarParams :: Map String (Complex Double)
                             , vectorParams :: Map String (Vector (Complex Double)) }
                               deriving (Show)

-- | A term builder: given a 'String' and a 'CellGraph' it produces, inside
-- 'Parameterized', a function from a 'Wavevector' to an 'AdjMatrix'.
-- NOTE(review): the 'String' is presumably the name of the parameter to
-- look up in 'Parameters' - confirm against the implementations.
type Term = String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)

-- | A 'Meshing' wraps a single 'Double', so evaluating that one field is
-- all that is needed to reach normal form.
instance NFData Meshing where
    rnf (Spacing step) = rnf step

-- | Fully evaluates every field of a 'Parameters' record.
--
-- Each field is handed to 'rnf' directly.  For the two list-valued fields
-- ('latticeData' and 'decomData') this matters: writing @map rnf xs `seq` ..@
-- would only evaluate the list produced by 'map' to its first constructor
-- and would never run 'rnf' on any element, leaving the contents
-- unevaluated.  The list instance of 'NFData' walks the whole spine and
-- forces every element.
instance NFData Parameters where
    rnf ps = rnf (latticeData ps)
       `seq` rnf (meshingData ps)
       `seq` rnf (decomData ps)
       `seq` rnf (scalarParams ps)
       `seq` rnf (vectorParams ps)

-- | A lattice, stored as a list of real vectors.
-- NOTE(review): presumably one primitive lattice vector per spatial
-- dimension - confirm with the code that constructs it.
type Lattice = [Vector Double]
-- | Errors that a 'Parameterized' computation can abort with.
--
--   * 'SingularLatticeError' carries no payload.
--
--   * 'DimensionalityError', 'UndefinedError' and 'UnknownParameter' each
--     carry a 'String'.  NOTE(review): for 'UnknownParameter' this is
--     presumably the name that failed to resolve, and for the other two a
--     human-readable message - verify at the throw sites.
--
-- Only 'Show' is derived, so values cannot be compared with '=='.
data TBError = SingularLatticeError 
             | DimensionalityError String
             | UndefinedError String
             | UnknownParameter String
               deriving Show

-- | A map keyed by 'GridIndex'.  The alias is partially applied, so the
-- value type is supplied at each use site (@Grid a@).
type Grid      = Map GridIndex
-- | Index of a band, as a plain 'Int'.
type BandIndex = Int
-- | A filling, expressed as a 'BandIndex'.
-- NOTE(review): presumably the number of occupied bands - confirm with callers.
type Filling   = BandIndex

-- | Mesh specification.  The only constructor, 'Spacing', wraps a single
-- 'Double'.  NOTE(review): presumably the distance between neighbouring
-- mesh points; units are not stated here - confirm with callers.
data Meshing   = Spacing Double deriving Show
-- | Index of a grid point, stored as a list of 'Int's.  The derived 'Eq'
-- and 'Ord' instances are what allow it to serve as the key type of 'Grid'.
-- NOTE(review): presumably one integer per lattice direction - verify.
data GridIndex = GID [Int] deriving (Ord, Eq, Show)

-- Physical quantities
-- | A Chern number, held as a 'Double'.
type Chern       = Double
-- | A wavevector, held as a real vector.
type Wavevector  = Vector Double
-- | A curvature value, held as a 'Double'.
-- NOTE(review): presumably Berry curvature at a single point - confirm.
type Curvature   = Double
-- | A path given as a list of (label, wavevector) pairs.
-- NOTE(review): the labels are presumably names of high-symmetry points - confirm.
type KPath       = [(String,Wavevector)]
-- | A Hamiltonian: a function from a 'Wavevector' to a complex matrix.
type Hamiltonian = Wavevector -> Matrix (Complex Double)
-- | A scalar magnetization, held as a 'Double'.
type Magnetization = Double
-- | A moment, held as a real vector.  Used as the 'mom' field of 'SiteData'.
type Moment      = Vector Double

-- | A chemical energy, held as a 'Double'.
-- NOTE(review): presumably the chemical potential - confirm with callers.
type ChemEnergy = Double
-- | An energy, held as a 'Double'.
type Energy     = Double
-- | A single eigenstate, held as a complex vector.
type Eigenstate = Vector (Complex Double)
-- | An eigen-bra, held as a complex matrix.
-- NOTE(review): presumably a 1 x n row matrix - the shape is not enforced by the type.
type Eigenbra   = Matrix (Complex Double)
-- | An eigen-ket, held as a complex matrix.
-- NOTE(review): presumably an n x 1 column matrix - the shape is not enforced by the type.
type Eigenket   = Matrix (Complex Double)

-- TB Parameters
-- The five aliases below all stand for the same underlying type,
-- 'Complex' 'Double'; they are interchangeable to the type checker and
-- differ only in the name shown in signatures.

-- | Complex-valued parameter.
-- NOTE(review): presumably an antiferromagnetic order parameter - confirm.
type AFOrder = Complex Double
-- | Complex-valued hopping parameter.
type Hopping = Complex Double
-- | Complex-valued on-site parameter.
type OnSite  = Complex Double
-- | Complex-valued Rashba parameter.
type Rashba  = Complex Double
-- | Complex-valued parameter.
-- NOTE(review): presumably a spin-orbit coupling strength - confirm.
type SOC     = Complex Double
-- | A position, held as a real vector.  This alias is not in the module's
-- export list and is not referenced by any other declaration visible here.
type Position     = Vector Double
-- | Label attached to each node of a 'CellGraph'.
--
--   * 'ScalarSite' carries only an 'Int' ('num').
--
--   * 'VectorSite' carries an 'Int' ('num') and a 'Moment' ('mom').
--
-- 'num' is defined for both constructors.  'mom' exists only on
-- 'VectorSite', so it is a partial selector: applying it to a 'ScalarSite'
-- fails at run time.  Pattern-match on the constructor when the kind of
-- site is not already known.
data SiteData     = ScalarSite { num :: Int }
                  | VectorSite { num :: Int, mom :: Moment } deriving Show

-- | Alternative name for 'Parameterized'.  It is defined in terms of that
-- alias rather than repeating the monad stack, so the two names cannot
-- drift apart if the stack ever changes.
type Parameterizable = Parameterized
-- | A displacement, held as a real vector.  Used as the edge label of a
-- 'CellGraph' and of the edges stored in 'decomData'.
type Displacement = Vector Double
-- | A graph whose nodes are labelled with 'SiteData' and whose edges are
-- labelled with 'Displacement's.
type CellGraph    = Gr SiteData Displacement
-- | A graph whose nodes and edges are both labelled with complex matrices.
type AdjMatrix    = Gr (Matrix (Complex Double)) (Matrix (Complex Double))

-- | Traverses only the third component of a triple; the first two are
-- carried through unchanged.  This is an orphan instance: neither the
-- class nor the tuple type is defined in this module.
instance Traversable ((,,) a b) where
    traverse act (p, q, r) = fmap (\r' -> (p, q, r')) (act r)

-- | Treats a triple as a container holding exactly one element, namely its
-- third component.  This is an orphan instance: neither the class nor the
-- tuple type is defined in this module.
instance Foldable ((,,) a b) where
    foldMap inject triple = case triple of
        (_, _, r) -> inject r
    foldr step acc triple = case triple of
        (_, _, r) -> step r acc

-- | Maps over the third component of a triple, leaving the first two as
-- they are.  This is an orphan instance.
-- NOTE(review): newer releases of base are believed to ship their own
-- 'Functor' instance for triples, which would make this a duplicate
-- instance - confirm against the base version this project builds with.
instance Functor ((,,) a b) where
    fmap g triple = case triple of
        (p, q, r) -> (p, q, g r)

-- | Traverses only the fourth component of a 4-tuple; the first three are
-- carried through unchanged.  This is an orphan instance: neither the
-- class nor the tuple type is defined in this module.
instance Traversable ((,,,) a b c) where
    traverse act (p, q, r, s) = fmap (\s' -> (p, q, r, s')) (act s)

-- | Treats a 4-tuple as a container holding exactly one element, namely
-- its fourth component.  This is an orphan instance: neither the class nor
-- the tuple type is defined in this module.
instance Foldable ((,,,) a b c) where
    foldMap inject quad = case quad of
        (_, _, _, s) -> inject s
    foldr step acc quad = case quad of
        (_, _, _, s) -> step s acc

-- | Maps over the fourth component of a 4-tuple, leaving the first three
-- as they are.  This is an orphan instance.
-- NOTE(review): newer releases of base are believed to ship their own
-- 'Functor' instance for 4-tuples, which would make this a duplicate
-- instance - confirm against the base version this project builds with.
instance Functor ((,,,) a b c) where
    fmap g quad = case quad of
        (p, q, r, s) -> (p, q, r, g s)

-- | Fully evaluates a 'TBError', including the whole message string of the
-- constructors that carry one.
--
-- The payloads go through 'rnf' rather than 'seq': 'seq' on a 'String'
-- stops at the first list constructor, which would leave the rest of the
-- message unevaluated and break the normal-form contract of 'NFData'.
instance NFData TBError where
    -- Nullary constructor: matching on it has already evaluated the value.
    rnf SingularLatticeError     = ()
    rnf (DimensionalityError s)  = rnf s
    rnf (UndefinedError      s)  = rnf s
    rnf (UnknownParameter    s)  = rnf s

-- | The Pauli matrix sigma-x as a 2x2 complex matrix: rows @(0, 1)@ and
-- @(1, 0)@.  The element list is consumed row by row.
sigmaX :: Matrix (Complex Double)
sigmaX = (2 >< 2) (topRow ++ bottomRow)
  where
    topRow    = [0, 1]
    bottomRow = [1, 0]

-- | The Pauli matrix sigma-y as a 2x2 complex matrix: rows @(0, -i)@ and
-- @(i, 0)@, with @i@ spelled 'iC'.  The element list is consumed row by row.
sigmaY :: Matrix (Complex Double)
sigmaY = (2 >< 2) (topRow ++ bottomRow)
  where
    topRow    = [0, negate iC]
    bottomRow = [iC, 0]

-- | The Pauli matrix sigma-z as a 2x2 complex matrix: rows @(1, 0)@ and
-- @(0, -1)@.  The element list is consumed row by row.
sigmaZ :: Matrix (Complex Double)
sigmaZ = (2 >< 2) (topRow ++ bottomRow)
  where
    topRow    = [1, 0]
    bottomRow = [0, -1]