-- Arrays indexed by the keys of a set that is reflected from a type-level
-- tag (see 'Subset'). Values are stored in a 'Vector', one slot per key, in
-- ascending key order.
module Data.Total.Array.Subset (
    Subset,
    TotalSubsetArray(..)
    ) where
import           Data.Bytes.Serial
import           Data.Distributive
import           Data.Functor.Rep
import           Data.Key
import           Data.Proxy
import           Data.Reflection
import           Data.Set (Set)
import qualified Data.Set as Set
import           Data.Total.Subset
import           Data.Vector (Vector)
import qualified Data.Vector as Vector
import           Linear
import           Prelude ()
import           Prelude.Compat hiding (zip, zipWith)
-- Fixity of the two-argument composition operator '.:' defined in this
-- module: right-associative at precedence 9.
infixr 9 .:
-- | An array of values of type @a@ backed by a 'Vector'.
--
-- @s@ and @k@ are phantom parameters: they do not occur in the stored
-- representation. The instances in this module use @s@ to 'reflect' a
-- @'Set' k@ of keys and assume the vector holds one element per key, laid
-- out in ascending key order. The constructor itself does not check this.
newtype TotalSubsetArray s k a = TotalSubsetArray (Vector a)
    deriving (Eq, Ord, Show, Read, Functor, Foldable, Traversable)
-- | Number of keys in the set reflected by @s@.
keyCount :: Subset s k => Proxy s -> Int
keyCount = Set.size . reflect
-- | Every key of the set reflected by @s@, as a 'Vector' in ascending order.
keys' :: Subset s k => Proxy s -> Vector k
keys' p = Vector.fromListN (Set.size subset) (Set.toAscList subset)
  where
    -- Reflect the key set once and reuse it for both the size and the list.
    subset = reflect p
-- | Position of a key within the set reflected by @s@, counted in ascending
-- key order.
--
-- Delegates to 'Set.findIndex', which (per the containers documentation)
-- calls 'error' when the key is not a member of the set.
toIndex :: (Ord k, Subset s k) => Proxy s -> k -> Int
toIndex p = (`Set.findIndex` reflect p)
-- | The array in which every slot holds its own key.
keys :: forall s k. Subset s k => TotalSubsetArray s k k
keys = TotalSubsetArray . keys' $ (Proxy :: Proxy s)
-- | Compose a one-argument function after a two-argument function:
-- @(f .: g) x y == f (g x y)@.
(.:) :: (c -> d) -> (a -> b -> c) -> a -> b -> d
(f .: g) x = f . g x
-- | Slot-wise applicative: 'pure' fills one slot per key of the reflected set
-- and '<*>' applies each function to the argument in the same slot.
instance Subset s k => Applicative (TotalSubsetArray s k) where
    pure x = TotalSubsetArray (Vector.replicate size x)
      where
        size = keyCount (Proxy :: Proxy s)
    fs <*> xs = zipWith ($) fs xs
-- | Keys of a 'TotalSubsetArray' are values of its key type parameter @k@.
type instance Key (TotalSubsetArray s k) = k
-- | Maps over the slots, passing each slot's key along with its value.
instance Subset s k => Keyed (TotalSubsetArray s k) where
    -- Pair every value with its key by zipping against the array of keys.
    mapWithKey f = zipWith f keys
-- | Slot-wise zipping, delegated to 'Vector.zipWith' on the wrapped vectors.
instance Zip (TotalSubsetArray s k) where
    zipWith combine (TotalSubsetArray xs) (TotalSubsetArray ys) =
        TotalSubsetArray (Vector.zipWith combine xs ys)
-- | Slot-wise zipping where the combining function also receives the key.
instance Subset s k => ZipWithKey (TotalSubsetArray s k) where
    -- First apply @f@ to each key and the matching element of @a@, then
    -- apply the resulting functions to @b@ slot by slot.
    zipWithKey f a b = zap (zipWith f keys a) b
