{-# LANGUAGE GADTs #-}

-- | References in the 'ST' monad that can be focused on a part of their
-- contents with a lens ('zoomSTRef') and combined into a reference to a
-- pair ('pairSTRefs').
--
-- The read\/write\/modify operations exported here share their names with
-- the ones in "Data.STRef" but operate on this module's own 'STRef' type.
module Data.STRef.Zoom 
  ( STRef
  , zoomSTRef
  , newSTRef
  , pairSTRefs
  , readSTRef
  , modifySTRef
  , modifySTRef'
  , writeSTRef
  ) where

import Control.Monad.ST
import qualified Data.STRef as ST
import Control.Lens

-- | A reference to a value of type @a@ living in state thread @s@.
--
-- The type of the underlying storage is hidden (existentially quantified in
-- each constructor); only the focused type @a@ is visible to users.
data STRef s a where 
  -- | A single underlying 'ST.STRef' holding an @x@, together with a lens
  -- selecting the @a@ inside that @x@.
  Leaf   :: ST.STRef s x              -> ALens'  x    a -> STRef s a
  -- | Two references viewed together as the pair @(x, y)@, with a lens
  -- selecting the @a@ inside that pair.
  Branch ::    STRef s x -> STRef s y -> ALens' (x,y) a -> STRef s a

-- | Narrow a reference to the part of its value selected by a lens.
--
-- The result shares the underlying storage of the original reference: the
-- constructor and its cell(s) are reused, and only the stored lens changes,
-- becoming the existing lens composed with the new one (passed through
-- 'fusing').
zoomSTRef :: ALens' a b -> STRef s a -> STRef s b
zoomSTRef inner ref = case ref of
  Leaf cell outer ->
    Leaf cell (fusing (cloneLens outer . cloneLens inner))
  Branch lhs rhs outer ->
    Branch lhs rhs (fusing (cloneLens outer . cloneLens inner))

-- | Allocate a fresh reference holding the given initial value.
--
-- The result is a 'Leaf' over a newly created 'ST.STRef', focused on the
-- whole value (the identity lens).
newSTRef :: a -> ST s (STRef s a)
newSTRef initial = do
  cell <- ST.newSTRef initial
  pure (Leaf cell id)

-- | Combine two references into a single reference to a pair.
--
-- Nothing is allocated or copied: the result is a 'Branch' that refers to
-- both arguments and is focused on the whole pair (the identity lens).
pairSTRefs :: STRef s a -> STRef s b -> STRef s (a,b)
pairSTRefs lhs rhs = Branch lhs rhs id

-- | Read the value a reference is focused on.
--
-- The underlying storage is read (one cell for a 'Leaf', both sides for a
-- 'Branch') and the stored lens is then used to view the focused part.
readSTRef :: STRef s a -> ST s a
readSTRef ref = case ref of
  Leaf cell focus -> do
    whole <- ST.readSTRef cell
    pure (whole ^# focus)
  Branch lhs rhs focus -> do
    both <- readBranch lhs rhs
    pure (both ^# focus)

-- | Apply a function to the value a reference is focused on.
--
-- For a 'Leaf' the update is pushed through the lens and handed to
-- 'ST.modifySTRef'. For a 'Branch' both sides are read, the update is
-- applied through the lens to the resulting pair, and the two components
-- are written back, left side first.
modifySTRef :: STRef s a -> (a -> a) -> ST s ()
modifySTRef ref update = case ref of
  Leaf cell focus ->
    ST.modifySTRef cell (focus #%~ update)
  Branch lhsRef rhsRef focus -> do
    current <- readBranch lhsRef rhsRef
    -- Take the updated pair apart with a (strict) case match before
    -- writing each half back to its own reference.
    case (focus #%~ update) current of
      (lhs, rhs) -> do
        writeSTRef lhsRef lhs
        writeSTRef rhsRef rhs

-- | Variant of 'modifySTRef' that uses the strict 'ST.modifySTRef'' on the
-- underlying cell when the reference is a 'Leaf'.
--
-- A 'Branch' is handled by delegating to 'modifySTRef'.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref update = case ref of
  Leaf cell focus -> ST.modifySTRef' cell (focus #%~ update)
  Branch _ _ _    -> modifySTRef ref update

-- | Replace the value a reference is focused on.
--
-- For a 'Leaf' the new value is set through the lens into the underlying
-- cell using the strict 'ST.modifySTRef''. For a 'Branch' the write is
-- expressed as a 'modifySTRef' with a constant function.
writeSTRef :: STRef s a -> a -> ST s ()
writeSTRef ref new = case ref of
  Leaf cell focus -> ST.modifySTRef' cell (focus #~ new)
  Branch _ _ _    -> modifySTRef ref (const new)

-- | Read two references, left one first, and return their values as a pair.
readBranch :: STRef s a -> STRef s b -> ST s (a,b)
readBranch lhs rhs = do
  l <- readSTRef lhs
  r <- readSTRef rhs
  pure (l, r)