-- | Sharing for 'UE's, based on a bidirectional map ("Data.Bimap").  The idea is to share common subexpressions of 'UE's.

module Language.Atom.UeMap 
  ( UeElem (..)
  , MUV (..)
  , UeMap
  , emptyMap
  , Hash
  , typeOf
  , UeState
  , recoverUE
  , getUE
  , newUE
  , newUV
--  , share
  , maybeUpdate
  , ueUpstream
  , nearestUVs
  , arrayIndices
  , isMathHCall
  ) where

import Control.Monad.State.Strict
--import qualified Data.IntMap as M
import qualified Data.Bimap as M
import Data.List (nub)

import Language.Atom.Expressions hiding (typeOf)
import qualified Language.Atom.Expressions as E

-- | Key under which a 'UeElem' is stored in a 'UeMap'.  Sub-expressions are
-- referred to by their 'Hash' rather than by value.
type Hash = Int

-- | Untyped variables as stored in the sharing map.  This mirrors 'UV', except
-- that the expression carried by an array reference is replaced by the 'Hash'
-- of that expression (see 'newUV').
data MUV
  -- Fields are copied unchanged from 'UV'; 'typeOf' derives the variable's
  -- type from the 'Const' field.
  = MUV Int String Const
  -- An array together with the hash of the expression 'UVArray' carried.
  | MUVArray UA Hash
  -- Fields are copied unchanged from 'UVExtern'; the 'Type' field is what
  -- 'typeOf' reports for it.
  | MUVExtern String Type
  deriving (Show, Eq, Ord)

