module TBit.Parameterization ( loadParams
                             , getScalar
                             , getVector
                             , getMesh
                             , crunch
                             , primitiveLattice
                             , recipPrimitiveLattice) where

import TBit.Types
import Prelude hiding (map, head)

import Control.Monad.State
import Control.Monad.Except 

import qualified Data.Map as M
import Data.Complex
import Data.List (map, head) -- .Stream (map, head)
import Data.Maybe

import Numeric.LinearAlgebra.HMatrix (toColumns, fromRows, col, linearSolve, Vector)

-- |Given a list of ('Prelude.String', a) pairs, return a mapping from the
--  string to the value. If a key occurs more than once, the value from the
--  last such pair is kept. Such a map is suitable for setting the
--  'TBit.Types.scalarParams' and 'TBit.Types.vectorParams' records of the
--  'TBit.Types.Parameters' type.
loadParams :: [(String, a)] -> M.Map String a
-- 'M.fromList' inserts left to right (later keys overwrite earlier ones),
-- without the thunk build-up of a lazy 'foldl'.
loadParams = M.fromList

-- |Given a 'Prelude.String', retrieve the scalar parameter of that name from
--  the parameterization monad. If no such scalar has been uploaded to the
--  'TBit.Types.Parameters' type, 'getScalar' will throw @UnknownParameter s@
--  in 'TBit.Types.Parameterized'\'s 'Control.Monad.Except.ExceptT' monad.
getScalar :: String -> Parameterized (Complex Double)
getScalar s = gets (M.lookup s . scalarParams)
          >>= maybe (throwError $ UnknownParameter s) return

-- |Look up a named scalar parameter (as 'getScalar' does) and keep only its
--  real part.
realScalar :: String -> Parameterized Double
realScalar name = realPart <$> getScalar name

-- |Look up a named scalar parameter, keeping its full complex value.
--  Behaves exactly like 'getScalar'.
complexScalar :: String -> Parameterized (Complex Double)
complexScalar name = getScalar name

-- |Given a 'Prelude.String', retrieve the vector parameter of that name from
--  the parameterization monad. If no such vector has been uploaded to the
--  'TBit.Types.Parameters' type, 'getVector' will throw @UnknownParameter s@
--  in 'TBit.Types.Parameterized'\'s 'Control.Monad.Except.ExceptT' monad.
getVector :: String -> Parameterized (Vector (Complex Double))
getVector s = gets (M.lookup s . vectorParams)
          >>= maybe (throwError $ UnknownParameter s) return

-- |Read the meshing (spacing) information used to grid the Brillouin zone
--  out of the current parameter set.
getMesh :: Parameterized Meshing
getMesh = fmap meshingData get

-- |Run a 'TBit.Types.Parameterized' computation against a set of input
--  parameters, returning either the error it threw or its result. The final
--  parameter state is discarded.
crunch :: Parameterized a -> Parameters -> Either TBError a
crunch pmz ps = evalState (runExceptT pmz) ps

-- |Read the list of primitive lattice vectors out of the current parameter set.
primitiveLattice :: Parameterized Lattice
primitiveLattice = latticeData <$> get

-- |Return a list of reciprocal primitive lattice vectors. They correspond in order
--  to the return values of 'primitiveLattice', in the sense that:
--
--  > do as <- primitiveLattice
--  >    bs <- recipPrimitiveLattice
--  >    return $ zipWith dot as bs
--
--  will return the list
--
--  > replicate dim (2*pi)
--
--  up to numerical precision.
--
--  The @n@-th reciprocal vector is the solution @b@ of @A b = 2*pi*e_n@.
--  Here the rows of @A@ are the primitive vectors and @e_n@ is the @n@-th
--  unit vector. If any of these systems has no solution from
--  'Numeric.LinearAlgebra.HMatrix.linearSolve', this throws
--  'SingularLatticeError'.
recipPrimitiveLattice :: Parameterized Lattice
recipPrimitiveLattice = do
    lat <- primitiveLattice
    let l     = length lat
        a     = fromRows lat
        -- Right-hand side 2*pi*e_n as a single-column matrix.
        rhs n = col [ 2.0 * pi * kd n j | j <- [1 .. l] ]
    -- 'sequence' collapses to Nothing as soon as any single solve fails.
    -- Each solution is a one-column matrix, so 'toColumns' yields exactly
    -- one vector and 'head' is safe here.
    maybe (throwError SingularLatticeError)
          (return . map (head . toColumns))
          (sequence [ linearSolve a (rhs n) | n <- [1 .. l] ])
  where
    -- Kronecker delta.
    kd :: Int -> Int -> Double
    kd m n = if m == n then 1.0 else 0.0