module TBit.Hamiltonian.Builder.PrimitiveLattice ( primLattice
                                                 , setPrimLattice 
                                                 , delLEdges 
                                                 , replicateE
                                                 , replicateG
                                                 ) where

import TBit.Types
import Numeric.LinearAlgebra.HMatrix
import Data.Graph.Inductive
import Data.Graph.Inductive.Query.SP (spLength)
import Data.Graph.Inductive.Query.DFS (components)
import Control.Monad.State
import Data.List (nub, nubBy, sortBy) -- .Stream

-- | Endpoints of every walk of exactly @n@ steps that starts at node @j@,
--   where one step is whatever 'branch' produces for the graph. Each
--   endpoint is paired with the displacement accumulated along its walk,
--   starting from a two-component zero vector. There is one entry per
--   walk, so a node may be listed many times.
--
--   @n@ is expected to be non-negative: the only base case is @n == 0@.
nthNeighborsTo :: CellGraph -> Int -> Node -> [(Node,Displacement)]
nthNeighborsTo g n j = walk n
    where
      -- Zero steps: only the starting node, with no displacement.
      walk 0 = [(j, vector [0,0])]
      -- k steps: extend every (k-1)-step walk by one more step.
      walk k = walk (pred k) >>= branch g

-- | Take one step from node @i@ along each of its outgoing edges. Every
--   result pairs the edge's target node with the running displacement @r@
--   plus that edge's label.
branch :: CellGraph -> (Node, Displacement) -> [(Node, Displacement)]
branch g (i,r) = [ (target, r + step) | (_, target, step) <- out g i ]

-- | Sets the 'TBit.Types.latticeData' field of the 'TBit.Types.Parameters'
--   state according to a 'primLattice'.
setPrimLattice :: CellGraph -> Parameterizable CellGraph
setPrimLattice g = modify install >> return g
    where
      -- Overwrite 'latticeData' with the primitive lattice computed from the
      -- state's own 'decomData' and the supplied graph. The graph itself is
      -- handed back to the caller unchanged.
      install ps = ps { latticeData = primLattice (decomData ps) g }

-- | Determine a primitive lattice for a given 'TBit.Types.CellGraph'.
--   This is currently accomplished by determining the diameter of the graph,
--   collecting all non-zero displacements from an arbitrary site to itself
--   which are within a diameter's worth of nearest-neighbor (NN) hopping,
--   and returning a maximal linearly independent subset of these with a
--   preference for vectors with smaller L2 norms.
--
--   For finite nanoribbons such as those generated by
--   'TBit.Hamiltonian.Builder.Decompactification.decompactify', the graph
--   diameter can be quite large. This means that finding the primitive
--   lattice vectors as described above can become unreasonably slow. In the
--   future, we will label decompactified lattice vectors in the CellGraph
--   so that 'graphDiameter' will understand not to count them.
primLattice :: [LEdge Displacement] -> CellGraph -> Lattice
primLattice tds gr = take independent cycles
    where
      -- Candidate vectors from 'bulkCycles', for walks that start and end
      -- on the first labelled node of the graph.
      cycles     = bulkCycles tds gr origin
      origin     = fst (head (labNodes gr))
      -- The candidates laid out as the columns of a matrix.
      candidates = fromColumns cycles
      -- How many candidates to keep: the rank of that matrix, except that
      -- a matrix with zero rows counts as rank 0 without calling 'rank'.
      independent
        | fst (size candidates) == 0 = 0
        | otherwise                  = rank candidates

