{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-|
    Module      :  Control.ERNet.Blocks.Real.LFT
    Description :  real protocol using LFTs as digits
    Copyright   :  (c) Michal Konecny
    License     :  BSD3

    Maintainer  :  mik@konecny.aow.cz
    Stability   :  experimental
    Portability :  portable

    A protocol for sending a real number as a stream of
    linear fractional transformation (LFT) digits, based on
    the work of Potts and Edalat (1997).

-}
module Control.ERNet.Blocks.Real.LFT 
(
    -- * protocol
    QALFTRealQ(..),
    QALFTRealA(..),
    LFTDigit(..),
    lftDigit2Tensor,
    chTLFTReal,
    -- * processes
    lftRealNumberIncremProcess,
    lftRealNumberBufferForkProcess,
--    lftBinaryTensorIncremProcess,
--    lftParametrisedBinaryTensorIncremProcess,
    -- * arithmetic
    LFTTensor(..),
    lftConst,
    lftMatrix,
    lftTensorBinary,
    lftTensorInfo,
    lftTensorIsPositive,
    lftTensorCompose,
    lftTensorComposeUnary
)
where

import Control.ERNet.Foundations.Protocol
import Control.ERNet.Foundations.Protocol.StandardCombinators
import qualified Control.ERNet.Foundations.Channel as CH
import Control.ERNet.Foundations.Process

import Control.ERNet.Blocks.Basic

import qualified Data.Number.ER.Real.Approx as RA
import Data.Number.ER.BasicTypes
import Data.Number.ER.Real.DefaultRepr

import Data.Number.ER.ShowHTML
import qualified Text.Html as H
import Data.Number.ER.Misc

import qualified Data.Map as Map

import Data.Typeable

-- | A question @QALFTRealQ n@ concerns the first @n@ digits of the number.
--   An answer matches it (qaMatch returns 'Nothing') when it carries either
--   exactly one digit (incremental use, only the newly requested digit is
--   sent) or exactly @n@ digits (non-incremental use).
instance 
    (QAProtocol QALFTRealQ QALFTRealA)
    where
    qaMatch (QALFTRealQ n) (QALFTRealA digits)
        | digitCount == 1 = Nothing -- when using incrementally  
        | digitCount == n = Nothing -- non-incremental interpretation
        | otherwise = Just "Protocol QALFTReal: wrong number of digits in answer" 
        where
        digitCount = length digits


-- | Question of the LFT real protocol: @QALFTRealQ n@ asks about level @n@,
--   ie the first @n@ LFT digits of the number.
data QALFTRealQ = QALFTRealQ Int
    deriving (Eq, Ord, Show, Typeable)

-- | Answer of the LFT real protocol: a list of LFT digits, the most
--   significant digit first.
data QALFTRealA = QALFTRealA [LFTDigit]
    deriving (Eq, Ord, Show, Typeable)

-- | An LFT digit: 'LFT_L', 'LFT_M' and 'LFT_R' are the refining digits and
--   the @LFT_SG_*@ constructors are the sign digits (their matrices are
--   given by 'lftDigit2Tensor').  The derived 'Ord' follows the order in
--   which the constructors are listed here.
data LFTDigit
    = LFT_L
    | LFT_M
    | LFT_R
    | LFT_SG_ZER
    | LFT_SG_INF
    | LFT_SG_POS
    | LFT_SG_NEG
    deriving (Eq, Ord, Typeable)

-- | Compact rendering (also used by the HTML rendering of answers):
--   a single letter for each refining digit, and for each sign digit the
--   extended interval onto which its LFT maps @[0,oo]@.
instance Show LFTDigit 
    where
    show digit =
        case digit of
            LFT_L -> "L"
            LFT_M -> "M"
            LFT_R -> "R"
            LFT_SG_ZER -> "S[-1,1]"
            LFT_SG_INF -> "S[1,-1]"
            LFT_SG_POS -> "S[0,oo]"
            LFT_SG_NEG -> "S[oo,0]"

