module Data.Total.Map.Sparse where
import           Data.Bytes.Serial
import           Data.Key
import           Data.List (sort)
import           Data.Map (Map)
import qualified Data.Map as Map
import           Data.Maybe
import           Data.Monoid (First(..))
import           Data.Semigroup hiding (First, getFirst)
import           Data.Total.Internal.SparseFold
import           Data.Total.Map
import           Linear
import           Prelude ()
import           Prelude.Compat hiding (zip, lookup)
-- | A total map from @k@ to @a@ stored sparsely.
--
-- The first field holds the keys that have an explicit value. The second
-- field is the default value, which 'index' returns for every key that
-- has no entry in the 'Map'.
data TotalSparseMap k a = TotalSparseMap (Map k a) a
    deriving (Show, Read, Functor)
-- | Pointwise equality: the two maps are zipped with '<*>' and every folded
-- comparison must be 'True'. This means an explicit entry equal to the
-- other map's default compares equal, regardless of representation.
instance (Ord k, Enum k, Bounded k, Eq a) => Eq (TotalSparseMap k a) where
    x == y = and pointwise
      where
        pointwise = (==) <$> x <*> y
-- | Lexicographic ordering over the 'Foldable' traversal.
--
-- The first pointwise comparison that is not 'EQ' decides the result. If
-- every position compares 'EQ', the maps are 'EQ'.
instance (Ord k, Enum k, Bounded k, Ord a) => Ord (TotalSparseMap k a) where
    compare x y = fromMaybe EQ (getFirst (foldMap decisive pointwise))
      where
        pointwise = compare <$> x <*> y
        -- Only a strict inequality can decide the ordering.
        decisive EQ = First Nothing
        decisive o  = First (Just o)
-- | 'pure' builds a map with no explicit entries, so every key holds the
-- given value as the default. '<*>' applies functions pointwise via
-- 'zap' from the 'Zip' instance.
instance Ord k => Applicative (TotalSparseMap k) where
    pure = TotalSparseMap Map.empty
    fs <*> xs = zap fs xs
-- | Folds over the map in key ('Enum') order.
--
-- Each explicit entry is placed as a point at its key's 'fromEnum'
-- position. Two 'mempty' sentinel points sit one position outside the key
-- range: at @fromEnum minBound - 1@ and @fromEnum maxBound + 1@.
--
-- NOTE(review): this presumes 'runSparseFold' fills the gaps between
-- consecutive points with its first argument (@f d@, the default's image).
-- Under that assumption, the sentinels make the gaps before the first and
-- after the last explicit entry count as well. Confirm against
-- "Data.Total.Internal.SparseFold".
instance (Ord k, Enum k, Bounded k) => Foldable (TotalSparseMap k) where
    foldMap f (TotalSparseMap m d) = runSparseFold (f d) $ \_ ->
           foldPoint lowSentinel mempty
        <> Map.foldMapWithKey (\k v -> foldPoint (position k) (f v)) m
        <> foldPoint highSentinel mempty
      where
        -- Positions are Integers so that minBound - 1 / maxBound + 1 cannot
        -- overflow even when fromEnum spans the whole Int range.
        position :: Enum e => e -> Integer
        position = toInteger . fromEnum
        -- The original code was missing the '-' here, which applied an
        -- Integer to 1 and did not typecheck.
        lowSentinel  = position (minBound :: k) - 1
        highSentinel = position (maxBound :: k) + 1
-- | The keys of a 'TotalSparseMap' are values of its @k@ parameter.
type instance Key (TotalSparseMap k) = k
-- | Lookup always succeeds. It returns the explicit entry if one exists,
-- and otherwise the default value.
instance Ord k => Lookup (TotalSparseMap k) where
    lookup key (TotalSparseMap entries fallback) =
      maybe (Just fallback) Just (lookup key entries)
-- | Indexing returns the explicit entry for the key when present.
-- Otherwise it returns the map's default value.
instance Ord k => Indexable (TotalSparseMap k) where
    index (TotalSparseMap entries fallback) key =
      fromMaybe fallback (lookup key entries)
-- | Updates are recorded as explicit entries in the sparse map.
--
-- Both methods must handle a key that has no explicit entry, i.e. a key
-- that currently holds the default value:
--
-- * 'adjust' applies @f@ to the default in that case.
-- * 'replace' inserts the new value unconditionally.
--
-- 'replace' uses 'Map.insert' directly instead of delegating to the
-- 'Map' instance of 'replace' from the @keys@ package. If that instance
-- falls back to @adjust (const v)@, it would silently ignore keys absent
-- from the 'Map', dropping the update for keys holding the default.
instance Ord k => Adjustable (TotalSparseMap k) where
    adjust f k (TotalSparseMap m d) = TotalSparseMap (Map.alter f' k m) d
      where
        -- A key with no explicit entry implicitly holds the default @d@.
        f' (Just x) = Just (f x)
        f' Nothing  = Just (f d)
    replace k v (TotalSparseMap m d) = TotalSparseMap (Map.insert k v m) d
-- | Pairs two maps pointwise.
--
-- A key explicit in both maps pairs the two entries. A key explicit in
-- only one map pairs that entry with the other map's default. The result's
-- default is the pair of the two defaults.
instance Ord k => Zip (TotalSparseMap k) where
    zip (TotalSparseMap xs dx) (TotalSparseMap ys dy) =
        TotalSparseMap paired (dx, dy)
      where
        paired    = Map.mergeWithKey both onlyLeft onlyRight xs ys
        both _ x y = Just (x, y)
        onlyLeft  = fmap (\x -> (x, dy))
        onlyRight = fmap (\y -> (dx, y))
-- | The zero vector has no explicit entries and a default of @0@. The
-- remaining 'Additive' methods use the class defaults.
instance Ord k => Additive (TotalSparseMap k) where
    zero = TotalSparseMap Map.empty 0
-- | All 'Metric' methods come from the class's default definitions; none are
-- overridden here.
-- NOTE(review): the 'Enum'/'Bounded' constraints suggest the defaults go
-- through the 'Foldable' instance — confirm against the linear version
-- in use.
instance (Ord k, Enum k, Bounded k) => Metric (TotalSparseMap k)
-- | Binary encoding: the explicit-entry 'Map' first, then the default value.
-- Decoding reads the two parts back in the same order.
instance (Ord k, Enum k, Bounded k, Serial k) => Serial1 (TotalSparseMap k) where
    serializeWith putA (TotalSparseMap entries dflt) =
        serializeWith putA entries >> putA dflt
    deserializeWith getA = do
        entries <- deserializeWith getA
        dflt <- getA
        return (TotalSparseMap entries dflt)
-- | Delegates to the 'Serial1' instance, using the element type's own
-- 'Serial' codec for the values.
instance (Ord k, Enum k, Bounded k, Serial k, Serial a)
         => Serial (TotalSparseMap k a) where
    serialize   = serializeWith serialize
    deserialize = deserializeWith deserialize
-- | Converts a sparse map into a dense 'TotalMap'.
--
-- The sparse map's explicit entries take precedence over a fallback map
-- obtained from @'pure' d@, because 'Map.union' is left-biased.
-- NOTE(review): this assumes 'pure' for 'TotalMap' holds @d@ at every key;
-- confirm in "Data.Total.Map".
toDenseMap :: (Ord k, Enum k, Bounded k) => TotalSparseMap k a -> TotalMap k a
toDenseMap (TotalSparseMap entries dflt) =
    let TotalMap everyKey = pure dflt
    in  TotalMap (entries `Map.union` everyKey)