{-# OPTIONS_GHC -fno-warn-missing-methods #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE UndecidableInstances   #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE DeriveDataTypeable   #-}
{-|
    Module      :  Data.Number.ER.RnToRm.Approx.DomEdges
    Description :  separate approximations per domain-box hyper-edge
    Copyright   :  (c) Michal Konecny
    License     :  BSD3

    Maintainer  :  mik@konecny.aow.cz
    Stability   :  experimental
    Portability :  portable

-}
module Data.Number.ER.RnToRm.Approx.DomEdges 
(
    ERFnDomEdgesApprox(..)
)
where

import qualified Data.Number.ER.RnToRm.Approx as FA
import qualified Data.Number.ER.Real.Approx as RA
import qualified Data.Number.ER.Real.Approx.Elementary as RAEL

import qualified Data.Number.ER.Real.DomainBox as DBox
import Data.Number.ER.Real.DomainBox (VariableID(..), DomainBox)
import Data.Number.ER.BasicTypes
import Data.Number.ER.Misc
import Data.Number.ER.PlusMinus

import Data.Typeable
import Data.Generics.Basics
import Data.Binary

import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.List

{-|
    Use some function approximation type and, for each domain box, keep a
    structure of function approximations of this type indexed by the
    hyper-edge structure of the box.  For each hyper-edge of the domain,
    the corresponding approximation has this edge as its domain.

    E.g. for a 2D square domain there are:

      * one approximation for the whole square

      * four 1D approximations, one for each edge

      * eight 0D approximations, one for each endpoint of each edge
        (so every corner of the square is represented twice)
 -}
data ERFnDomEdgesApprox varid fa
    = ERFnDomEdgesApprox
        { erfnMainVolume :: fa
            -- ^ approximation over the whole domain box
        , erfnEdges :: Map.Map (varid, PlusMinus) (ERFnDomEdgesApprox varid fa)
            -- ^ for each variable and side ('Minus' = lower endpoint of the
            --   variable's domain, 'Plus' = upper endpoint), the approximation
            --   restricted to that face of the box, with its own edges
        }
    deriving (Typeable, Data)

-- | Serialised as the main enclosure followed by the (recursive) edge map.
instance (Ord a, Binary a, Binary b) => Binary (ERFnDomEdgesApprox a b) where
  put (ERFnDomEdgesApprox mainVol edgeMap) = do
      put mainVol
      put edgeMap
  get = do
      mainVol <- get
      edgeMap <- get
      return (ERFnDomEdgesApprox mainVol edgeMap)

{-|
    Apply a unary operation to the main enclosure and, recursively, to the
    approximation on every hyper-edge, keeping the edge structure intact.
-}
edgesLift1 ::
    (fa -> fa) ->
    (ERFnDomEdgesApprox varid fa) -> (ERFnDomEdgesApprox varid fa)
edgesLift1 op = go
    where
    go (ERFnDomEdgesApprox mainVol edgeMap) =
        ERFnDomEdgesApprox (op mainVol) (fmap go edgeMap)
        
{-|
    Combine two approximations with a binary operation: the main enclosures
    are combined with each other and, recursively, so are the approximations
    stored under equal edge keys.

    When the two edge maps are not keyed identically, both arguments are
    first extended with 'unifyEdgeVariables' and the combination is retried.
    NOTE(review): 'unifyEdgeVariables' only adds missing /variables/, so if
    the key sets differ while the variable sets agree (e.g. a variable with
    only a 'Plus' edge) this retry never terminates.
-}
edgesLift2 ::
    (Ord varid) =>
    (fa -> fa -> fa) ->
    (ERFnDomEdgesApprox varid fa) -> (ERFnDomEdgesApprox varid fa) -> (ERFnDomEdgesApprox varid fa)
edgesLift2 op f1 f2
    | sameEdgeKeys =
        ERFnDomEdgesApprox
            (op (erfnMainVolume f1) (erfnMainVolume f2))
            (Map.intersectionWith (edgesLift2 op) edgeMapA edgeMapB)
    | otherwise =
        uncurry (edgesLift2 op) (unifyEdgeVariables f1 f2)
    where
    edgeMapA = erfnEdges f1
    edgeMapB = erfnEdges f2
    sameEdgeKeys = Map.keys edgeMapA == Map.keys edgeMapB

{-|
    Bring two approximations to a common top-level edge structure.

    Every variable that has hyper-edges in one argument but not in the other
    is added to the other argument with 'addVarToEdges'.  The main
    enclosures are left unchanged ('addVarToEdges' preserves them).

    NOTE: only the /variables/ of the edge keys are compared, not the
    individual 'Plus'/'Minus' keys, so the resulting edge maps have equal
    key sets only if every variable present carries both of its edges.
-}
unifyEdgeVariables ::
    (Ord varid) =>
    ERFnDomEdgesApprox varid fa ->
    ERFnDomEdgesApprox varid fa ->
    (ERFnDomEdgesApprox varid fa, ERFnDomEdgesApprox varid fa)
unifyEdgeVariables f1 f2 =
    (addVars (vars2 `Set.difference` vars1) f1,
     addVars (vars1 `Set.difference` vars2) f2)
    where
    vars1 = edgeVars f1
    vars2 = edgeVars f2
    -- the set of variables that have at least one edge
    edgeVars f = Set.map fst $ Map.keysSet $ erfnEdges f
    -- strict left fold: each step returns a fresh constructor, so there is
    -- no reason to accumulate a chain of thunks
    addVars newVars f =
        foldl' (flip addVarToEdges) f (Set.toList newVars)

{-|
    Add hyper-edges for the variable @var@ to an approximation, at every
    level of its edge structure.  At each level both new edges ('Minus' and
    'Plus') are that level's approximation itself, unchanged, and they
    replace any edges already stored under those keys.

    This is meaningful when the approximation does not depend on @var@;
    presumably that holds for the callers in this module
    ('unifyEdgeVariables' only adds variables that have no edges yet) -
    confirm before reusing elsewhere.
-}
addVarToEdges ::
    (Ord varid) =>
    varid ->
    ERFnDomEdgesApprox varid fa ->
    ERFnDomEdgesApprox varid fa 
addVarToEdges var approx =
    approx { erfnEdges = Map.union newEdges oldEdgesExtended }
    where
    -- Map.union is left-biased, so these win over any existing (var, _) keys
    newEdges = Map.fromList [((var, Minus), approx), ((var, Plus), approx)]
    oldEdgesExtended = Map.map (addVarToEdges var) (erfnEdges approx)


instance 
    (FA.ERFnDomApprox box varid domra ranra fa, Ord varid, VariableID varid) =>
    Show (ERFnDomEdgesApprox varid fa)
    where
    -- Render the main enclosure and then, depth-first, every hyper-edge.
    -- Each edge is labelled with the variable/endpoint assignments on the
    -- path leading to it, outermost assignment first.
    show topFn@(ERFnDomEdgesApprox topEncl _) =
        render [] topFn
        where
        -- endpoint values are always read from the top-level domain box
        topDom = FA.dom topEncl
        render path (ERFnDomEdgesApprox encl subEdges) =
            heading path ++ show encl
                ++ concatMap (renderEdge path) (Map.toList subEdges)
        heading [] = "\n>>>>> main enclosure: "
        heading path =
            "\n>>>>> edge" ++ concatMap showAssignment (reverse path) ++ ": "
        showAssignment (varID, val) =
            " " ++ showVar varID ++ "=" ++ show val
        renderEdge path ((varID, side), edgeFn) =
            render ((varID, endpoint) : path) edgeFn
            where
            (domLo, domHi) =
                RA.bounds $ DBox.findWithDefault RA.bottomApprox varID topDom
            endpoint =
                case side of
                    Minus -> domLo
                    Plus -> domHi

-- Equality compares the main enclosures only; edge approximations are ignored.
instance
    (FA.ERFnApprox box varid domra ranra fa) =>
    Eq (ERFnDomEdgesApprox varid fa)
    where
    (ERFnDomEdgesApprox mainA _) == (ERFnDomEdgesApprox mainB _) =
        mainA == mainB

-- Ordering is decided by the main enclosures only; edge approximations are
-- ignored (consistently with the Eq instance).
instance
    (FA.ERFnApprox box varid domra ranra fa, Ord fa) =>
    Ord (ERFnDomEdgesApprox varid fa)
    where
    compare (ERFnDomEdgesApprox mainA _) (ERFnDomEdgesApprox mainB _) =
        mainA `compare` mainB

-- Arithmetic acts on the main enclosure and on every hyper-edge separately.
-- Integer constants carry no edges; edges are supplied by edgesLift2 (via
-- unifyEdgeVariables) when a constant is combined with an approximation
-- that has them.
-- NOTE: abs and signum are not defined here (missing-method warnings are
-- disabled for this module); see the RAEL.ERApproxElementary instance for abs.
instance
    (FA.ERFnDomApprox box varid domra ranra fa, VariableID varid) =>
    Num (ERFnDomEdgesApprox varid fa)
    where
    fA + fB = edgesLift2 (+) fA fB
    fA * fB = edgesLift2 (*) fA fB
    negate f = edgesLift1 negate f
    fromInteger n = ERFnDomEdgesApprox (fromInteger n) Map.empty

-- Reciprocal is taken on the main enclosure and on every hyper-edge;
-- rational constants carry no edges (cf. fromInteger).
instance 
    (FA.ERFnDomApprox box varid domra ranra fa, VariableID varid) =>
    Fractional (ERFnDomEdgesApprox varid fa)
    where
    recip f = edgesLift1 recip f
    fromRational q = ERFnDomEdgesApprox (fromRational q) Map.empty


instance 
    (FA.ERFnDomApprox box varid domra ranra fa, VariableID varid) =>
    RA.ERApprox (ERFnDomEdgesApprox varid fa)
    where
    -- granularity is read from the main enclosure only, but set everywhere
    getGranularity (ERFnDomEdgesApprox mainVol _) =
        RA.getGranularity mainVol
    setGranularity gran f = edgesLift1 (RA.setGranularity gran) f
    setMinGranularity gran f = edgesLift1 (RA.setMinGranularity gran) f
    (/\) = edgesLift2 (RA./\)
    -- Intersect the main enclosures and the matching edges pairwise, returning
    -- the intersection together with the improvement measure.  Arguments whose
    -- edge keys differ are first brought to a common edge structure.
    intersectMeasureImprovement ix
            fA@(ERFnDomEdgesApprox mainA edgesA)
            fB@(ERFnDomEdgesApprox mainB edgesB)
        | Map.keys edgesA /= Map.keys edgesB =
            uncurry (RA.intersectMeasureImprovement ix) (unifyEdgeVariables fA fB)
        | otherwise =
            (ERFnDomEdgesApprox mainIsect (Map.map fst edgePairs),
             ERFnDomEdgesApprox mainImpr (Map.map snd edgePairs))
        where
        (mainIsect, mainImpr) = RA.intersectMeasureImprovement ix mainA mainB
        edgePairs =
            Map.intersectionWith (RA.intersectMeasureImprovement ix) edgesA edgesB
    -- the comparison looks at the main enclosures alone
    leqReals fA fB =
        RA.leqReals (erfnMainVolume fA) (erfnMainVolume fB)

instance 
    (FA.ERFnDomApprox box varid domra ranra fa, RA.ERIntApprox fa, VariableID varid) =>
    RA.ERIntApprox (ERFnDomEdgesApprox varid fa)
    where
--    doubleBounds = :: ira -> (Double, Double) 
--    floatBounds :: ira -> (Float, Float)
--    integerBounds :: ira -> (ExtendedInteger, ExtendedInteger)
    -- Bisect the main enclosure and every edge approximation.  If a point is
    -- given, the main enclosure is split at its main volume and each edge at
    -- the point's matching edge.  Every edge of the argument is kept: when
    -- the point has no edge under some key (e.g. a point built by
    -- fromInteger/fromRational, which has no edges at all), the point itself
    -- is used for that edge, treating it as independent of that variable
    -- (the same convention addVarToEdges uses).  Previously such edges were
    -- silently dropped by Map.intersectionWith.
    bisectDomain maybePt (ERFnDomEdgesApprox mainEncl edges) =
        (ERFnDomEdgesApprox mainEnclLo edgesLo,
         ERFnDomEdgesApprox mainEnclHi edgesHi)
        where
        (mainEnclLo, mainEnclHi) =
            RA.bisectDomain (fmap erfnMainVolume maybePt) mainEncl
        edgesLoHi = Map.mapWithKey bisectEdge edges
        edgesLo = Map.map fst edgesLoHi 
        edgesHi = Map.map snd edgesLoHi 
        bisectEdge key edge =
            RA.bisectDomain (fmap (pointOnEdge key) maybePt) edge
        -- the point's own edge under this key, or the whole point if absent
        pointOnEdge key pt =
            Map.findWithDefault pt key (erfnEdges pt)
    -- lower and upper bounds are taken separately on the main enclosure and
    -- on every edge approximation
    bounds (ERFnDomEdgesApprox mainEncl edges) =
        (ERFnDomEdgesApprox mainEnclLo edgesLo,
         ERFnDomEdgesApprox mainEnclHi edgesHi)
        where
        (mainEnclLo, mainEnclHi) = RA.bounds mainEncl
        edgesLoHi = Map.map (RA.bounds) edges
        edgesLo = Map.map fst edgesLoHi 
        edgesHi = Map.map snd edgesLoHi
    f1 \/ f2 = edgesLift2 (RA.\/) f1 f2

-- Every elementary function is applied independently to the main enclosure
-- and to each hyper-edge approximation, with the same effort index.
instance 
    (FA.ERFnDomApprox box varid domra ranra fa, RAEL.ERApproxElementary fa, VariableID varid) =>
    RAEL.ERApproxElementary (ERFnDomEdgesApprox varid fa)
    where
    abs ix f = edgesLift1 (RAEL.abs ix) f
    exp ix f = edgesLift1 (RAEL.exp ix) f
    log ix f = edgesLift1 (RAEL.log ix) f
    sin ix f = edgesLift1 (RAEL.sin ix) f
    cos ix f = edgesLift1 (RAEL.cos ix) f
    atan ix f = edgesLift1 (RAEL.atan ix) f
        
instance 
    (FA.ERFnDomApprox box varid domra ranra fa, VariableID varid) =>
    FA.ERFnApprox box varid domra ranra (ERFnDomEdgesApprox varid fa)
    where
    -- Check the main enclosure and every edge; each edge's location tag is
    -- extended with its variable and side.
    check prgLocation (ERFnDomEdgesApprox mainEncl edges) =
        ERFnDomEdgesApprox 
            (FA.check prgLocation mainEncl) 
            (Map.mapWithKey checkEdge edges)
        where
        checkEdge (var, pm) edgeFA =
            FA.check (prgLocation ++ showVar var ++ show pm ++ ": ") edgeFA
    -- domain/range conversions are delegated to the main enclosure
    domra2ranra fa d =
        FA.domra2ranra (erfnMainVolume fa) d
    ranra2domra fa r =
        FA.ranra2domra (erfnMainVolume fa) r
    setMaxDegree maxDegree = edgesLift1 (FA.setMaxDegree maxDegree)
    getTupleSize (ERFnDomEdgesApprox mainEncl _) =
        FA.getTupleSize mainEncl
    tuple [] = error "ERFnDomEdgesApprox: FA.tuple: empty list"
    tuple fs =
        -- combine left to right, pairing the accumulated value with the next
        -- element via FA.tuple at the underlying type
        foldl1 consFs fs 
        where
        consFs = edgesLift2 $ \a b -> FA.tuple [a,b]
    -- The tuple function is applied at every level of the edge structure
    -- (through edgesLift1).  At each level the components are re-wrapped with
    -- edges rebuilt from fn's top-level edge structure, by partially
    -- evaluating each component at the corresponding domain endpoints; only
    -- the main volumes of tupleFn's results are kept.
    applyTupleFn tupleFn fn = (edgesLift1 $ FA.applyTupleFn tupleFnNoEdges) fn
        where
        tupleFnNoEdges fas =
            map erfnMainVolume $
                tupleFn $
                    map (\fa -> ERFnDomEdgesApprox fa (makeEdges fa (erfnEdges fn))) 
                        fas
        makeEdges fa oldEdges =
            Map.mapWithKey (makeVarPMEdge fa) oldEdges
        makeVarPMEdge fa (var, pm) oldEdge =
            ERFnDomEdgesApprox faNoVar $ makeEdges faNoVar (erfnEdges oldEdge)
            where
            faNoVar =
                FA.partialEval (DBox.singleton var domEndPt) fa
            domEndPt =
                case pm of Minus -> domL; Plus -> domR
            (domL, domR) = RA.bounds dom
            -- Look this edge's own variable up in the domain box.
            -- Destructuring the box as a one-element list would fail with a
            -- pattern-match error for any domain that does not have exactly
            -- one variable.
            dom = DBox.findWithDefault RA.bottomApprox var $ FA.dom fa
    volume (ERFnDomEdgesApprox mainEncl edges) = FA.volume mainEncl
    scale ratio = edgesLift1 (FA.scale ratio)
    -- Partially intersect the main enclosures; an edge pair is intersected
    -- only when that edge's domain endpoint refines the substitution given
    -- for its variable, otherwise the first argument's edge is kept.
    -- Arguments whose edge keys differ are first unified.
    partialIntersect ix substitutions 
            f1@(ERFnDomEdgesApprox mainEncl1 edges1) 
            f2@(ERFnDomEdgesApprox mainEncl2 edges2) 
        | Map.keys edges1 == Map.keys edges2 =
            ERFnDomEdgesApprox (FA.partialIntersect ix substitutions mainEncl1 mainEncl2) $
                Map.intersectionWithKey partialIntersectEdge edges1 edges2
        | otherwise =
            FA.partialIntersect ix substitutions f1a f2a
        where
        (f1a, f2a) = unifyEdgeVariables f1 f2
        partialIntersectEdge (var, pm) edge1 edge2 
            | withinSubstitutions =
                FA.partialIntersect ix substitutions edge1 edge2
            | otherwise = edge1
            where
            withinSubstitutions =
                (varDomEndpoint pm) `RA.refines` varVal
                where
                varVal =
                    DBox.findWithDefault RA.bottomApprox var substitutions
            varDomEndpoint Minus = varDomLO
            varDomEndpoint Plus = varDomHI
            (varDomLO, varDomHI) = RA.bounds varDom
            varDom = DBox.lookup "DomEdges: partialIntersect: " var $ FA.dom mainEncl1 
    -- Evaluate at a point box.  For every edge whose variable's coordinate
    -- refines the matching domain endpoint (Minus: lower, Plus: upper), the
    -- edge is evaluated with that coordinate removed; the results of all such
    -- edges are intersected componentwise.  The main enclosure's value is
    -- returned only when no edge applies.
    eval ptBox (ERFnDomEdgesApprox mainEncl edges) 
        | null edgeVals =
            mainVal
        | otherwise =
            foldl1 (zipWith (RA./\)) edgeVals
        where
        mainVal = FA.eval ptBox mainEncl
        edgeVals = 
            concatMap edgeEval $ Map.toList edges
        edgeEval ((x, sign), edgeFA) 
            | xPt `RA.refines` xDomLo && sign == Minus =
                [FA.eval ptBoxNoX edgeFA]
            | xPt `RA.refines` xDomHi && sign == Plus =
                [FA.eval ptBoxNoX edgeFA]
            | otherwise = []
            where
            (xDomLo, xDomHi) = RA.bounds xDom
            xDom = DBox.findWithDefault RA.bottomApprox x $ FA.dom mainEncl
            xPt = DBox.findWithDefault RA.bottomApprox x ptBox
            ptBoxNoX = DBox.delete x ptBox
    -- Partially evaluate.  Each substitution whose value refines one of its
    -- variable's domain endpoints is resolved by descending into the matching
    -- edge (when present) and is removed from the substitutions passed on to
    -- the remaining edges.  The selected main enclosure receives all of the
    -- original substitutions.  Edges of variables that are still substituted
    -- are discarded.
    partialEval substitutions f@(ERFnDomEdgesApprox mainEncl edges) =
        (ERFnDomEdgesApprox mainEnclSubst edgesSubst)
        where
        mainEnclSubst = FA.partialEval substitutions mainEnclSelect
        edgesSubst = 
            Map.map (FA.partialEval substitutionsSelect) $
            Map.filterWithKey (\ (varID,_) _ -> varID `DBox.notMember` substitutionsSelect) edgesSelect
        (ERFnDomEdgesApprox mainEnclSelect edgesSelect, substitutionsSelect) = 
            foldl selectVar (f, substitutions) $ DBox.toList substitutions
        selectVar (fPrev@(ERFnDomEdgesApprox _ edgesPrev), substitutionsPrev) (varID, varVal)
            | varVal `RA.refines` varDomLo =
                (Map.findWithDefault fPrev (varID, Minus) edgesPrev, substitutionsNew) 
            | varVal `RA.refines` varDomHi =
                (Map.findWithDefault fPrev (varID, Plus) edgesPrev, substitutionsNew) 
            | otherwise = (fPrev, substitutionsPrev)
            where
            (varDomLo, varDomHi) = RA.bounds varDom
            -- endpoints always come from the top-level domain box
            varDom = DBox.findWithDefault RA.bottomApprox varID $ FA.dom mainEncl
            substitutionsNew = DBox.delete varID substitutionsPrev
            
instance 
    (FA.ERFnDomApprox box varid domra ranra fa, VariableID varid) =>
    FA.ERFnDomApprox box varid domra ranra (ERFnDomEdgesApprox varid fa)
    where
    -- the domain is that of the main enclosure
    dom (ERFnDomEdgesApprox mainEncl edges) = FA.dom mainEncl
    -- For each domain variable both edges are built recursively (at this same
    -- type) over the domain box with that variable removed.
    bottomApprox domB tupleSize =
        ERFnDomEdgesApprox (FA.bottomApprox domB tupleSize) $
            Map.fromList $ concat $
                map varEdges $ DBox.toList domB
        where 
        varEdges (varId, _) =
            [((varId, Minus), fEdge), ((varId, Plus), fEdge)]
            where
            fEdge = 
                FA.bottomApprox (DBox.delete varId domB) tupleSize
    -- as bottomApprox, but with the constant values on every edge
    const domB vals =
        ERFnDomEdgesApprox (FA.const domB vals) $
            Map.fromList $ concat $
                map varEdges $ DBox.toList domB
        where 
        varEdges (varId, _) =
            [((varId, Minus), fEdge), ((varId, Plus), fEdge)]
            where
            fEdge = 
                FA.const (DBox.delete varId domB) vals
    -- Projection onto variable i: the edges of i itself are the constants at
    -- i's domain endpoints; the edges of every other variable are projections
    -- onto i over the reduced domain.
    proj domB i =
        ERFnDomEdgesApprox mainEncl edges
--            Nothing ->
--                error $ 
--                    "DomEdges: projection index " ++ show i 
--                    ++ " out of range for domain " ++ show domB
        where
        mainEncl = FA.proj domB i
        edges =
            Map.fromList $ concat $ map makeVarEdges $ DBox.toList domB
        makeVarEdges (varID, varDom)
            | i == varID =
                [((varID, Minus), FA.const domNoVar [FA.domra2ranra mainEncl idomLo]),
                 ((varID, Plus), FA.const domNoVar [FA.domra2ranra mainEncl idomHi])]
            | otherwise =
                [((varID, Minus), faNoVar),
                 ((varID, Plus), faNoVar)]
            where
            domNoVar = DBox.delete varID domB
            (idomLo, idomHi) = RA.bounds idom
            idom = DBox.lookup "DomEdges: FA.proj: " i domB
            faNoVar = FA.proj domNoVar i
    -- Bisect the domain of var at maybePt, or at the default bisection point
    -- of var's domain.  The lower half keeps the original Minus edge of var and
    -- gets f partially evaluated at the cut point as its Plus edge; the upper
    -- half symmetrically.  Edges of other variables are bisected recursively.
    -- If f has no Minus edge for var, f is returned unchanged as both halves.
    bisect var maybePt f@(ERFnDomEdgesApprox mainEncl edges) 
        | varAbsent = (f,f)
        | otherwise =
            (ERFnDomEdgesApprox mainEnclLo edgesLo,
             ERFnDomEdgesApprox mainEnclHi edgesHi)
        where
        varAbsent =
            Map.notMember (var, Minus) edges
        -- The cut point is resolved once here and passed explicitly (as
        -- Just pt) to every recursive bisection, so that the main enclosure,
        -- the remaining edges and the new cut edge fAtPt are all split at the
        -- same point.  Forwarding a Nothing would let the underlying type pick
        -- its own default point, which need not coincide with pt.
        pt = 
            case maybePt of 
                Nothing -> RA.defaultBisectPt varDom
                Just givenPt -> givenPt
        varDom = 
            DBox.findWithDefault RA.bottomApprox var $ FA.dom mainEncl
        (mainEnclLo, mainEnclHi) = FA.bisect var (Just pt) mainEncl
        edgesLo =
            Map.insert (var, Minus) (edges Map.! (var, Minus)) $
            Map.insert (var, Plus) fAtPt $
            edgesLoNoVar
        edgesHi =
            Map.insert (var, Minus) fAtPt $
            Map.insert (var, Plus) (edges Map.! (var, Plus)) $
            edgesHiNoVar
        fAtPt = FA.partialEval (DBox.singleton var pt) f
        edgesLoNoVar = Map.map fst edgesLoHiNoVar
        edgesHiNoVar = Map.map snd edgesLoHiNoVar
        edgesLoHiNoVar = 
            Map.map (FA.bisect var (Just pt)) edgesNoVar
        edgesNoVar = 
            Map.delete (var, Plus) $ Map.delete (var, Minus) edges
    -- Integrate fD in variable x.  The initial condition is taken from
    -- fInit's (x, Minus) edge when it has one (after unifying edge variables
    -- with fD).  The edges of x are then obtained by partially evaluating the
    -- integrated approximation at x's domain endpoints; the remaining edges
    -- are integrated recursively with the edge variable removed from
    -- integdomBox.
    integrate ix fD x integdomBox origin fInit =
        ERFnDomEdgesApprox mainEncl edges
        where
        (ERFnDomEdgesApprox mainEnclD edgesD, 
         fInitWithX@(ERFnDomEdgesApprox _ edgesInitWithX)) = 
            unifyEdgeVariables fD fInit
        (ERFnDomEdgesApprox mainEnclInit edgesInit) = 
            Map.findWithDefault fInitWithX (x, Minus) edgesInitWithX 
        mainEncl = 
            FA.integrate ix mainEnclD x integdomBox origin mainEnclInit
        edges = 
            Map.insert (x, Minus) (FA.partialEval (DBox.singleton x xDomLo) fNoX) $ 
            Map.insert (x, Plus) (FA.partialEval (DBox.singleton x xDomHi) fNoX) $
            edgesNoX
        fNoX = ERFnDomEdgesApprox mainEncl edgesNoX
        edgesNoX =
            Map.intersectionWithKey integrEdge edgesD edgesInit
        (xDomLo, xDomHi) = RA.bounds xDom
        xDom = DBox.findWithDefault RA.bottomApprox x $ FA.dom fD
        integrEdge (varID, _) edgeD edgeInit =
            FA.integrate ix edgeD x (DBox.delete varID integdomBox) origin edgeInit
            
    -- Integrate from the lower (and, unless xOrigin is exact, also the upper)
    -- endpoint of xOrigin, using fP partially evaluated there as the initial
    -- condition, then intersect the result with fP and report the improvement.
    integrateMeasureImprovement ix fD x integdomBox xOrigin fP =
--        unsafePrint 
--            ("DomEdges: integrateMeasureImprovement: faIntegrLo = " ++ show faIntegrLo)  
        (faIntegr, faImprovement)
        where
        faIntegr =
            faIntegrIsect
--            case RA.compareReals (FA.volume faIntegrIsect) (FA.volume faIntegrRaw) of
--                Just LT -> faIntegrIsect
--                _ -> faIntegrRaw -- this is wrong - forgets initial conditions!
        (faIntegrIsect, faImprovement) = 
            RA.intersectMeasureImprovement ix fP faIntegrRaw
        faIntegrRaw 
            | RA.isExact xOrigin = faIntegrLo
            | otherwise = faIntegrLo RA./\ faIntegrHi
        (xOriginLo, xOriginHi) = RA.bounds xOrigin
        faIntegrLo = 
            FA.integrate ix fD x integdomBox xOriginLo faPxLo      
        faPxLo = 
            FA.partialEval (DBox.singleton x xOriginLo) fP 
        faIntegrHi = 
            FA.integrate ix fD x integdomBox xOriginHi faPxHi      
        faPxHi = 
            FA.partialEval (DBox.singleton x xOriginHi) fP