/* -----------------------------------------------------------------------------
Copyright 2021 Kevin P. Barry

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
----------------------------------------------------------------------------- */

// Author: Kevin P. Barry [ta0kira@gmail.com]

// Maps categories of type #c to non-negative Int weights, and maps a position
// in [0,total) back to the category whose weight range contains it.
define CategoricalTree {
  // NOTE(review): the first type argument (CategoricalTreeNode<#c>) was missing
  // from this file and has been restored from context; confirm it against the
  // declaration of AutoBinaryTree.
  @value AutoBinaryTree<CategoricalTreeNode<#c>,#c,Int,CategoricalSearch<#c>> tree

  // Creates a tree with no categories.
  new () {
    return #self{ AutoBinaryTree<CategoricalTreeNode<#c>,#c,Int,CategoricalSearch<#c>>.new() }
  }

  // Returns the total stored in the root node, or 0 if there is no root.
  getTotal () {
    return CategoricalTreeNode<#c>.tryTotal(tree.getRoot())
  }

  // Sets the weight of cat to size and returns self.
  // - size < 0 is a failure.
  // - size > 0 stores the weight under cat.
  // - size = 0 removes cat from the tree.
  setWeight (cat,size) {
    if (size < 0) {
      fail("size must not be negative")
    } elif (size > 0) {
      \ tree.set(cat,size)
    } else {
      \ tree.remove(cat)
    }
    return self
  }

  // Returns the weight stored for cat, or 0 if the tree has no entry for it.
  getWeight (cat) {
    scoped {
      optional Int size <- tree.get(cat)
    } in if (present(size)) {
      return require(size)
    } else {
      return 0
    }
  }

  // Returns the category that contains position pos.
  // Fails if pos is negative or is not strictly less than getTotal().
  locate (pos) {
    if (pos < 0) {
      fail("position must not be negative")
    } elif (pos >= getTotal()) {
      fail("position must be strictly less than the total")
    } else {
      return require(CategoricalTreeNode<#c>.findPosition(pos,tree.getRoot()))
    }
  }

  // Same as locate(pos).
  readAt (pos) {
    return locate(pos)
  }

  // Same as getTotal().
  size () {
    return getTotal()
  }
}

// Same storage and lookup logic as CategoricalTree, but every setWeight call is
// followed by a full check of the totals cached in the nodes.
define ValidatedTree {
  // NOTE(review): the first type argument (CategoricalTreeNode<#c>) was missing
  // from this file and has been restored from context; confirm it against the
  // declaration of AutoBinaryTree.
  @value AutoBinaryTree<CategoricalTreeNode<#c>,#c,Int,CategoricalSearch<#c>> tree

  // Creates a tree with no categories.
  new () {
    return #self{ AutoBinaryTree<CategoricalTreeNode<#c>,#c,Int,CategoricalSearch<#c>>.new() }
  }

  // Returns the total stored in the root node, or 0 if there is no root.
  getTotal () {
    return CategoricalTreeNode<#c>.tryTotal(tree.getRoot())
  }

  // Sets the weight of cat to size, validates the whole tree, and returns self.
  // - size < 0 is a failure.
  // - size > 0 stores the weight under cat.
  // - size = 0 removes cat from the tree.
  setWeight (cat,size) {
    if (size < 0) {
      fail("size must not be negative")
    } elif (size > 0) {
      \ tree.set(cat,size)
    } else {
      \ tree.remove(cat)
    }
    // Fails if any node's cached total is inconsistent after the update.
    \ validateTotal(tree.getRoot())
    return self
  }

  // Returns the weight stored for cat, or 0 if the tree has no entry for it.
  getWeight (cat) {
    scoped {
      optional Int size <- tree.get(cat)
    } in if (present(size)) {
      return require(size)
    } else {
      return 0
    }
  }

  // Returns the category that contains position pos.
  // Fails if pos is negative or is not strictly less than getTotal().
  locate (pos) {
    if (pos < 0) {
      fail("position must not be negative")
    } elif (pos >= getTotal()) {
      fail("position must be strictly less than the total")
    } else {
      return require(CategoricalTreeNode<#c>.findPosition(pos,tree.getRoot()))
    }
  }

  // Recursively checks that each node's getTotal() equals its own value plus
  // the totals of its lower and higher children. Children are checked before
  // the node itself. Fails with a "bad total" message on the first mismatch.
  @type validateTotal (optional CategoricalSearch<#c>) -> ()
  validateTotal (node) {
    if (present(node)) {
      CategoricalSearch<#c> node2 <- require(node)
      \ validateTotal(node2.getLower())
      \ validateTotal(node2.getHigher())
      Int size   <- node2.getValue()
      Int lower  <- CategoricalTreeNode<#c>.tryTotal(node2.getLower())
      Int higher <- CategoricalTreeNode<#c>.tryTotal(node2.getHigher())
      if (node2.getTotal() != size+lower+higher) {
        fail(String.builder()
            .append("bad total: ")
            .append(size.formatted())
            .append("+")
            .append(lower.formatted())
            .append("+")
            .append(higher.formatted())
            .append(" != ")
            .append(node2.getTotal().formatted())
            .build())
      }
    }
  }
}

// Read-only view of a tree node that additionally exposes the total for the
// subtree rooted at that node.
@value interface CategoricalSearch<|#c> {
  refines BinaryTreeNode<#c,Int>

  getTotal () -> (Int)
}

// Tree node that stores a category, its weight, and a cached subtree total.
concrete CategoricalTreeNode<#c> {
  defines KVFactory<#c,Int>
  refines CategoricalSearch<#c>
  // NOTE(review): the first type argument (CategoricalTreeNode<#c>) was missing
  // from this file and has been restored from context; confirm it against the
  // declaration of BalancedTreeNode.
  refines BalancedTreeNode<CategoricalTreeNode<#c>,#c,Int>

  // Returns the key of the node whose range contains the position, or empty if
  // the node is empty or the position is not below the node's total.
  @type findPosition (Int,optional CategoricalSearch<#c>) -> (optional #c)

  // Returns the node's total, or 0 if the node is empty.
  @type tryTotal (optional CategoricalSearch<#c>) -> (Int)

  @value getTotal () -> (Int)
}

define CategoricalTreeNode {
  $ReadOnly[key]$

  @value Int height
  @value #c key
  // Weight of this node's own category.
  @value Int size
  // Cached value of size plus the totals of both children; see updateNode.
  @value Int total
  @value optional #self lower
  @value optional #self higher

  findPosition (pos,node) (cat) {
    // Result when node is empty or pos is not below the node's total.
    cat <- empty
    if (present(node)) {
      scoped {
        // pos2 is a mutable copy of pos; pos itself is hidden from here on.
        Int pos2 <- pos
        $Hidden[pos]$
        CategoricalSearch<#c> node2 <- require(node)
        Int lower <- tryTotal(node2.getLower())
        Int size  <- node2.getValue()
        $ReadOnly[node2,lower,size]$
      } in if (pos2 < node2.getTotal()) {
        if (pos2 < lower) {
          // pos2 is in the 1st (lower) of the 3 sections.
          return findPosition(pos2,node2.getLower())
        }
        // Make pos2 relative to the start of the middle section.
        pos2 <- pos2-lower
        if (pos2 < size) {
          // pos2 is in the 2nd (middle) of the 3 sections.
          return node2.getKey()
        }
        // Make pos2 relative to the start of the higher section.
        pos2 <- pos2-size
        // pos2 is in the 3rd (higher) of the 3 sections.
        return findPosition(pos2,node2.getHigher())
      }
    }
  }

  tryTotal (node) {
    if (present(node)) {
      return require(node).getTotal()
    } else {
      return 0
    }
  }

  // Creates a childless node: height 1, key k, weight v, total v.
  newNode (k,v) {
    return #self{ 1, k, v, v, empty, empty }
  }

  getLower () {
    return lower
  }

  setLower (l) {
    lower <- l
  }

  getHigher () {
    return higher
  }

  setHigher (h) {
    higher <- h
  }

  getKey () {
    return key
  }

  getValue () {
    return size
  }

  // Only replaces the weight; the cached total is not touched here.
  // NOTE(review): presumably the owning tree calls updateNode() afterwards;
  // confirm against AutoBinaryTree.
  setValue (v) {
    size <- v
  }

  getHeight () {
    return height
  }

  getTotal () {
    return total
  }

  // Recomputes the cached fields from the current children:
  // - height becomes 1 + the larger child height (an empty child counts as 0).
  // - total becomes size + the totals of both children.
  updateNode () {
    scoped {
      Int l <- 0
      Int h <- 0
      if (present(lower)) {
        l <- require(lower).getHeight()
      }
      if (present(higher)) {
        h <- require(higher).getHeight()
      }
    } in if (l > h) {
      height <- l + 1
    } else {
      height <- h + 1
    }
    total <- size+tryTotal(lower)+tryTotal(higher)
  }
}