-- | HTML rendering of a single digit, delegated to 'toHtmlDefault'.
instance H.HTML LFTDigit
    where
    toHtml digit = toHtmlDefault digit
    
-- | HTML rendering of a question, delegated to 'toHtmlDefault'.
instance H.HTML QALFTRealQ 
    where
    toHtml question = toHtmlDefault question
    
-- | HTML rendering of an answer: the 'show' forms of its digits
--   concatenated into one string without separators.
instance H.HTML QALFTRealA
    where
    toHtml (QALFTRealA digits) =
        H.toHtml (concatMap show digits)
            
-- | Channel type of the LFT real protocol, built from an example question
--   and a matching example answer.
--   NOTE(review): presumably the values only fix the question/answer types.
chTLFTReal :: ChannelType
chTLFTReal =
    ChannelType exampleQuestion exampleAnswer
    where
    exampleQuestion = QALFTRealQ 3
    exampleAnswer = QALFTRealA [LFT_SG_ZER, LFT_R, LFT_M]


{-| 
    Interpret an LFT digit as a unary LFT.  @lftMatrix a b c d@ stands for
    @x -> (a*x + b) \/ (c*x + d)@, so the refining digits map @[0,oo]@
    onto @[0,1]@ (L), @[1\/3,3]@ (M) and @[1,oo]@ (R), while each sign digit
    maps @[0,oo]@ onto the extended interval shown by its 'Show' instance.
-}
lftDigit2Tensor :: LFTDigit -> LFTTensor
lftDigit2Tensor LFT_L = lftMatrix 1 0 1 2
lftDigit2Tensor LFT_M = lftMatrix 3 1 1 3
lftDigit2Tensor LFT_R = lftMatrix 2 1 0 1
lftDigit2Tensor LFT_SG_INF = lftMatrix 1 1 (-1) 1
lftDigit2Tensor LFT_SG_NEG = lftMatrix 0 1 (-1) 0
lftDigit2Tensor LFT_SG_ZER = lftMatrix 1 (-1) 1 1
lftDigit2Tensor LFT_SG_POS = lftMatrix 1 0 0 1
            

{-|
    Compose a sequence of LFT digits into one unary LFT, the first digit
    being the outermost transformation.  An empty sequence yields the
    identity LFT instead of failing in 'foldl1'.
-}
lftDigits2Tensor :: 
    [LFTDigit] -> 
    LFTTensor
lftDigits2Tensor digits =
    case map lftDigit2Tensor digits of
        [] -> lftTensorId -- composing no digits leaves x unchanged
        tensors -> foldl1 lftTensorComposeUnary tensors

{-|
    The extended interval onto which the composition of the given digits
    maps @[0,oo]@, computed with the given granularity.
-}
lftDigitsInfo :: 
    (RA.ERApprox ra) =>
    Granularity ->
    [LFTDigit] -> 
    ExtInterval ra
lftDigitsInfo gran =
    lftTensorInfoUnary gran . lftDigits2Tensor
    

{-|
    A process communicating a real number to a single client
    incrementally digit by digit.

    The first query has to be a plain 'QAChangesQ' for level 1; it is
    answered with a sign digit.  Every later query has to be a
    'QAChangesQWhenNew' for level @n@ referring to level @n - 1@, ie to the
    previously answered query; it is answered with the single refining
    digit (L, M or R) of level @n@.  Any other query raises an error.

    Each digit is found by trying the candidate digits against the
    approximations @xByIx ix@ for increasing effort indices @ix@ (starting
    at 10) until one approximation lies within the image of @[0,oo]@ under
    the digits emitted so far composed with a candidate.  If no
    approximation ever fits, the search does not terminate.
-}    
lftRealNumberIncremProcess ::
    (CH.Channel sIn sOut sInAnyProt sOutAnyProt,
     RA.ERIntApprox ra, Typeable ra) =>
    ERProcessName ->
    (EffortIndex -> ra) 
    {-^
        the number to represent; 
        intersection of this sequence has to converge to a singleton 
     -} -> 
    ERProcess sInAnyProt sOutAnyProt
