Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
6eed46f2bf fix: Float32 runtime support
This PR adds missing features and fixes bugs in the `Float32` support
2024-12-09 16:32:06 -08:00
7 changed files with 94 additions and 25 deletions

View File

@@ -81,9 +81,9 @@ then one of the following must hold in each (execution) branch.
inductive IRType where
| float | uint8 | uint16 | uint32 | uint64 | usize
| irrelevant | object | tobject
| float32
| struct (leanTypeName : Option Name) (types : Array IRType) : IRType
| union (leanTypeName : Name) (types : Array IRType) : IRType
| float32
deriving Inhabited, Repr
namespace IRType

View File

@@ -352,6 +352,16 @@ class to_ir_fn {
return ir::mk_sset(to_var_id(args[0]), n, offset, to_var_id(args[1]), ir::type::Float, b);
}
ir::fn_body visit_f32set(local_decl const & decl, ir::fn_body const & b) {
expr val = *decl.get_value();
buffer<expr> args;
expr const & fn = get_app_args(val, args);
lean_assert(args.size() == 2);
unsigned n, offset;
lean_verify(is_llnf_f32set(fn, n, offset));
return ir::mk_sset(to_var_id(args[0]), n, offset, to_var_id(args[1]), ir::type::Float32, b);
}
ir::fn_body visit_uset(local_decl const & decl, ir::fn_body const & b) {
expr val = *decl.get_value();
buffer<expr> args;
@@ -419,6 +429,8 @@ class to_ir_fn {
return visit_sset(decl, b);
else if (is_llnf_fset(fn))
return visit_fset(decl, b);
else if (is_llnf_f32set(fn))
return visit_f32set(decl, b);
else if (is_llnf_uset(fn))
return visit_uset(decl, b);
else if (is_llnf_proj(fn))
@@ -451,7 +463,7 @@ class to_ir_fn {
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val);
fvars.push_back(new_fvar);
expr const & op = get_app_fn(val);
if (is_llnf_sset(op) || is_llnf_fset(op) || is_llnf_uset(op)) {
if (is_llnf_sset(op) || is_llnf_fset(op) || is_llnf_f32set(op) || is_llnf_uset(op)) {
/* In the Lean IR, sset and uset are instructions that perform destructive updates. */
subst.push_back(app_arg(app_fn(val)));
} else {

View File

@@ -14,13 +14,13 @@ namespace ir {
inductive IRType
| float | uint8 | uint16 | uint32 | uint64 | usize
| irrelevant | object | tobject
| float32
| struct (leanTypeName : Option Name) (types : Array IRType) : IRType
| union (leanTypeName : Name) (types : Array IRType) : IRType
| float32
Remark: we don't create struct/union types from C++.
*/
enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject, Float32 };
enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject, Float32, Struct, Union };
typedef nat var_id;
typedef nat jp_id;

View File

@@ -188,6 +188,11 @@ option_ref<decl> find_ir_decl(environment const & env, name const & n) {
extern "C" double lean_float_of_nat(lean_obj_arg a);
// TODO: define in Lean like `lean_float_of_nat`
float lean_float32_of_nat(lean_obj_arg a) {
return lean_float_of_nat(a);
}
static string_ref * g_mangle_prefix = nullptr;
static string_ref * g_boxed_suffix = nullptr;
static string_ref * g_boxed_mangled_suffix = nullptr;
@@ -227,6 +232,7 @@ union value {
uint64 m_num; // big enough for any unboxed integral type
static_assert(sizeof(size_t) <= sizeof(uint64), "uint64 should be the largest unboxed type"); // NOLINT
double m_float;
float m_float32;
object * m_obj;
value() {}
@@ -240,36 +246,50 @@ union value {
v.m_float = f;
return v;
}
static value from_float32(float f) {
value v;
v.m_float32 = f;
return v;
}
};
object * box_t(value v, type t) {
switch (t) {
case type::Float: return box_float(v.m_float);
case type::UInt8: return box(v.m_num);
case type::UInt16: return box(v.m_num);
case type::UInt32: return box_uint32(v.m_num);
case type::UInt64: return box_uint64(v.m_num);
case type::USize: return box_size_t(v.m_num);
case type::Object:
case type::TObject:
case type::Irrelevant:
return v.m_obj;
case type::Float: return box_float(v.m_float);
case type::Float32: return box_float(v.m_float32);
case type::UInt8: return box(v.m_num);
case type::UInt16: return box(v.m_num);
case type::UInt32: return box_uint32(v.m_num);
case type::UInt64: return box_uint64(v.m_num);
case type::USize: return box_size_t(v.m_num);
case type::Object:
case type::TObject:
case type::Irrelevant:
return v.m_obj;
case type::Struct:
case type::Union:
throw exception("not implemented yet");
}
lean_unreachable();
}
value unbox_t(object * o, type t) {
switch (t) {
case type::Float: return value::from_float(unbox_float(o));
case type::UInt8: return unbox(o);
case type::UInt16: return unbox(o);
case type::UInt32: return unbox_uint32(o);
case type::UInt64: return unbox_uint64(o);
case type::USize: return unbox_size_t(o);
case type::Irrelevant:
case type::Object:
case type::TObject:
break;
case type::Float: return value::from_float(unbox_float(o));
case type::Float32: return value::from_float32(unbox_float32(o));
case type::UInt8: return unbox(o);
case type::UInt16: return unbox(o);
case type::UInt32: return unbox_uint32(o);
case type::UInt64: return unbox_uint64(o);
case type::USize: return unbox_size_t(o);
case type::Irrelevant:
case type::Object:
case type::TObject:
break;
case type::Struct:
case type::Union:
throw exception("not implemented yet");
}
lean_unreachable();
}
@@ -278,6 +298,8 @@ value unbox_t(object * o, type t) {
void print_value(tout & ios, value const & v, type t) {
if (t == type::Float) {
ios << v.m_float;
} else if (t == type::Float32) {
ios << v.m_float32;
} else if (type_is_scalar(t)) {
ios << v.m_num;
} else {
@@ -472,6 +494,7 @@ private:
object * o = var(expr_sproj_obj(e)).m_obj;
switch (t) {
case type::Float: return value::from_float(cnstr_get_float(o, offset));
case type::Float32: return value::from_float32(cnstr_get_float32(o, offset));
case type::UInt8: return cnstr_get_uint8(o, offset);
case type::UInt16: return cnstr_get_uint16(o, offset);
case type::UInt32: return cnstr_get_uint32(o, offset);
@@ -480,6 +503,8 @@ private:
case type::Irrelevant:
case type::Object:
case type::TObject:
case type::Struct:
case type::Union:
break;
}
throw exception("invalid instruction");
@@ -530,6 +555,9 @@ private:
case type::Float:
lean_inc(n.raw());
return value::from_float(lean_float_of_nat(n.raw()));
case type::Float32:
lean_inc(n.raw());
return value::from_float32(lean_float32_of_nat(n.raw()));
case type::UInt8:
case type::UInt16:
case type::UInt32:
@@ -543,6 +571,9 @@ private:
return n.to_obj_arg();
case type::Irrelevant:
break;
case type::Union:
case type::Struct:
break;
}
throw exception("invalid instruction");
}
@@ -654,6 +685,7 @@ private:
lean_assert(is_exclusive(o));
switch (fn_body_sset_type(b)) {
case type::Float: cnstr_set_float(o, offset, v.m_float); break;
case type::Float32: cnstr_set_float32(o, offset, v.m_float32); break;
case type::UInt8: cnstr_set_uint8(o, offset, v.m_num); break;
case type::UInt16: cnstr_set_uint16(o, offset, v.m_num); break;
case type::UInt32: cnstr_set_uint32(o, offset, v.m_num); break;
@@ -662,6 +694,8 @@ private:
case type::Irrelevant:
case type::Object:
case type::TObject:
case type::Struct:
case type::Union:
throw exception(sstream() << "invalid instruction");
}
b = fn_body_sset_cont(b);
@@ -807,6 +841,7 @@ private:
// constants do not have boxed wrappers, but we'll survive
switch (t) {
case type::Float: return value::from_float(*static_cast<double *>(e.m_addr));
case type::Float32: return value::from_float32(*static_cast<float *>(e.m_addr));
case type::UInt8: return *static_cast<uint8 *>(e.m_addr);
case type::UInt16: return *static_cast<uint16 *>(e.m_addr);
case type::UInt32: return *static_cast<uint32 *>(e.m_addr);
@@ -816,6 +851,9 @@ private:
case type::TObject:
case type::Irrelevant:
return *static_cast<object **>(e.m_addr);
case type::Struct:
case type::Union:
throw exception("not implemented yet");
}
}

View File

@@ -34,6 +34,7 @@ static char const * g_cnstr = "_cnstr";
static name * g_reuse = nullptr;
static name * g_reset = nullptr;
static name * g_fset = nullptr;
static name * g_f32set = nullptr;
static name * g_sset = nullptr;
static name * g_uset = nullptr;
static name * g_proj = nullptr;
@@ -162,6 +163,9 @@ bool is_llnf_sset(expr const & e, unsigned & sz, unsigned & n, unsigned & offset
expr mk_llnf_fset(unsigned n, unsigned offset) { return mk_constant(name(name(*g_fset, n), offset)); }
bool is_llnf_fset(expr const & e, unsigned & n, unsigned & offset) { return is_llnf_binary_primitive(e, *g_fset, n, offset); }
expr mk_llnf_f32set(unsigned n, unsigned offset) { return mk_constant(name(name(*g_f32set, n), offset)); }
bool is_llnf_f32set(expr const & e, unsigned & n, unsigned & offset) { return is_llnf_binary_primitive(e, *g_f32set, n, offset); }
/* The `_uset.<n>` instruction sets a `usize` value in a constructor object at offset `sizeof(void*)*n`. */
expr mk_llnf_uset(unsigned n) { return mk_constant(name(*g_uset, n)); }
bool is_llnf_uset(expr const & e, unsigned & n) { return is_llnf_unary_primitive(e, *g_uset, n); }
@@ -218,6 +222,7 @@ bool is_llnf_op(expr const & e) {
is_llnf_reset(e) ||
is_llnf_sset(e) ||
is_llnf_fset(e) ||
is_llnf_f32set(e) ||
is_llnf_uset(e) ||
is_llnf_proj(e) ||
is_llnf_sproj(e) ||
@@ -520,6 +525,10 @@ class to_lambda_pure_fn {
return mk_app(mk_llnf_fset(num, offset), major, v);
}
expr mk_f32set(expr const & major, unsigned num, unsigned offset, expr const & v) {
return mk_app(mk_llnf_f32set(num, offset), major, v);
}
expr mk_uset(expr const & major, unsigned idx, expr const & v) {
return mk_app(mk_llnf_uset(idx), major, v);
}
@@ -684,8 +693,10 @@ class to_lambda_pure_fn {
if (first) {
r = mk_let_decl(mk_enf_object_type(), r);
}
if (info.is_float() || info.is_float32()) {
if (info.is_float()) {
r = mk_let_decl(mk_enf_object_type(), mk_fset(r, info.m_idx, info.m_offset, args[j]));
} else if (info.is_float32()) {
r = mk_let_decl(mk_enf_object_type(), mk_f32set(r, info.m_idx, info.m_offset, args[j]));
} else {
r = mk_let_decl(mk_enf_object_type(), mk_sset(r, info.m_size, info.m_idx, info.m_offset, args[j]));
}
@@ -834,6 +845,8 @@ void initialize_llnf() {
mark_persistent(g_sset->raw());
g_fset = new name("_fset");
mark_persistent(g_fset->raw());
g_f32set = new name("_f32set");
mark_persistent(g_f32set->raw());
g_uset = new name("_uset");
mark_persistent(g_uset->raw());
g_proj = new name("_proj");
@@ -864,6 +877,7 @@ void finalize_llnf() {
delete g_reset;
delete g_sset;
delete g_fset;
delete g_f32set;
delete g_proj;
delete g_sproj;
delete g_fproj;

View File

@@ -26,6 +26,7 @@ bool is_llnf_fproj(expr const & e, unsigned & n, unsigned & offset);
bool is_llnf_uproj(expr const & e, unsigned & idx);
bool is_llnf_sset(expr const & e, unsigned & sz, unsigned & n, unsigned & offset);
bool is_llnf_fset(expr const & e, unsigned & n, unsigned & offset);
bool is_llnf_f32set(expr const & e, unsigned & n, unsigned & offset);
bool is_llnf_uset(expr const & e, unsigned & n);
bool is_llnf_jmp(expr const & e);
bool is_llnf_unbox(expr const & e, unsigned & n);
@@ -43,6 +44,7 @@ inline bool is_llnf_fproj(expr const & e) { unsigned d1, d2; return is_llnf_fpro
inline bool is_llnf_uproj(expr const & e) { unsigned d; return is_llnf_uproj(e, d); }
inline bool is_llnf_sset(expr const & e) { unsigned d1, d2, d3; return is_llnf_sset(e, d1, d2, d3); }
inline bool is_llnf_fset(expr const & e) { unsigned d1, d2; return is_llnf_fset(e, d1, d2); }
inline bool is_llnf_f32set(expr const & e) { unsigned d1, d2; return is_llnf_f32set(e, d1, d2); }
inline bool is_llnf_uset(expr const & e) { unsigned d; return is_llnf_uset(e, d); }
inline bool is_llnf_box(expr const & e) { unsigned n; return is_llnf_box(e, n); }
inline bool is_llnf_unbox(expr const & e) { unsigned n; return is_llnf_unbox(e, n); }

View File

@@ -85,11 +85,13 @@ inline uint16 cnstr_get_uint16(b_obj_arg o, unsigned offset) { return lean_ctor_
inline uint32 cnstr_get_uint32(b_obj_arg o, unsigned offset) { return lean_ctor_get_uint32(o, offset); }
inline uint64 cnstr_get_uint64(b_obj_arg o, unsigned offset) { return lean_ctor_get_uint64(o, offset); }
inline double cnstr_get_float(b_obj_arg o, unsigned offset) { return lean_ctor_get_float(o, offset); }
inline float cnstr_get_float32(b_obj_arg o, unsigned offset) { return lean_ctor_get_float32(o, offset); }
inline void cnstr_set_uint8(b_obj_arg o, unsigned offset, uint8 v) { lean_ctor_set_uint8(o, offset, v); }
inline void cnstr_set_uint16(b_obj_arg o, unsigned offset, uint16 v) { lean_ctor_set_uint16(o, offset, v); }
inline void cnstr_set_uint32(b_obj_arg o, unsigned offset, uint32 v) { lean_ctor_set_uint32(o, offset, v); }
inline void cnstr_set_uint64(b_obj_arg o, unsigned offset, uint64 v) { lean_ctor_set_uint64(o, offset, v); }
inline void cnstr_set_float(b_obj_arg o, unsigned offset, double v) { lean_ctor_set_float(o, offset, v); }
inline void cnstr_set_float32(b_obj_arg o, unsigned offset, float v) { lean_ctor_set_float32(o, offset, v); }
// =======================================
// Closures
@@ -372,6 +374,7 @@ inline obj_res box_uint64(unsigned long long v) { return lean_box_uint64(v); }
inline unsigned long long unbox_uint64(b_obj_arg o) { return lean_unbox_uint64(o); }
inline obj_res box_float(double v) { return lean_box_float(v); }
inline double unbox_float(b_obj_arg o) { return lean_unbox_float(o); }
inline float unbox_float32(b_obj_arg o) { return lean_unbox_float32(o); }
inline obj_res box_size_t(size_t v) { return lean_box_usize(v); }
inline size_t unbox_size_t(b_obj_arg o) { return lean_unbox_usize(o); }