module Data.Extensible.Struct (
  -- * Mutable struct
  Struct
  , set
  , get
  , new
  , newRepeat
  , newFor
  , newFromHList
  , WrappedPointer(..)
  , (-$>)
  -- ** Atomic operations
  , atomicModify
  , atomicModify'
  , atomicModify_
  , atomicModify'_
  -- * Immutable product
  , (:*)
  , unsafeFreeze
  , newFrom
  , hlookup
  , hlength
  , hfoldrWithIndex
  , thaw
  , hfrozen
  , hmodify
  , toHList) where
import GHC.Prim
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Constraint
import Data.Extensible.Class
import Data.Extensible.Internal
import Data.Extensible.Internal.Rig
import Data.Extensible.Wrapper
import Control.Comonad
import Data.Profunctor.Rep
import Data.Profunctor.Sieve
import qualified Data.StateVar as V
import Data.Typeable (Typeable)
import qualified Data.Extensible.HList as L
import GHC.Types
-- | A mutable, type-indexed struct.
--
-- The payload is a 'SmallMutableArray#' of 'Any' living in state thread @s@.
-- The parameters @h@ and @xs@ are phantom as far as the representation goes;
-- the accessors in this module coerce elements to and from @h x@.
data Struct s (h :: k -> *) (xs :: [k]) = Struct (SmallMutableArray# s Any)
  deriving Typeable
-- | Write a value into the slot of a 'Struct' selected by a 'Membership'.
set :: PrimMonad m => Struct (PrimState m) h xs -> Membership xs x -> h x -> m ()
set (Struct arr) (getMemberId -> I# ix) val =
  -- The primop itself is coerced so that it accepts an @h x@ for an 'Any' slot.
  primitive (\st0 -> case unsafeCoerce# writeSmallArray# arr ix val st0 of
    st1 -> (# st1, () #))
-- | Read the value stored in the slot of a 'Struct' selected by a 'Membership'.
get :: PrimMonad m => Struct (PrimState m) h xs -> Membership xs x -> m (h x)
get (Struct arr) (getMemberId -> I# ix) =
  -- The primop is coerced so that the 'Any' slot is read back as @h x@.
  primitive (unsafeCoerce# readSmallArray# arr ix)
-- | Atomically modify an element in a 'Struct'.
--
-- The current element is read, @f@ is applied to it, and the first component
-- of the result is installed with 'casSmallArray#'. When the swap does not
-- succeed the loop retries, starting from the element the CAS handed back.
-- The second component of the pair from the winning attempt is returned.
-- Neither component is forced by this function.
atomicModify :: PrimMonad m
  => Struct (PrimState m) h xs -> Membership xs x -> (h x -> (h x, a)) -> m a
atomicModify (Struct m) (getMemberId -> I# i) f = primitive
  $ \s0 -> case readSmallArray# m i s0 of
    (# s, x #) -> retry x s
  where
    -- @f@ is coerced so it can be applied to the untyped ('Any') slot contents.
    retry x s = let p = unsafeCoerce# f x in
      case casSmallArray# m i x (fst p) s of
        (# s', b, y #) -> case b of
          -- A zero flag is treated as "swap succeeded" (see the GHC.Prim
          -- documentation of 'casSmallArray#').
          0# -> (# s', snd p #)
          -- Otherwise try again with the element returned by the CAS.
          _ -> retry y s'
-- | Strict variant of 'atomicModify': the returned value is forced to WHNF
-- before the action finishes, and forcing it also forces the new element.
atomicModify' :: PrimMonad m
  => Struct (PrimState m) h xs -> Membership xs x -> (h x -> (h x, a)) -> m a
atomicModify' s i f = do
  result <- atomicModify s i step
  return $! result
  where
    -- The lazy pattern keeps the pair itself cheap to build; evaluation of
    -- @f x@ is only triggered through the @seq@ on the result component.
    step x = let ~(y, a) = f x in (y, y `seq` a)
-- | Atomically apply a function to an element of a 'Struct'.
--
-- The current element is read and @f x@ is installed with 'casSmallArray#';
-- on failure the loop retries with the element the CAS handed back. The
-- result is the element reported by the successful CAS.
-- NOTE(review): per the GHC.Prim documentation that is the element at the
-- offset after the operation, i.e. presumably the new value — confirm.
atomicModify_ :: PrimMonad m
  => Struct (PrimState m) h xs -> Membership xs x -> (h x -> h x) -> m (h x)
atomicModify_ (Struct m) (getMemberId -> I# i) f = primitive
  $ \s0 -> case readSmallArray# m i s0 of
    (# s, x #) -> retry x s
  where
    -- @f@ is coerced so it can be applied to the untyped ('Any') slot contents.
    retry x s = case casSmallArray# m i x (unsafeCoerce# f x) s of
      (# s', b, y #) -> case b of
        -- A zero flag is treated as "swap succeeded"; hand back the element
        -- reported by the CAS, coerced to @h x@.
        0# -> (# s', unsafeCoerce# y #)
        -- Otherwise try again with the element returned by the CAS.
        _ -> retry y s'
-- | Strict variant of 'atomicModify_': the element it returns is forced to
-- WHNF before the action finishes.
atomicModify'_ :: PrimMonad m
  => Struct (PrimState m) h xs -> Membership xs x -> (h x -> h x) -> m (h x)
atomicModify'_ s i f = do
  result <- atomicModify_ s i f
  return $! result
-- | A reference to a single element of a 'Struct'.
--
-- It pairs a struct with the 'Membership' selecting one of its slots (both
-- fields are strict). The element type @x@ and the list @xs@ are hidden; the
-- last type index exposes only @'Repr' h x@, the unwrapped representation.
data WrappedPointer s h a where
  WrappedPointer :: !(Struct s h xs)
    -> !(Membership xs x)
    -> WrappedPointer s h (Repr h x)
-- | Reading through a 'WrappedPointer' fetches the element with 'get' and
-- unwraps it via '_Wrapper'.
instance (s ~ RealWorld, Wrapper h) => V.HasGetter (WrappedPointer s h a) a where
  get (WrappedPointer struct member) =
    liftIO (fmap (view _Wrapper) (get struct member))
-- | Writing through a 'WrappedPointer' wraps the value via '_Wrapper' and
-- stores it with 'set'.
instance (s ~ RealWorld, Wrapper h) => V.HasSetter (WrappedPointer s h a) a where
  WrappedPointer struct member $= v =
    liftIO (set struct member (review _Wrapper v))
-- | Updating through a 'WrappedPointer' lifts the function over '_Wrapper'
-- and applies it with 'atomicModify_' (lazy) or 'atomicModify'_' (strict);
-- the value those return is discarded.
instance (s ~ RealWorld, Wrapper h) => V.HasUpdate (WrappedPointer s h a) a a where
  WrappedPointer struct member $~ f =
    liftIO (void (atomicModify_ struct member (over _Wrapper f)))
  WrappedPointer struct member $~! f =
    liftIO (void (atomicModify'_ struct member (over _Wrapper f)))
-- | Build a 'WrappedPointer' to the field keyed by @k@. The proxy argument
-- only fixes the key; the slot is found through 'association'.
(-$>) :: forall k h xs v s. (Associate k v xs) => Struct s h xs -> Proxy k -> WrappedPointer s h (Repr h (k ':> v))
struct -$> _ = WrappedPointer struct member
  where
    member :: Membership xs (k ':> v)
    member = association
-- | Create a new 'Struct' whose elements are produced by the supplied
-- function of the slot's 'Membership'. Delegates to 'newDict' with the
-- ambient 'Generate' dictionary.
new :: forall h m xs. (PrimMonad m, Generate xs)
  => (forall x. Membership xs x -> h x)
  -> m (Struct (PrimState m) h xs)
new gen = newDict Dict gen
-- | Worker for 'new' that takes the 'Generate' evidence as an explicit 'Dict'.
--
-- The struct is first allocated with 'undefined' in every slot; 'henumerate'
-- then stores @gen pos@ at each position it visits.
-- NOTE(review): presumably 'henumerate' visits every index, so no placeholder
-- survives — confirm against its definition.
newDict :: PrimMonad m
  => Dict (Generate xs)
  -> (forall x. Membership xs x -> h x)
  -> m (Struct (PrimState m) h xs)
newDict Dict gen = do
  struct <- newRepeat undefined
  henumerate
    (\pos rest -> set struct pos (gen pos) >> rest)
    (return struct)
-- | Create a new 'Struct' with every slot holding the same value. The number
-- of slots is @'hcount' (Proxy :: Proxy xs)@.
newRepeat :: forall h m xs. (PrimMonad m, Generate xs)
  => (forall x. h x)
  -> m (Struct (PrimState m) h xs)
newRepeat x = case hcount (Proxy :: Proxy xs) of
  -- The fill value is coerced to 'Any', the array's element type.
  I# len -> primitive (\st0 -> case newSmallArray# len (unsafeCoerce# x) st0 of
    (# st1, arr #) -> (# st1, Struct arr #))
-- | Create a new 'Struct' using a generator that may rely on the constraint
-- @c@ holding for each element. Delegates to 'newForDict' with the ambient
-- 'Forall' dictionary.
newFor :: forall proxy c h m xs. (PrimMonad m, Forall c xs)
  => proxy c
  -> (forall x. c x => Membership xs x -> h x)
  -> m (Struct (PrimState m) h xs)
newFor p gen = newForDict Dict p gen
-- | Worker for 'newFor' that takes the 'Forall' evidence as an explicit 'Dict'.
--
-- The struct is first allocated with 'undefined' in every slot;
-- 'henumerateFor' then stores @gen pos@ at each position it visits.
-- NOTE(review): presumably 'henumerateFor' visits every index, so no
-- placeholder survives — confirm against its definition.
newForDict :: forall proxy c h m xs. (PrimMonad m)
  => Dict (Forall c xs)
  -> proxy c
  -> (forall x. c x => Membership xs x -> h x)
  -> m (Struct (PrimState m) h xs)
newForDict Dict p gen = do
  struct <- newRepeat undefined
  henumerateFor p (Proxy :: Proxy xs)
    (\pos rest -> set struct pos (gen pos) >> rest)
    (return struct)
-- | Create a new 'Struct' from an 'L.HList'.
--
-- An array of @'L.hlength' l@ slots is allocated with 'undefined'
-- placeholders and then filled from the list, front to back, at consecutive
-- indices starting from 0.
newFromHList :: forall h m xs. PrimMonad m => L.HList h xs -> m (Struct (PrimState m) h xs)
newFromHList l = do
  let !(I# len) = L.hlength l
  struct <- primitive (\st0 -> case newSmallArray# len undefined st0 of
    (# st1, arr #) -> (# st1, Struct arr #))
  -- Walk the list, storing each head at the running index.
  let fill :: Int -> L.HList h t -> m ()
      fill _ L.HNil = return ()
      fill pos (L.HCons e rest) = set struct (unsafeMembership pos) e >> fill (pos + 1) rest
  fill 0 l
  return struct
-- | The immutable counterpart of 'Struct': a product of @h x@ for each @x@
-- in the type-level list.
--
-- The payload is a 'SmallArray#' of 'Any'; accessors such as 'hlookup' coerce
-- elements back to @h x@.
data (h :: k -> *) :* (s :: [k]) = HProduct (SmallArray# Any)
-- | Turn a 'Struct' into an immutable product using
-- 'unsafeFreezeSmallArray#'.
-- NOTE(review): as with that primop, the struct should presumably not be
-- mutated after this call — confirm against the GHC.Prim documentation.
unsafeFreeze :: PrimMonad m => Struct (PrimState m) h xs -> m (h :* xs)
unsafeFreeze (Struct marr) = primitive (\st0 ->
  case unsafeFreezeSmallArray# marr st0 of
    (# st1, arr #) -> (# st1, HProduct arr #))
-- | Create a new 'Struct' from a product, by applying 'thawSmallArray#' to
-- the whole range of the underlying array (offset 0, full size).
thaw :: PrimMonad m => h :* xs -> m (Struct (PrimState m) h xs)
thaw (HProduct arr) = primitive (\st0 ->
  case thawSmallArray# arr 0# (sizeofSmallArray# arr) st0 of
    (# st1, marr #) -> (# st1, Struct marr #))
-- | The number of elements in a product, taken from the size of the
-- underlying array.
hlength :: h :* xs -> Int
hlength prod = case prod of
  HProduct arr -> I# (sizeofSmallArray# arr)
-- | Reinterpret an 'Int' as a 'Membership' via 'unsafeCoerce#'. Nothing
-- checks that the index is in range or that @x@ is the type at that position.
-- NOTE(review): this presumably relies on 'Membership' having the same
-- runtime representation as 'Int' — confirm against its definition.
unsafeMembership :: Int -> Membership xs x
unsafeMembership ix = unsafeCoerce# ix
-- | Right-associative fold over a product, passing each element together
-- with its 'Membership'.
--
-- Indices run from @0@ to @'hlength' p - 1@; for an empty product the range
-- is empty and the initial accumulator @r@ is returned unchanged.
hfoldrWithIndex :: (forall x. Membership xs x -> h x -> r -> r) -> r -> h :* xs -> r
hfoldrWithIndex f r p = foldr
  -- The same coerced membership is used both for the callback and the lookup.
  (\i -> let m = unsafeMembership i in f m (hlookup m p)) r [0..hlength p - 1]
-- | Convert a product into an 'L.HList' by looking up consecutive indices
-- from 0 up to @'hlength' p@.
toHList :: forall h xs. h :* xs -> L.HList h xs
toHList p = build 0 where
  -- Every suffix is given the same nominal type; 'retype' papers over the
  -- fact that the real index list shrinks as we go.
  build :: Int -> L.HList h xs
  build pos =
    if pos == hlength p
      then retype L.HNil
      else retype (L.HCons (hlookup (unsafeMembership pos) p) (build (pos + 1)))
  -- Change only the type-level index list of an 'L.HList'.
  retype :: L.HList h ys -> L.HList h zs
  retype = unsafeCoerce#
-- | Create a new 'Struct' by mapping over a product: slot @i@ receives
-- @k@ applied to the membership and element at @i@ of the source.
--
-- Each mapped value is forced to WHNF before it is stored. The array starts
-- out filled with 'undefined' placeholders and is overwritten index by index.
newFrom :: forall g h m xs. (PrimMonad m)
  => g :* xs
  -> (forall x. Membership xs x -> g x -> h x)
  -> m (Struct (PrimState m) h xs)
newFrom hp@(HProduct ar) k = do
  let !len = sizeofSmallArray# ar
  struct <- primitive (\st0 -> case newSmallArray# len undefined st0 of
    (# st1, arr #) -> (# st1, Struct arr #))
  let fill pos =
        if pos == I# len
          then return struct
          else do
            let !member = unsafeMembership pos
            set struct member $! k member (hlookup member hp)
            fill (pos + 1)
  fill 0
-- | Fetch the element of a product selected by a 'Membership'. The stored
-- 'Any' is coerced back to @h x@.
hlookup :: Membership xs x -> h :* xs -> h x
hlookup (getMemberId -> I# ix) (HProduct arr) =
  case indexSmallArray# arr ix of
    (# e #) -> unsafeCoerce# e
-- | Run an 'ST' action that builds a 'Struct' and return the result as an
-- immutable product (frozen with 'unsafeFreeze').
hfrozen :: (forall s. ST s (Struct s h xs)) -> h :* xs
hfrozen m = runST (m >>= \struct -> unsafeFreeze struct)
-- | Produce a modified product: the input is thawed into a fresh 'Struct',
-- the supplied 'ST' action is run on it, and the struct is frozen again.
hmodify :: (forall s. Struct s h xs -> ST s ()) -> h :* xs -> h :* xs
hmodify f m = runST (thaw m >>= \struct -> f struct >> unsafeFreeze struct)
instance (Corepresentable p, Comonad (Corep p), Functor f) => Extensible f p (:*) where
  -- | An optic focusing on the element at a known position: the element is
  -- read with 'hlookup', and a replacement is written back with 'hmodify'.
  pieceAt i pafb = cotabulate (\ctx ->
      fmap (put (extract ctx)) (cosieve pafb (fmap (hlookup i) ctx)))
    where
      -- Replace slot @i@ of a product; the new element is forced first.
      put prod !v = hmodify (\struct -> set struct i v) prod