lftRealNumberIncremProcess defName xByIx =
    constantStatefulProcess defName responderTransferFn initState (chTChanges chTLFTReal)
    where
    -- State: (level of the previous query (0 = none yet),
    --         (composition of the digits emitted so far,
    --          effort index reached, approximation of x at that index))
    initState = (0, (lftTensorId, 10, xByIx 10))
    
    responderTransferFn
            (prevQuery, (prevTensor, prevIx, prevRA)) 
            (_qryId, qry) =
        case qry of
            QAChangesQ (QALFTRealQ n) 
                | n == 1 && prevQuery == 0 ->
            -- initial query: answer with a sign digit
                (
                 (False, QAChangesANew $ QALFTRealA [signDigit])
                , 
                 Just (1, (signTensor, signIx, signRA))
                )
            QAChangesQWhenNew prevQry (QALFTRealQ n) 
                | prevQry == prevQuery && prevQry == (n - 1) ->
            -- strictly incremental follow-up: answer with one refining digit
                (
                 (False, QAChangesANew $ QALFTRealA $ [newDigit])
                , 
                 Just (n, (newTensor, newIx, newRA))
                )
            _ ->
                error $
                    "ERNet.Blocks.Real.LFT: lftRealNumberIncremProcess: " ++
                    "query " ++ show qry ++ " is not strictly incremental."
        where
        (signDigit, signTensor, signIx, signRA) =
            searchSignDigit prevIx prevRA
        -- the granularity grows with the number of digits already emitted
        (newDigit, newTensor, newIx, newRA) =
            searchDigit (10 + prevQuery) prevTensor prevIx prevRA
            
    -- Find a sign digit whose image of [0,oo] contains the approximation,
    -- moving on to higher effort indices until one does.
    -- Candidates are tried in the order ZER, INF, POS, NEG.
    searchSignDigit ix ra
        | ra `refinesRAExtInt` infoZ = (LFT_SG_ZER, tZ, ix, ra)
        | ra `refinesRAExtInt` infoI = (LFT_SG_INF, tI, ix, ra)
        | ra `refinesRAExtInt` infoP = (LFT_SG_POS, tP, ix, ra)
        | ra `refinesRAExtInt` infoN = (LFT_SG_NEG, tN, ix, ra)
        | otherwise =
            searchSignDigit nextIx (xByIx nextIx)
        where
        nextIx = ix + 1
        infoZ = lftTensorInfoUnary gran tZ
        infoI = lftTensorInfoUnary gran tI
        infoP = lftTensorInfoUnary gran tP
        infoN = lftTensorInfoUnary gran tN
        gran = 10
    -- Find a refining digit such that the digits so far (tensor) composed
    -- with it map [0,oo] onto an interval containing the approximation,
    -- moving on to higher effort indices until one does.
    -- Candidates are tried in the order L, M, R.
    searchDigit gran tensor ix ra
        | ra `refinesRAExtInt` infoL = (LFT_L, compWithL, ix, ra)
        | ra `refinesRAExtInt` infoM = (LFT_M, compWithM, ix, ra)
        | ra `refinesRAExtInt` infoR = (LFT_R, compWithR, ix, ra)
        | otherwise =
            searchDigit gran tensor nextIx nextRA
        where
        nextIx = ix + 1
        nextRA = xByIx nextIx
        infoL = lftTensorInfoUnary gran compWithL
        compWithL = lftTensorComposeUnary tensor tL
        infoM = lftTensorInfoUnary gran compWithM 
        compWithM = lftTensorComposeUnary tensor tM
        infoR = lftTensorInfoUnary gran compWithR 
        compWithR = lftTensorComposeUnary tensor tR
    [tZ, tI, tP, tN, tL, tM, tR] = 
        map lftDigit2Tensor 
            [LFT_SG_ZER, LFT_SG_INF, LFT_SG_POS, LFT_SG_NEG, LFT_L, LFT_M, LFT_R]
        
