feat(util/map_foreach): add helper functions for traversing Lean maps

This commit is contained in:
Leonardo de Moura
2019-05-13 12:27:59 -07:00
parent de5b68f126
commit deb2310b6d
4 changed files with 110 additions and 3 deletions

View File

@@ -2,4 +2,4 @@ add_library(util OBJECT object_ref.cpp name.cpp name_set.cpp fresh_name.cpp
escaped.cpp bit_tricks.cpp safe_arith.cpp ascii.cpp shared_mutex.cpp
path.cpp lean_path.cpp lbool.cpp bitap_fuzzy_search.cpp
init_module.cpp list_fn.cpp file_lock.cpp timeit.cpp timer.cpp
parser_exception.cpp name_generator.cpp kvmap.cpp)
parser_exception.cpp name_generator.cpp kvmap.cpp map_foreach.cpp)

87
src/util/map_foreach.cpp Normal file
View File

@@ -0,0 +1,87 @@
/*
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "util/map_foreach.h"
namespace lean {
/*
inductive RBNode (α : Type u) (β : α → Type v)
| leaf {} : RBNode
| node (color : Rbcolor) (lchild : RBNode) (key : α) (val : β key) (rchild : RBNode) : RBNode
*/
class rbmap_visitor_fn {
std::function<void(b_obj_arg, b_obj_arg)> const & m_fn;
void visit(b_obj_arg m) {
if (is_scalar(m)) return;
visit(cnstr_get(m, 0));
m_fn(cnstr_get(m, 1), cnstr_get(m, 2));
visit(cnstr_get(m, 3));
}
public:
rbmap_visitor_fn(std::function<void(b_obj_arg, b_obj_arg)> const & fn):m_fn(fn) {}
void operator()(b_obj_arg m) { visit(m); }
};
void rbmap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn) {
return rbmap_visitor_fn(fn)(m);
}
/*
inductive AssocList (α : Type u) (β : Type v)
| nil {} : AssocList
| cons (key : α) (value : β) (tail : AssocList) : AssocList
def HashMapBucket (α : Type u) (β : Type v) :=
{ b : Array (AssocList α β) // b.size > 0 }
structure HashMapImp (α : Type u) (β : Type v) :=
(size : Nat)
(buckets : HashMapBucket α β)
*/
class hashmap_visitor_fn {
std::function<void(b_obj_arg, b_obj_arg)> const & m_fn;
void visit_assoc_list(b_obj_arg lst) {
while (!is_scalar(lst)) {
m_fn(cnstr_get(lst, 0), cnstr_get(lst, 1));
lst = cnstr_get(lst, 2);
}
}
void visit_buckets(b_obj_arg bs) {
usize sz = array_size(bs);
for (usize i = 0; i < sz; i++) {
visit_assoc_list(array_get(bs, i));
}
}
public:
hashmap_visitor_fn(std::function<void(b_obj_arg, b_obj_arg)> const & fn):m_fn(fn) {}
void operator()(b_obj_arg m) {
visit_buckets(cnstr_get(m, 1));
}
};
void hashmap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn) {
return hashmap_visitor_fn(fn)(m);
}
/*
structure SMap (α : Type u) (β : Type v) (lt : αα → Bool) [HasBeq α] [Hashable α] :=
(stage₁ : Bool := true)
(map₁ : HashMap α β := {})
(map₂ : RBMap α β lt := {})
*/
void smap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn) {
hashmap_foreach(cnstr_get(m, 0), fn);
rbmap_foreach(cnstr_get(m, 1), fn);
}
extern "C" obj_res lean_smap_foreach_test(b_obj_arg m) {
smap_foreach(m, [](b_obj_arg k, b_obj_arg v) {
std::cout << ">> " << unbox(k) << " |-> " << unbox(v) << "\n";
});
return box(0);
}
}

16
src/util/map_foreach.h Normal file
View File

@@ -0,0 +1,16 @@
/*
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <functional>
#include "runtime/object.h"
namespace lean {
/* Helper functions for iterating over Lean maps. */
void rbmap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn);
void hashmap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn);
void smap_foreach(b_obj_arg m, std::function<void(b_obj_arg, b_obj_arg)> const & fn);
}

View File

@@ -2,12 +2,16 @@ import init.lean.smap
abbrev Map : Type := Lean.SMap Nat Bool (λ a b, a < b)
@[extern "lean_smap_foreach_test"]
constant foreachTest : Map Nat := default _
def test1 (n₁ n₂ : Nat) : IO Unit :=
let m : Map := {} in
let m := n₁.for (λ i (m : Map), m.insert i (i % 2 == 0)) m in
let m := n₁.fold (λ i (m : Map), m.insert i (i % 2 == 0)) m in
let m := m.switch in
let m := n₂.for (λ i (m : Map), m.insert (i+n₁) (i % 3 == 0)) m in
let m := n₂.fold (λ i (m : Map), m.insert (i+n₁) (i % 3 == 0)) m in
do
IO.println (foreachTest m),
n₁.mfor $ λ i, IO.println (i, (m.find i)),
n₂.mfor $ λ i, IO.println (i+n₁, (m.find (i+n₁))),
IO.println (m.foldStage2 (λ kvs k v, (k, v)::kvs) [])