{-
Copyright (c) 2008 Jim Snow
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. The name of the author may not be used to endorse or promote products
   derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

-}

{-# OPTIONS_GHC -funbox-strict-fields #-}
{-# LANGUAGE BangPatterns #-}

module Data.Glome.Bih (bih) where
import Data.Glome.Vec
import Data.Glome.Solid
import Data.List hiding (group) -- for "partition"


-- Bounding Interval Hierarchy
-- http://en.wikipedia.org/wiki/Bounding_interval_hierarchy

-- A complete hierarchy: the bounding box of everything it contains
-- (bihbb) together with the root of the node tree (bihroot).
data Bih = Bih {bihbb :: Bbox, bihroot :: BihNode} deriving Show
-- A node of the tree.
--   BihLeaf   holds a single SolidItem; build_leaf fills it with a
--             "group" of all the objects that ended up in the leaf.
--   BihBranch holds two split planes along axis "ax": lmax is the upper
--             limit of the left child and rmin the lower limit of the
--             right child (build_rec pads both with delta in the usual
--             left/right split).  The children l and r are the only
--             fields that are not strict.
data BihNode = BihLeaf !SolidItem 
             | BihBranch {lmax :: !Flt, rmin :: !Flt, ax :: !Int, 
                          l :: BihNode, r :: BihNode} deriving Show

-- bih construction
-- Wrap a list of bounded objects in a single leaf node.
-- The bounding boxes are dropped and the objects are merged with
-- "group", so the leaf can treat the whole bunch as one SolidItem.
build_leaf :: [(Bbox, SolidItem)] -> BihNode
build_leaf = BihLeaf . group . map snd

-- Return the surface area of a bounding box that encloses the objects'
-- bounding boxes, divided by the surface area of the node box.  The
-- enclosing box is first combined with the node box through bboverlap
-- (presumably their intersection -- confirm against Data.Glome.Vec).
--
-- This doesn't seem to be much of a win; the only reference to it in
-- this file is a commented-out leaf condition in build_rec.
optimality :: [(Bbox, SolidItem)] -> Bbox -> Flt
optimality objs bb =
 let -- Strict left fold, so the accumulated box is forced at every step
     -- instead of piling up a chain of bbjoin thunks.  The argument
     -- order of bbjoin (object box first) is the original one.
     enclosing = foldl' (\acc (obb,_) -> bbjoin obb acc) empty_bbox objs
     obbsurf   = bbsa $! bboverlap enclosing bb
 in
  obbsurf / bbsa bb

-- Tuning parameter for build_rec: an object whose bounding-box surface
-- area exceeds this fraction of the node box's surface area counts as
-- "big" and may be separated from the small objects instead of going
-- through the usual left/right sorting.
--
-- (This used to be 0.3.)
max_bih_sa :: Flt
max_bih_sa = 0.4

-- Recursive constructor; it looks like quicksort if you squint hard enough.
-- We split along the splitbox's axis of greatest extent, then sort objects
-- to one side or the other by the centre of their bounding box (they can
-- overlap the split plane), then construct the branch node and recurse.
--
-- A nonstandard heuristic is applied first: if there are a few very large
-- objects and a lot of small ones, we create one branch with the big objects
-- and the other with the small objects, instead of sorting by location.
--
-- Arguments:
--   objs      objects still to be placed, paired with their bounding boxes
--   nodebox   box known to enclose the objects of this node
--   splitbox  box whose midpoint along its longest axis is the next
--             candidate split plane
--   depth     recursion depth; it is passed along and incremented but
--             never inspected, so it does not limit the recursion
build_rec :: [(Bbox,SolidItem)] -> Bbox -> Bbox -> Int -> BihNode
build_rec objs nodebox@(Bbox nodeboxp1 nodeboxp2) splitbox@(Bbox splitboxp1 splitboxp2) depth
 -- Zero or one object: nothing left to split.
 -- (A disabled extra condition here was: (optimality objs nodebox) > 0.2)
 | length (take 2 objs) < 2 = build_leaf objs

 -- The candidate plane lies beyond the node box: shrink the split box
 -- towards the node and try again with the same objects.
 | candidate > va nodeboxp2 axis =
    build_rec objs nodebox
              (Bbox splitboxp1 (vset splitboxp2 axis candidate))
              depth
 | candidate < va nodeboxp1 axis =
    build_rec objs nodebox
              (Bbox (vset splitboxp1 axis candidate) splitboxp2)
              depth

 -- Big/small separation (not sure if this is a big win).  Both children
 -- keep the whole node box, so the branch planes are simply the node's
 -- own limits along axis 0.
 | not (null big) && length big < length small * 2 =
    BihBranch (va nodeboxp2 0) (va nodeboxp1 0) 0
              (build_rec big   nodebox splitbox (depth+1))
              (build_rec small nodebox splitbox (depth+1))

 -- Stop if there's no progress being made.
 | stalled = build_leaf objs

 -- Ordinary left/right split.
 | otherwise =
    BihBranch (leftMax + delta) (rightMin - delta) axis
              (build_rec lefts  leftBox  lsplit (depth+1))
              (build_rec rights rightBox rsplit (depth+1))
 where
  axis      = vmaxaxis (vsub splitboxp2 splitboxp1)
  bbmin     = va splitboxp1 axis
  bbmax     = va splitboxp2 axis
  candidate = (bbmin + bbmax) * 0.5

  nbsa         = bbsa nodebox
  (big, small) = partition (\(bb,_) -> bbsa bb > nbsa * max_bih_sa) objs

  -- Objects go left when the centre of their box is below the plane.
  centre (Bbox p1 p2) = (va p1 axis + va p2 axis) * 0.5
  (lefts, rights)     = partition (\(bb,_) -> centre bb < candidate) objs

  -- Strict folds: these run over every object of the node, and a lazy
  -- foldl would build a thunk chain as long as the list.
  leftMax  = foldl' fmax (-infinity) [va p2 axis | (Bbox _ p2, _) <- lefts]
  rightMin = foldl' fmin   infinity  [va p1 axis | (Bbox p1 _, _) <- rights]

  (lsplit, rsplit) = bbsplit splitbox axis candidate
  leftBox  = Bbox nodeboxp1 (vset nodeboxp2 axis leftMax)
  rightBox = Bbox (vset nodeboxp1 axis rightMin) nodeboxp2

  stalled = (null lefts  && rightMin <= bbmin) ||
            (null rights && leftMax  >= bbmax)

-- | The bih constructor creates a Bounding Interval Hierarchy
-- from a list of primitives.  BIH is a type of data structure
-- that groups primitives into a hierarchy of bounding objects,
-- so that a ray need not be tested against every single
-- primitive.  This can make the difference between a rendering
-- job that takes days or seconds.  BIH usually performs a little
-- worse than a SAH-based KD-tree, but construction time is much
-- better.
--
-- An empty list yields a Void object.  If the combined bounding box
-- of the primitives is infinite in any direction, this calls 'error'.
--
-- See http://en.wikipedia.org/wiki/Bounding_interval_hierarchy
bih :: [SolidItem] -> SolidItem
bih [] = SolidItem Void
-- bih (sld:[]) = sld  -- sometimes we'd like to be able to use a
                       -- single object bih just for its aabb
bih slds
 | unbounded = error $ "bih: infinite bounding box " ++ (show objs)
 | otherwise = SolidItem (Bih bb root)
 where
  -- every primitive paired with its own bounding box
  objs = [ (bound s, s) | s <- flatten_group slds ]
  -- Strict fold: the list can hold every primitive of the scene, and a
  -- lazy foldl would build one bbjoin thunk per primitive.
  bb   = foldl' bbjoin empty_bbox (map fst objs)
  -- the root starts with the scene box as both node box and split box
  root = build_rec objs bb bb 0
  Bbox (Vec p1x p1y p1z) (Vec p2x p2y p2z) = bb
  unbounded = any (== (-infinity)) [p1x, p1y, p1z] ||
              any (== infinity)    [p2x, p2y, p2z]

-- Standard ray traversal.
--
-- The ray is clipped against the root bounding box with bbclip, and the
-- resulting [near, far] interval is narrowed at every branch using the
-- two split planes.  A node whose interval is empty (near > far) is a
-- miss.  That test is applied to leaves as well as branches, so a
-- hierarchy whose root is a leaf (the single-object case) still uses
-- its bounding box to reject rays.
--
-- Arguments: the hierarchy, the ray, the maximum distance d to search,
-- and a texture that is passed through to the primitives unchanged.
rayint_bih :: Bih -> Ray -> Flt -> Texture -> Rayint 
rayint_bih (Bih bb root) !r@(Ray orig dir) !d t =
 let dir_rcp = vrcp dir
     Interval near far = bbclip r bb

     -- Leaf: test the contained object, never searching further than
     -- the far end of the current interval.
     descend (BihLeaf s) tnear tfar
      | tnear > tfar = RayMiss
      | otherwise    = rayint s r (fmin d tfar) t
     -- Branch: visit the child the ray enters first, then the other
     -- one, each with the interval clipped by its own split plane.
     descend (BihBranch lsplit rsplit axis lchild rchild) tnear tfar
      | tnear > tfar = RayMiss
      | dirr > 0 =
         nearest
          (if tnear < dl
           then descend lchild tnear (fmin dl tfar)
           else RayMiss)
          (if dr < tfar
           then descend rchild (fmax dr tnear) tfar
           else RayMiss)
      | otherwise =
         nearest
          (if tnear < dr
           then descend rchild tnear (fmin dr tfar)
           else RayMiss)
          (if dl < tfar
           then descend lchild (fmax dl tnear) tfar
           else RayMiss)
      where
       dirr = va dir_rcp axis
       o    = va orig axis
       -- distances along the ray to the left and right split planes
       dl   = (lsplit - o) * dirr
       dr   = (rsplit - o) * dirr
 in
  descend root near far

-- Ray traversal with debug counter.  Same traversal as rayint_bih, but
-- every result carries an Int; each branch node visited passes its
-- result through debug_wrap with an increment of 1.
--
-- As in the branch case, a leaf whose interval is empty (near > far) is
-- reported as a miss with a count of 0, so a hierarchy whose root is a
-- leaf still uses its bounding box to reject rays.
rayint_debug_bih :: Bih -> Ray -> Flt -> Texture -> (Rayint,Int) 
rayint_debug_bih (Bih bb root) r@(Ray orig dir) d t =
 let dir_rcp = vrcp dir
     Interval near far = bbclip r bb

     descend (BihLeaf s) tnear tfar
      | tnear > tfar = (RayMiss,0)
      | otherwise    = rayint_debug s r (fmin d tfar) t
     descend (BihBranch lsplit rsplit axis lchild rchild) tnear tfar =
      debug_wrap visited 1 -- increment the debug value for every box we hit
      where
       dirr = va dir_rcp axis
       o    = va orig axis
       -- distances along the ray to the left and right split planes
       dl   = (lsplit - o) * dirr
       dr   = (rsplit - o) * dirr
       visited
        | tnear > tfar = (RayMiss,0)
        | dirr > 0 =
           nearest_debug
            (if tnear < dl
             then descend lchild tnear (fmin dl tfar)
             else (RayMiss,0))
            (if dr < tfar
             then descend rchild (fmax dr tnear) tfar
             else (RayMiss,0))
        | otherwise =
           nearest_debug
            (if tnear < dr
             then descend rchild tnear (fmin dr tfar)
             else (RayMiss,0))
            (if dl < tfar
             then descend lchild (fmax dl tnear) tfar
             else (RayMiss,0))
 in
  descend root near far

-- Four-ray packet traversal.
--
-- This is unwieldy, but the performance gains sometimes make it
-- worthwhile.  By testing 4 rays against each cell, we (theoretically)
-- do ~1/4 the memory accesses.
--
-- This originally made a big difference, but after switching
-- everything to typeclasses, it doesn't perform any better
-- than regular traversal.
--
-- One simplifying assumption we make that adds a little bit of
-- overhead:  If one ray hits a cell, we act as though they all do
-- (the packet uses the smallest near and the largest far of the four
-- rays).  For that reason, this only works well with coherent rays.
--
-- If the reciprocal directions of the four rays do not all agree in
-- sign with the first ray (veqsign), the packet is abandoned and each
-- ray is traced on its own with rayint.
--
-- A node whose packet interval is empty (near > far) is a miss; this is
-- applied to leaves as well as branches, so a hierarchy whose root is a
-- leaf still uses its bounding box to reject packets.
packetint_bih :: Bih -> Ray -> Ray -> Ray -> Ray -> Flt -> Texture -> PacketResult
packetint_bih whole@(Bih bb root)
              !r1@(Ray orig1 dir1)
              !r2@(Ray orig2 dir2)
              !r3@(Ray orig3 dir3)
              !r4@(Ray orig4 dir4) !d t
 | not coherent =
    PacketResult (rayint whole r1 d t)
                 (rayint whole r2 d t)
                 (rayint whole r3 d t)
                 (rayint whole r4 d t)
 | otherwise = descend root near far
 where
  dir_rcp1 = vrcp dir1
  dir_rcp2 = vrcp dir2
  dir_rcp3 = vrcp dir3
  dir_rcp4 = vrcp dir4

  -- We want all the ray components to have at least the same sign.
  coherent = veqsign dir_rcp1 dir_rcp2 &&
             veqsign dir_rcp1 dir_rcp3 &&
             veqsign dir_rcp1 dir_rcp4

  Interval near1 far1 = bbclip r1 bb
  Interval near2 far2 = bbclip r2 bb
  Interval near3 far3 = bbclip r3 bb
  Interval near4 far4 = bbclip r4 bb

  -- widest interval covering all four rays
  near = fmin4 near1 near2 near3 near4
  far  = fmax4 far1  far2  far3  far4

  descend (BihLeaf s) tnear tfar
   | tnear > tfar = packetmiss
   | otherwise    = packetint s r1 r2 r3 r4 (fmin d tfar) t
  descend (BihBranch lsplit rsplit axis lchild rchild) tnear tfar
   | tnear > tfar = packetmiss
   | dirr1 > 0 =  -- true for all, since signs match
      let dl = fmax4 dl1 dl2 dl3 dl4
          dr = fmin4 dr1 dr2 dr3 dr4
      in
       nearest_packetresult
        (if tnear < dl
         then descend lchild tnear (fmin dl tfar)
         else packetmiss)
        (if dr < tfar
         then descend rchild (fmax dr tnear) tfar
         else packetmiss)
   | otherwise =
      let dl = fmin4 dl1 dl2 dl3 dl4
          dr = fmax4 dr1 dr2 dr3 dr4
      in
       nearest_packetresult
        (if tnear < dr
         then descend rchild tnear (fmin dr tfar)
         else packetmiss)
        (if dl < tfar
         then descend lchild (fmax dl tnear) tfar
         else packetmiss)
   where
    dirr1 = va dir_rcp1 axis
    dirr2 = va dir_rcp2 axis
    dirr3 = va dir_rcp3 axis
    dirr4 = va dir_rcp4 axis

    o1    = va orig1 axis
    o2    = va orig2 axis
    o3    = va orig3 axis
    o4    = va orig4 axis

    -- per-ray distances to the left split plane
    dl1   = (lsplit - o1) * dirr1
    dl2   = (lsplit - o2) * dirr2
    dl3   = (lsplit - o3) * dirr3
    dl4   = (lsplit - o4) * dirr4

    -- per-ray distances to the right split plane
    dr1   = (rsplit - o1) * dirr1
    dr2   = (rsplit - o2) * dirr2
    dr3   = (rsplit - o3) * dirr3
    dr4   = (rsplit - o4) * dirr4

-- Shadow test: is anything in the hierarchy hit by the ray within
-- distance d?  The traversal is the same as in rayint_bih, but the
-- result is a Bool and the second child is only visited when the first
-- one reports no occlusion.
--
-- A node whose interval is empty (near > far) occludes nothing; this is
-- applied to leaves as well as branches, so a hierarchy whose root is a
-- leaf still uses its bounding box to reject rays.
shadow_bih :: Bih -> Ray -> Flt -> Bool
shadow_bih (Bih bb root) r@(Ray orig dir) d =
 let dir_rcp = vrcp dir
     Interval near far = bbclip r bb

     occluded (BihLeaf s) tnear tfar
      | tnear > tfar = False
      | otherwise    = shadow s r (fmin d tfar)
     occluded (BihBranch lsplit rsplit axis lchild rchild) tnear tfar
      | tnear > tfar = False
      | dirr > 0 =
         (tnear < dl && occluded lchild tnear (fmin dl tfar))
         ||
         (dr < tfar && occluded rchild (fmax dr tnear) tfar)
      | otherwise =
         (tnear < dr && occluded rchild tnear (fmin dr tfar))
         ||
         (dl < tfar && occluded lchild (fmax dl tnear) tfar)
      where
       dirr = va dir_rcp axis
       o    = va orig axis
       -- distances along the ray to the left and right split planes
       dl   = (lsplit - o) * dirr
       dr   = (rsplit - o) * dirr

 in occluded root near far

-- Inside/outside test; essentially a point traversal.
-- The point must lie strictly inside the root bounding box, and then
-- inside at least one of the objects stored in the hierarchy.  At a
-- branch, the left child is searched when the point is below the left
-- plane and the right child when it is above the right plane.
inside_bih :: Bih -> Vec -> Bool
inside_bih (Bih (Bbox (Vec x1 y1 z1) (Vec x2 y2 z2)) root) pt@(Vec x y z) =
 x > x1 && x < x2 &&
 y > y1 && y < y2 &&
 z > z1 && z < z2 && contains root
 where
  contains (BihLeaf s) = inside s pt
  contains (BihBranch lsplit rsplit axis lchild rchild) =
   let o = va pt axis
   in (o < lsplit && contains lchild) || (o > rsplit && contains rchild)

-- The bounding box was computed when the hierarchy was built, so this
-- just reads it back out of the record.
bound_bih :: Bih -> Bbox
bound_bih = bihbb

-- Primitive count of the whole hierarchy: the counts of all leaf
-- objects, plus one pcsinglebound for every branch node and one more
-- for the hierarchy itself.
primcount_bih :: Bih -> Pcount
primcount_bih (Bih _ root) = count root `pcadd` pcsinglebound
 where
  count node = case node of
   BihLeaf s -> primcount s
   BihBranch _ _ _ lchild rchild ->
    (count lchild `pcadd` count rchild) `pcadd` pcsinglebound

-- Hook the Bih-specific implementations into the Solid class.
instance Solid Bih where
 bound        = bound_bih
 primcount    = primcount_bih
 inside       = inside_bih
 shadow       = shadow_bih
 rayint       = rayint_bih
 rayint_debug = rayint_debug_bih
 packetint    = packetint_bih