-- | Convert a 'UV' to its 'MUV' counterpart.  Plain and external variables
-- leave the map untouched; for an array reference the index expression is
-- inserted into the map via 'newUE', and the resulting map is returned.
newUV :: UV -> UeMap -> (MUV, UeMap)
newUV (UV i j k) ueMap = (MUV i j k, ueMap)
newUV (UVExtern i j) ueMap = (MUVExtern i j, ueMap)
newUV (UVArray arr ix) ueMap = (MUVArray arr ixHash, ueMap')
  where
    -- Share the index expression first; its hash stands in for it.
    (ixHash, ueMap') = newUE ix ueMap

-- | Corresponds to 'UE's --- the elements in the sharing structure.  Each
-- constructor parallels the 'UE' constructor of the same name (minus the
-- leading @M@), with every sub-expression replaced by the 'Hash' under which
-- that sub-expression is stored in the same 'UeMap'.
data UeElem
  -- Leaves and casts.
  = MUVRef !MUV
  | MUConst !Const
  | MUCast  !Type !Hash
  -- Arithmetic.
  | MUAdd   !Hash !Hash
  | MUSub   !Hash !Hash
  | MUMul   !Hash !Hash
  | MUDiv   !Hash !Hash
  | MUMod   !Hash !Hash
  -- Logic; 'MUAnd' takes any number of operands (its list field is the only
  -- non-strict field of this type).
  | MUNot   !Hash
  | MUAnd   [Hash]
  -- Bitwise operations; the 'Int' of 'MUShift' is copied from 'UShift'.
  | MUBWNot !Hash
  | MUBWAnd !Hash !Hash
  | MUBWOr  !Hash !Hash
  | MUShift !Hash !Int
  -- Comparisons and selection.
  | MUEq    !Hash !Hash
  | MULt    !Hash !Hash
  | MUMux   !Hash !Hash !Hash
  -- Float/double <-> bits conversions ('typeOf' reports Word32, Word64,
  -- Float and Double respectively).
  | MUF2B   !Hash
  | MUD2B   !Hash
  | MUB2F   !Hash
  | MUB2D   !Hash
-- math.h:
  | MUPi
  | MUExp   !Hash
  | MULog   !Hash
  | MUSqrt  !Hash
  | MUPow   !Hash !Hash
  | MUSin   !Hash
  | MUAsin  !Hash
  | MUCos   !Hash
  | MUAcos  !Hash
  | MUSinh  !Hash
  | MUCosh  !Hash
  | MUAsinh !Hash
  | MUAcosh !Hash
  | MUAtan  !Hash
  | MUAtanh !Hash
  deriving (Show, Eq, Ord)

-- | The 'Type' of the expression stored under the given 'Hash'.
--
-- Variables, constants and casts carry their type directly; logical and
-- comparison operators are 'Bool'; the bit-conversion operators have fixed
-- result types; every other operator takes the type of one of its operands.
-- Fails (via 'getUE') if the hash is not in the map.
typeOf :: Hash -> UeMap -> Type
typeOf h mp = classify (getUE h mp)
  where
    -- Type of another element stored in the same map.
    sameAs operand = typeOf operand mp

    classify el = case el of
      -- Elements that carry their own type.
      MUVRef ref -> case ref of
        MUV _ _ c     -> E.typeOf c
        MUVArray ua _ -> E.typeOf ua
        MUVExtern _ t -> t
      MUConst c -> E.typeOf c
      MUCast t _ -> t
      -- Fixed result types.
      MUNot{} -> Bool
      MUAnd{} -> Bool
      MUEq{}  -> Bool
      MULt{}  -> Bool
      MUF2B{} -> Word32
      MUD2B{} -> Word64
      MUB2F{} -> Float
      MUB2D{} -> Double
      MUPi    -> Double
      -- The mux takes the type of its second operand.
      MUMux _ x _ -> sameAs x
      -- Everything else takes the type of its first operand.
      MUAdd x _   -> sameAs x
      MUSub x _   -> sameAs x
      MUMul x _   -> sameAs x
      MUDiv x _   -> sameAs x
      MUMod x _   -> sameAs x
      MUBWNot x   -> sameAs x
      MUBWAnd x _ -> sameAs x
      MUBWOr x _  -> sameAs x
      MUShift x _ -> sameAs x
      MUExp x     -> sameAs x
      MULog x     -> sameAs x
      MUSqrt x    -> sameAs x
      MUPow x _   -> sameAs x
      MUSin x     -> sameAs x
      MUAsin x    -> sameAs x
      MUCos x     -> sameAs x
      MUAcos x    -> sameAs x
      MUSinh x    -> sameAs x
      MUCosh x    -> sameAs x
      MUAsinh x   -> sameAs x
      MUAcosh x   -> sameAs x
      MUAtan x    -> sameAs x
      MUAtanh x   -> sameAs x

-- | The sharing structure: a counter paired with a bidirectional map between
-- hashes and 'UeElem's.  The counter is 0 in 'emptyMap'; when 'maybeUpdate'
-- inserts a new element it uses the counter plus one as the new hash and
-- stores that value back as the counter.
type UeMap = (Hash, M.Bimap Int UeElem)

-- | A (strict) 'State' computation whose state is the 'UeMap' being built.
type UeState a = State UeMap a

-- | Fetch the element stored under a 'Hash'.  The hash must be present:
-- a missing hash is a fatal error whose message includes a dump of the map.
getUE :: Hash -> UeMap -> UeElem
getUE h (_, table) = maybe missing id (M.lookup h table)
  where
    missing =
      error $ "Error looking up hash " ++ show h ++ " in the UE map\n" ++ show table

-- | Insert a 'UE' (and all of its sub-expressions) into the map, reusing any
-- elements that are already present.  Returns the hash of the expression
-- together with the resulting map.
newUE :: UE -> UeMap -> (Hash, UeMap)
newUE = runState . share

-- | The map containing no elements.  Its counter is 0, so the first element
-- inserted by 'maybeUpdate' receives hash 1.
emptyMap :: UeMap
emptyMap = (0, M.empty)

-- | Insert an expression into the sharing map, bottom-up: sub-expressions are
-- shared first (left to right), then the node built from their hashes is
-- looked up or inserted with 'maybeUpdate'.  The result is the node's hash.
share :: UE -> UeState Hash
-- Variable references; an array reference shares its index expression.
share (UVRef (UV i j k))       = maybeUpdate (MUVRef (MUV i j k))
share (UVRef (UVExtern i j))   = maybeUpdate (MUVRef (MUVExtern i j))
share (UVRef (UVArray arr ix)) = unOp ix (MUVRef . MUVArray arr)
-- Constants and casts.
share (UConst c)               = maybeUpdate (MUConst c)
share (UCast t x)              = unOp x (MUCast t)
-- Arithmetic.
share (UAdd x y)               = binOp (x, y) MUAdd
share (USub x y)               = binOp (x, y) MUSub
share (UMul x y)               = binOp (x, y) MUMul
share (UDiv x y)               = binOp (x, y) MUDiv
share (UMod x y)               = binOp (x, y) MUMod
-- Logic and bitwise operations.
share (UNot x)                 = unOp x MUNot
share (UAnd xs)                = listOp xs MUAnd
share (UBWNot x)               = unOp x MUBWNot
share (UBWAnd x y)             = binOp (x, y) MUBWAnd
share (UBWOr x y)              = binOp (x, y) MUBWOr
share (UShift x n)             = unOp x (flip MUShift n)
-- Comparisons and selection.
share (UEq x y)                = binOp (x, y) MUEq
share (ULt x y)                = binOp (x, y) MULt
share (UMux c x y)             = triOp (c, x, y) MUMux
-- Bit conversions.
share (UF2B x)                 = unOp x MUF2B
share (UD2B x)                 = unOp x MUD2B
share (UB2F x)                 = unOp x MUB2F
share (UB2D x)                 = unOp x MUB2D
-- math.h:
share UPi                      = maybeUpdate MUPi
share (UExp x)                 = unOp x MUExp
share (ULog x)                 = unOp x MULog
share (USqrt x)                = unOp x MUSqrt
share (UPow x y)               = binOp (x, y) MUPow
share (USin x)                 = unOp x MUSin
share (UAsin x)                = unOp x MUAsin
share (UCos x)                 = unOp x MUCos
share (UAcos x)                = unOp x MUAcos
share (USinh x)                = unOp x MUSinh
share (UCosh x)                = unOp x MUCosh
share (UAsinh x)               = unOp x MUAsinh
share (UAcosh x)               = unOp x MUAcosh
share (UAtan x)                = unOp x MUAtan
share (UAtanh x)               = unOp x MUAtanh

-- XXX I could combine some of the following functions (unOp, binOp, etc.) to
-- slightly reduce code...
-- | Share the single operand, then look up or insert the node that the
-- supplied constructor builds from the operand's hash.
unOp :: UE -> (Hash -> UeElem) -> UeState Hash
unOp e code = share e >>= maybeUpdate . code

-- | Share both operands (first, then second), then look up or insert the node
-- that the supplied constructor builds from their hashes.
binOp :: (UE, UE) -> (Hash -> Hash -> UeElem) -> UeState Hash
binOp (lhs, rhs) code =
  share lhs >>= \l ->
  share rhs >>= \r ->
  maybeUpdate (code l r)

-- | Share the three operands in order, then look up or insert the node that
-- the supplied constructor builds from their hashes.
triOp :: (UE, UE, UE) -> (Hash -> Hash -> Hash -> UeElem) -> UeState Hash
triOp (x, y, z) code =
  share x >>= \hx ->
  share y >>= \hy ->
  share z >>= \hz ->
  maybeUpdate (code hx hy hz)

-- | Share every operand of a variadic operator, left to right, then look up
-- or insert the node that the supplied constructor builds from the list of
-- operand hashes.
--
-- The hashes are passed to the constructor in the same order as the operands
-- they came from, so that 'recoverUE' and 'ueUpstream' see the operands in
-- their original order.  (Accumulating with a prepending fold would hand the
-- constructor the hashes in reverse.)
listOp :: [UE] -> ([Hash] -> UeElem) -> UeState Hash
listOp es code = do
  -- 'mapM' runs the 'share' actions left to right and keeps their results in
  -- that order.
  hashes <- mapM share es
  maybeUpdate (code hashes)

-- | Find the hash of an element, inserting the element if necessary.
--
-- If the element is already in the map its existing hash is returned and the
-- state is left unchanged.  Otherwise the element is inserted under the
-- counter plus one, that value becomes the new counter, and it is returned.
maybeUpdate :: UeElem -> UeState Hash
maybeUpdate e = do
  (counter, table) <- get
  case M.lookupR e table of
    Just known -> return known
    Nothing -> do
      let fresh = counter + 1
      put (fresh, M.insert fresh e table)
      return fresh

-- -- Lookup an elem, returning 'Nothing' if no hash exists in the map and 'Just'
-- -- the hash value otherwise.
-- getHash :: UeElem -> UeMap -> Maybe Hash
-- getHash e mp = M.lookupR e


-- ((k,e'):_) | e == e' = Just k
-- getHash e (_:es) | otherwise = getHash e es
-- getHash _ [] = Nothing

-- | Rebuild the full 'UE' tree stored under a 'Hash', recursively expanding
-- every sub-expression hash.  Fails (via 'getUE') if the hash, or any hash
-- reachable from it, is not in the map.
recoverUE :: UeMap -> Hash -> UE
recoverUE st h = rebuild (getUE h st)
  where
    -- Expand a sub-expression hash from the same map.
    expand = recoverUE st
    -- Apply a one- or two-operand 'UE' constructor to expanded operands.
    unary con x = con (expand x)
    binary con x y = con (expand x) (expand y)

    rebuild (MUVRef ref) = case ref of
      MUV i j k       -> UVRef (UV i j k)
      MUVArray arr ix -> UVRef (UVArray arr (expand ix))
      MUVExtern i j   -> UVRef (UVExtern i j)
    rebuild (MUCast t x)  = UCast t (expand x)
    rebuild (MUConst c)   = UConst c
    rebuild (MUAdd x y)   = binary UAdd x y
    rebuild (MUSub x y)   = binary USub x y
    rebuild (MUMul x y)   = binary UMul x y
    rebuild (MUDiv x y)   = binary UDiv x y
    rebuild (MUMod x y)   = binary UMod x y
    rebuild (MUNot x)     = unary UNot x
    rebuild (MUAnd xs)    = UAnd (map expand xs)
    rebuild (MUBWNot x)   = unary UBWNot x
    rebuild (MUBWAnd x y) = binary UBWAnd x y
    rebuild (MUBWOr x y)  = binary UBWOr x y
    rebuild (MUShift x n) = UShift (expand x) n
    rebuild (MUEq x y)    = binary UEq x y
    rebuild (MULt x y)    = binary ULt x y
    rebuild (MUMux c x y) = UMux (expand c) (expand x) (expand y)
    rebuild (MUF2B x)     = unary UF2B x
    rebuild (MUD2B x)     = unary UD2B x
    rebuild (MUB2F x)     = unary UB2F x
    rebuild (MUB2D x)     = unary UB2D x
    -- math.h:
    rebuild MUPi          = UPi
    rebuild (MUExp x)     = unary UExp x
    rebuild (MULog x)     = unary ULog x
    rebuild (MUSqrt x)    = unary USqrt x
    rebuild (MUPow x y)   = binary UPow x y
    rebuild (MUSin x)     = unary USin x
    rebuild (MUAsin x)    = unary UAsin x
    rebuild (MUCos x)     = unary UCos x
    rebuild (MUAcos x)    = unary UAcos x
    rebuild (MUSinh x)    = unary USinh x
    rebuild (MUCosh x)    = unary UCosh x
    rebuild (MUAsinh x)   = unary UAsinh x
    rebuild (MUAcosh x)   = unary UAcosh x
    rebuild (MUAtan x)    = unary UAtan x
    rebuild (MUAtanh x)   = unary UAtanh x

-- | The hashes of the elements immediately upstream of a 'UE', i.e. its
-- direct operands, in the order they are stored.  Constants, plain and
-- external variables and @pi@ have none; an array reference has its index
-- expression.  Fails (via 'getUE') if the hash is not in the map.
ueUpstream :: Hash -> UeMap -> [Hash]
ueUpstream h t = operands (getUE h t)
  where
    operands (MUVRef ref) = case ref of
      MUV{}         -> []
      MUVArray _ ix -> [ix]
      MUVExtern{}   -> []
    operands (MUCast _ x)  = [x]
    operands MUConst{}     = []
    operands (MUAdd x y)   = [x, y]
    operands (MUSub x y)   = [x, y]
    operands (MUMul x y)   = [x, y]
    operands (MUDiv x y)   = [x, y]
    operands (MUMod x y)   = [x, y]
    operands (MUNot x)     = [x]
    operands (MUAnd xs)    = xs
    operands (MUBWNot x)   = [x]
    operands (MUBWAnd x y) = [x, y]
    operands (MUBWOr x y)  = [x, y]
    operands (MUShift x _) = [x]
    operands (MUEq x y)    = [x, y]
    operands (MULt x y)    = [x, y]
    operands (MUMux c x y) = [c, x, y]
    operands (MUF2B x)     = [x]
    operands (MUD2B x)     = [x]
    operands (MUB2F x)     = [x]
    operands (MUB2D x)     = [x]
    -- math.h:
    operands MUPi          = []
    operands (MUExp x)     = [x]
    operands (MULog x)     = [x]
    operands (MUSqrt x)    = [x]
    operands (MUPow x y)   = [x, y]
    operands (MUSin x)     = [x]
    operands (MUAsin x)    = [x]
    operands (MUCos x)     = [x]
    operands (MUAcos x)    = [x]
    operands (MUSinh x)    = [x]
    operands (MUCosh x)    = [x]
    operands (MUAsinh x)   = [x]
    operands (MUAcosh x)   = [x]
    operands (MUAtan x)    = [x]
    operands (MUAtanh x)   = [x]

-- | All variables that directly control the value of an expression, without
-- duplicates.  The search descends through operators and stops at the first
-- variable reference on each path, except that for an array reference it
-- also continues into the index expression.
nearestUVs :: Hash -> UeMap -> [MUV]
nearestUVs h mp = nub (collect h)
  where
    collect :: Hash -> [MUV]
    collect cur = case getUE cur mp of
      MUVRef uv -> uv : throughIndex uv
      _         -> concatMap collect (ueUpstream cur mp)

    -- Only array references have anything further to explore.
    throughIndex (MUVArray _ ix) = collect ix
    throughIndex _               = []

-- | Every array access reachable from an expression, as pairs of the array
-- and the hash of its index expression, without duplicates.  Index
-- expressions are themselves searched for further array accesses.
arrayIndices :: Hash -> UeMap -> [(UA, Hash)]
arrayIndices h mp = nub (walk h)
  where
    walk :: Hash -> [(UA, Hash)]
    walk cur =
      case getUE cur mp of
        MUVRef (MUVArray arr ix) -> (arr, ix) : walk ix
        _other                   -> concatMap walk (ueUpstream cur mp)

-- XXX can put this back after making UE map---won't be expensive.
-- | 'True' exactly for the elements listed under the @math.h@ heading of
-- 'UeElem' (from 'MUPi' through 'MUAtanh'); 'False' for everything else.
isMathHCall :: UeElem -> Bool
isMathHCall MUPi      = True
isMathHCall MUExp{}   = True
isMathHCall MULog{}   = True
isMathHCall MUSqrt{}  = True
isMathHCall MUPow{}   = True
isMathHCall MUSin{}   = True
isMathHCall MUAsin{}  = True
isMathHCall MUCos{}   = True
isMathHCall MUAcos{}  = True
isMathHCall MUSinh{}  = True
isMathHCall MUCosh{}  = True
isMathHCall MUAsinh{} = True
isMathHCall MUAcosh{} = True
isMathHCall MUAtan{}  = True
isMathHCall MUAtanh{} = True
isMathHCall _         = False