{-|
    A process that receives a real number incrementally digit by digit
    and makes it available to multiple clients incrementally or non-incrementally.

    Digits received from upstream are memoised.  A query for a level that
    has already been reached is answered at once from the memo.  Otherwise
    the query either waits for a level that is already being worked
    towards, or it becomes the new target, in which case the process asks
    upstream for one further digit at a time until the target is reached.

    NOTE(review): the first field of an incoming 'QAChangesQWhenNew' is
    looked up as a query id in the memory of past queries, while the
    upstream queries built here put a level into that field; confirm
    against the QAChanges protocol definition.
-}    
lftRealNumberBufferForkProcess ::
    (CH.Channel sIn sOut sInAnyProt sOutAnyProt) =>
    ERProcessName ->
    ERProcess sInAnyProt sOutAnyProt
lftRealNumberBufferForkProcess defName =
    passThroughStatefulProcess 
        defName qryStFn ansStFn initState 
        (chTChanges chTLFTReal) (chTChanges chTLFTReal)
    where
    initState :: 
        (
         Map.Map QueryId Int, -- memory of past queries: query id -> requested level
         [LFTDigit],
           -- memoising past replies
           -- the number of digits equals the highest answered query so far
           -- the most significant digit is last
         Int, -- the highest query answered so far
         Int
            -- current highest target query
            -- it is larger or equal to the number above 
            -- the numbers are equal iff each incoming query either has been answered
            --   or can be answered without further queries
        )
    initState = (Map.empty, [], 0, 0)
    
    -- Handle an incoming client query for level n.
    -- The combination (True, True, False) cannot occur: it would need
    -- largestAnswered < n <= target == largestAnswered.
    qryStFn 
            (pastQueries, prevDigits, largestAnswered, target) 
            (qryId, qry1) =
        case (n > largestAnswered, largestAnswered == target, n > target) of
            (False, _, _) -> -- can answer immediately
                (
                 ERProcessActionAnswer False $
                    answerFromDigits pastQueries largestAnswered prevDigits qry1
                , 
                 Just (pastQueriesNew, prevDigits, largestAnswered, target)
                )
            (True, False, False) -> -- have to wait until the desired level is reached
                (
                 ERProcessActionRetryWhen $ \ (_, _, answeredNow, _) -> answeredNow >= n 
                , 
                 Just (pastQueriesNew, prevDigits, largestAnswered, target)
                )
            (True, False, True) -> -- have to wait until the previous target is reached  
                (
                 ERProcessActionRetryWhen $ \ (_, _, answeredNow, _) -> answeredNow == target 
                , 
                 Just (pastQueriesNew, prevDigits, largestAnswered, target)
                )
            (True, True, True) -> -- have to set a new target and start working towards it  
                (
                 ERProcessActionQuery $ 
                    case largestAnswered > 0 of
                        True -> QAChangesQWhenNew largestAnswered $ QALFTRealQ (largestAnswered + 1)
                        False ->  QAChangesQ $ QALFTRealQ (largestAnswered + 1)
                , 
                 Just (pastQueriesNew, prevDigits, largestAnswered, n)
                )
        where
        pastQueriesNew = Map.insert qryId n pastQueries
        n = 
            case qry1 of
                QAChangesQ (QALFTRealQ lvl) -> lvl
                QAChangesQWhenNew _ (QALFTRealQ lvl) -> lvl
        
    -- Handle a new digit from upstream (expected to be a single-digit answer):
    -- memoise it, then either answer qry1 (presumably the query that set the
    -- current target) or ask upstream for the next digit.
    ansStFn 
            (pastQueries, prevDigits, largestAnswered, target) 
            (_qryId, qry1)
            (_, QAChangesANew (QALFTRealA [newDigit])) 
            =
        (action, Just newState)
        where
        newState = (pastQueries, newDigits, largestAnsweredNew, target)
        newDigits = newDigit : prevDigits
        largestAnsweredNew = largestAnswered + 1
        action 
            | largestAnsweredNew == target =
                -- newDigits holds the levels largestAnsweredNew down to 1
                ERProcessActionAnswer True $
                    answerFromDigits pastQueries largestAnsweredNew newDigits qry1
            | otherwise =
                ERProcessActionQuery $ 
                    QAChangesQWhenNew largestAnsweredNew $ QALFTRealQ (largestAnsweredNew + 1)         

    -- Build the answer to a client query from memoised digits, where
    -- `digits` holds the levels `answered` down to 1 (most significant last).
    -- A plain query receives all digits up to its level; an incremental
    -- query receives only the digits above the level of the query it
    -- refers to.  The answer lists the most significant digit first.
    answerFromDigits pastQs answered digits qry =
        QAChangesANew $ QALFTRealA $ reverse $
            case qry of
                QAChangesQ (QALFTRealQ lvl) ->
                    upToLevel lvl
                QAChangesQWhenNew prevQry (QALFTRealQ lvl) ->
                    take (lvl - levelOf prevQry) $ upToLevel lvl
        where
        -- the digits of levels lvl down to 1
        upToLevel lvl = drop (answered - lvl) digits
        levelOf prevQry =
            case Map.lookup prevQry pastQs of
                Just p -> p
                Nothing -> 
                    error $ "ERNet.Blocks.Real.LFT: lftRealNumberBufferForkProcess:" 
                        ++ " query refers to non-existent previous query" 
        