-- | Look a key up in the array.
--
-- Returns 'Nothing' when the key is not a member of the reflected key set,
-- rather than a 'Just' wrapping the error raised by the partial
-- 'Set.findIndex'. Keys that are members behave as before.
instance (Ord k, Subset s k) => Lookup (TotalSubsetArray s k) where
    lookup k (TotalSubsetArray v) =
        -- 'Set.lookupIndex' gives the key's rank in the ascending key set,
        -- which is its slot in the vector. The unchecked read relies on the
        -- vector having one slot per key, as the rest of this module assumes.
        Vector.unsafeIndex v <$> Set.lookupIndex k (reflect (Proxy :: Proxy s))
-- | Total indexing by key. The key is translated to a vector position with
-- 'toIndex' and read without a bounds check.
instance (Ord k, Subset s k) => Indexable (TotalSubsetArray s k) where
    index (TotalSubsetArray v) =
        Vector.unsafeIndex v . toIndex (Proxy :: Proxy s)
-- | Applies a function to the value stored under a single key.
instance (Ord k, Subset s k) => Adjustable (TotalSubsetArray s k) where
    adjust f k (TotalSubsetArray v) =
        let -- Slot of the key inside the vector.
            pos = toIndex (Proxy :: Proxy s) k
            -- New value for that slot, computed from the current one.
            updated = f (Vector.unsafeIndex v pos)
        in  TotalSubsetArray (Vector.unsafeUpd v [(pos, updated)])
-- | Monoidal fold over the slots in which the mapping function also
-- receives each slot's key.
instance Subset s k => FoldableWithKey (TotalSubsetArray s k) where
    -- Map every (key, value) slot to its monoid value first, then combine
    -- the results in slot order.
    foldMapWithKey f = foldMap id . zipWith f keys
-- | Effectful traversal in which the action also receives each slot's key.
instance Subset s k => TraversableWithKey (TotalSubsetArray s k) where
    -- Build the array of per-slot actions, then run them in slot order.
    traverseWithKey f = sequenceA . zipWith f keys
-- | Only 'zero' is defined here; every other 'Additive' method uses the
-- class default.
instance Subset s k => Additive (TotalSubsetArray s k) where
    -- The zero vector: 0 in every slot, built with this type's 'pure'.
    zero = pure 0
-- | No methods are overridden; all 'Metric' operations use the class defaults.
instance Subset s k => Metric (TotalSubsetArray s k)
-- | Serialises the elements one after another with no length prefix. The
-- element count on decoding is taken from the size of the reflected key set.
instance Subset s k => Serial1 (TotalSubsetArray s k) where
    serializeWith putElem (TotalSubsetArray elems) = Vector.mapM_ putElem elems
    deserializeWith getElem =
        fmap TotalSubsetArray (Vector.replicateM size getElem)
      where
        size = keyCount (Proxy :: Proxy s)
-- | Serialisation of arrays whose elements are themselves 'Serial',
-- expressed through the 'Serial1' instance.
instance (Subset s k, Serial a) => Serial (TotalSubsetArray s k a) where
    serialize   = serializeWith serialize
    deserialize = deserializeWith deserialize
-- | Turns a functor of arrays into an array of functors, one slot per key
-- of the reflected set.
instance Subset s k => Distributive (TotalSubsetArray s k) where
    distribute arrays = TotalSubsetArray (Vector.generate size slot)
      where
        size = keyCount (Proxy :: Proxy s)
        -- Slot @i@ of the result collects element @i@ of every inner array.
        slot i = (\(TotalSubsetArray v) -> Vector.unsafeIndex v i) <$> arrays
-- | Arrays are represented by functions from the key type @k@.
instance (Ord k, Subset s k) => Representable (TotalSubsetArray s k) where
    type Rep (TotalSubsetArray s k) = k
    -- Evaluate the function at every key of the reflected set.
    tabulate f = f <$> keys
    -- Reuse the 'Indexable' implementation from "Data.Key".
    index = Data.Key.index