module Data.Total.Map where
import           Data.Bytes.Serial
import           Data.Distributive
import           Data.Functor.Rep
import           Data.Key
import           Data.List (sort)
import           Data.Map (Map)
import qualified Data.Map as Map
import           Linear
import           Prelude ()
import           Prelude.Compat hiding (zip)
-- | A 'Map' wrapper whose instances below ('pure', 'tabulate', the
-- serialization and 'distribute' instances) assume it holds a value for
-- every key in @[minBound .. maxBound]@.
--
-- The constructor is exported (this module has no export list), so that
-- invariant is a convention, not something the type enforces.
newtype TotalMap k a = TotalMap (Map k a)
    deriving
        ( Eq, Ord, Show, Read
        , Functor, Foldable, Traversable
        )
-- | 'pure' stores the same value under every key in @[minBound .. maxBound]@;
-- '<*>' pairs functions with arguments key-by-key using 'zap'.
instance (Ord k, Enum k, Bounded k) => Applicative (TotalMap k) where
    pure x = TotalMap (Map.fromList (map withValue [minBound .. maxBound]))
      where
        withValue key = (key, x)
    fs <*> xs = zap fs xs
-- The 'Key' of a 'TotalMap' is the key type of the wrapped 'Map'.
type instance Key (TotalMap k) = k
-- The "Data.Key" container classes below are obtained by standalone
-- newtype deriving, so each one reuses the instance for the wrapped
-- @Map k@ unchanged.
-- NOTE(review): this needs StandaloneDeriving and
-- GeneralizedNewtypeDeriving. No LANGUAGE pragma in this file enables them,
-- so presumably they come from the cabal file's default-extensions; confirm.
deriving instance Keyed (TotalMap k)
deriving instance Ord k => Zip (TotalMap k)
deriving instance Ord k => ZipWithKey (TotalMap k)
deriving instance Ord k => Lookup (TotalMap k)
deriving instance Ord k => Indexable (TotalMap k)
deriving instance Ord k => Adjustable (TotalMap k)
deriving instance Ord k => FoldableWithKey (TotalMap k)
-- | Keyed traversal: traverses the wrapped 'Map' together with its keys and
-- re-wraps the result in 'TotalMap'.
instance Ord k => TraversableWithKey (TotalMap k) where
    traverseWithKey f (TotalMap inner) = fmap TotalMap (traverseWithKey f inner)
-- | Vector-space structure.
--
-- The zero vector holds @0@ at every key in @[minBound .. maxBound]@. The
-- remaining 'Additive' methods are not overridden here, so they use the
-- class defaults.
instance (Ord k, Enum k, Bounded k) => Additive (TotalMap k) where
    zero = TotalMap . Map.fromList $ [ (key, 0) | key <- [minBound .. maxBound] ]
-- All 'Metric' methods (e.g. 'dot', 'norm') use the class's default
-- implementations; none are overridden here.
instance (Ord k, Enum k, Bounded k) => Metric (TotalMap k)
-- | Binary encoding of a 'TotalMap'.
--
-- Only the values are written, as a list in ascending key order
-- ('Map.elems'); the keys themselves are not serialized. Deserialization
-- re-attaches the decoded values to @sort [minBound .. maxBound]@.
--
-- It fails if the number of decoded values differs from the number of keys.
-- A truncated or mismatched payload therefore cannot yield a map that is
-- missing keys, and surplus values are never silently dropped.
instance (Ord k, Enum k, Bounded k) => Serial1 (TotalMap k) where
    serializeWith f (TotalMap m) = serializeWith f (Map.elems m)
    deserializeWith f = do
        elems <- deserializeWith f
        case zipExactly (sort [minBound .. maxBound]) elems of
            Just assocs -> return $ TotalMap (Map.fromDistinctAscList assocs)
            Nothing     -> fail $ "Data.Total.Map.deserializeWith: got "
                               ++ show (length elems)
                               ++ " values, expected exactly one per key"
      where
        -- Like 'zip', but succeeds only when both lists have the same length.
        zipExactly :: [a] -> [b] -> Maybe [(a, b)]
        zipExactly (x:xs) (y:ys) = ((x, y) :) <$> zipExactly xs ys
        zipExactly []     []     = Just []
        zipExactly _      _      = Nothing
-- | Serializes a 'TotalMap' by delegating to its 'Serial1' instance, using
-- the element type's own 'Serial' instance for each value.
instance (Ord k, Enum k, Bounded k, Serial a) => Serial (TotalMap k a) where
    serialize   = serializeWith serialize
    deserialize = deserializeWith deserialize
-- | Pushes a functor inside a 'TotalMap'. The component at key @k@ is
-- obtained by looking @k@ up in every map inside the functor.
--
-- Each key is looked up directly in the wrapped 'Map' (O(log n) per key).
-- This replaces walking a chain of 'tail's over 'Map.elems', which costs
-- O(i) for the i-th key and relies on partial 'head'/'tail'.
--
-- It assumes every map inside @fm@ holds all keys in @[minBound .. maxBound]@;
-- if one does not, 'Map.!' raises an error for the missing key.
instance (Ord k, Enum k, Bounded k) => Distributive (TotalMap k) where
    distribute fm = TotalMap . Map.fromDistinctAscList $
        [ (k, fmap (lookupKey k) fm) | k <- sort [minBound .. maxBound] ]
      where
        -- Partial only when the totality invariant has been broken.
        lookupKey key (TotalMap m) = m Map.! key
-- | A 'TotalMap' is represented by its key type. 'tabulate' evaluates the
-- function at every key in @[minBound .. maxBound]@; 'index' delegates to the
-- "Data.Key" 'Indexable' instance.
instance (Ord k, Enum k, Bounded k) => Representable (TotalMap k) where
    type Rep (TotalMap k) = k
    tabulate f = TotalMap . Map.fromList $ (\key -> (key, f key)) <$> [minBound .. maxBound]
    index table key = Data.Key.index table key