-- | Displacements of closed walks through @site@, shortest first, with
--   every vector dropped that fails to form a rank-2 pair with an earlier
--   (shorter) kept vector.
--
--   The walk length is one more than the diameter of a helper graph. That
--   graph is built from the first connected component left after the @tds@
--   edges are removed (presumably the decompactified edges -- confirm
--   against the caller).
bulkCycles :: [LEdge Displacement] -> CellGraph -> Node -> Lattice
bulkCycles tds gr site = nub (nubBy dependent shortestFirst)
    where
      -- Treated as "equal" by 'nubBy': the pair does not have rank 2.
      dependent u v = rank (fromColumns [u,v]) /= 2
      -- 'sortBy' is stable, so equal-norm vectors keep their walk order.
      shortestFirst = sortBy (\a b -> norm_2 a `compare` norm_2 b) closed
      -- Displacements of walks that end where they started, with at least
      -- one non-zero component and an L2 norm above 1e-4.
      closed = [ d | (k,d) <- walks
                   , k == site
                   , any (/= 0) (toList d)
                   , 0.0001 < norm_2 d ]
      walks = nthNeighborsTo gr (succ (graphDiameter bulk)) site
      -- Helper graph used only for its diameter: every edge weighs 1.
      bulk :: Gr Int Int
      bulk = emap (const 1) (mkGraph [ (x,x) | x <- bulkNodes ] bulkEdges)
      -- Edges of the full graph with both endpoints inside the component.
      bulkEdges = [ e | e@(a,b,_) <- labEdges gr
                      , a `elem` bulkNodes
                      , b `elem` bulkNodes ]
      -- First component reported once the @tds@ edges are deleted.
      bulkNodes = head (components (delLEdges tds gr))

-- | Largest shortest-path length found between nodes of the graph, using
--   the 'Int' edge labels as weights.
--
--   Node identifiers are read from the graph itself, so they need not be
--   the contiguous range @1..noNodes g@. Only pairs @(i, j)@ with @i <= j@
--   are examined, and only in the direction from @i@ to @j@; the result is
--   the true diameter when the graph is symmetric (assumed by the caller --
--   confirm). An empty graph has diameter 0.
graphDiameter :: Graph gr => gr a Int -> Int
graphDiameter g = maximum $ 0 : [ spLength i j g | i <- vs
                                                 , j <- vs
                                                 , i <= j ]
    where
      -- The graph's actual node identifiers.
      vs = map fst (labNodes g)

-- | Remove each of the given 'Data.Graph.Inductive.LEdge's from a graph by
--   applying 'delLEdge' once per list element. The applications nest from
--   the right, so the last edge in the list is applied to the graph first.
delLEdges :: (Eq b, DynGraph gr) => [LEdge b] -> gr a b -> gr a b
delLEdges []       g = g
delLEdges (e:rest) g = delLEdge e (delLEdges rest g)

-- | Produce n copies of a 'Data.Graph.Inductive.LEdge'. Copy number j
--   (counting from 0) has both endpoints raised by j*m and keeps the
--   original label. The count n is the first argument and the stride m
--   the second; a non-positive n yields the empty list.
replicateE :: Int -> Int -> LEdge Displacement -> [LEdge Displacement]
replicateE n m (v1,v2,e) = map shifted offsets
    where
      -- 0, m, 2m, ... : one offset per requested copy.
      offsets   = take n (iterate (+ m) 0)
      shifted k = (v1 + k, v2 + k, e)

-- | Build a 'TBit.Types.CellGraph' out of n copies of the given one.
--   With m the number of nodes in the input, copy j (counting from 0) has
--   every node identifier, every edge endpoint and the @num@ field of
--   every node label raised by j*m. Edge labels are carried over as-is.
replicateG :: Int -> CellGraph -> CellGraph
replicateG n gr = mkGraph shiftedNodes shiftedEdges
    where
      m       = noNodes gr
      -- One offset per copy: 0, m, 2m, ...
      offsets = [ j * m | j <- take n [0..] ]
      -- Copies are emitted in order, each listing its nodes (or edges) in
      -- the order the input graph reports them.
      shiftedNodes = [ (v + off, a { num = num a + off })
                     | off <- offsets
                     , (v, a) <- labNodes gr ]
      shiftedEdges = [ (v1 + off, v2 + off, r)
                     | off <- offsets
                     , (v1, v2, r) <- labEdges gr ]