-- | Efficient enumeration of subsets of triangular elements. Given a list
-- @[1..n]@ we want to enumerate a subset @[(i,j)]@ of ordered pairs in
-- such a way that we only have to hold the elements necessary for this
-- subset in memory.

module Data.Paired.Foldable where

import Data.IntMap as IM
import Data.Foldable as F
import Data.List as L
import Control.Arrow ((***))
import Data.Vector as V
import Data.Vector.Generic as VG
import Debug.Trace (traceShow)
import Text.Printf

import Data.Paired.Common
import Math.TriangularNumbers



-- | Generalized upper triangular elements. Given a list of elements
-- @[e_1,...,e_k]@, we want to return pairs @(e_i,e_j)@ such that we have
-- all ordered pairs with @i<j@ (if @NoDiag@onal elements), or @i<=j@ (if
-- @OnDiag@onal elements).
--
-- @upperTri@ will force the spine of @t a@ but is consumed linearly with
-- a strict @Data.Foldable.foldl'@. Internally we keep a @Data.IntMap@ of
-- the retained elements.
--
-- This is important if the @Enumerate@ type is set to @FromN k n@. We
-- start at the @k@th element, and produce @n@ elements.
--
-- TODO compare @IntMap@ and @HashMap@.
--
-- TODO inRange is broken.

upperTri
  :: (Foldable t)
  => SizeHint
  -- ^ If the size of @t a@ is known beforehand, give the appropriate
  -- @KnownSize n@, otherwise give @UnknownSize@. Using @UnknownSize@ will
  -- force the complete spine of @t a@.
  -> OnDiag
  -- ^ The enumeration will include the pairs on the main diagonal with
  -- @OnDiag@, meaning @(i,i)@ will be included for all @i@. Otherwise,
  -- @NoDiag@ will exclude these elements.
  -> Enumerate
  -- ^ Either enumerate @All@ elements or enumerate the @s@ elements
  -- starting at @k@ with @FromN k s@.
  -> t a
  -- ^ The foldable data structure to enumerate over.
  -> Either String (IntMap a, Int, [((Int,Int),(a,a))])
  -- ^ If there is any error then return @Left errorMsg@. Otherwise we have
  -- @Right (imap, numElems, list)@. The @imap@ structure holds the subset
  -- of elements with which we actually generate elements. @numElems@ is
  -- the total number of elements that will be generated. This is
  -- calculated without touching @list@. Finally, @list@ is the lazy list
  -- of elements to be generated.
upperTri sz d e xs
  -- The fold over @xs@ counted a different number of elements than the
  -- size we assumed (hint or measured length): report it instead of
  -- producing pairs from inconsistent bounds.
  | szLen /= readLen = Left $ printf "Expected SizeHint %d elements, but processed only %d elements!" szLen readLen
  | otherwise        = Right (imp, numElems, ys)
  where -- The lazily produced pairs. @go@ walks the index pairs starting
        -- at @initEnum e d@; for @FromN _ cnt@ the stream is cut after
        -- @cnt@ pairs.
        ys   = (case e of { All -> id ; FromN _ cnt -> L.take cnt })
             . L.unfoldr go $ initEnum e d
        -- how many elements we will emit depends on enumeration and on
        -- diagonal element counting. For @FromN s k@ the request is
        -- clipped to what remains after linear position @s@, never
        -- dropping below zero.
        numElems :: Int
        numElems = case e of
          All       -> allSize
          FromN s k -> if s+k > allSize then max 0 (allSize - s) else k
        -- The length of the input. With a given size hint, @xs :: t a@
        -- will only be touched once (by the fold below); without a hint
        -- the length is computed here by a separate traversal.
        szLen :: Int
#if MIN_VERSION_base(4,8,0)
        szLen = case sz of { UnknownSize -> F.length xs ; KnownSize z -> z }
#else
        szLen = case sz of { UnknownSize -> L.length . F.toList $ xs ; KnownSize z -> z }
