{-# LANGUAGE  GeneralizedNewtypeDeriving 
  , NoMonomorphismRestriction 
  , BangPatterns #-}
module Control.Monad.Atom
    ( MonadAtom (..)
    , AtomTable
    , Atom
    , AtomT
    , empty
    , evalAtom
    , evalAtomT
    , runAtom
    , runAtomT
    , atoms
    )
where
import Control.Applicative (Applicative)
import Control.Monad.Identity
import Control.Monad.State
import qualified Data.Binary as B
import qualified Data.ByteString as Strict
import qualified Data.ByteString.Lazy as Lazy
import qualified Data.IntMap as IntMap
import qualified Data.Map as Map

-- | Binary encoding of an atomised value; used as the key of the forward
-- map and as the stored value of the reverse map.
type Blob = Strict.ByteString

-- | Bidirectional table between encoded values and the 'Int' atoms
-- assigned to them.
--
-- All fields are strict.  'toAtom' forces only the forward map (through its
-- lookup), while the reverse map is read only by 'fromAtom' and 'atoms'.  A
-- lazy 'from' field would therefore accumulate a chain of unevaluated
-- 'IntMap.insert' thunks across many 'toAtom' calls.
data AtomTable = T
    { lastID :: {-# UNPACK #-} !Int
      -- ^ Id the next new atom will receive; equivalently, the number of
      -- atoms allocated so far.
    , to     :: !(Map.Map Blob Int)
      -- ^ Encoded value to atom.
    , from   :: !(IntMap.IntMap Blob)
      -- ^ Atom to encoded value.
    }
                   


-- | The three fields are written in declaration order (id counter, forward
-- map, reverse map) and read back in that same order.
instance B.Binary AtomTable where
    put t = B.put (lastID t) >> B.put (to t) >> B.put (from t)
    get = do
      counter <- B.get
      forward <- B.get
      reverse' <- B.get
      return (T counter forward reverse')


-- | Monads that can intern values as 'Int' atoms.  Every method requires a
-- 'B.Binary' instance for the value type.
class Monad m => MonadAtom m where
    -- | Monadically convert the argument into an atom (represented as an
    -- 'Int'), creating the atom if it does not exist yet.
    toAtom      :: B.Binary a => a -> m Int
    -- | Monadically convert the argument into an atom, but only if the
    -- corresponding atom has already been created; 'Nothing' otherwise.
    maybeToAtom :: B.Binary a => a -> m (Maybe Int)
    -- | Monadically convert an atom represented as an 'Int' back to its
    -- corresponding object.  The caller chooses the result type @a@.
    -- NOTE(review): nothing in this signature checks that @a@ matches the
    -- type the atom was created from, or that the atom exists.
    fromAtom    :: B.Binary a => Int -> m a


-- | Atoms are keyed on the 'B.Binary' encoding of the argument (see 'enc'),
-- so values of different types whose encodings coincide share an atom.
-- New atoms are numbered consecutively starting from the table's 'lastID'.
instance Monad m => MonadAtom (AtomT m) where
    toAtom x = AtomT $ do
      let b = enc x
      t <- get
      case Map.lookup b (to t) of
        Just j  -> return $! j
        Nothing -> do
          -- Allocate the next id and record the pair in both directions.
          -- The bang forces the updated table to WHNF before it is stored.
          let i   = lastID t
              !t' = t { lastID = i + 1
                      , to     = Map.insert b i (to t)
                      , from   = IntMap.insert i b (from t) }
          put t'
          return $! i

    -- Read-only lookup: never allocates a new atom.
    maybeToAtom x = AtomT $ do
      t <- get
      return (Map.lookup (enc x) (to t))

    -- Reverse lookup, then decode at the type the caller asks for.  An id
    -- this table never allocated is a caller error.  Report it with a
    -- message naming this function and the id.  The result is left lazy,
    -- so the error surfaces only when the value is forced.
    fromAtom i = AtomT $ do
      t <- get
      return (maybe (error ("Control.Monad.Atom.fromAtom: unknown atom " ++ show i))
                    dec
                    (IntMap.lookup i (from t)))

-- | The current atom table, read from the computation's state.
--
-- NOTE(review): not exported and not referenced elsewhere in the visible
-- part of this module.
table :: Monad m => AtomT m AtomTable
table = AtomT get

-- | A table with no atoms; the first atom allocated from it is @0@.
empty :: AtomTable
empty = T { lastID = 0, to = Map.empty, from = IntMap.empty }

-- | Run an 'AtomT' computation starting from the given table, returning
-- its result together with the final table.
runAtomT :: AtomT t t1 -> AtomTable -> t (t1, AtomTable)
runAtomT (AtomT action) = runStateT action

-- | Pure counterpart of 'runAtomT': run an 'Atom' computation from the
-- given table, returning its result and the final table.
runAtom :: Atom t -> AtomTable -> (t, AtomTable)
runAtom (Atom action) = runIdentity . runAtomT action


-- | Run an 'Atom' computation from the 'empty' table and keep only its
-- result, discarding the final table.
evalAtom :: Atom t -> t
evalAtom action = fst (runAtom action empty)

-- | Run an 'AtomT' computation from the 'empty' table and keep only its
-- result, discarding the final table.
evalAtomT :: (Monad m) => AtomT m a -> m a
evalAtomT action = runAtomT action empty >>= return . fst

-- | Atom-interning monad transformer: a 'StateT' over an 'AtomTable'.
--
-- 'Applicative' is derived alongside 'Monad' because, since GHC 7.10
-- (the Applicative-Monad Proposal), every 'Monad' instance requires an
-- 'Applicative' instance.
newtype AtomT m r = AtomT (StateT AtomTable m r)
    deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)

-- | Pure atom-interning monad: 'AtomT' over 'Identity'.
newtype Atom r = Atom (AtomT Identity r)
    deriving (Functor, Applicative, Monad, MonadAtom)

-- | The list of atoms (as Ints) stored in the atom table, i.e. the keys of
-- its reverse map.
atoms :: AtomTable -> [Int]
atoms table' = IntMap.keys (from table')

-- | Serialise a value to a strict 'Strict.ByteString' (the 'Blob' stored in
-- the table) using its 'B.Binary' instance.  Atoms are keyed on this
-- encoding, so two values whose encodings coincide map to the same atom.
enc :: B.Binary a => a -> Strict.ByteString
enc = Strict.concat . Lazy.toChunks . B.encode

-- | Decode a value from a strict encoding produced by 'enc'.  For a lawful
-- 'B.Binary' instance, @dec . enc@ is the identity.
dec :: B.Binary a => Strict.ByteString -> a
dec = B.decode . Lazy.fromChunks . (: [])

-- | Small demonstration: atomise the strings "1", "2", "3", "1", "2", "3"
-- and map the resulting atoms back to strings.
example :: [String]
example = evalAtom (mapM toAtom inputs >>= mapM fromAtom)
  where
    inputs = map show [1, 2, 3, 1, 2, 3 :: Integer]