{-# LANGUAGE NoImplicitPrelude #-}
module TBit.Hamiltonian.Builder.Terms ( neighborTerm
                                      , onsiteTerm
                                      , parityStaggeredTerm
                                      , localMoments
                                      , kagomeSOC
                                      , rashbaZ
                                      ) where

import Prelude.Listless
import Data.List -- .Stream
import TBit.Types
import TBit.Parameterization
import Control.Monad.State
import Data.Graph.Inductive
import Data.Maybe
import Numeric.LinearAlgebra.HMatrix

-- | Build the hopping term of a lattice model from the edges of the cell
--   graph. The scalar parameter named @p@ is multiplied by the phase
--   @cis (r d \`dot\` r k)@ for every edge carrying displacement @d@, where
--   @r@ is the map supplied by 'restrictor' and @k@ is the wavevector.
--   Every node label is replaced by a zero scalar.
neighborTerm :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
neighborTerm p g = do
    amplitude <- getScalar p
    project   <- restrictor
    return $ \k ->
        let phase d = cis (project d `dot` project k)
         in nmap (\_ -> scalar 0)
                 (emap (\d -> scalar (amplitude * phase d)) g)

-- | Build a uniform onsite energy term. Every node of the cell graph is
--   labelled with the scalar parameter named @p@ and every edge with a zero
--   scalar; the wavevector argument of the result is ignored.
onsiteTerm :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
onsiteTerm p g = do
    energy <- getScalar p
    return $ \_ -> nmap (\_ -> scalar energy)
                        (emap (\_ -> scalar 0) g)

-- | Build a staggered onsite term. Each node is labelled with the scalar
--   parameter named @p@ times @(-1)^(num e + 1)@, where @num e@ is taken
--   from the node's label, so the sign pattern depends on how the sites of
--   the particular model are numbered. Edges are labelled with a zero
--   scalar and the wavevector argument of the result is ignored.
parityStaggeredTerm :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
parityStaggeredTerm p g = do
    strength <- getScalar p
    return $ \_ ->
        nmap (\site -> scalar (strength * (-1.0) ^ (num site + 1)))
             (emap (\_ -> scalar 0) g)
{-|
    Builds an exchange-like onsite term from the per-site moment returned
    by 'TBit.Types.mom': the components of the moment are paired, in order,
    with @sigmaX@, @sigmaY@ and @sigmaZ@, the scaled matrices are summed,
    and the sum is scaled by the parameter named @p@. Edges receive a
    2x2 zero matrix, and the wavevector argument of the result is ignored.

    Intended for 'TBit.Types.VectorSite' data; fails clumsily if applied
    to Scalar sites.
-}
localMoments :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
localMoments p g = do
    coupling <- getScalar p
    return $ \_ -> nmap (\site -> scale coupling (pauliDot (mom site)))
                        (emap (\_ -> noCoupling) g)
  where
    -- 2x2 zero block placed on every edge.
    noCoupling = scale 0 (ident 2)
    -- Contract a real moment vector with the Pauli matrices.
    pauliDot m = sum (zipWith (\sigma coeff -> scale coeff sigma)
                              [sigmaX, sigmaY, sigmaZ]
                              (map (:+ 0.0) (toList m)))

{-| Produces a Rashba spin-orbit coupling term for an E-field applied along
    the /z/ direction.

    For every first-neighbor hop from site /u/ to site /v/ with displacement
    /d/ (zero-displacement entries are discarded), an edge is created whose
    label is

    > iC * t * cis (r k `dot` r d) * (d!0 * sigmaY - d!1 * sigmaX)

    where /t/ is the parameter named @p@ and /r/ is the map supplied by
    'restrictor'. All node labels of the result are set to 0. -}
rashbaZ :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
rashbaZ p g = do
    strength <- getScalar p
    project  <- restrictor
    return $ \k -> assemble strength k project
  where
    -- Same shape as g, but each node label is replaced by the list of
    -- (source index, (neighbor node, displacement)) for its first neighbors.
    neighbours = nmap hopsFrom g
    hopsFrom site =
        [ (num site, hop)
        | hop <- nthNeighborsTo g 1 (num site)
        , snd hop /= vector [0,0]
        ]

    -- Turn one neighbor entry into a labelled edge.
    rashbaEdge t k r (u, (v, d)) =
        (u, v, scale (iC * t) (scale phase cross))
      where
        phase = cis (r k `dot` r d)
        cross = scale (re (d ! 0)) sigmaY - scale (re (d ! 1)) sigmaX

    -- Promote a real number to a complex one.
    re x = x :+ 0.0

    -- Collect all edges and rebuild the graph with zeroed node labels.
    assemble t k r =
        let labelled = labNodes (nmap (map (rashbaEdge t k r)) neighbours)
            hops     = concatMap snd labelled
         in nmap (const 0) (mkGraph labelled hops)
                  

{-|
    Produces a spin-orbit interaction in the style of that given by
    Hua et al in PRL /112/, 017205 (2014). It should probably only
    be used with a Kagomé lattice 'TBit.Types.CellGraph'. It works by
    looking at nearest neighbor pairs (/i/,/j/) and then looking up the
    'TBit.Types.VectorSite' moment of site /k/; /k/ is computed as
    
    > k = let i' = succ $ (i - 1) `mod` 3
    >         j' = succ $ (j - 1) `mod` 3
    >      in head $ [1,2,3] \\ [i',j']

    (Recall that /i/ and /j/ are indexed from 1 as nodes.) Clearly, 
    this function is not safe unless it's applied to the correct lattice:
    if node /k/ is not present in the graph the term fails with an error.

    Once 'TBit.Types.mom' is computed for 'TBit.Types.VectorSite' /k/, it
    is coupled to the Pauli matrix tensor as expected. The parity &nu;
    is chosen by asking whether succ /i/ == /j/ mod 3.

    Each resulting edge label is additionally multiplied by @iC@, by the
    parameter named @p@, and by the phase @cis (r k \`dot\` r d)@, where
    /r/ is the map supplied by 'restrictor' and /d/ is the displacement
    of the hop. Neighbor entries with zero displacement are discarded,
    and all node labels of the result are set to 0.
-}
kagomeSOC :: String -> CellGraph -> Parameterized (Wavevector -> AdjMatrix)
kagomeSOC p g = do t <- getScalar p
                   r <- restrictor
                   return $ \k -> mtx t k r
          -- The nng graph is just like g, but nodes are replaced with
          -- lists of nearest neighbors and their displacements
    where nng = nmap (\u -> map ((,) (num u))
                          -- Compare against an explicit two-component zero
                          -- vector (as rashbaZ does). The numeric literal 0
                          -- denotes a one-element vector, which never
                          -- compares equal to a two-component displacement,
                          -- so filtering on it would keep every entry.
                          . filter (\nbr -> snd nbr /= vector [0,0])
                          $ nthNeighborsTo g 1 (num u)) g
          -- One labelled edge (u, v, matrix) per neighbor entry.
          moms t k r = nmap 
                     ( map (\(u,(v,d)) -> (,,) u v
                                        . scale (iC * t)
                                        . scale (cis $ r k `dot` r d)
                                       $! exch (nab u v))) nng
          -- Collect all edges and rebuild the graph with zeroed node labels.
          mtx t k r = let ns = labNodes $! moms t k r
                          es = concatMap snd ns
                       in nmap (const 0) $! mkGraph ns es
          -- Signed moment of the third site associated with the pair (i,j).
          nab i j = neg i j . mom . siteAt . head
                      $! [1,2,3] \\ [sublattice i, sublattice j]
          -- Label of node n; the lattice must actually contain it.
          siteAt n = fromMaybe
                       (error "kagomeSOC: third-site node missing from cell graph")
                       (lab g n)
          -- Fold a 1-based node index onto {1,2,3}.
          sublattice n = succ $ pred n `mod` 3
          -- Parity: +1 when j follows i cyclically, -1 otherwise.
          neg i j = if   (succ (sublattice i) `mod` 3) == (sublattice j `mod` 3)
                    then scale 1.0
                    else scale (-1.0)
          -- Contract a real moment vector with the Pauli matrices.
          exch = sum 
               . zipWith (flip scale) [sigmaX, sigmaY, sigmaZ]
               . map (:+ 0.0)
               . toList
          
          

-- | Enumerate the end points of all walks of exactly @n@ outgoing edges
--   starting at node @j@, each paired with the sum of the displacements
--   along the walk. Walks are not deduplicated. For @n = 0@ the result is
--   the start node with a two-component zero displacement. A negative @n@
--   never reaches the base case.
nthNeighborsTo :: CellGraph -> Int -> Node -> [(Node,Displacement)]
nthNeighborsTo g n j = case n of
    0 -> [(j, vector [0,0])]
    _ -> concatMap (branch g) $! nthNeighborsTo g (pred n) j

-- | Extend a walk that currently ends at node @i@ with accumulated
--   displacement @r@ by one step: one result per outgoing edge of @i@,
--   giving the edge's target and @r@ plus the edge's displacement.
branch :: CellGraph -> (Node, Displacement) -> [(Node, Displacement)]
branch g (i,r) = [ (target, r + step) | (_, target, step) <- out g i ]

-- | Produce the map applied to displacements and wavevectors before they
--   are dotted together. The lattice vectors are read from the state via
--   @latticeData@ (the list must be non-empty). The dimension of the first
--   vector is compared with the rank of the matrix whose columns are the
--   lattice vectors:
--
--   * equal: the identity map is returned;
--   * greater: the returned map multiplies a vector by the matrix whose
--     rows are the lattice vectors scaled to unit 2-norm;
--   * smaller: an error is raised.
--
--   NOTE(review): the rank of a matrix cannot exceed its number of rows,
--   so the last case looks unreachable -- confirm the intended check.
restrictor :: Parameterized (Vector Double -> Vector Double)
restrictor = do
    (l:ls) <- gets latticeData
    let basis     = l : ls
        projector = fromRows (map unit basis)
    case size l `compare` rank (fromColumns basis) of
        EQ -> return id
        GT -> return (projector #>)
        LT -> error "primitive lattice overdetermines space"
  where
    unit v = scale (1.0 / norm_2 v) v