#endif
        -- Bound handed to @fromLinear@ / @toLinear@; it depends on
        -- whether diagonal pairs are counted.
        szLn' :: Int
        szLn' = case d of { OnDiag -> szLen - 1 ; NoDiag -> szLen - 2 }
        -- Construct an intmap @imp@ of all elements in the accepted range.
        -- At the same time, return the length or size of the foldable
        -- container we gave as input. @xs@ is touched only once and can
        -- be efficiently consumed. The accumulator is kept strict so the
        -- fold does not build thunks.
        (!imp,!readLen) = F.foldl' (\(!i,!l) x -> (if inRange l then IM.insert l x i else i,l+1)) (IM.empty, 0) xs
        -- Total number of pairs for the full triangle:
        -- @n*(n+1)/2@ with the diagonal, @n*(n-1)/2@ without.
        allSize :: Int
        allSize = szLen * (szLen + if d == OnDiag then 1 else -1) `div` 2
        -- we need three ranges. @cMin@ and @cMax@ are the range for the
        -- slow-moving first element in the tuple. @rMin@ and @rMax@ are
        -- the first and last element of the range starting at @cMin@ (we
        -- can actually start at @cMax@ but it doesn't matter).
        -- Finally, @lMin@ and @lMax@ are the range to the left of @cMin@.
        (lMin,lMax,cMin,cMax,rMin,rMax) = case e of
          All -> (0, szLen-1, 0, szLen-1, 0, szLen-1)
          FromN s k ->
            let (cmin,rmin) = fromLinear szLn' s
                (cmax,_   ) = fromLinear szLn' (s+k)
                rmax = rmin+k -- if this is @>= len@ we are safe anyway.
                lmin = if rmin+k >= szLen then 0 else cmin
                lmax = if rmin+k >= szLen then lmin + toLinear szLn' (cmin+1,cmin+1+rmin+k-szLn') else cmax
            in  (lmin, lmax, cmin, cmax, rmin, rmax)
        -- Determine if an element at linear index @z@ is in the range to
        -- be consumed.
        -- NOTE(review): the module header carries a TODO stating that
        -- this predicate is broken. @go@ looks indices up with the
        -- partial @IM.!@, so any index it visits that was rejected here
        -- raises a runtime error. Logic left unchanged pending a proper
        -- fix.
        inRange :: Int -> Bool
        inRange z =  lMin <= z && z <= lMax
                  || cMin <= z && z <= cMax
                  || rMin <= z && z <= rMax
        -- Step function for @unfoldr@: index into the retained elements
        -- @imp@ when generating pairs. @k@ is the slow-moving first
        -- index, @l@ the fast-moving second one. When @l@ runs off the
        -- end we advance @k@ and restart @l@ at @k+1@ (on the diagonal)
        -- or @k+2@ (off the diagonal); when @k@ runs off the end the
        -- stream is finished.
        go (k,l)
          | k >= szLen  = Nothing
          | l >= szLen  = go (k+1,k+1 + if d == OnDiag then 0 else 1)
          | otherwise = Just (((k,l),(imp IM.! k, imp IM.! l)), (k,l+1))
        -- Initialize the enumeration at the correct pair @(i,j)@. From
        -- then on we can @take@ the correct number of elements, or stream
        -- all of them. A start position at or beyond @allSize@ yields
        -- @(szLen,szLen)@, on which @go@ stops immediately.
        initEnum :: Enumerate -> OnDiag -> (Int, Int)
        initEnum All OnDiag = (0,0)
        initEnum All NoDiag = (0,1)
        initEnum (FromN s _) OnDiag
          | s >= allSize = (szLen,szLen)
          | otherwise    = fromLinear szLn' s
        initEnum (FromN s _) NoDiag
          | s >= allSize = (szLen,szLen)
          | otherwise    = id *** (+1) $ fromLinear szLn' s