{-| 
    A multi-dimensional linear fractional transformation with integer
    coefficients, stored as a tensor whose indices are lists of 'Bool's.
-}
data LFTTensor = LFTTensor
    { -- | length of every coefficient index, counting the
      --   numerator/denominator position (1 for a constant,
      --   2 for a unary and 3 for a binary LFT)
      lftTNSrank :: Int
      -- | the coefficients by index; the first Bool indicates whether
      --   or not the term is in the numerator of the LFT
    , lftTNScoeffs :: Map.Map [Bool] Integer
    }
    deriving (Show)
    
{-| 
    Constructor for a 0-ary LFT with integer coefficients, ie the constant
    fraction @numerator \/ denominator@.
-}
lftConst :: 
    Integer -> Integer -> 
    LFTTensor
lftConst numer denom =
    LFTTensor 1 $ Map.fromList [([False], denom), ([True], numer)]

{-| 
    Constructor for a unary LFT with integer coefficients, ie the map
    @x -> (a*x + b) \/ (c*x + d)@.
-}
lftMatrix :: 
    Integer -> Integer -> Integer -> Integer -> 
    LFTTensor
lftMatrix a b c d =
    LFTTensor 2 $ Map.fromList $ zip indices [a, b, c, d]
    where
    -- numerator/denominator flag first, then variable/constant flag;
    -- yields [True,True], [True,False], [False,True], [False,False]
    indices = [[num, var] | num <- [True, False], var <- [True, False]]

-- | The identity LFT @x -> x@: ones on the diagonal, zeros elsewhere.
--   Composing with it via 'lftTensorComposeUnary' changes nothing.
lftTensorId :: LFTTensor
lftTensorId =
    lftMatrix diagonal offDiagonal offDiagonal diagonal
    where
    diagonal = 1
    offDiagonal = 0

{-|
    Constructor for a binary LFT with integer coefficients.  The eight
    coefficients are assigned to the indices in lexicographic order with
    'True' before 'False': @a@ goes to @[True,True,True]@ and @h@ to
    @[False,False,False]@, so @a@..@d@ are numerator coefficients and
    @e@..@h@ denominator coefficients.
-}
lftTensorBinary :: 
    Integer -> Integer -> Integer -> Integer -> 
    Integer -> Integer -> Integer -> Integer -> 
    LFTTensor
lftTensorBinary a b c d e f g h =
    LFTTensor 3 $ Map.fromList $ zip indices [a, b, c, d, e, f, g, h]
    where
    indices =
        [ [num, v1, v2]
        | num <- [True, False], v1 <- [True, False], v2 <- [True, False] ]

{-|
    An interval of the extended real line, running from 'extIntervalL' to
    'extIntervalR' in the positive direction; when the left end lies above
    the right end, the interval passes through infinity.
-}
data ExtInterval ra = ExtInterval
    { extIntervalL :: ra
    , extIntervalR :: ra
    }
    deriving (Show)

