Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
5844405ee9 feat: LeanMPZ at sharecommon_quick_fn 2025-04-06 12:36:49 -07:00
Leonardo de Moura
9e4fe939dc test: mpz at sharecommon issue 2025-04-06 12:36:49 -07:00
Leonardo de Moura
4ac3cdf713 feat: support LeanMPZ objects at sharecommon.cpp 2025-04-06 12:36:49 -07:00
2 changed files with 81 additions and 12 deletions

View File

@@ -17,22 +17,32 @@ extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2) {
size_t sz2 = lean_object_data_byte_size(o2);
if (sz1 != sz2) return false;
// compare relevant parts of the header
if (lean_ptr_tag(o1) != lean_ptr_tag(o2)) return false;
uint8_t tag = lean_ptr_tag(o1);
if (tag != lean_ptr_tag(o2)) return false;
if (lean_ptr_other(o1) != lean_ptr_other(o2)) return false;
size_t header_sz = sizeof(lean_object);
lean_assert(sz1 >= header_sz);
// compare objects' bodies
return memcmp(reinterpret_cast<char*>(o1) + header_sz, reinterpret_cast<char*>(o2) + header_sz, sz1 - header_sz) == 0;
if (tag == LeanMPZ) {
return mpz_value(o1) == mpz_value(o2);
} else {
size_t header_sz = sizeof(lean_object);
lean_assert(sz1 >= header_sz);
// compare objects' bodies
return memcmp(reinterpret_cast<char*>(o1) + header_sz, reinterpret_cast<char*>(o2) + header_sz, sz1 - header_sz) == 0;
}
}
extern "C" LEAN_EXPORT uint64_t lean_sharecommon_hash(b_obj_arg o) {
lean_assert(!lean_is_scalar(o));
size_t sz = lean_object_data_byte_size(o);
size_t header_sz = sizeof(lean_object);
// hash relevant parts of the header
unsigned init = hash(lean_ptr_tag(o), lean_ptr_other(o));
// hash body
return hash_str(sz - header_sz, reinterpret_cast<unsigned char const *>(o) + header_sz, init);
uint8_t tag = lean_ptr_tag(o);
if (tag == LeanMPZ) {
return hash(tag, mpz_value(o).hash());
} else {
// hash relevant parts of the header
unsigned init = hash(tag, lean_ptr_other(o));
// hash body
return hash_str(sz - header_sz, reinterpret_cast<unsigned char const *>(o) + header_sz, init);
}
}
static obj_res mk_pair(obj_arg a, obj_arg b) {
@@ -114,7 +124,7 @@ class sharecommon_fn {
case LeanReserved:
lean_unreachable();
// We do not maximize sharing for the following kinds of objects
case LeanMPZ: case LeanThunk:
case LeanThunk:
case LeanTask: case LeanRef:
case LeanExternal: case LeanClosure:
case LeanPromise:
@@ -201,6 +211,11 @@ class sharecommon_fn {
save(a, (lean_object*)new_a);
}
void visit_mpz(b_obj_arg a) {
object * new_a = alloc_mpz(mpz_value(a));
save(a, new_a);
}
void visit_ctor(b_obj_arg a) {
clear_children();
unsigned num_objs = lean_ctor_num_objs(a);
@@ -247,7 +262,7 @@ public:
case LeanArray: visit_array(curr); break;
case LeanScalarArray: visit_sarray(curr); break;
case LeanString: visit_string(curr); break;
case LeanMPZ: lean_unreachable();
case LeanMPZ: visit_mpz(curr); break;
case LeanThunk: lean_unreachable();
case LeanTask: lean_unreachable();
case LeanPromise: lean_unreachable();
@@ -409,7 +424,6 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) {
Similarly to `sharecommon_fn`, we only maximally share arrays, scalar arrays, strings, and
constructor objects.
*/
case LeanMPZ: lean_inc_ref(a); return a;
case LeanClosure: lean_inc_ref(a); return a;
case LeanThunk: lean_inc_ref(a); return a;
case LeanTask: lean_inc_ref(a); return a;
@@ -417,6 +431,7 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) {
case LeanRef: lean_inc_ref(a); return a;
case LeanExternal: lean_inc_ref(a); return a;
case LeanReserved: lean_inc_ref(a); return a;
case LeanMPZ: return visit_terminal(a);
case LeanScalarArray: return visit_terminal(a);
case LeanString: return visit_terminal(a);
case LeanArray: return visit_array(a);

View File

@@ -0,0 +1,54 @@
import Lean
open Lean Meta Tactic Grind
def runGrind (x : GrindM α) : MetaM α := do
GrindM.run x `dummy ( mkParams {}) (pure ())
@[noinline] def mkA (x : Nat) := x + 1
def tst (a b : Nat) : GrindM Unit := do
IO.println a
IO.println b
let a shareCommon (mkNatLit a)
let b shareCommon (mkNatLit b)
IO.println (isSameExpr a b)
/--
info: 1000000000000000000000000001
1000000000000000000000000001
true
-/
#guard_msgs (info) in
run_meta do
let a := mkA 1000000000000000000000000000
let b := 1000000000000000000000000001
runGrind (tst a b)
/--
info: 1001
1001
true
-/
#guard_msgs (info) in
run_meta do
let a := mkA 1000
let b := 1001
runGrind (tst a b)
def tst2 (a b : Nat) : IO Unit := do
IO.println a
IO.println b
let (a, b) := ShareCommon.shareCommon' (mkNatLit a, mkNatLit b)
IO.println (isSameExpr a b)
/--
info: 1000000000000000000000000001
1000000000000000000000000001
true
-/
#guard_msgs (info) in
run_meta do
let a := mkA 1000000000000000000000000000
let b := 1000000000000000000000000001
tst2 a b