{-# LANGUAGE NoImplicitPrelude #-}
module TBit.Hamiltonian.Builder.Matrification (mixedSum, toMatrix) where

import Prelude.Listless
import Data.List -- .Stream
import Numeric.LinearAlgebra.HMatrix
import Data.Graph.Inductive
import Data.Maybe

mixedAdd :: Maybe (Matrix (Complex Double))
         -> Maybe (Matrix (Complex Double))
         -> Maybe (Matrix (Complex Double))
-- Add two optional matrices whose row counts are either equal or differ by
-- exactly a factor of two. The smaller operand is lifted with a right
-- Kronecker product with the 2x2 identity before adding. Any missing operand,
-- or any other size relationship, gives Nothing.
mixedAdd ma mb = case (ma, mb) of
    (Just a, Just b) -> combine a b
    _                -> Nothing
  where
    combine a b
        | na == nb     = Just (a + b)
        | na == 2 * nb = Just (a + lift b)
        | nb == 2 * na = Just (lift a + b)
        | otherwise    = Nothing
      where
        na = rowCount a
        nb = rowCount b
    -- Only the row count is compared; operands are assumed square.
    rowCount = fst . size
    -- Double the dimension by attaching a trivial 2-dimensional factor.
    lift m = kronecker m (ident 2)

{-|
    Take a list of matrices, some of which may differ from the others in
    dimensionality by a factor of two, and maybe return their sum. Right
    Kronecker products with the 2x2 identity are taken where needed to make
    the summation well-formed. Useful for combining matrices written in
    forms with and without spin indices.

    Returns 'Nothing' for an empty list. It also returns 'Nothing' when a
    partial sum and the next matrix have dimensions that are neither equal
    nor related by a factor of two. The list is combined left to right, so
    the outcome can depend on order if more than two distinct dimensions
    appear.
-}
mixedSum :: [Matrix (Complex Double)] -> Maybe (Matrix (Complex Double))
-- An empty list has no well-defined dimension to sum at, so report failure
-- through the Maybe result instead of crashing as foldl1 would.
mixedSum []     = Nothing
mixedSum (m:ms) = foldl' mixedAdd (Just m) (map Just ms)

{-|
    Send a tight-binding graph model to the corresponding Hamiltonian
    matrix.

    Node labels are on-site blocks and edge labels are hopping blocks. Block
    rows and columns follow the graph's node ids in ascending order. The
    hopping carried by an edge @(i, j)@ is placed in block row @j@, block
    column @i@, and parallel edges between the same ordered pair are summed.
    The hopping part and the block-diagonal on-site part are combined with
    'mixedSum', so one of them may be written without spin indices (half the
    dimension) relative to the other.

    Calls 'error' if 'mixedSum' cannot reconcile the two parts' dimensions.
-}
toMatrix :: Gr (Matrix (Complex Double)) (Matrix (Complex Double))
         -> Matrix (Complex Double)
toMatrix gr
    -- With no edges every hopping block is zero, so the Hamiltonian is just
    -- the on-site part. Building the hopping matrix here would give an n x n
    -- matrix of scalar zeros that mixedSum cannot match against larger
    -- on-site blocks.
    | null es   = ond
    | otherwise = fromMaybe mismatch $ mixedSum [offd, ond]
  where
    es = map (\(i, j, k) -> ((i, j), k)) $ labEdges gr
    -- Use the actual node ids rather than assuming they are exactly 1..n.
    vs = sort $ nodes gr
    onsite j = fromMaybe 0 $ lab gr j
    -- A pair with no edge sums to a 1x1 zero, which fromBlocks expands.
    hopping i j = sum $ map snd $ filter ((==) (i, j) . fst) es
    -- NOTE(review): if a node has no incident edges while others do, its
    -- block row/column here consists only of 1x1 zeros and appears to
    -- collapse to size 1. offd can then be smaller than ond, and mixedSum
    -- fails (see mismatch).
    offd = fromBlocks [[ hopping i j | i <- vs ] | j <- vs ]
    -- Off-diagonal blocks are scalar zeros so fromBlocks sizes each one from
    -- its row's and column's diagonal block. This stays correct when nodes
    -- carry on-site blocks of different sizes.
    ond  = fromBlocks [[ if i == j then onsite j else 0 | i <- vs ] | j <- vs ]
    mismatch = error "TBit.Hamiltonian.Builder.Matrification.toMatrix: hopping and on-site matrices have incompatible dimensions"