{-|
    Check whether the approximation @ra@ lies within the given extended
    interval.  When the left end is below the right end (or equal to it)
    this is ordinary containment in their hull.  When it is above, the
    interval wraps through infinity, and @ra@ has to lie within
    @[l, +oo]@ or within @[-oo, r]@.

    If the order of the two ends cannot be decided from their
    approximations, containment cannot be confirmed and the result is
    'False'.  Previously this case, like the equal-ends case, crashed
    with a pattern-match failure.
-}
refinesRAExtInt ::
    (RA.ERIntApprox ra) =>
    ra ->
    ExtInterval ra ->
    Bool
ra `refinesRAExtInt` (ExtInterval lRA rRA) =
    case lRA `RA.compareReals` rRA of
        Just GT -> 
            -- the interval passes through infinity
            (ra `RA.refines` (lRA RA.\/ infty))
            ||
            (ra `RA.refines` ((-infty) RA.\/ rRA))
        Just _ ->
            -- LT: an ordinary interval; EQ: a degenerate one-point interval
            ra `RA.refines` (lRA RA.\/ rRA)
        Nothing ->
            -- the ends overlap, so the direction of the interval is unknown
            False
    where
    infty = RA.plusInfinity

{-|
    Work out the interval that is the image of the lft when all
    variables range over @[0,oo]@.
    The returned interval may be slightly bigger than the
    exact image due to rounding, but it always contains the
    whole exact image.
    Only constant (rank 1) and unary (rank 2) LFTs are supported;
    any other rank raises an error.
-}
lftTensorInfo ::
    (RA.ERApprox ra) =>
    Granularity -> 
    LFTTensor ->
    ExtInterval ra
lftTensorInfo gran t =
    case lftTNSrank t of
        1 -> lftTensorInfoConst gran t
        2 -> lftTensorInfoUnary gran t
        rank ->
            error $ "ERNet.Blocks.Real.LFT: lftTensorInfo: unsupported rank " ++ show rank

{-|
    Like 'lftTensorInfo' but assuming the lft is constant (rank 1): the
    result is the one-point interval holding numerator \/ denominator,
    both converted with at least the given granularity.
-}
lftTensorInfoConst ::
    (RA.ERApprox ra) =>
    Granularity -> 
    LFTTensor ->
    ExtInterval ra
lftTensorInfoConst gran (LFTTensor _ coeffs) =
    ExtInterval value value
    where
    value = toRA numer / toRA denom
    toRA i = RA.setMinGranularityOuter gran (fromInteger i)
    -- key [False] sorts before [True]: denominator first
    [denom, numer] = Map.elems coeffs

{-|
    Like 'lftTensorInfo' but assuming the lft is unary, ie
    @x -> (a*x + b) \/ (c*x + d)@.  The image of @[0,oo]@ lies between the
    values at @0@ (@b\/d@) and at @oo@ (@a\/c@).  The LFT increases when
    the determinant @a*d - b*c@ is positive, decreases when it is negative,
    and is constant when it is zero.
-}
lftTensorInfoUnary ::
    (RA.ERApprox ra) =>
    Granularity -> 
    LFTTensor ->
    ExtInterval ra
lftTensorInfoUnary gran (LFTTensor _ coeffs) =
    case compare det 0 of
        EQ -> ExtInterval atZero atZero
        GT -> ExtInterval atZero atInfinity
        LT -> ExtInterval atInfinity atZero
    where
    atZero = toRA b / toRA d
    atInfinity = toRA a / toRA c
    toRA i = RA.setMinGranularityOuter gran (fromInteger i)
    -- ascending keys: [False,False], [False,True], [True,False], [True,True]
    [d, c, b, a] = Map.elems coeffs
    det = a * d - b * c

--{-|
--    Like "lftTensorInfo" but assuming the lft is binary.
---}
--lftTensorInfoBinary ::
--    (RA.ERApprox ra) =>
--    Granularity -> 
--    LFTTensor ->
--    ExtInterval ra
--lftTensorInfoBinary gran t@(LFTTensor _ coeffs) 
--    | det == 0 =
--        ExtInterval ratioLL ratioLL
--    where
--    ratioLL = dRA / hRA 
--    ratioLR = cRA / gRA 
--    ratioRL = bRA / fRA 
--    ratioRR = aRA / eRA 
--    [aRA, bRA, cRA, dRA, eRA, fRA, gRA, hRA] = 
--        map (RA.setMinGranularity gran . fromInteger) 
--        [a, b, c, d, e, f, g, h] 
--    [h, g, f, e, d, c, b, a] = map snd $ Map.toAscList coeffs
--    det = a * d - b * c

{-|
    Check whether all columns of the LFT have one and the same non-zero
    sign.  A column pairs a numerator coefficient with the denominator
    coefficient that has the same remaining index.  Its sign is the common
    sign of its non-zero entries, or 0 when both entries are zero or have
    opposite signs.  An LFT without coefficients counts as positive.
    NOTE(review): presumably this tests that the LFT maps @[0,oo]@ into
    itself; confirm.
-}
lftTensorIsPositive ::
    LFTTensor ->
    Bool
lftTensorIsPositive (LFTTensor _ coeffs) =
    case map columnSign columns of
        [] -> True
        (firstSign : otherSigns) ->
            firstSign /= 0 && all (== firstSign) otherSigns
    where
    -- ascending keys put every denominator index (False:...) before every
    -- numerator index (True:...), with the same order inside both halves
    allCoeffs = Map.elems coeffs
    (denominators, numerators) = splitAt (length allCoeffs `div` 2) allCoeffs
    columns = zip numerators denominators
    columnSign (num, den)
        | num == 0 = signum den
        | num < 0 = if den > 0 then 0 else -1
        | otherwise = if den < 0 then 0 else 1
    
{-|
    Compose two unary LFTs, ie substitute the second one for the variable
    of the first one.  On the 2x2 coefficient matrices this is the product
    first times second.
-}
lftTensorComposeUnary :: 
    LFTTensor ->
    LFTTensor ->
    LFTTensor 
lftTensorComposeUnary outer =
    lftTensorCompose outer 1

{-|
    Compose two LFTs, ie substitute the second one into the first one in
    place of one of the first one's variables.

    @k@ is the position of that variable in the first LFT's coefficient
    index.  Position 0 is the numerator/denominator flag, so the first
    variable is @k = 1@.  The result has rank @n1 + n2 - 2@; its indices
    are the first LFT's flag and remaining variables, followed by the
    second LFT's variables.  Each result coefficient sums, over both
    values of the linking index, the product of the matching coefficients
    of the two LFTs.
-}
lftTensorCompose :: 
    LFTTensor ->
    Int ->
    LFTTensor ->
    LFTTensor 
lftTensorCompose (LFTTensor n1 coeffs1) k (LFTTensor n2 coeffs2) =
    LFTTensor n $ Map.fromList $ map (\ix -> (ix, coeffAt ix)) $ allIndices n
    where
    n = n1 + n2 - 2
    coeffAt ix =
        sum [ (coeffs1 Map.! insertAt k link outerPart)
                * (coeffs2 Map.! (link : innerPart))
            | link <- [True, False] ]
        where
        -- the first n1 - 1 components belong to the first LFT,
        -- the remaining ones to the second LFT
        (outerPart, innerPart) = splitAt (n1 - 1) ix

-- | Insert an element at the given position of a list.  A position past
--   the end appends the element; a non-positive position prepends it.
insertAt :: Int -> a -> [a] -> [a]
insertAt k x xs =
    let (before, after) = splitAt k xs
    in before ++ x : after
     
-- | All lists of @n@ 'Bool's in lexicographic order with 'True' before
--   'False', so the first list is all 'True'.  @n@ must be non-negative;
--   for a negative @n@ the recursion does not terminate.
allIndices :: (Eq a, Num a) => a -> [[Bool]]
allIndices 0 = [[]]
allIndices n =
    [ b : rest | b <- [True, False], rest <- shorter ]
    where
    shorter = allIndices (n - 1)