mirror of
https://github.com/leanprover/lean4.git
synced 2026-04-06 12:14:07 +00:00
Compare commits
12 Commits
sofia/open
...
sym-arith-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f12d008bb1 | ||
|
|
30315a59d4 | ||
|
|
9d078f64bc | ||
|
|
083b393294 | ||
|
|
9ea2b7b533 | ||
|
|
e17b0347c8 | ||
|
|
0a6c7eef66 | ||
|
|
46046b47a8 | ||
|
|
069e676532 | ||
|
|
94bf1d34d1 | ||
|
|
681769fb6d | ||
|
|
dad6fe832d |
6
.github/workflows/build-template.yml
vendored
6
.github/workflows/build-template.yml
vendored
@@ -59,11 +59,11 @@ jobs:
|
||||
with:
|
||||
msystem: clang64
|
||||
# `:` means do not prefix with msystem
|
||||
pacboy: "make: python: cmake clang ccache gmp libuv openssl: git: zip: unzip: diffutils: binutils: tree: zstd tar:"
|
||||
pacboy: "make: python: cmake clang ccache gmp libuv git: zip: unzip: diffutils: binutils: tree: zstd tar:"
|
||||
if: runner.os == 'Windows'
|
||||
- name: Install Brew Packages
|
||||
run: |
|
||||
brew install ccache tree zstd coreutils gmp libuv openssl
|
||||
brew install ccache tree zstd coreutils gmp libuv
|
||||
if: runner.os == 'macOS'
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
run: |
|
||||
sudo dpkg --add-architecture i386
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y gcc-multilib g++-multilib ccache libuv1-dev:i386 libssl-dev:i386 pkgconf:i386
|
||||
sudo apt-get install -y gcc-multilib g++-multilib ccache libuv1-dev:i386 pkgconf:i386
|
||||
if: matrix.cmultilib
|
||||
- name: Restore Cache
|
||||
id: restore-cache
|
||||
|
||||
@@ -9,7 +9,6 @@ Requirements
|
||||
- [CMake](http://www.cmake.org)
|
||||
- [GMP (GNU multiprecision library)](http://gmplib.org/)
|
||||
- [LibUV](https://libuv.org/)
|
||||
- [OpenSSL](https://www.openssl.org/)
|
||||
|
||||
Platform-Specific Setup
|
||||
-----------------------
|
||||
|
||||
@@ -32,7 +32,7 @@ MSYS2 has a package management system, [pacman][pacman].
|
||||
Here are the commands to install all dependencies needed to compile Lean on your machine.
|
||||
|
||||
```bash
|
||||
pacman -S make python mingw-w64-clang-x86_64-cmake mingw-w64-clang-x86_64-clang mingw-w64-clang-x86_64-ccache mingw-w64-clang-x86_64-libuv mingw-w64-clang-x86_64-gmp mingw-w64-clang-x86_64-openssl git unzip diffutils binutils
|
||||
pacman -S make python mingw-w64-clang-x86_64-cmake mingw-w64-clang-x86_64-clang mingw-w64-clang-x86_64-ccache mingw-w64-clang-x86_64-libuv mingw-w64-clang-x86_64-gmp git unzip diffutils binutils
|
||||
```
|
||||
|
||||
You should now be able to run these commands:
|
||||
|
||||
@@ -32,13 +32,12 @@ following to use `g++`.
|
||||
cmake -DCMAKE_CXX_COMPILER=g++ ...
|
||||
```
|
||||
|
||||
## Required Packages: CMake, GMP, libuv, OpenSSL, pkgconf
|
||||
## Required Packages: CMake, GMP, libuv, pkgconf
|
||||
|
||||
```bash
|
||||
brew install cmake
|
||||
brew install gmp
|
||||
brew install libuv
|
||||
brew install openssl
|
||||
brew install pkgconf
|
||||
```
|
||||
|
||||
|
||||
@@ -8,5 +8,5 @@ follow the [generic build instructions](index.md).
|
||||
## Basic packages
|
||||
|
||||
```bash
|
||||
sudo apt-get install git libgmp-dev libuv1-dev libssl-dev cmake ccache clang pkgconf
|
||||
sudo apt-get install git libgmp-dev libuv1-dev cmake ccache clang pkgconf
|
||||
```
|
||||
|
||||
22
flake.nix
22
flake.nix
@@ -24,7 +24,7 @@
|
||||
stdenv = pkgs.overrideCC pkgs.stdenv llvmPackages.clang;
|
||||
} ({
|
||||
buildInputs = with pkgs; [
|
||||
cmake gmp libuv ccache pkg-config openssl openssl.dev
|
||||
cmake gmp libuv ccache pkg-config
|
||||
llvmPackages.bintools # wrapped lld
|
||||
llvmPackages.llvm # llvm-symbolizer for asan/lsan
|
||||
gdb
|
||||
@@ -34,21 +34,7 @@
|
||||
hardeningDisable = [ "all" ];
|
||||
# more convenient `ctest` output
|
||||
CTEST_OUTPUT_ON_FAILURE = 1;
|
||||
} // pkgs.lib.optionalAttrs pkgs.stdenv.isLinux (let
|
||||
# Build OpenSSL 3 statically using pkgsDist's old-glibc stdenv,
|
||||
# so the resulting static libs don't require newer glibc symbols.
|
||||
opensslForDist = pkgsDist.stdenv.mkDerivation {
|
||||
name = "openssl-static-${pkgs.lib.getVersion pkgs.openssl.name}";
|
||||
inherit (pkgs.openssl) src;
|
||||
nativeBuildInputs = [ pkgsDist.perl ];
|
||||
configurePhase = ''
|
||||
patchShebangs .
|
||||
./config --prefix=$out no-shared no-tests
|
||||
'';
|
||||
buildPhase = "make -j$NIX_BUILD_CORES";
|
||||
installPhase = "make install_sw";
|
||||
};
|
||||
in {
|
||||
} // pkgs.lib.optionalAttrs pkgs.stdenv.isLinux {
|
||||
GMP = (pkgsDist.gmp.override { withStatic = true; }).overrideAttrs (attrs:
|
||||
pkgs.lib.optionalAttrs (pkgs.stdenv.system == "aarch64-linux") {
|
||||
# would need additional linking setup on Linux aarch64, we don't use it anywhere else either
|
||||
@@ -67,15 +53,13 @@
|
||||
};
|
||||
doCheck = false;
|
||||
});
|
||||
OPENSSL = opensslForDist;
|
||||
OPENSSL_DEV = opensslForDist;
|
||||
GLIBC = pkgsDist.glibc;
|
||||
GLIBC_DEV = pkgsDist.glibc.dev;
|
||||
GCC_LIB = pkgsDist.gcc.cc.lib;
|
||||
ZLIB = pkgsDist.zlib;
|
||||
# for CI coredumps
|
||||
GDB = pkgsDist.gdb;
|
||||
}));
|
||||
});
|
||||
in {
|
||||
devShells.${system} = {
|
||||
# The default development shell for working on lean itself
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euxo pipefail
|
||||
|
||||
# run from root build directory (from inside nix-shell or otherwise defining GLIBC/ZLIB/GMP/OPENSSL) as in
|
||||
# run from root build directory (from inside nix-shell or otherwise defining GLIBC/ZLIB/GMP) as in
|
||||
# ```
|
||||
# eval cmake ../.. $(../../script/prepare-llvm-linux.sh ~/Downloads/lean-llvm-x86_64-linux-gnu.tar.zst)
|
||||
# ```
|
||||
@@ -42,8 +42,6 @@ $CP $GLIBC/lib/*crt* stage1/lib/
|
||||
# runtime
|
||||
(cd llvm; $CP --parents lib/clang/*/lib/*/{clang_rt.*.o,libclang_rt.builtins*} ../stage1)
|
||||
$CP llvm/lib/*/lib{c++,c++abi,unwind}.* $GMP/lib/libgmp.a $LIBUV/lib/libuv.a stage1/lib/
|
||||
# bundle OpenSSL static libs
|
||||
cp $OPENSSL/lib/libssl.a $OPENSSL/lib/libcrypto.a stage1/lib/
|
||||
# LLVM 19 appears to ship the dependencies in 'llvm/lib/<target-triple>/' and 'llvm/include/<target-triple>/'
|
||||
# but clang-19 that we use to compile is linked against 'llvm/lib/' and 'llvm/include'
|
||||
# https://github.com/llvm/llvm-project/issues/54955
|
||||
@@ -59,7 +57,6 @@ for f in $GLIBC/lib/{ld,lib{c,dl,m,rt,pthread}}-*; do b=$(basename $f); cp $f st
|
||||
OPTIONS=()
|
||||
# We build cadical using the custom toolchain on Linux to avoid glibc versioning issues
|
||||
echo -n " -DLEAN_STANDALONE=ON -DCADICAL_USE_CUSTOM_CXX=ON"
|
||||
echo -n " -DOPENSSL_INCLUDE_DIR=$OPENSSL_DEV/include -DOPENSSL_SSL_LIBRARY=$OPENSSL/lib/libssl.a -DOPENSSL_CRYPTO_LIBRARY=$OPENSSL/lib/libcrypto.a"
|
||||
echo -n " -DCMAKE_CXX_COMPILER=$PWD/llvm-host/bin/clang++ -DLEAN_CXX_STDLIB='-Wl,-Bstatic -lc++ -lc++abi -Wl,-Bdynamic'"
|
||||
# these should also be used for cadical, so do not use `LEAN_EXTRA_CXX_FLAGS` here
|
||||
echo -n " -DCMAKE_CXX_FLAGS='--sysroot $PWD/llvm -idirafter $GLIBC_DEV/include ${EXTRA_FLAGS:-}'"
|
||||
@@ -77,8 +74,8 @@ fi
|
||||
echo -n " -DLEANC_INTERNAL_FLAGS='--sysroot ROOT -nostdinc -isystem ROOT/include/clang' -DLEANC_CC=ROOT/bin/clang"
|
||||
# ld.so is usually included by the libc.so linker script but we discard those. Make sure it is linked to only after `libc.so` like in the original
|
||||
# linker script so that no libc symbols are bound to it instead.
|
||||
echo -n " -DLEANC_INTERNAL_LINKER_FLAGS='--sysroot ROOT -L ROOT/lib -L ROOT/lib/glibc -lc -lc_nonshared -Wl,--as-needed -l:ld.so -Wl,--no-as-needed -lpthread_nonshared -Wl,--as-needed -Wl,-Bstatic -lgmp -lunwind -luv -lssl -lcrypto -Wl,-Bdynamic -Wl,--no-as-needed -Wl,--disable-new-dtags,-rpath,ROOT/lib -fuse-ld=lld'"
|
||||
# when not using the above flags, link GMP/libuv/OpenSSL dynamically/as usual
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-Wl,--as-needed -lgmp -luv -Wl,-Bstatic -lssl -lcrypto -Wl,-Bdynamic -lpthread -ldl -lrt -Wl,--no-as-needed'"
|
||||
echo -n " -DLEANC_INTERNAL_LINKER_FLAGS='--sysroot ROOT -L ROOT/lib -L ROOT/lib/glibc -lc -lc_nonshared -Wl,--as-needed -l:ld.so -Wl,--no-as-needed -lpthread_nonshared -Wl,--as-needed -Wl,-Bstatic -lgmp -lunwind -luv -Wl,-Bdynamic -Wl,--no-as-needed -fuse-ld=lld'"
|
||||
# when not using the above flags, link GMP dynamically/as usual
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-Wl,--as-needed -lgmp -luv -lpthread -ldl -lrt -Wl,--no-as-needed'"
|
||||
# do not set `LEAN_CC` for tests
|
||||
echo -n " -DLEAN_TEST_VARS=''"
|
||||
|
||||
@@ -10,7 +10,6 @@ set -uxo pipefail
|
||||
|
||||
GMP=${GMP:-$(brew --prefix)}
|
||||
LIBUV=${LIBUV:-$(brew --prefix)}
|
||||
OPENSSL=${OPENSSL:-$(brew --prefix openssl@3)}
|
||||
|
||||
[[ -d llvm ]] || (mkdir llvm; gtar xf $1 --strip-components 1 --directory llvm)
|
||||
[[ -d llvm-host ]] || if [[ "$#" -gt 1 ]]; then
|
||||
@@ -42,7 +41,6 @@ gcp llvm/lib/libc++.dylib stage1/lib/libc
|
||||
# and apparently since Sonoma does not do so implicitly either
|
||||
install_name_tool -id /usr/lib/libc++.dylib stage1/lib/libc/libc++.dylib
|
||||
echo -n " -DLEAN_STANDALONE=ON"
|
||||
echo -n " -DOPENSSL_INCLUDE_DIR=$OPENSSL/include -DOPENSSL_SSL_LIBRARY=$OPENSSL/lib/libssl.a -DOPENSSL_CRYPTO_LIBRARY=$OPENSSL/lib/libcrypto.a"
|
||||
# do not change C++ compiler; libc++ etc. being system libraries means there's no danger of conflicts,
|
||||
# and the custom clang++ outputs a myriad of warnings when consuming the SDK
|
||||
echo -n " -DLEAN_EXTRA_CXX_FLAGS='${EXTRA_FLAGS:-}'"
|
||||
@@ -50,7 +48,7 @@ if [[ -L llvm-host ]]; then
|
||||
echo -n " -DCMAKE_C_COMPILER=$PWD/stage1/bin/clang"
|
||||
gcp $GMP/lib/libgmp.a stage1/lib/
|
||||
gcp $LIBUV/lib/libuv.a stage1/lib/
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-lgmp -luv $OPENSSL/lib/libssl.a $OPENSSL/lib/libcrypto.a'"
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-lgmp -luv'"
|
||||
else
|
||||
echo -n " -DCMAKE_C_COMPILER=$PWD/llvm-host/bin/clang -DLEANC_OPTS='--sysroot $PWD/stage1 -resource-dir $PWD/stage1/lib/clang/15.0.1 ${EXTRA_FLAGS:-}'"
|
||||
fi
|
||||
|
||||
@@ -40,14 +40,14 @@ cp /clang64/lib/{crtbegin,crtend,crt2,dllcrt2}.o stage1/lib/
|
||||
# tells the compiler how to dynamically link against `bcrypt.dll` (which is located in the System32 folder).
|
||||
# This distinction is relevant specifically for `libicu.a`/`icu.dll` because there we want updates to the time zone database to
|
||||
# be delivered to users via Windows Update without having to recompile Lean or Lean programs.
|
||||
cp /clang64/lib/lib{m,bcrypt,mingw32,moldname,mingwex,msvcrt,pthread,advapi32,shell32,user32,kernel32,ucrtbase,psapi,iphlpapi,userenv,ws2_32,dbghelp,ole32,icu,crypt32,gdi32}.* /clang64/lib/libgmp.a /clang64/lib/libuv.a /clang64/lib/libssl.a /clang64/lib/libcrypto.a llvm/lib/lib{c++,c++abi,unwind}.a stage1/lib/
|
||||
cp /clang64/lib/lib{m,bcrypt,mingw32,moldname,mingwex,msvcrt,pthread,advapi32,shell32,user32,kernel32,ucrtbase,psapi,iphlpapi,userenv,ws2_32,dbghelp,ole32,icu}.* /clang64/lib/libgmp.a /clang64/lib/libuv.a llvm/lib/lib{c++,c++abi,unwind}.a stage1/lib/
|
||||
echo -n " -DLEAN_STANDALONE=ON"
|
||||
echo -n " -DCMAKE_C_COMPILER=$PWD/stage1/bin/clang.exe -DCMAKE_C_COMPILER_WORKS=1 -DCMAKE_CXX_COMPILER=$PWD/llvm/bin/clang++.exe -DCMAKE_CXX_COMPILER_WORKS=1 -DLEAN_CXX_STDLIB='-lc++ -lc++abi'"
|
||||
echo -n " -DSTAGE0_CMAKE_C_COMPILER=clang -DSTAGE0_CMAKE_CXX_COMPILER=clang++"
|
||||
echo -n " -DLEAN_EXTRA_CXX_FLAGS='--sysroot $PWD/llvm -idirafter /clang64/include/'"
|
||||
echo -n " -DLEANC_INTERNAL_FLAGS='--sysroot ROOT -nostdinc -isystem ROOT/include/clang' -DLEANC_CC=ROOT/bin/clang.exe"
|
||||
echo -n " -DLEANC_INTERNAL_LINKER_FLAGS='--sysroot ROOT -L ROOT/lib -Wl,-Bstatic -lgmp $(pkg-config --static --libs libuv) -lssl -lcrypto -lunwind -Wl,-Bdynamic -lcrypt32 -lgdi32 -fuse-ld=lld'"
|
||||
# when not using the above flags, link GMP/libuv/OpenSSL dynamically/as usual. Always link ICU dynamically.
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-lgmp $(pkg-config --libs libuv) -lssl -lcrypto -lcrypt32 -lgdi32 -lucrtbase'"
|
||||
echo -n " -DLEANC_INTERNAL_LINKER_FLAGS='--sysroot ROOT -L ROOT/lib -Wl,-Bstatic -lgmp $(pkg-config --static --libs libuv) -lunwind -Wl,-Bdynamic -fuse-ld=lld'"
|
||||
# when not using the above flags, link GMP dynamically/as usual. Always link ICU dynamically.
|
||||
echo -n " -DLEAN_EXTRA_LINKER_FLAGS='-lgmp $(pkg-config --libs libuv) -lucrtbase'"
|
||||
# do not set `LEAN_CC` for tests
|
||||
echo -n " -DLEAN_TEST_VARS=''"
|
||||
|
||||
@@ -356,48 +356,6 @@ if(NOT LEAN_STANDALONE)
|
||||
string(APPEND LEAN_EXTRA_LINKER_FLAGS " ${LIBUV_LDFLAGS}")
|
||||
endif()
|
||||
|
||||
# OpenSSL
|
||||
if("${CMAKE_SYSTEM_NAME}" MATCHES "Emscripten")
|
||||
# Only on WebAssembly we compile OpenSSL ourselves
|
||||
set(OPENSSL_EMSCRIPTEN_FLAGS "${EMSCRIPTEN_SETTINGS}")
|
||||
|
||||
# OpenSSL needs to be configured for Emscripten using their configuration system
|
||||
ExternalProject_add(openssl
|
||||
PREFIX openssl
|
||||
GIT_REPOSITORY https://github.com/openssl/openssl
|
||||
# Sync version with flake.nix if applicable
|
||||
GIT_TAG openssl-3.0.15
|
||||
CONFIGURE_COMMAND <SOURCE_DIR>/Configure linux-generic32 no-shared no-dso no-engine no-tests --prefix=<INSTALL_DIR> CC=${CMAKE_C_COMPILER} CXX=${CMAKE_CXX_COMPILER} AR=${CMAKE_AR} CFLAGS=${OPENSSL_EMSCRIPTEN_FLAGS}
|
||||
BUILD_COMMAND emmake make -j
|
||||
INSTALL_COMMAND emmake make install_sw
|
||||
BUILD_IN_SOURCE ON)
|
||||
set(OPENSSL_INCLUDE_DIR "${CMAKE_BINARY_DIR}/openssl/include")
|
||||
set(OPENSSL_CRYPTO_LIBRARY "${CMAKE_BINARY_DIR}/openssl/lib/libcrypto.a")
|
||||
set(OPENSSL_SSL_LIBRARY "${CMAKE_BINARY_DIR}/openssl/lib/libssl.a")
|
||||
set(OPENSSL_LIBRARIES "${OPENSSL_SSL_LIBRARY} ${OPENSSL_CRYPTO_LIBRARY}")
|
||||
else()
|
||||
find_package(OpenSSL 3 REQUIRED)
|
||||
set(OPENSSL_LIBRARIES ${OPENSSL_SSL_LIBRARY} ${OPENSSL_CRYPTO_LIBRARY})
|
||||
endif()
|
||||
include_directories(${OPENSSL_INCLUDE_DIR})
|
||||
string(JOIN " " OPENSSL_LIBRARIES_STR ${OPENSSL_LIBRARIES})
|
||||
string(APPEND LEANSHARED_LINKER_FLAGS " ${OPENSSL_LIBRARIES_STR}")
|
||||
|
||||
if(NOT LEAN_STANDALONE)
|
||||
string(APPEND LEAN_EXTRA_LINKER_FLAGS " ${OPENSSL_LIBRARIES_STR}")
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
string(APPEND LEANSHARED_LINKER_FLAGS " -Wl,--disable-new-dtags,-rpath,$$ORIGIN")
|
||||
endif()
|
||||
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
string(APPEND LEANSHARED_LINKER_FLAGS " -lcrypt32 -lgdi32")
|
||||
if(NOT LEAN_STANDALONE)
|
||||
string(APPEND LEAN_EXTRA_LINKER_FLAGS " -lcrypt32 -lgdi32")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Windows SDK (for ICU)
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||
# Pass 'tools' to skip MSVC version check (as MSVC/Visual Studio is not necessarily installed)
|
||||
@@ -772,9 +730,9 @@ if(STAGE GREATER 1)
|
||||
endif()
|
||||
else()
|
||||
add_subdirectory(runtime)
|
||||
if("${CMAKE_SYSTEM_NAME}" MATCHES "Emscripten")
|
||||
add_dependencies(leanrt libuv openssl)
|
||||
add_dependencies(leanrt_initial-exec libuv openssl)
|
||||
if(CMAKE_SYSTEM_NAME MATCHES "Emscripten")
|
||||
add_dependencies(leanrt libuv)
|
||||
add_dependencies(leanrt_initial-exec libuv)
|
||||
endif()
|
||||
|
||||
add_subdirectory(util)
|
||||
|
||||
@@ -25,6 +25,7 @@ public import Lean.Meta.Sym.Simp
|
||||
public import Lean.Meta.Sym.Util
|
||||
public import Lean.Meta.Sym.Eta
|
||||
public import Lean.Meta.Sym.Canon
|
||||
public import Lean.Meta.Sym.Arith
|
||||
public import Lean.Meta.Sym.Grind
|
||||
public import Lean.Meta.Sym.SynthInstance
|
||||
|
||||
|
||||
20
src/Lean/Meta/Sym/Arith.lean
Normal file
20
src/Lean/Meta/Sym/Arith.lean
Normal file
@@ -0,0 +1,20 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public import Lean.Meta.Sym.Arith.EvalNum
|
||||
public import Lean.Meta.Sym.Arith.Classify
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadRing
|
||||
public import Lean.Meta.Sym.Arith.MonadSemiring
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.Reify
|
||||
public import Lean.Meta.Sym.Arith.DenoteExpr
|
||||
public import Lean.Meta.Sym.Arith.ToExpr
|
||||
public import Lean.Meta.Sym.Arith.VarRename
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
143
src/Lean/Meta/Sym/Arith/Classify.lean
Normal file
143
src/Lean/Meta/Sym/Arith/Classify.lean
Normal file
@@ -0,0 +1,143 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.EvalNum
|
||||
import Lean.Meta.Sym.SynthInstance
|
||||
import Lean.Meta.Sym.Canon
|
||||
import Lean.Meta.DecLevel
|
||||
import Init.Grind.Ring
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Algebraic structure classification
|
||||
|
||||
Detects the strongest algebraic structure available for a type and caches
|
||||
the classification in `Arith.State.typeClassify`. The detection order is:
|
||||
|
||||
1. `Grind.CommRing` (includes `Field` check)
|
||||
2. `Grind.Ring` (non-commutative)
|
||||
3. `Grind.CommSemiring` (via `OfSemiring.Q` envelope)
|
||||
4. `Grind.Semiring` (non-commutative)
|
||||
|
||||
Results (including failures) are cached in a single `PHashMap ExprPtr ClassifyResult`
|
||||
to avoid repeated synthesis attempts.
|
||||
-/
|
||||
|
||||
private def getIsCharInst? (u : Level) (type : Expr) (semiringInst : Expr) : SymM (Option (Expr × Nat)) := do
|
||||
withNewMCtxDepth do
|
||||
let n ← mkFreshExprMVar (mkConst ``Nat)
|
||||
let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type semiringInst n
|
||||
let some charInst ← Sym.synthInstance? charType | return none
|
||||
let n ← instantiateMVars n
|
||||
let some n ← evalNat? n | return none
|
||||
return some (charInst, n)
|
||||
|
||||
private def getNoZeroDivInst? (u : Level) (type : Expr) : SymM (Option Expr) := do
|
||||
let natModuleType := mkApp (mkConst ``Grind.NatModule [u]) type
|
||||
let some natModuleInst ← Sym.synthInstance? natModuleType | return none
|
||||
let noZeroDivType := mkApp2 (mkConst ``Grind.NoNatZeroDivisors [u]) type natModuleInst
|
||||
Sym.synthInstance? noZeroDivType
|
||||
|
||||
/-- Try to classify `type` as a `CommRing`. Returns the ring id on success. -/
|
||||
private def tryCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let commRing := mkApp (mkConst ``Grind.CommRing [u]) type
|
||||
let some commRingInst ← Sym.synthInstance? commRing | return none
|
||||
let ringInst := mkApp2 (mkConst ``Grind.CommRing.toRing [u]) type commRingInst
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.Ring.toSemiring [u]) type ringInst
|
||||
let commSemiringInst := mkApp2 (mkConst ``Grind.CommRing.toCommSemiring [u]) type semiringInst
|
||||
let charInst? ← getIsCharInst? u type semiringInst
|
||||
let noZeroDivInst? ← getNoZeroDivInst? u type
|
||||
let fieldInst? ← Sym.synthInstance? <| mkApp (mkConst ``Grind.Field [u]) type
|
||||
let semiringId? := none
|
||||
let id := (← getArithState).rings.size
|
||||
let ring : CommRing := {
|
||||
id, semiringId?, type, u, semiringInst, ringInst, commSemiringInst,
|
||||
commRingInst, charInst?, noZeroDivInst?, fieldInst?,
|
||||
}
|
||||
modifyArithState fun s => { s with rings := s.rings.push ring }
|
||||
return some id
|
||||
|
||||
/-- Try to classify `type` as a non-commutative `Ring`. -/
|
||||
private def tryNonCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let ring := mkApp (mkConst ``Grind.Ring [u]) type
|
||||
let some ringInst ← Sym.synthInstance? ring | return none
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.Ring.toSemiring [u]) type ringInst
|
||||
let charInst? ← getIsCharInst? u type semiringInst
|
||||
let id := (← getArithState).ncRings.size
|
||||
let ring : Ring := {
|
||||
id, type, u, semiringInst, ringInst, charInst?
|
||||
}
|
||||
modifyArithState fun s => { s with ncRings := s.ncRings.push ring }
|
||||
return some id
|
||||
|
||||
/-- Helper function for `tryCommSemiring? -/
|
||||
private def tryCacheAndCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
if let some result := (← getArithState).typeClassify.find? { expr := type } then
|
||||
let .commRing id := result | return none
|
||||
return id
|
||||
let id? ← tryCommRing? type
|
||||
let result := match id? with
|
||||
| none => .none
|
||||
| some id => .commRing id
|
||||
modifyArithState fun s => { s with typeClassify := s.typeClassify.insert { expr := type } result }
|
||||
return id?
|
||||
|
||||
/-- Try to classify `type` as a `CommSemiring`. Creates the `OfSemiring.Q` envelope ring. -/
|
||||
private def tryCommSemiring? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let commSemiring := mkApp (mkConst ``Grind.CommSemiring [u]) type
|
||||
let some commSemiringInst ← Sym.synthInstance? commSemiring | return none
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.CommSemiring.toSemiring [u]) type commSemiringInst
|
||||
let q ← shareCommon (← Sym.canon (mkApp2 (mkConst ``Grind.Ring.OfSemiring.Q [u]) type semiringInst))
|
||||
-- The envelope `Q` type must be classifiable as a CommRing.
|
||||
let some ringId ← tryCacheAndCommRing? q
|
||||
| reportIssue! "unexpected failure initializing ring{indentExpr q}"; return none
|
||||
let id := (← getArithState).semirings.size
|
||||
let semiring : CommSemiring := {
|
||||
id, type, ringId, u, semiringInst, commSemiringInst
|
||||
}
|
||||
modifyArithState fun s => { s with semirings := s.semirings.push semiring }
|
||||
-- Link the envelope ring back to this semiring
|
||||
modifyArithState fun s =>
|
||||
let rings := s.rings.modify ringId fun r => { r with semiringId? := some id }
|
||||
{ s with rings }
|
||||
return some id
|
||||
|
||||
/-- Try to classify `type` as a non-commutative `Semiring`. -/
|
||||
private def tryNonCommSemiring? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let semiring := mkApp (mkConst ``Grind.Semiring [u]) type
|
||||
let some semiringInst ← Sym.synthInstance? semiring | return none
|
||||
let id := (← getArithState).ncSemirings.size
|
||||
let semiring : Semiring := { id, type, u, semiringInst }
|
||||
modifyArithState fun s => { s with ncSemirings := s.ncSemirings.push semiring }
|
||||
return some id
|
||||
|
||||
/--
|
||||
Classify the algebraic structure of `type`, trying the strongest first:
|
||||
CommRing > Ring > CommSemiring > Semiring.
|
||||
Results are cached in `Arith.State.typeClassify`.
|
||||
-/
|
||||
def classify? (type : Expr) : SymM ClassifyResult := do
|
||||
if let some result := (← getArithState).typeClassify.find? { expr := type } then
|
||||
return result
|
||||
let result ← go
|
||||
modifyArithState fun s => { s with typeClassify := s.typeClassify.insert { expr := type } result }
|
||||
return result
|
||||
where
|
||||
go : SymM ClassifyResult := do
|
||||
if let some id ← tryCommRing? type then return .commRing id
|
||||
if let some id ← tryNonCommRing? type then return .nonCommRing id
|
||||
if let some id ← tryCommSemiring? type then return .commSemiring id
|
||||
if let some id ← tryNonCommSemiring? type then return .nonCommSemiring id
|
||||
return .none
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
93
src/Lean/Meta/Sym/Arith/DenoteExpr.lean
Normal file
93
src/Lean/Meta/Sym/Arith/DenoteExpr.lean
Normal file
@@ -0,0 +1,93 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Denotation of reified expressions
|
||||
|
||||
Converts reified `RingExpr`, `Poly`, `Mon`, `Power` back into Lean `Expr`s using
|
||||
the ring's cached operator functions and variable array.
|
||||
-/
|
||||
|
||||
variable [Monad m] [MonadError m] [MonadLiftT MetaM m] [MonadCanon m] [MonadRing m]
|
||||
|
||||
/-- Convert an integer to a numeral expression in the ring. Negative values use `getNegFn`. -/
|
||||
def denoteNum (k : Int) : m Expr := do
|
||||
let ring ← getRing
|
||||
let n := mkRawNatLit k.natAbs
|
||||
let ofNatInst ← if let some inst ← MonadCanon.synthInstance? (mkApp2 (mkConst ``OfNat [ring.u]) ring.type n) then
|
||||
pure inst
|
||||
else
|
||||
pure <| mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
|
||||
let e := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
|
||||
if k < 0 then
|
||||
return mkApp (← getNegFn) e
|
||||
else
|
||||
return e
|
||||
|
||||
/-- Denote a `Power` (variable raised to a power). -/
|
||||
def denotePower [MonadGetVar m] (pw : Power) : m Expr := do
|
||||
let x ← getVar pw.x
|
||||
if pw.k == 1 then
|
||||
return x
|
||||
else
|
||||
return mkApp2 (← getPowFn) x (toExpr pw.k)
|
||||
|
||||
/-- Denote a `Mon` (product of powers). -/
|
||||
def denoteMon [MonadGetVar m] (mn : Mon) : m Expr := do
|
||||
match mn with
|
||||
| .unit => denoteNum 1
|
||||
| .mult pw mn => go mn (← denotePower pw)
|
||||
where
|
||||
go (mn : Mon) (acc : Expr) : m Expr := do
|
||||
match mn with
|
||||
| .unit => return acc
|
||||
| .mult pw mn => go mn (mkApp2 (← getMulFn) acc (← denotePower pw))
|
||||
|
||||
/-- Denote a `Poly` (sum of coefficient × monomial terms). -/
|
||||
def denotePoly [MonadGetVar m] (p : Poly) : m Expr := do
|
||||
match p with
|
||||
| .num k => denoteNum k
|
||||
| .add k mn p => go p (← denoteTerm k mn)
|
||||
where
|
||||
denoteTerm (k : Int) (mn : Mon) : m Expr := do
|
||||
if k == 1 then
|
||||
denoteMon mn
|
||||
else
|
||||
return mkApp2 (← getMulFn) (← denoteNum k) (← denoteMon mn)
|
||||
|
||||
go (p : Poly) (acc : Expr) : m Expr := do
|
||||
match p with
|
||||
| .num 0 => return acc
|
||||
| .num k => return mkApp2 (← getAddFn) acc (← denoteNum k)
|
||||
| .add k mn p => go p (mkApp2 (← getAddFn) acc (← denoteTerm k mn))
|
||||
|
||||
/-- Denote a `RingExpr` using a variable lookup function. -/
|
||||
@[specialize]
|
||||
private def denoteRingExprCore (getVarExpr : Nat → Expr) (e : RingExpr) : m Expr := do
|
||||
go e
|
||||
where
|
||||
go : RingExpr → m Expr
|
||||
| .num k => denoteNum k
|
||||
| .natCast k => return mkApp (← getNatCastFn) (mkNatLit k)
|
||||
| .intCast k => return mkApp (← getIntCastFn) (mkIntLit k)
|
||||
| .var x => return getVarExpr x
|
||||
| .add a b => return mkApp2 (← getAddFn) (← go a) (← go b)
|
||||
| .sub a b => return mkApp2 (← getSubFn) (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getMulFn) (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getPowFn) (← go a) (toExpr k)
|
||||
| .neg a => return mkApp (← getNegFn) (← go a)
|
||||
|
||||
/-- Denote a `RingExpr` using an explicit variable array. -/
|
||||
def denoteRingExpr (vars : Array Expr) (e : RingExpr) : m Expr := do
|
||||
denoteRingExprCore (fun x => vars[x]!) e
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
90
src/Lean/Meta/Sym/Arith/EvalNum.lean
Normal file
90
src/Lean/Meta/Sym/Arith/EvalNum.lean
Normal file
@@ -0,0 +1,90 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
import Lean.Meta.Sym.LitValues
|
||||
import Lean.Meta.IntInstTesters
|
||||
import Lean.Meta.NatInstTesters
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
Functions for evaluating simple `Nat` and `Int` expressions that appear in type classes
|
||||
(e.g., `ToInt` and `IsCharP`). Using `whnf` for this purpose is too expensive and can
|
||||
exhaust the stack. We considered `evalExpr` as an alternative, but it introduces
|
||||
considerable overhead. We may use `evalExpr` as a fallback in the future.
|
||||
-/
|
||||
|
||||
def checkExp (k : Nat) : OptionT SymM Unit := do
|
||||
let exp ← getExpThreshold
|
||||
if k > exp then
|
||||
reportIssue! "exponent {k} exceeds threshold for exponentiation `(exp := {exp})`"
|
||||
failure
|
||||
|
||||
/-
|
||||
**Note**: It is safe to use (the more efficient) structural instance tests here because
|
||||
`Sym.Canon` has already run.
|
||||
-/
|
||||
open Structural in
|
||||
mutual
|
||||
private partial def evalNatCore (e : Expr) : OptionT SymM Nat := do
|
||||
match_expr e with
|
||||
| Nat.zero => return 0
|
||||
| Nat.succ a => return (← evalNatCore a) + 1
|
||||
| Int.toNat a => return (← evalIntCore a).toNat
|
||||
| Int.natAbs a => return (← evalIntCore a).natAbs
|
||||
| HAdd.hAdd _ _ _ inst a b => guard (← isInstHAddNat inst); return (← evalNatCore a) + (← evalNatCore b)
|
||||
| HMul.hMul _ _ _ inst a b => guard (← isInstHMulNat inst); return (← evalNatCore a) * (← evalNatCore b)
|
||||
| HSub.hSub _ _ _ inst a b => guard (← isInstHSubNat inst); return (← evalNatCore a) - (← evalNatCore b)
|
||||
| HDiv.hDiv _ _ _ inst a b => guard (← isInstHDivNat inst); return (← evalNatCore a) / (← evalNatCore b)
|
||||
| HMod.hMod _ _ _ inst a b => guard (← isInstHModNat inst); return (← evalNatCore a) % (← evalNatCore b)
|
||||
| OfNat.ofNat _ _ _ =>
|
||||
let some n := Sym.getNatValue? e |>.run | failure
|
||||
return n
|
||||
| HPow.hPow _ _ _ inst a k =>
|
||||
guard (← isInstHPowNat inst)
|
||||
let k ← evalNatCore k
|
||||
checkExp k
|
||||
let a ← evalNatCore a
|
||||
return a ^ k
|
||||
| _ => failure
|
||||
|
||||
private partial def evalIntCore (e : Expr) : OptionT SymM Int := do
|
||||
match_expr e with
|
||||
| Neg.neg _ i a => guard (← isInstNegInt i); return - (← evalIntCore a)
|
||||
| HAdd.hAdd _ _ _ i a b => guard (← isInstHAddInt i); return (← evalIntCore a) + (← evalIntCore b)
|
||||
| HSub.hSub _ _ _ i a b => guard (← isInstHSubInt i); return (← evalIntCore a) - (← evalIntCore b)
|
||||
| HMul.hMul _ _ _ i a b => guard (← isInstHMulInt i); return (← evalIntCore a) * (← evalIntCore b)
|
||||
| HDiv.hDiv _ _ _ i a b => guard (← isInstHDivInt i); return (← evalIntCore a) / (← evalIntCore b)
|
||||
| HMod.hMod _ _ _ i a b => guard (← isInstHModInt i); return (← evalIntCore a) % (← evalIntCore b)
|
||||
| HPow.hPow _ _ _ i a k =>
|
||||
guard (← isInstHPowInt i)
|
||||
let a ← evalIntCore a
|
||||
let k ← evalNatCore k
|
||||
checkExp k
|
||||
return a ^ k
|
||||
| OfNat.ofNat _ _ _ =>
|
||||
let some n := Sym.getIntValue? e |>.run | failure
|
||||
return n
|
||||
| NatCast.natCast _ i a =>
|
||||
let_expr instNatCastInt ← i | failure
|
||||
return (← evalNatCore a)
|
||||
| Nat.cast _ i a =>
|
||||
let_expr instNatCastInt ← i | failure
|
||||
return (← evalNatCore a)
|
||||
| _ => failure
|
||||
|
||||
end
|
||||
|
||||
def evalNat? (e : Expr) : SymM (Option Nat) :=
|
||||
evalNatCore e |>.run
|
||||
|
||||
def evalInt? (e : Expr) : SymM (Option Int) :=
|
||||
evalIntCore e |>.run
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
171
src/Lean/Meta/Sym/Arith/Functions.lean
Normal file
171
src/Lean/Meta/Sym/Arith/Functions.lean
Normal file
@@ -0,0 +1,171 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadRing
|
||||
public import Lean.Meta.Sym.Arith.MonadSemiring
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Cached function expressions for arithmetic operators
|
||||
|
||||
Synthesizes and caches the canonical Lean expressions for arithmetic operators
|
||||
(`+`, `*`, `-`, `^`, `intCast`, `natCast`, etc.). These cached expressions are used
|
||||
during reification to validate instances via pointer equality (`isSameExpr`).
|
||||
|
||||
Each getter checks the cache field first. On a miss, it synthesizes the instance,
|
||||
verifies it against the expected instance from the ring structure using `isDefEqI`,
|
||||
canonicalizes the result via `canonExpr`, and stores it.
|
||||
-/
|
||||
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m]
|
||||
|
||||
private def checkInst (declName : Name) (inst inst' : Expr) : MetaM Unit := do
|
||||
unless (← withReducibleAndInstances <| isDefEq inst inst') do
|
||||
throwError "error while initializing arithmetic operators:\ninstance for `{declName}` {indentExpr inst}\nis not definitionally equal to the expected one {indentExpr inst'}\nwhen only reducible definitions and instances are reduced"
|
||||
|
||||
private def mkUnaryFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp (mkConst instDeclName [u]) type
|
||||
checkInst declName inst expectedInst
|
||||
canonExpr <| mkApp2 (mkConst declName [u]) type inst
|
||||
|
||||
private def mkBinHomoFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst instDeclName [u, u, u]) type type type
|
||||
checkInst declName inst expectedInst
|
||||
canonExpr <| mkApp4 (mkConst declName [u, u, u]) type type type inst
|
||||
|
||||
private def mkPowFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.npow [u]) type semiringInst
|
||||
checkInst ``HPow.hPow inst inst'
|
||||
canonExpr <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
|
||||
|
||||
private def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
|
||||
let instType := mkApp (mkConst ``NatCast [u]) type
|
||||
-- Note: `Semiring.natCast` is not a global instance, so `NatCast α` may not be available.
|
||||
-- When present, verify defeq; otherwise fall back to the semiring field.
|
||||
let inst ← match (← MonadCanon.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst ``NatCast.natCast inst inst'; pure inst
|
||||
canonExpr <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
|
||||
|
||||
section RingFns
|
||||
variable [MonadRing m]
|
||||
|
||||
def getAddFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some addFn := ring.addFn? then return addFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHAdd [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toAdd [ring.u]) ring.type ring.semiringInst
|
||||
let addFn ← mkBinHomoFn ring.type ring.u ``HAdd ``HAdd.hAdd expectedInst
|
||||
modifyRing fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getMulFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some mulFn := ring.mulFn? then return mulFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHMul [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toMul [ring.u]) ring.type ring.semiringInst
|
||||
let mulFn ← mkBinHomoFn ring.type ring.u ``HMul ``HMul.hMul expectedInst
|
||||
modifyRing fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getSubFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some subFn := ring.subFn? then return subFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHSub [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Ring.toSub [ring.u]) ring.type ring.ringInst
|
||||
let subFn ← mkBinHomoFn ring.type ring.u ``HSub ``HSub.hSub expectedInst
|
||||
modifyRing fun s => { s with subFn? := some subFn }
|
||||
return subFn
|
||||
|
||||
def getNegFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some negFn := ring.negFn? then return negFn
|
||||
let expectedInst := mkApp2 (mkConst ``Grind.Ring.toNeg [ring.u]) ring.type ring.ringInst
|
||||
let negFn ← mkUnaryFn ring.type ring.u ``Neg ``Neg.neg expectedInst
|
||||
modifyRing fun s => { s with negFn? := some negFn }
|
||||
return negFn
|
||||
|
||||
def getPowFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some powFn := ring.powFn? then return powFn
|
||||
let powFn ← mkPowFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getIntCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some intCastFn := ring.intCastFn? then return intCastFn
|
||||
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [ring.u]) ring.type ring.ringInst
|
||||
let instType := mkApp (mkConst ``IntCast [ring.u]) ring.type
|
||||
-- Note: `Ring.intCast` is not a global instance. Same pattern as `mkNatCastFn`.
|
||||
let inst ← match (← MonadCanon.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst ``Int.cast inst inst'; pure inst
|
||||
let intCastFn ← canonExpr <| mkApp2 (mkConst ``IntCast.intCast [ring.u]) ring.type inst
|
||||
modifyRing fun s => { s with intCastFn? := some intCastFn }
|
||||
return intCastFn
|
||||
|
||||
def getNatCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some natCastFn := ring.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
end RingFns
|
||||
|
||||
section CommRingFns
|
||||
variable [MonadCommRing m]
|
||||
|
||||
def getInvFn : m Expr := do
|
||||
let ring ← getCommRing
|
||||
let some fieldInst := ring.fieldInst?
|
||||
| throwError "internal error: type is not a field{indentExpr ring.type}"
|
||||
if let some invFn := ring.invFn? then return invFn
|
||||
let expectedInst := mkApp2 (mkConst ``Grind.Field.toInv [ring.u]) ring.type fieldInst
|
||||
let invFn ← mkUnaryFn ring.type ring.u ``Inv ``Inv.inv expectedInst
|
||||
modifyCommRing fun s => { s with invFn? := some invFn }
|
||||
return invFn
|
||||
|
||||
end CommRingFns
|
||||
|
||||
section SemiringFns
|
||||
variable [MonadSemiring m]
|
||||
|
||||
def getAddFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some addFn := sr.addFn? then return addFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHAdd [sr.u]) sr.type <| mkApp2 (mkConst ``Grind.Semiring.toAdd [sr.u]) sr.type sr.semiringInst
|
||||
let addFn ← mkBinHomoFn sr.type sr.u ``HAdd ``HAdd.hAdd expectedInst
|
||||
modifySemiring fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getMulFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some mulFn := sr.mulFn? then return mulFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHMul [sr.u]) sr.type <| mkApp2 (mkConst ``Grind.Semiring.toMul [sr.u]) sr.type sr.semiringInst
|
||||
let mulFn ← mkBinHomoFn sr.type sr.u ``HMul ``HMul.hMul expectedInst
|
||||
modifySemiring fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getPowFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some powFn := sr.powFn? then return powFn
|
||||
let powFn ← mkPowFn sr.u sr.type sr.semiringInst
|
||||
modifySemiring fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getNatCastFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some natCastFn := sr.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn sr.u sr.type sr.semiringInst
|
||||
modifySemiring fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
end SemiringFns
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -1,24 +1,23 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadCanon (m : Type → Type) where
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`.
|
||||
In `RingM` and `SemiringM`, this is just `sharedCommon (← canon e)`
|
||||
When printing counterexamples, we are at `MetaM`, and this is just the identity function.
|
||||
Canonicalize an expression (types, instances, support arguments).
|
||||
In `SymM`, this is `Sym.canon`. In `PP.M` (diagnostics), this is the identity.
|
||||
-/
|
||||
canonExpr : Expr → m Expr
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`. During search we
|
||||
want to track the instances synthesized by `grind`, and this is `Grind.synthInstance`.
|
||||
Synthesize an instance, returning `none` on failure.
|
||||
In `SymM`, this is `Sym.synthInstance?`. In `PP.M`, this is `Meta.synthInstance?`.
|
||||
-/
|
||||
synthInstance? : Expr → m (Option Expr)
|
||||
|
||||
@@ -31,7 +30,7 @@ instance (m n) [MonadLift m n] [MonadCanon m] : MonadCanon n where
|
||||
|
||||
def MonadCanon.synthInstance [Monad m] [MonadError m] [MonadCanon m] (type : Expr) : m Expr := do
|
||||
let some inst ← synthInstance? type
|
||||
| throwError "`grind` failed to find instance{indentExpr type}"
|
||||
| throwError "failed to find instance{indentExpr type}"
|
||||
return inst
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
end Lean.Meta.Sym.Arith
|
||||
39
src/Lean/Meta/Sym/Arith/MonadRing.lean
Normal file
39
src/Lean/Meta/Sym/Arith/MonadRing.lean
Normal file
@@ -0,0 +1,39 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadRing (m : Type → Type) where
|
||||
getRing : m Ring
|
||||
modifyRing : (Ring → Ring) → m Unit
|
||||
|
||||
export MonadRing (getRing modifyRing)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadRing m] : MonadRing n where
|
||||
getRing := liftM (getRing : m Ring)
|
||||
modifyRing f := liftM (modifyRing f : m Unit)
|
||||
|
||||
class MonadCommRing (m : Type → Type) where
|
||||
getCommRing : m CommRing
|
||||
modifyCommRing : (CommRing → CommRing) → m Unit
|
||||
|
||||
export MonadCommRing (getCommRing modifyCommRing)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadCommRing m] : MonadCommRing n where
|
||||
getCommRing := liftM (getCommRing : m CommRing)
|
||||
modifyCommRing f := liftM (modifyCommRing f : m Unit)
|
||||
|
||||
@[always_inline]
|
||||
instance (m) [Monad m] [MonadCommRing m] : MonadRing m where
|
||||
getRing := return (← getCommRing).toRing
|
||||
modifyRing f := modifyCommRing fun s => { s with toRing := f s.toRing }
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
39
src/Lean/Meta/Sym/Arith/MonadSemiring.lean
Normal file
39
src/Lean/Meta/Sym/Arith/MonadSemiring.lean
Normal file
@@ -0,0 +1,39 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadSemiring (m : Type → Type) where
|
||||
getSemiring : m Semiring
|
||||
modifySemiring : (Semiring → Semiring) → m Unit
|
||||
|
||||
export MonadSemiring (getSemiring modifySemiring)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadSemiring m] : MonadSemiring n where
|
||||
getSemiring := liftM (getSemiring : m Semiring)
|
||||
modifySemiring f := liftM (modifySemiring f : m Unit)
|
||||
|
||||
class MonadCommSemiring (m : Type → Type) where
|
||||
getCommSemiring : m CommSemiring
|
||||
modifyCommSemiring : (CommSemiring → CommSemiring) → m Unit
|
||||
|
||||
export MonadCommSemiring (getCommSemiring modifyCommSemiring)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadCommSemiring m] : MonadCommSemiring n where
|
||||
getCommSemiring := liftM (getCommSemiring : m CommSemiring)
|
||||
modifyCommSemiring f := liftM (modifyCommSemiring f : m Unit)
|
||||
|
||||
@[always_inline]
|
||||
instance (m) [Monad m] [MonadCommSemiring m] : MonadSemiring m where
|
||||
getSemiring := return (← getCommSemiring).toSemiring
|
||||
modifySemiring f := modifyCommSemiring fun s => { s with toSemiring := f s.toSemiring }
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
32
src/Lean/Meta/Sym/Arith/MonadVar.lean
Normal file
32
src/Lean/Meta/Sym/Arith/MonadVar.lean
Normal file
@@ -0,0 +1,32 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-- Read a variable's Lean expression by index. Used by `DenoteExpr` and diagnostics (PP). -/
|
||||
class MonadGetVar (m : Type → Type) where
|
||||
getVar : Var → m Expr
|
||||
|
||||
export MonadGetVar (getVar)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadGetVar m] : MonadGetVar n where
|
||||
getVar x := liftM (getVar x : m Expr)
|
||||
|
||||
/-- Create or lookup a variable for a Lean expression. Used by reification. -/
|
||||
class MonadMkVar (m : Type → Type) where
|
||||
mkVar : Expr → m Var
|
||||
|
||||
export MonadMkVar (mkVar)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadMkVar m] : MonadMkVar n where
|
||||
mkVar e := liftM (mkVar e : m Var)
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
205
src/Lean/Meta/Sym/Arith/Reify.lean
Normal file
205
src/Lean/Meta/Sym/Arith/Reify.lean
Normal file
@@ -0,0 +1,205 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public import Lean.Meta.Sym.LitValues
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/-!
|
||||
# Reification of arithmetic expressions
|
||||
|
||||
Converts Lean expressions into `CommRing.Expr` (ring) or `CommSemiring.Expr`
|
||||
(semiring) for reflection-based normalization.
|
||||
|
||||
Instance validation uses pointer equality (`isSameExpr`) against cached function
|
||||
expressions from `Functions.lean`.
|
||||
|
||||
## Differences from grind's `Reify.lean`
|
||||
|
||||
- Uses `MonadMkVar` for variable creation instead of grind's `internalize` + `mkVarCore`
|
||||
- Uses `Sym.getNatValue?`/`Sym.getIntValue?` (pure) instead of `MetaM` versions
|
||||
- No `MonadSetTermId` — term-to-ring-id tracking is grind-specific
|
||||
-/
|
||||
|
||||
section RingReify
|
||||
|
||||
variable [MonadLiftT SymM m] [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadRing m] [MonadMkVar m]
|
||||
|
||||
def isAddInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getAddFn).appArg! inst
|
||||
def isMulInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getMulFn).appArg! inst
|
||||
def isSubInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getSubFn).appArg! inst
|
||||
def isNegInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getNegFn).appArg! inst
|
||||
def isPowInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getPowFn).appArg! inst
|
||||
def isIntCastInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getIntCastFn).appArg! inst
|
||||
def isNatCastInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getNatCastFn).appArg! inst
|
||||
|
||||
private def reportRingAppIssue [MonadLiftT SymM m] (e : Expr) : m Unit := do
|
||||
reportIssue! "ring term with unexpected instance{indentExpr e}"
|
||||
|
||||
/--
|
||||
Converts a Lean expression `e` into a `RingExpr`.
|
||||
|
||||
If `skipVar` is `true`, returns `none` if `e` is not an interpreted ring term
|
||||
(used for equalities/disequalities). If `false`, treats non-interpreted terms
|
||||
as variables (used for inequalities).
|
||||
-/
|
||||
partial def reifyRing? (e : Expr) (skipVar : Bool := true) : m (Option RingExpr) := do
|
||||
let toVar (e : Expr) : m RingExpr := do
|
||||
return .var (← mkVar e)
|
||||
let asVar (e : Expr) : m RingExpr := do
|
||||
reportRingAppIssue e
|
||||
return .var (← mkVar e)
|
||||
let rec go (e : Expr) : m RingExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if (← isAddInst i) then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if (← isMulInst i) then return .mul (← go a) (← go b) else asVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if (← isSubInst i) then return .sub (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | toVar e
|
||||
if (← isPowInst i) then return .pow (← go a) k else asVar e
|
||||
| Neg.neg _ i a =>
|
||||
if (← isNegInst i) then return .neg (← go a) else asVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if (← isIntCastInst i) then
|
||||
let some k := Sym.getIntValue? a |>.run | toVar e
|
||||
return .intCast k
|
||||
else
|
||||
asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if (← isNatCastInst i) then
|
||||
let some k := Sym.getNatValue? a |>.run | toVar e
|
||||
return .natCast k
|
||||
else
|
||||
asVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
/-
|
||||
**Note**: We extract `n` directly as a raw nat literal. The grind version uses `MetaM`'s
|
||||
`getNatValue?` which handles multiple encodings (raw literals, nested `OfNat`, etc.).
|
||||
In `SymM`, we assume terms have been canonicalized by `Sym.canon` before reification,
|
||||
so `OfNat.ofNat _ n _` always has a raw nat literal at position 1.
|
||||
-/
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| BitVec.ofNat _ n =>
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| _ => toVar e
|
||||
let toTopVar (e : Expr) : m (Option RingExpr) := do
|
||||
if skipVar then
|
||||
return none
|
||||
else
|
||||
return some (← toVar e)
|
||||
let asTopVar (e : Expr) : m (Option RingExpr) := do
|
||||
reportRingAppIssue e
|
||||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if (← isAddInst i) then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if (← isMulInst i) then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if (← isSubInst i) then return some (.sub (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | asTopVar e
|
||||
if (← isPowInst i) then return some (.pow (← go a) k) else asTopVar e
|
||||
| Neg.neg _ i a =>
|
||||
if (← isNegInst i) then return some (.neg (← go a)) else asTopVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if (← isIntCastInst i) then
|
||||
let some k := Sym.getIntValue? a |>.run | toTopVar e
|
||||
return some (.intCast k)
|
||||
else
|
||||
asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if (← isNatCastInst i) then
|
||||
let some k := Sym.getNatValue? a |>.run | toTopVar e
|
||||
return some (.natCast k)
|
||||
else
|
||||
asTopVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | asTopVar e
|
||||
return some (.num k)
|
||||
| _ => toTopVar e
|
||||
|
||||
end RingReify
|
||||
|
||||
section SemiringReify
|
||||
|
||||
variable [MonadLiftT SymM m] [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadSemiring m] [MonadMkVar m]
|
||||
|
||||
private def reportSemiringAppIssue [MonadLiftT SymM m] (e : Expr) : m Unit := do
|
||||
reportIssue! "semiring term with unexpected instance{indentExpr e}"
|
||||
|
||||
/--
|
||||
Converts a Lean expression `e` into a `SemiringExpr`.
|
||||
Only recognizes `add`, `mul`, `pow`, `natCast`, and numerals (no `sub`, `neg`, `intCast`).
|
||||
-/
|
||||
partial def reifySemiring? (e : Expr) : m (Option SemiringExpr) := do
|
||||
let toVar (e : Expr) : m SemiringExpr := do
|
||||
return .var (← mkVar e)
|
||||
let asVar (e : Expr) : m SemiringExpr := do
|
||||
reportSemiringAppIssue e
|
||||
return .var (← mkVar e)
|
||||
let rec go (e : Expr) : m SemiringExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getAddFn').appArg! i then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getMulFn').appArg! i then return .mul (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | toVar e
|
||||
if isSameExpr (← getPowFn').appArg! i then return .pow (← go a) k else asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k := Sym.getNatValue? a |>.run | toVar e
|
||||
return .num k
|
||||
else
|
||||
asVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| _ => toVar e
|
||||
let toTopVar (e : Expr) : m (Option SemiringExpr) := do
|
||||
return some (← toVar e)
|
||||
let asTopVar (e : Expr) : m (Option SemiringExpr) := do
|
||||
reportSemiringAppIssue e
|
||||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getAddFn').appArg! i then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getMulFn').appArg! i then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | return none
|
||||
if isSameExpr (← getPowFn').appArg! i then return some (.pow (← go a) k) else asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k := Sym.getNatValue? a |>.run | toTopVar e
|
||||
return some (.num k)
|
||||
else
|
||||
asTopVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | asTopVar e
|
||||
return some (.num k)
|
||||
| _ => toTopVar e
|
||||
|
||||
end SemiringReify
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -8,7 +8,7 @@ prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.ToExpr
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
open Grind.CommRing
|
||||
/-!
|
||||
`ToExpr` instances for `CommRing.Poly` types.
|
||||
@@ -57,4 +57,4 @@ instance : ToExpr CommRing.Expr where
|
||||
toExpr := ofRingExpr
|
||||
toTypeExpr := mkConst ``CommRing.Expr
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
end Lean.Meta.Sym.Arith
|
||||
137
src/Lean/Meta/Sym/Arith/Types.lean
Normal file
137
src/Lean/Meta/Sym/Arith/Types.lean
Normal file
@@ -0,0 +1,137 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.Meta.Sym.SymM
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
export Lean.Grind.CommRing (Var Power Mon Poly)
|
||||
abbrev RingExpr := Grind.CommRing.Expr
|
||||
/-
|
||||
**Note**: recall that we use ring expressions to represent semiring expressions,
|
||||
and ignore non-applicable constructors.
|
||||
-/
|
||||
abbrev SemiringExpr := Grind.CommRing.Expr
|
||||
|
||||
/-- Classification state for a type with a `Semiring` instance. -/
|
||||
structure Semiring where
|
||||
id : Nat
|
||||
type : Expr
|
||||
/-- Cached `getDecLevel type` -/
|
||||
u : Level
|
||||
/-- `Semiring` instance for `type` -/
|
||||
semiringInst : Expr
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Classification state for a type with a `Ring` instance. -/
|
||||
structure Ring where
|
||||
id : Nat
|
||||
type : Expr
|
||||
/-- Cached `getDecLevel type` -/
|
||||
u : Level
|
||||
/-- `Ring` instance for `type` -/
|
||||
ringInst : Expr
|
||||
/-- `Semiring` instance for `type` -/
|
||||
semiringInst : Expr
|
||||
/-- `IsCharP` instance for `type` if available. -/
|
||||
charInst? : Option (Expr × Nat)
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
subFn? : Option Expr := none
|
||||
negFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
intCastFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
one? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Classification state for a type with a `CommRing` instance. -/
|
||||
structure CommRing extends Ring where
|
||||
/-- Inverse function if `fieldInst?` is `some inst` -/
|
||||
invFn? : Option Expr := none
|
||||
/--
|
||||
If this is a `OfSemiring.Q α` ring, this field contains the
|
||||
`semiringId` for `α`.
|
||||
-/
|
||||
semiringId? : Option Nat
|
||||
/-- `CommSemiring` instance for `type` -/
|
||||
commSemiringInst : Expr
|
||||
/-- `CommRing` instance for `type` -/
|
||||
commRingInst : Expr
|
||||
/-- `NoNatZeroDivisors` instance for `type` if available. -/
|
||||
noZeroDivInst? : Option Expr
|
||||
/-- `Field` instance for `type` if available. -/
|
||||
fieldInst? : Option Expr
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Classification state for a type with a `CommSemiring` instance.
|
||||
Recall that `CommSemiring` types are normalized using the `OfSemiring.Q` envelope.
|
||||
-/
|
||||
structure CommSemiring extends Semiring where
|
||||
/-- Id of the envelope ring `OfSemiring.Q type` -/
|
||||
ringId : Nat
|
||||
/-- `CommSemiring` instance for `type` -/
|
||||
commSemiringInst : Expr
|
||||
/-- `AddRightCancel` instance for `type` if available. -/
|
||||
addRightCancelInst? : Option (Option Expr) := none
|
||||
toQFn? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Result of classifying a type's algebraic structure. -/
|
||||
inductive ClassifyResult where
|
||||
| commRing (id : Nat)
|
||||
| nonCommRing (id : Nat)
|
||||
| commSemiring (id : Nat)
|
||||
| nonCommSemiring (id : Nat)
|
||||
| /-- No algebraic structure found. -/ none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Arith type classification state, stored as a `SymExtension`. -/
|
||||
structure State where
|
||||
/-- Exponent threshold for `HPow` evaluation. -/
|
||||
exp : Nat := 8
|
||||
/-- Commutative rings. -/
|
||||
rings : Array CommRing := {}
|
||||
/-- Commutative semirings. -/
|
||||
semirings : Array CommSemiring := {}
|
||||
/-- Non-commutative rings. -/
|
||||
ncRings : Array Ring := {}
|
||||
/-- Non-commutative semirings. -/
|
||||
ncSemirings : Array Semiring := {}
|
||||
/-- Mapping from types to their classification result. Caches failures as `.none`. -/
|
||||
typeClassify : PHashMap ExprPtr ClassifyResult := {}
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize arithExt : SymExtension State ← registerSymExtension (return {})
|
||||
|
||||
def getArithState : SymM State :=
|
||||
arithExt.getState
|
||||
|
||||
@[inline] def modifyArithState (f : State → State) : SymM Unit :=
|
||||
arithExt.modifyState f
|
||||
|
||||
/-- Get the exponent threshold. -/
|
||||
def getExpThreshold : SymM Nat :=
|
||||
return (← getArithState).exp
|
||||
|
||||
/-- Set the exponent threshold. -/
|
||||
def setExpThreshold (exp : Nat) : SymM Unit :=
|
||||
modifyArithState fun s => { s with exp }
|
||||
|
||||
/-- Run `k` with a temporary exponent threshold. -/
|
||||
def withExpThreshold (exp : Nat) (k : SymM α) : SymM α := do
|
||||
let oldExp := (← getArithState).exp
|
||||
setExpThreshold exp
|
||||
try k finally setExpThreshold oldExp
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -5,11 +5,9 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
|
||||
@@ -21,8 +19,6 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.Proof
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Inv
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadSemiring
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Action
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
/-!
|
||||
Helper functions for converting reified terms back into their denotations.
|
||||
-/
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m]
|
||||
|
||||
section
|
||||
|
||||
@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
open Sym.Arith
|
||||
structure NonCommRingM.Context where
|
||||
ringId : Nat
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
structure NonCommSemiringM.Context where
|
||||
semiringId : Nat
|
||||
|
||||
@@ -10,6 +10,7 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
import Init.Omega
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
private abbrev M := StateT CommRing MetaM
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@ import Lean.Data.RArray
|
||||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.ProofUtil
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Init.Data.Nat.Order
|
||||
import Init.Data.Order.Lemmas
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/--
|
||||
Returns a context of type `RArray α` containing the variables `vars` where
|
||||
`α` is the type of the ring.
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadRing m]
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
def checkMaxSteps : GoalM Bool := do
|
||||
return (← get').steps >= (← getConfig).ringSteps
|
||||
|
||||
@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
import Lean.Meta.Tactic.Grind.Arith.EvalNum
|
||||
import Init.Data.Nat.Linear
|
||||
public section
|
||||
|
||||
@@ -11,6 +11,7 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
structure SemiringM.Context where
|
||||
semiringId : Nat
|
||||
|
||||
@@ -7,7 +7,7 @@ module
|
||||
prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
@@ -14,8 +14,8 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.CommRing
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Init.Data.Nat.Order
|
||||
import Init.Data.Order.Lemmas
|
||||
public section
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.Arith.Linear.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.Linear
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
def get' : GoalM State := do
|
||||
linearExt.getState
|
||||
|
||||
@@ -11,8 +11,8 @@ import Lean.Data.RArray
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.ProofUtil
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.VarRename
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.OfNatModule
|
||||
|
||||
@@ -97,6 +97,8 @@ def mkCnstrNorm0 (s : Struct) (ringInst : Expr) (kind : CnstrKind) (lhs rhs : Ex
|
||||
| .le => mkLeNorm0 s ringInst lhs rhs
|
||||
| .lt => mkLtNorm0 s ringInst lhs rhs
|
||||
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/--
|
||||
Returns `rel lhs (rhs + 0)`
|
||||
-/
|
||||
|
||||
@@ -21,9 +21,6 @@ opaque maxSmallNatFn : Unit → Nat
|
||||
@[extern "lean_libuv_version"]
|
||||
opaque libUVVersionFn : Unit → Nat
|
||||
|
||||
@[extern "lean_openssl_version"]
|
||||
opaque openSSLVersionFn : Unit → Nat
|
||||
|
||||
def closureMaxArgs : Nat :=
|
||||
closureMaxArgsFn ()
|
||||
|
||||
@@ -33,7 +30,4 @@ def maxSmallNat : Nat :=
|
||||
def libUVVersion : Nat :=
|
||||
libUVVersionFn ()
|
||||
|
||||
def openSSLVersion : Nat :=
|
||||
openSSLVersionFn ()
|
||||
|
||||
end Lean
|
||||
|
||||
@@ -10,7 +10,6 @@ public import Std.Internal.Async
|
||||
public import Std.Internal.Http
|
||||
public import Std.Internal.Parsec
|
||||
public import Std.Internal.UV
|
||||
public import Std.Internal.SSL
|
||||
|
||||
@[expose] public section
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ public import Std.Internal.Async.Basic
|
||||
public import Std.Internal.Async.ContextAsync
|
||||
public import Std.Internal.Async.Timer
|
||||
public import Std.Internal.Async.TCP
|
||||
public import Std.Internal.Async.TCP.SSL
|
||||
public import Std.Internal.Async.UDP
|
||||
public import Std.Internal.Async.DNS
|
||||
public import Std.Internal.Async.Select
|
||||
@@ -18,4 +17,3 @@ public import Std.Internal.Async.Process
|
||||
public import Std.Internal.Async.System
|
||||
public import Std.Internal.Async.Signal
|
||||
public import Std.Internal.Async.IO
|
||||
public import Std.Internal.SSL
|
||||
|
||||
@@ -8,7 +8,6 @@ module
|
||||
prelude
|
||||
public import Std.Time
|
||||
public import Std.Internal.UV.TCP
|
||||
public import Std.Internal.Async.IO
|
||||
public import Std.Internal.Async.Select
|
||||
|
||||
public section
|
||||
|
||||
@@ -1,442 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Time
|
||||
public import Std.Internal.UV.TCP
|
||||
public import Std.Internal.Async.Timer
|
||||
public import Std.Internal.Async.Select
|
||||
public import Std.Internal.SSL
|
||||
|
||||
public section
|
||||
|
||||
namespace Std.Internal.IO.Async.TCP.SSL
|
||||
|
||||
open Std.Internal.SSL
|
||||
open Std.Internal.UV.TCP
|
||||
open Std.Net
|
||||
|
||||
/--
|
||||
Default chunk size used by TLS I/O loops.
|
||||
-/
|
||||
def ioChunkSize : UInt64 := 16 * 1024
|
||||
|
||||
-- ## Private helpers: SSL ↔ TCP I/O bridge
|
||||
|
||||
/--
|
||||
Feeds an encrypted chunk into the SSL input BIO.
|
||||
Raises an error if OpenSSL consumed fewer bytes than supplied.
|
||||
-/
|
||||
@[inline]
|
||||
private def feedEncryptedChunk (ssl : Session r) (encrypted : ByteArray) : IO Unit := do
|
||||
if encrypted.size == 0 then return ()
|
||||
let consumed ← ssl.feedEncrypted encrypted
|
||||
if consumed.toNat != encrypted.size then
|
||||
throw <| IO.userError s!"TLS input short write: consumed {consumed} / {encrypted.size} bytes"
|
||||
|
||||
/--
|
||||
Drains all pending encrypted bytes from the SSL output BIO and sends them over TCP.
|
||||
-/
|
||||
private partial def flushEncrypted (native : Socket) (ssl : Session r) : Async Unit := do
|
||||
let out ← ssl.drainEncrypted
|
||||
if out.size == 0 then return ()
|
||||
Async.ofPromise <| native.send #[out]
|
||||
flushEncrypted native ssl
|
||||
|
||||
/--
|
||||
Runs the TLS handshake loop to completion, interleaving SSL state machine steps
|
||||
with TCP I/O.
|
||||
-/
|
||||
private partial def doHandshake (native : Socket) (ssl : Session r) (chunkSize : UInt64) : Async Unit := do
|
||||
let want ← ssl.handshake
|
||||
flushEncrypted native ssl
|
||||
match want with
|
||||
| none =>
|
||||
return ()
|
||||
| some .write =>
|
||||
doHandshake native ssl chunkSize
|
||||
| some .read =>
|
||||
let encrypted? ← Async.ofPromise <| native.recv? chunkSize
|
||||
match encrypted? with
|
||||
| none =>
|
||||
throw <| IO.userError "connection closed during TLS handshake"
|
||||
| some encrypted =>
|
||||
feedEncryptedChunk ssl encrypted
|
||||
doHandshake native ssl chunkSize
|
||||
|
||||
-- ## Types
|
||||
|
||||
/--
|
||||
Represents a TLS-enabled TCP server socket. Carries its own server context so
|
||||
that each accepted connection gets a session configured from the same context.
|
||||
-/
|
||||
structure Server where
|
||||
private ofNative ::
|
||||
native : Socket
|
||||
serverCtx : Context.Server
|
||||
|
||||
/--
|
||||
Represents a TLS-enabled TCP connection, parameterized by TLS role.
|
||||
Use `Client` for outgoing connections and `ServerConn` for server-accepted connections.
|
||||
-/
|
||||
structure Connection (r : Role) where
|
||||
private ofNative ::
|
||||
native : Socket
|
||||
ssl : Session r
|
||||
|
||||
/--
|
||||
An outgoing TLS client connection.
|
||||
-/
|
||||
abbrev Client := Connection .client
|
||||
|
||||
/--
|
||||
An incoming TLS connection accepted by a `Server`.
|
||||
-/
|
||||
abbrev ServerConn := Connection .server
|
||||
|
||||
namespace Server
|
||||
|
||||
/--
|
||||
Creates a new TLS-enabled TCP server socket using the given context.
|
||||
-/
|
||||
@[inline]
|
||||
def mk (serverCtx : Context.Server) : IO Server := do
|
||||
let native ← Socket.new
|
||||
return Server.ofNative native serverCtx
|
||||
|
||||
/--
|
||||
Configures the server context with a PEM certificate and private key.
|
||||
-/
|
||||
@[inline]
|
||||
def configureServer (s : Server) (certFile keyFile : String) : IO Unit :=
|
||||
s.serverCtx.configure certFile keyFile
|
||||
|
||||
/--
|
||||
Binds the server socket to the specified address.
|
||||
-/
|
||||
@[inline]
|
||||
def bind (s : Server) (addr : SocketAddress) : IO Unit :=
|
||||
s.native.bind addr
|
||||
|
||||
/--
|
||||
Listens for incoming connections with the given backlog.
|
||||
-/
|
||||
@[inline]
|
||||
def listen (s : Server) (backlog : UInt32) : IO Unit :=
|
||||
s.native.listen backlog
|
||||
|
||||
@[inline] private def mkServerConn (native : Socket) (ctx : Context.Server) : IO ServerConn := do
|
||||
let ssl ← Session.Server.mk ctx
|
||||
return ⟨native, ssl⟩
|
||||
|
||||
/--
|
||||
Accepts an incoming TLS connection and performs the TLS handshake.
|
||||
-/
|
||||
@[inline]
|
||||
def accept (s : Server) (chunkSize : UInt64 := ioChunkSize) : Async ServerConn := do
|
||||
let native ← Async.ofPromise <| s.native.accept
|
||||
let conn ← mkServerConn native s.serverCtx
|
||||
doHandshake conn.native conn.ssl chunkSize
|
||||
return conn
|
||||
|
||||
/--
|
||||
Creates a `Selector` that resolves once `s` has a connection available and the TLS handshake
|
||||
has completed.
|
||||
-/
|
||||
def acceptSelector (s : Server) : Selector ServerConn :=
|
||||
{
|
||||
tryFn := do
|
||||
let res ← s.native.tryAccept
|
||||
match ← IO.ofExcept res with
|
||||
| none => return none
|
||||
| some native =>
|
||||
let conn ← mkServerConn native s.serverCtx
|
||||
doHandshake conn.native conn.ssl ioChunkSize
|
||||
return some conn
|
||||
|
||||
registerFn waiter := do
|
||||
let connTask ← (do
|
||||
let native ← Async.ofPromise <| s.native.accept
|
||||
let ssl ← Session.Server.mk s.serverCtx
|
||||
let conn : ServerConn := ⟨native, ssl⟩
|
||||
doHandshake conn.native conn.ssl ioChunkSize
|
||||
return conn
|
||||
).asTask
|
||||
|
||||
-- If we get cancelled the promise will be dropped so prepare for that
|
||||
discard <| IO.mapTask (t := connTask) fun res => do
|
||||
let lose := return ()
|
||||
let win promise := do
|
||||
try
|
||||
let conn ← IO.ofExcept res
|
||||
promise.resolve (.ok conn)
|
||||
catch e =>
|
||||
promise.resolve (.error e)
|
||||
waiter.race lose win
|
||||
|
||||
unregisterFn := s.native.cancelAccept
|
||||
}
|
||||
|
||||
/--
|
||||
Gets the local address of the server socket.
|
||||
-/
|
||||
@[inline]
|
||||
def getSockName (s : Server) : IO SocketAddress :=
|
||||
s.native.getSockName
|
||||
|
||||
/--
|
||||
Disables the Nagle algorithm for all client sockets accepted by this server socket.
|
||||
-/
|
||||
@[inline]
|
||||
def noDelay (s : Server) : IO Unit :=
|
||||
s.native.noDelay
|
||||
|
||||
/--
|
||||
Enables TCP keep-alive for all client sockets accepted by this server socket.
|
||||
-/
|
||||
@[inline]
|
||||
def keepAlive (s : Server) (enable : Bool) (delay : Std.Time.Second.Offset) (_ : delay.val ≥ 1 := by decide) : IO Unit :=
|
||||
s.native.keepAlive enable.toInt8 delay.val.toNat.toUInt32
|
||||
|
||||
end Server
|
||||
|
||||
namespace Connection
|
||||
|
||||
/--
|
||||
Attempts to write plaintext data into TLS. Returns true when accepted.
|
||||
Any encrypted TLS output generated is flushed to the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def write {r : Role} (s : Connection r) (data : ByteArray) : Async Bool := do
|
||||
match ← s.ssl.write data with
|
||||
| none =>
|
||||
flushEncrypted s.native s.ssl
|
||||
return true
|
||||
| some _ =>
|
||||
-- Data was queued internally; flush whatever the SSL engine produced.
|
||||
flushEncrypted s.native s.ssl
|
||||
return false
|
||||
|
||||
/--
|
||||
Sends data through a TLS-enabled socket.
|
||||
Fails if OpenSSL reports the write as pending additional I/O.
|
||||
-/
|
||||
@[inline]
|
||||
def send {r : Role} (s : Connection r) (data : ByteArray) : Async Unit := do
|
||||
if ← s.write data then
|
||||
return ()
|
||||
else
|
||||
throw <| IO.userError "TLS write is pending additional I/O; call `recv?` or retry later"
|
||||
|
||||
/--
|
||||
Sends multiple data buffers through the TLS-enabled socket.
|
||||
-/
|
||||
@[inline]
|
||||
def sendAll {r : Role} (s : Connection r) (data : Array ByteArray) : Async Unit :=
|
||||
data.forM (s.send ·)
|
||||
|
||||
/--
|
||||
Receives decrypted plaintext data from TLS.
|
||||
If no plaintext is immediately available, this function performs the required socket I/O
|
||||
(flush or receive) and retries until data arrives or the connection is closed.
|
||||
-/
|
||||
partial def recv? {r : Role} (s : Connection r) (size : UInt64) (chunkSize : UInt64 := ioChunkSize) : Async (Option ByteArray) := do
|
||||
match ← s.ssl.read? size with
|
||||
| .data plain =>
|
||||
flushEncrypted s.native s.ssl
|
||||
return some plain
|
||||
| .closed =>
|
||||
return none
|
||||
| .wantIO _ =>
|
||||
flushEncrypted s.native s.ssl
|
||||
let encrypted? ← Async.ofPromise <| s.native.recv? chunkSize
|
||||
match encrypted? with
|
||||
| none =>
|
||||
return none
|
||||
| some encrypted =>
|
||||
feedEncryptedChunk s.ssl encrypted
|
||||
recv? s size chunkSize
|
||||
|
||||
/--
|
||||
Tries to receive decrypted plaintext data without blocking.
|
||||
Returns `some (some data)` if plaintext is available, `some none` if the peer closed,
|
||||
or `none` if no data is ready yet.
|
||||
-/
|
||||
partial def tryRecv {r : Role} (s : Connection r) (size : UInt64) (chunkSize : UInt64 := ioChunkSize) : Async (Option (Option ByteArray)) := do
|
||||
let pending ← s.ssl.pendingPlaintext
|
||||
|
||||
if pending > 0 then
|
||||
return some (← s.recv? size)
|
||||
else
|
||||
let readableWaiter ← s.native.waitReadable
|
||||
|
||||
flushEncrypted s.native s.ssl
|
||||
|
||||
if ← readableWaiter.isResolved then
|
||||
let encrypted? ← Async.ofPromise <| s.native.recv? ioChunkSize
|
||||
match encrypted? with
|
||||
| none =>
|
||||
return none
|
||||
| some encrypted =>
|
||||
feedEncryptedChunk s.ssl encrypted
|
||||
tryRecv s size chunkSize
|
||||
else
|
||||
s.native.cancelRecv
|
||||
return none
|
||||
|
||||
/--
|
||||
Feeds encrypted socket data into SSL until plaintext is pending.
|
||||
Resolves the returned promise once plaintext is available.
|
||||
-/
|
||||
partial def waitReadable {r : Role} (s : Connection r) : Async Unit := do
|
||||
flushEncrypted s.native s.ssl
|
||||
|
||||
let pending ← s.ssl.pendingPlaintext
|
||||
if pending > 0 then
|
||||
return ()
|
||||
|
||||
if (← s.ssl.pendingPlaintext) > 0 then
|
||||
return ()
|
||||
|
||||
match ← s.ssl.read? 0 with
|
||||
| .data _ =>
|
||||
flushEncrypted s.native s.ssl
|
||||
return ()
|
||||
| .closed =>
|
||||
return ()
|
||||
| .wantIO _ =>
|
||||
flushEncrypted s.native s.ssl
|
||||
let encrypted? ← Async.ofPromise <| s.native.recv? ioChunkSize
|
||||
match encrypted? with
|
||||
| none => return ()
|
||||
| some encrypted =>
|
||||
feedEncryptedChunk s.ssl encrypted
|
||||
waitReadable s
|
||||
|
||||
/--
|
||||
Creates a `Selector` that resolves once `s` has plaintext data available.
|
||||
-/
|
||||
def recvSelector {r : Role} (s : Connection r) (size : UInt64) : Selector (Option ByteArray) :=
|
||||
{
|
||||
tryFn := s.tryRecv size
|
||||
|
||||
registerFn waiter := do
|
||||
let readableWaiter ← s.waitReadable.asTask
|
||||
|
||||
-- If we get cancelled the promise will be dropped so prepare for that
|
||||
discard <| IO.mapTask (t := readableWaiter) fun res => do
|
||||
match res with
|
||||
| .error _ => return ()
|
||||
| .ok _ =>
|
||||
let lose := return ()
|
||||
let win promise := do
|
||||
try
|
||||
-- We know that this read should not block.
|
||||
let result ← (s.recv? size).block
|
||||
promise.resolve (.ok result)
|
||||
catch e =>
|
||||
promise.resolve (.error e)
|
||||
waiter.race lose win
|
||||
|
||||
unregisterFn := s.native.cancelRecv
|
||||
}
|
||||
|
||||
/--
|
||||
Shuts down the write side of the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def shutdown {r : Role} (s : Connection r) : Async Unit :=
|
||||
Async.ofPromise <| s.native.shutdown
|
||||
|
||||
/--
|
||||
Gets the remote address of the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def getPeerName {r : Role} (s : Connection r) : IO SocketAddress :=
|
||||
s.native.getPeerName
|
||||
|
||||
/--
|
||||
Gets the local address of the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def getSockName {r : Role} (s : Connection r) : IO SocketAddress :=
|
||||
s.native.getSockName
|
||||
|
||||
/--
|
||||
Returns the X.509 verification result code for this TLS session.
|
||||
-/
|
||||
@[inline]
|
||||
def verifyResult {r : Role} (s : Connection r) : IO UInt64 :=
|
||||
s.ssl.verifyResult
|
||||
|
||||
/--
|
||||
Disables the Nagle algorithm for the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def noDelay {r : Role} (s : Connection r) : IO Unit :=
|
||||
s.native.noDelay
|
||||
|
||||
/--
|
||||
Enables TCP keep-alive with a specified delay for the socket.
|
||||
-/
|
||||
@[inline]
|
||||
def keepAlive {r : Role} (s : Connection r) (enable : Bool) (delay : Std.Time.Second.Offset) (_ : delay.val ≥ 0 := by decide) : IO Unit :=
|
||||
s.native.keepAlive enable.toInt8 delay.val.toNat.toUInt32
|
||||
|
||||
end Connection
|
||||
|
||||
-- ## Client (outgoing connection setup)
|
||||
|
||||
namespace Client
|
||||
|
||||
/--
|
||||
Creates a new outgoing TLS client connection using the given client context.
|
||||
-/
|
||||
@[inline]
|
||||
def mk (ctx : Context.Client) : IO Client := do
|
||||
let native ← Socket.new
|
||||
let ssl ← Session.Client.mk ctx
|
||||
return ⟨native, ssl⟩
|
||||
|
||||
/--
|
||||
Configures the given client context.
|
||||
`caFile` may be empty to use default trust settings.
|
||||
-/
|
||||
@[inline]
|
||||
def configureContext (ctx : Context.Client) (caFile := "") (verifyPeer := true) : IO Unit :=
|
||||
ctx.configure caFile verifyPeer
|
||||
|
||||
/--
|
||||
Binds the client socket to the specified address.
|
||||
-/
|
||||
@[inline]
|
||||
def bind (s : Client) (addr : SocketAddress) : IO Unit :=
|
||||
s.native.bind addr
|
||||
|
||||
/--
|
||||
Sets SNI server name used during the TLS handshake.
|
||||
-/
|
||||
@[inline]
|
||||
def setServerName (s : Client) (host : String) : IO Unit :=
|
||||
Session.Client.setServerName s.ssl host
|
||||
|
||||
/--
|
||||
Performs the TLS handshake on an established TCP connection.
|
||||
-/
|
||||
@[inline]
|
||||
def handshake (s : Client) (chunkSize : UInt64 := ioChunkSize) : Async Unit :=
|
||||
doHandshake (Connection.native s) (Connection.ssl s) chunkSize
|
||||
|
||||
/--
|
||||
Connects the client socket to the given address and performs the TLS handshake.
|
||||
-/
|
||||
@[inline]
|
||||
def connect (s : Client) (addr : SocketAddress) (chunkSize : UInt64 := ioChunkSize) : Async Unit := do
|
||||
Async.ofPromise <| (Connection.native s).connect addr
|
||||
s.handshake chunkSize
|
||||
|
||||
end Std.Internal.IO.Async.TCP.SSL.Client
|
||||
@@ -1,10 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.SSL.Context
|
||||
public import Std.Internal.SSL.Session
|
||||
@@ -1,75 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.System.Promise
|
||||
|
||||
public section
|
||||
|
||||
namespace Std.Internal.SSL
|
||||
|
||||
/--
|
||||
Distinguishes server-side from client-side TLS contexts and sessions at the type level.
|
||||
-/
|
||||
inductive Role where
|
||||
| server
|
||||
| client
|
||||
|
||||
private opaque ContextServerImpl : NonemptyType.{0}
|
||||
private opaque ContextClientImpl : NonemptyType.{0}
|
||||
|
||||
/--
|
||||
Server-side TLS context (`SSL_CTX` configured with `TLS_server_method`).
|
||||
-/
|
||||
def Context.Server : Type := ContextServerImpl.type
|
||||
|
||||
/--
|
||||
Client-side TLS context (`SSL_CTX` configured with `TLS_client_method`).
|
||||
-/
|
||||
def Context.Client : Type := ContextClientImpl.type
|
||||
|
||||
instance : Nonempty Context.Server := ContextServerImpl.property
|
||||
instance : Nonempty Context.Client := ContextClientImpl.property
|
||||
|
||||
namespace Context
|
||||
|
||||
namespace Server
|
||||
|
||||
/--
|
||||
Creates a new server-side TLS context using `TLS_server_method`.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_ctx_mk_server"]
|
||||
opaque mk : IO Context.Server
|
||||
|
||||
/--
|
||||
Loads a PEM certificate and private key into a server context.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_ctx_configure_server"]
|
||||
opaque configure (ctx : @& Context.Server) (certFile : @& String) (keyFile : @& String) : IO Unit
|
||||
|
||||
end Server
|
||||
|
||||
namespace Client
|
||||
|
||||
/--
|
||||
Creates a new client-side TLS context using `TLS_client_method`.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_ctx_mk_client"]
|
||||
opaque mk : IO Context.Client
|
||||
|
||||
/--
|
||||
Configures CA trust anchors and peer verification for a client context.
|
||||
`caFile` may be empty to use platform default trust anchors.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_ctx_configure_client"]
|
||||
opaque configure (ctx : @& Context.Client) (caFile : @& String) (verifyPeer : Bool) : IO Unit
|
||||
|
||||
end Client
|
||||
|
||||
end Context
|
||||
|
||||
end Std.Internal.SSL
|
||||
@@ -1,152 +0,0 @@
|
||||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Internal.SSL.Context
|
||||
|
||||
public section
|
||||
|
||||
namespace Std.Internal.SSL
|
||||
|
||||
private opaque SessionImpl : Role → NonemptyType.{0}
|
||||
|
||||
/--
|
||||
Indicates what kind of socket I/O OpenSSL needs before the current operation can proceed.
|
||||
-/
|
||||
inductive IOWant where
|
||||
|
||||
/--
|
||||
OpenSSL needs more encrypted bytes from the socket (`SSL_ERROR_WANT_READ`).
|
||||
-/
|
||||
| read
|
||||
|
||||
/--
|
||||
OpenSSL needs to flush encrypted bytes to the socket (`SSL_ERROR_WANT_WRITE`).
|
||||
-/
|
||||
| write
|
||||
|
||||
/--
|
||||
Result of a `Session.read?` call.
|
||||
-/
|
||||
inductive ReadResult where
|
||||
|
||||
/--
|
||||
Plaintext data was successfully decrypted.
|
||||
-/
|
||||
| data (bytes : ByteArray)
|
||||
|
||||
/--
|
||||
OpenSSL needs socket I/O before it can produce plaintext.
|
||||
-/
|
||||
| wantIO (want : IOWant)
|
||||
|
||||
/--
|
||||
The peer closed the TLS session cleanly (`SSL_ERROR_ZERO_RETURN`).
|
||||
-/
|
||||
| closed
|
||||
|
||||
/--
|
||||
Represents an OpenSSL SSL session parameterized by role.
|
||||
Use `Session.Server` or `Session.Client` for the concrete aliases.
|
||||
-/
|
||||
def Session (r : Role) : Type := (SessionImpl r).type
|
||||
|
||||
/--
|
||||
Server-side TLS session.
|
||||
-/
|
||||
abbrev Session.Server := Session .server
|
||||
|
||||
/--
|
||||
Client-side TLS session.
|
||||
-/
|
||||
abbrev Session.Client := Session .client
|
||||
|
||||
instance (r : Role) : Nonempty (Session r) := (SessionImpl r).property
|
||||
|
||||
namespace Session
|
||||
|
||||
namespace Server
|
||||
|
||||
/--
|
||||
Creates a new server-side SSL session from the given context.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_mk_server"]
|
||||
opaque mk (ctx : @& Context.Server) : IO Session.Server
|
||||
|
||||
end Server
|
||||
|
||||
namespace Client
|
||||
|
||||
/--
|
||||
Creates a new client-side SSL session from the given context.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_mk_client"]
|
||||
opaque mk (ctx : @& Context.Client) : IO Session.Client
|
||||
|
||||
/--
|
||||
Sets the SNI host name for client-side handshakes.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_set_server_name"]
|
||||
opaque setServerName (ssl : @& Session.Client) (host : @& String) : IO Unit
|
||||
|
||||
end Client
|
||||
|
||||
/--
|
||||
Gets the X.509 verify result code after handshake.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_verify_result"]
|
||||
opaque verifyResult {r : Role} (ssl : @& Session r) : IO UInt64
|
||||
|
||||
/--
|
||||
Runs one handshake step.
|
||||
Returns `none` when the handshake is complete, or `some w` when OpenSSL needs socket I/O of
|
||||
kind `w` before the handshake can proceed.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_handshake"]
|
||||
opaque handshake {r : Role} (ssl : @& Session r) : IO (Option IOWant)
|
||||
|
||||
/--
|
||||
Attempts to write plaintext application data into SSL.
|
||||
Returns `none` when the data was accepted, or `some w` when OpenSSL needs socket I/O of kind
|
||||
`w` before the write can complete (the data is queued internally and retried after the next read).
|
||||
-/
|
||||
@[extern "lean_uv_ssl_write"]
|
||||
opaque write {r : Role} (ssl : @& Session r) (data : @& ByteArray) : IO (Option IOWant)
|
||||
|
||||
/--
|
||||
Attempts to read decrypted plaintext data.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_read"]
|
||||
opaque read? {r : Role} (ssl : @& Session r) (maxBytes : UInt64) : IO ReadResult
|
||||
|
||||
/--
|
||||
Feeds encrypted TLS bytes into the SSL input BIO.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_feed_encrypted"]
|
||||
opaque feedEncrypted {r : Role} (ssl : @& Session r) (data : @& ByteArray) : IO UInt64
|
||||
|
||||
/--
|
||||
Drains encrypted TLS bytes from the SSL output BIO.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_drain_encrypted"]
|
||||
opaque drainEncrypted {r : Role} (ssl : @& Session r) : IO ByteArray
|
||||
|
||||
/--
|
||||
Returns the amount of encrypted TLS bytes currently pending in the output BIO.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_pending_encrypted"]
|
||||
opaque pendingEncrypted {r : Role} (ssl : @& Session r) : IO UInt64
|
||||
|
||||
/--
|
||||
Returns the amount of decrypted plaintext bytes currently buffered inside the SSL object.
|
||||
-/
|
||||
@[extern "lean_uv_ssl_pending_plaintext"]
|
||||
opaque pendingPlaintext {r : Role} (ssl : @& Session r) : IO UInt64
|
||||
|
||||
end Session
|
||||
|
||||
end Std.Internal.SSL
|
||||
@@ -33,9 +33,6 @@ set(
|
||||
uv/dns.cpp
|
||||
uv/system.cpp
|
||||
uv/signal.cpp
|
||||
openssl.cpp
|
||||
openssl/context.cpp
|
||||
openssl/session.cpp
|
||||
)
|
||||
if(USE_MIMALLOC)
|
||||
list(APPEND RUNTIME_OBJS ${LEAN_BINARY_DIR}/../mimalloc/src/mimalloc/src/static.c)
|
||||
|
||||
@@ -14,9 +14,6 @@ Author: Leonardo de Moura
|
||||
#include "runtime/mutex.h"
|
||||
#include "runtime/init_module.h"
|
||||
#include "runtime/libuv.h"
|
||||
#include "runtime/openssl.h"
|
||||
#include "runtime/openssl/context.h"
|
||||
#include "runtime/openssl/session.h"
|
||||
|
||||
namespace lean {
|
||||
extern "C" LEAN_EXPORT void lean_initialize_runtime_module() {
|
||||
@@ -28,9 +25,6 @@ extern "C" LEAN_EXPORT void lean_initialize_runtime_module() {
|
||||
initialize_mutex();
|
||||
initialize_process();
|
||||
initialize_stack_overflow();
|
||||
initialize_openssl();
|
||||
initialize_openssl_context();
|
||||
initialize_openssl_session();
|
||||
initialize_libuv();
|
||||
}
|
||||
void initialize_runtime_module() {
|
||||
@@ -38,7 +32,6 @@ void initialize_runtime_module() {
|
||||
}
|
||||
void finalize_runtime_module() {
|
||||
finalize_stack_overflow();
|
||||
finalize_openssl();
|
||||
finalize_process();
|
||||
finalize_mutex();
|
||||
finalize_thread();
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
#include "runtime/openssl.h"
|
||||
#include "runtime/io.h"
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
#include <openssl/opensslv.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/err.h>
|
||||
|
||||
namespace lean {
|
||||
|
||||
void initialize_openssl() {
|
||||
if (OPENSSL_init_ssl(0, nullptr) != 1) {
|
||||
lean_internal_panic("failed to initialize OpenSSL");
|
||||
}
|
||||
}
|
||||
|
||||
void finalize_openssl() {}
|
||||
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_openssl_version(lean_obj_arg o) {
|
||||
return lean_unsigned_to_nat(OPENSSL_VERSION_NUMBER);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
namespace lean {
|
||||
|
||||
void initialize_openssl() {}
|
||||
void finalize_openssl() {}
|
||||
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_openssl_version(lean_obj_arg o) {
|
||||
return lean_box(0);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -1,16 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
#pragma once
|
||||
#include <lean/lean.h>
|
||||
|
||||
namespace lean {
|
||||
|
||||
void initialize_openssl();
|
||||
void finalize_openssl();
|
||||
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_openssl_version(lean_obj_arg);
|
||||
@@ -1,148 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
|
||||
#include "runtime/openssl/context.h"
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
#include <openssl/err.h>
|
||||
#endif
|
||||
|
||||
namespace lean {
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
|
||||
static inline lean_obj_res mk_ssl_ctx_io_error(char const * where) {
|
||||
unsigned long err = ERR_get_error();
|
||||
char err_buf[256];
|
||||
err_buf[0] = '\0';
|
||||
|
||||
if (err != 0) {
|
||||
ERR_error_string_n(err, err_buf, sizeof(err_buf));
|
||||
}
|
||||
|
||||
ERR_clear_error();
|
||||
|
||||
std::string msg(where);
|
||||
if (err_buf[0] != '\0') {
|
||||
msg += ": ";
|
||||
msg += err_buf;
|
||||
}
|
||||
|
||||
return lean_io_result_mk_error(lean_mk_io_user_error(mk_string(msg.c_str())));
|
||||
}
|
||||
|
||||
static void configure_ctx_options(SSL_CTX * ctx) {
|
||||
SSL_CTX_clear_options(ctx, SSL_OP_NO_RENEGOTIATION);
|
||||
}
|
||||
|
||||
static void lean_ssl_context_finalizer(void * ptr) {
|
||||
lean_ssl_context_object * obj = (lean_ssl_context_object*)ptr;
|
||||
if (obj->ctx != nullptr) {
|
||||
SSL_CTX_free(obj->ctx);
|
||||
}
|
||||
free(obj);
|
||||
}
|
||||
|
||||
void initialize_openssl_context() {
|
||||
g_ssl_context_external_class = lean_register_external_class(lean_ssl_context_finalizer, [](void * obj, lean_object * f) {
|
||||
(void)obj;
|
||||
(void)f;
|
||||
});
|
||||
}
|
||||
|
||||
static lean_obj_res mk_ssl_context(const SSL_METHOD * method) {
|
||||
SSL_CTX * ctx = SSL_CTX_new(method);
|
||||
if (ctx == nullptr) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_new failed");
|
||||
}
|
||||
|
||||
configure_ctx_options(ctx);
|
||||
|
||||
lean_ssl_context_object * obj = (lean_ssl_context_object*)malloc(sizeof(lean_ssl_context_object));
|
||||
if (obj == nullptr) {
|
||||
SSL_CTX_free(ctx);
|
||||
return mk_ssl_ctx_io_error("failed to allocate SSL context object");
|
||||
}
|
||||
|
||||
obj->ctx = ctx;
|
||||
lean_object * lean_obj = lean_ssl_context_object_new(obj);
|
||||
lean_mark_mt(lean_obj);
|
||||
return lean_io_result_mk_ok(lean_obj);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Context.mkServer : IO Context */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_server() {
|
||||
return mk_ssl_context(TLS_server_method());
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Context.mkClient : IO Context */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_client() {
|
||||
return mk_ssl_context(TLS_client_method());
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Context.configureServer (ctx : @& Context) (certFile keyFile : @& String) : IO Unit */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_server(b_obj_arg ctx_obj, b_obj_arg cert_file, b_obj_arg key_file) {
|
||||
lean_ssl_context_object * obj = lean_to_ssl_context_object(ctx_obj);
|
||||
const char * cert = lean_string_cstr(cert_file);
|
||||
const char * key = lean_string_cstr(key_file);
|
||||
|
||||
if (SSL_CTX_use_certificate_file(obj->ctx, cert, SSL_FILETYPE_PEM) <= 0) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_use_certificate_file failed");
|
||||
}
|
||||
if (SSL_CTX_use_PrivateKey_file(obj->ctx, key, SSL_FILETYPE_PEM) <= 0) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_use_PrivateKey_file failed");
|
||||
}
|
||||
if (SSL_CTX_check_private_key(obj->ctx) != 1) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_check_private_key failed");
|
||||
}
|
||||
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Context.configureClient (ctx : @& Context) (caFile : @& String) (verifyPeer : Bool) : IO Unit */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_client(b_obj_arg ctx_obj, b_obj_arg ca_file, uint8_t verify_peer) {
|
||||
lean_ssl_context_object * obj = lean_to_ssl_context_object(ctx_obj);
|
||||
const char * ca = lean_string_cstr(ca_file);
|
||||
|
||||
if (ca != nullptr && ca[0] != '\0') {
|
||||
if (SSL_CTX_load_verify_locations(obj->ctx, ca, nullptr) != 1) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_load_verify_locations failed");
|
||||
}
|
||||
} else if (verify_peer) {
|
||||
if (SSL_CTX_set_default_verify_paths(obj->ctx) != 1) {
|
||||
return mk_ssl_ctx_io_error("SSL_CTX_set_default_verify_paths failed");
|
||||
}
|
||||
}
|
||||
|
||||
SSL_CTX_set_verify(obj->ctx, verify_peer ? SSL_VERIFY_PEER : SSL_VERIFY_NONE, nullptr);
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void initialize_openssl_context() {}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_server() {
|
||||
return io_result_mk_error("lean_uv_ssl_ctx_mk_server is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_client() {
|
||||
return io_result_mk_error("lean_uv_ssl_ctx_mk_client is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_server(b_obj_arg ctx_obj, b_obj_arg cert_file, b_obj_arg key_file) {
|
||||
(void)ctx_obj; (void)cert_file; (void)key_file;
|
||||
return io_result_mk_error("lean_uv_ssl_ctx_configure_server is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_client(b_obj_arg ctx_obj, b_obj_arg ca_file, uint8_t verify_peer) {
|
||||
(void)ctx_obj; (void)ca_file; (void)verify_peer;
|
||||
return io_result_mk_error("lean_uv_ssl_ctx_configure_client is not supported");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <lean/lean.h>
|
||||
#include "runtime/io.h"
|
||||
#include "runtime/object.h"
|
||||
#include "runtime/openssl.h"
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
#include <openssl/ssl.h>
|
||||
#endif
|
||||
|
||||
namespace lean {
|
||||
|
||||
static lean_external_class * g_ssl_context_external_class = nullptr;
|
||||
void initialize_openssl_context();
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
typedef struct {
|
||||
SSL_CTX * ctx;
|
||||
} lean_ssl_context_object;
|
||||
|
||||
static inline lean_object * lean_ssl_context_object_new(lean_ssl_context_object * c) {
|
||||
return lean_alloc_external(g_ssl_context_external_class, c);
|
||||
}
|
||||
static inline lean_ssl_context_object * lean_to_ssl_context_object(lean_object * o) {
|
||||
return (lean_ssl_context_object*)(lean_get_external_data(o));
|
||||
}
|
||||
#endif
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_server();
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_mk_client();
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_server(b_obj_arg ctx, b_obj_arg cert_file, b_obj_arg key_file);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_ctx_configure_client(b_obj_arg ctx, b_obj_arg ca_file, uint8_t verify_peer);
|
||||
|
||||
}
|
||||
@@ -1,501 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
|
||||
#include "runtime/openssl/session.h"
|
||||
|
||||
#include <climits>
|
||||
#include <new>
|
||||
#include <string>
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
#include <openssl/err.h>
|
||||
#endif
|
||||
|
||||
namespace lean {
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
|
||||
static inline lean_object * mk_ssl_error(char const * where, int ssl_err = 0) {
|
||||
unsigned long err = ERR_get_error();
|
||||
char err_buf[256];
|
||||
err_buf[0] = '\0';
|
||||
|
||||
if (err != 0) {
|
||||
ERR_error_string_n(err, err_buf, sizeof(err_buf));
|
||||
}
|
||||
|
||||
// Drain remaining errors so they don't pollute future calls.
|
||||
ERR_clear_error();
|
||||
|
||||
std::string msg(where);
|
||||
|
||||
if (ssl_err != 0) {
|
||||
msg += " (ssl_error=" + std::to_string(ssl_err) + ")";
|
||||
}
|
||||
if (err_buf[0] != '\0') {
|
||||
msg += ": ";
|
||||
msg += err_buf;
|
||||
}
|
||||
|
||||
return lean_mk_io_user_error(mk_string(msg.c_str()));
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_ssl_io_error(char const * where, int ssl_err = 0) {
|
||||
return lean_io_result_mk_error(mk_ssl_error(where, ssl_err));
|
||||
}
|
||||
|
||||
/*
|
||||
* Lean encoding for `Option IOWant`:
|
||||
* none = lean_box(0) (handshake done / write accepted)
|
||||
* some IOWant.read = ctor(1){ lean_box(0) } (SSL_ERROR_WANT_READ)
|
||||
* some IOWant.write = ctor(1){ lean_box(1) } (SSL_ERROR_WANT_WRITE)
|
||||
*
|
||||
* Lean encoding for `ReadResult`:
|
||||
* data bytes = ctor(0){ bytes } (non-nullary constructor 0)
|
||||
* wantIO .read = ctor(1){ lean_box(0) } (non-nullary constructor 1)
|
||||
* wantIO .write = ctor(1){ lean_box(1) }
|
||||
* closed = lean_box(0) (first nullary constructor)
|
||||
*/
|
||||
static inline lean_obj_res mk_option_io_want_none() {
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_option_io_want_read() {
|
||||
lean_object * r = lean_alloc_ctor(1, 1, 0);
|
||||
lean_ctor_set(r, 0, lean_box(0));
|
||||
return lean_io_result_mk_ok(r);
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_option_io_want_write() {
|
||||
lean_object * r = lean_alloc_ctor(1, 1, 0);
|
||||
lean_ctor_set(r, 0, lean_box(1));
|
||||
return lean_io_result_mk_ok(r);
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_read_result_data(lean_object * bytes) {
|
||||
lean_object * r = lean_alloc_ctor(0, 1, 0);
|
||||
lean_ctor_set(r, 0, bytes);
|
||||
return lean_io_result_mk_ok(r);
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_read_result_want_read() {
|
||||
lean_object * r = lean_alloc_ctor(1, 1, 0);
|
||||
lean_ctor_set(r, 0, lean_box(0));
|
||||
return lean_io_result_mk_ok(r);
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_read_result_want_write() {
|
||||
lean_object * r = lean_alloc_ctor(1, 1, 0);
|
||||
lean_ctor_set(r, 0, lean_box(1));
|
||||
return lean_io_result_mk_ok(r);
|
||||
}
|
||||
|
||||
static inline lean_obj_res mk_read_result_closed() {
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
static inline lean_object * mk_empty_byte_array() {
|
||||
lean_object * arr = lean_alloc_sarray(1, 0, 0);
|
||||
lean_sarray_set_size(arr, 0);
|
||||
return arr;
|
||||
}
|
||||
|
||||
/*
|
||||
Return values:
|
||||
1 -> write completed
|
||||
0 -> write blocked (WANT_READ / WANT_WRITE / ZERO_RETURN)
|
||||
-1 -> fatal error
|
||||
*/
|
||||
static int ssl_write_step(lean_ssl_session_object * obj, char const * data, size_t size, int * out_err) {
|
||||
if (size > INT_MAX) {
|
||||
*out_err = SSL_ERROR_SSL;
|
||||
return -1;
|
||||
}
|
||||
|
||||
int rc = SSL_write(obj->ssl, data, (int)size);
|
||||
if (rc > 0) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
int err = SSL_get_error(obj->ssl, rc);
|
||||
*out_err = err;
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE || err == SSL_ERROR_ZERO_RETURN) {
|
||||
return 0;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
/*
|
||||
Return values:
|
||||
1 -> all pending writes flushed
|
||||
0 -> still blocked, *out_err filled with SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE
|
||||
-1 -> fatal error, *out_err filled
|
||||
*/
|
||||
static int try_flush_pending_writes(lean_ssl_session_object * obj, int * out_err) {
|
||||
while (!obj->pending_writes.empty()) {
|
||||
auto & pw = obj->pending_writes.front();
|
||||
int step = ssl_write_step(obj, pw.data(), pw.size(), out_err);
|
||||
if (step < 0) return -1;
|
||||
if (step == 0) return 0;
|
||||
obj->pending_writes.erase(obj->pending_writes.begin());
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
void lean_ssl_session_finalizer(void * ptr) {
|
||||
lean_ssl_session_object * obj = (lean_ssl_session_object*)ptr;
|
||||
if (obj->ssl != nullptr) SSL_free(obj->ssl);
|
||||
delete obj;
|
||||
}
|
||||
|
||||
void initialize_openssl_session() {
|
||||
g_ssl_session_external_class = lean_register_external_class(lean_ssl_session_finalizer, [](void * obj, lean_object * f) {
|
||||
(void)obj;
|
||||
(void)f;
|
||||
});
|
||||
}
|
||||
|
||||
static lean_obj_res mk_ssl_session(SSL_CTX * ctx, uint8_t is_server) {
|
||||
SSL * ssl = SSL_new(ctx);
|
||||
if (ssl == nullptr) {
|
||||
return mk_ssl_io_error("SSL_new failed");
|
||||
}
|
||||
|
||||
BIO * read_bio = BIO_new(BIO_s_mem());
|
||||
BIO * write_bio = BIO_new(BIO_s_mem());
|
||||
|
||||
if (read_bio == nullptr || write_bio == nullptr) {
|
||||
if (read_bio != nullptr) BIO_free(read_bio);
|
||||
if (write_bio != nullptr) BIO_free(write_bio);
|
||||
SSL_free(ssl);
|
||||
return mk_ssl_io_error("BIO_new failed");
|
||||
}
|
||||
|
||||
BIO_set_nbio(read_bio, 1);
|
||||
BIO_set_nbio(write_bio, 1);
|
||||
|
||||
SSL_set_bio(ssl, read_bio, write_bio);
|
||||
|
||||
if (is_server) {
|
||||
SSL_set_accept_state(ssl);
|
||||
} else {
|
||||
SSL_set_connect_state(ssl);
|
||||
}
|
||||
|
||||
lean_ssl_session_object * ssl_obj = new (std::nothrow) lean_ssl_session_object();
|
||||
if (ssl_obj == nullptr) {
|
||||
SSL_free(ssl);
|
||||
return mk_ssl_io_error("failed to allocate SSL session object");
|
||||
}
|
||||
|
||||
ssl_obj->ssl = ssl;
|
||||
ssl_obj->read_bio = read_bio;
|
||||
ssl_obj->write_bio = write_bio;
|
||||
|
||||
lean_object * obj = lean_ssl_session_object_new(ssl_obj);
|
||||
lean_mark_mt(obj);
|
||||
return lean_io_result_mk_ok(obj);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.Server.mk (ctx : @& Context.Server) : IO Session.Server */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_server(b_obj_arg ctx_obj) {
|
||||
lean_ssl_context_object * ctx = lean_to_ssl_context_object(ctx_obj);
|
||||
return mk_ssl_session(ctx->ctx, 1);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.Client.mk (ctx : @& Context.Client) : IO Session.Client */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_client(b_obj_arg ctx_obj) {
|
||||
lean_ssl_context_object * ctx = lean_to_ssl_context_object(ctx_obj);
|
||||
return mk_ssl_session(ctx->ctx, 0);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.Client.setServerName (ssl : @& Session.Client) (host : @& String) : IO Unit */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_set_server_name(b_obj_arg ssl, b_obj_arg host) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
const char * server_name = lean_string_cstr(host);
|
||||
if (SSL_set_tlsext_host_name(ssl_obj->ssl, server_name) != 1) {
|
||||
return mk_ssl_io_error("SSL_set_tlsext_host_name failed");
|
||||
}
|
||||
return lean_io_result_mk_ok(lean_box(0));
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.verifyResult (ssl : @& Session) : IO UInt64 */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_verify_result(b_obj_arg _role, b_obj_arg ssl) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
long result = SSL_get_verify_result(ssl_obj->ssl);
|
||||
return lean_io_result_mk_ok(lean_box_uint64((uint64_t)result));
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.handshake (ssl : @& Session) : IO (Option IOWant) */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_handshake(b_obj_arg _role, b_obj_arg ssl) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
int rc = SSL_do_handshake(ssl_obj->ssl);
|
||||
|
||||
if (rc == 1) {
|
||||
return mk_option_io_want_none();
|
||||
}
|
||||
|
||||
int err = SSL_get_error(ssl_obj->ssl, rc);
|
||||
if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_ZERO_RETURN) {
|
||||
return mk_option_io_want_read();
|
||||
}
|
||||
if (err == SSL_ERROR_WANT_WRITE) {
|
||||
return mk_option_io_want_write();
|
||||
}
|
||||
|
||||
return mk_ssl_io_error("SSL_do_handshake failed", err);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.write (ssl : @& Session) (data : @& ByteArray) : IO (Option IOWant) */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_write(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
size_t data_len = lean_sarray_size(data);
|
||||
char const * payload = (char const*)lean_sarray_cptr(data);
|
||||
|
||||
if (data_len == 0) {
|
||||
return mk_option_io_want_none();
|
||||
}
|
||||
|
||||
// If there are pending writes, try to flush them first to preserve write order.
|
||||
// Only attempt the new write directly if the queue fully drains.
|
||||
if (!ssl_obj->pending_writes.empty()) {
|
||||
int flush_err = 0;
|
||||
int flushed = try_flush_pending_writes(ssl_obj, &flush_err);
|
||||
|
||||
if (flushed < 0) {
|
||||
return mk_ssl_io_error("pending SSL write flush failed", flush_err);
|
||||
}
|
||||
|
||||
if (flushed == 0) {
|
||||
ssl_obj->pending_writes.emplace_back(payload, payload + data_len);
|
||||
if (flush_err == SSL_ERROR_WANT_READ) {
|
||||
return mk_option_io_want_read();
|
||||
}
|
||||
return mk_option_io_want_write();
|
||||
}
|
||||
// flushed == 1: queue is clear, fall through to attempt the new write
|
||||
}
|
||||
|
||||
int err = 0;
|
||||
int step = ssl_write_step(ssl_obj, payload, data_len, &err);
|
||||
|
||||
if (step == 1) {
|
||||
return mk_option_io_want_none();
|
||||
}
|
||||
|
||||
if (step == 0 && err == SSL_ERROR_ZERO_RETURN) {
|
||||
return mk_ssl_io_error("SSL_write failed: peer closed the TLS session", err);
|
||||
}
|
||||
|
||||
// Queue plaintext so it is retried after the required socket I/O completes.
|
||||
if (step == 0) {
|
||||
ssl_obj->pending_writes.emplace_back(payload, payload + data_len);
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
return mk_option_io_want_read();
|
||||
}
|
||||
return mk_option_io_want_write();
|
||||
}
|
||||
|
||||
return mk_ssl_io_error("SSL_write failed", err);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.read? (ssl : @& Session) (maxBytes : UInt64) : IO ReadResult */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_read(b_obj_arg _role, b_obj_arg ssl, uint64_t max_bytes) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
|
||||
if (max_bytes == 0) {
|
||||
return mk_read_result_data(mk_empty_byte_array());
|
||||
}
|
||||
|
||||
if (max_bytes > INT_MAX) {
|
||||
max_bytes = INT_MAX;
|
||||
}
|
||||
|
||||
lean_object * out = lean_alloc_sarray(1, 0, max_bytes);
|
||||
int rc = SSL_read(ssl_obj->ssl, (void*)lean_sarray_cptr(out), (int)max_bytes);
|
||||
|
||||
if (rc > 0) {
|
||||
int flush_err = 0;
|
||||
if (try_flush_pending_writes(ssl_obj, &flush_err) < 0) {
|
||||
lean_dec(out);
|
||||
return mk_ssl_io_error("pending SSL write flush failed", flush_err);
|
||||
}
|
||||
lean_sarray_set_size(out, (size_t)rc);
|
||||
return mk_read_result_data(out);
|
||||
}
|
||||
|
||||
lean_dec(out);
|
||||
|
||||
int err = SSL_get_error(ssl_obj->ssl, rc);
|
||||
|
||||
if (err == SSL_ERROR_ZERO_RETURN) {
|
||||
int flush_err = 0;
|
||||
if (try_flush_pending_writes(ssl_obj, &flush_err) < 0) {
|
||||
return mk_ssl_io_error("pending SSL write flush failed", flush_err);
|
||||
}
|
||||
return mk_read_result_closed();
|
||||
}
|
||||
|
||||
if (err == SSL_ERROR_WANT_READ) {
|
||||
int flush_err = 0;
|
||||
int flushed = try_flush_pending_writes(ssl_obj, &flush_err);
|
||||
if (flushed < 0) {
|
||||
return mk_ssl_io_error("pending SSL write flush failed", flush_err);
|
||||
}
|
||||
if (flushed == 0 && flush_err == SSL_ERROR_WANT_WRITE) {
|
||||
return mk_read_result_want_write();
|
||||
}
|
||||
return mk_read_result_want_read();
|
||||
}
|
||||
|
||||
if (err == SSL_ERROR_WANT_WRITE) {
|
||||
int flush_err = 0;
|
||||
int flushed = try_flush_pending_writes(ssl_obj, &flush_err);
|
||||
if (flushed < 0) {
|
||||
return mk_ssl_io_error("pending SSL write flush failed", flush_err);
|
||||
}
|
||||
if (flushed == 0 && flush_err == SSL_ERROR_WANT_READ) {
|
||||
return mk_read_result_want_read();
|
||||
}
|
||||
return mk_read_result_want_write();
|
||||
}
|
||||
|
||||
return mk_ssl_io_error("SSL_read failed", err);
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.feedEncrypted (ssl : @& Session) (data : @& ByteArray) : IO UInt64 */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_feed_encrypted(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
size_t data_len = lean_sarray_size(data);
|
||||
|
||||
if (data_len == 0) {
|
||||
return lean_io_result_mk_ok(lean_box_uint64(0));
|
||||
}
|
||||
|
||||
if (data_len > INT_MAX) {
|
||||
return mk_ssl_io_error("BIO_write input too large");
|
||||
}
|
||||
|
||||
int rc = BIO_write(ssl_obj->read_bio, lean_sarray_cptr(data), (int)data_len);
|
||||
if (rc >= 0) {
|
||||
return lean_io_result_mk_ok(lean_box_uint64((uint64_t)rc));
|
||||
}
|
||||
|
||||
if (BIO_should_retry(ssl_obj->read_bio)) {
|
||||
return lean_io_result_mk_ok(lean_box_uint64(0));
|
||||
}
|
||||
|
||||
return mk_ssl_io_error("BIO_write failed");
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.drainEncrypted (ssl : @& Session) : IO ByteArray */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_drain_encrypted(b_obj_arg _role, b_obj_arg ssl) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
size_t pending = BIO_ctrl_pending(ssl_obj->write_bio);
|
||||
|
||||
if (pending == 0) {
|
||||
return lean_io_result_mk_ok(mk_empty_byte_array());
|
||||
}
|
||||
|
||||
if (pending > INT_MAX) {
|
||||
return mk_ssl_io_error("BIO_pending output too large");
|
||||
}
|
||||
|
||||
lean_object * out = lean_alloc_sarray(1, 0, pending);
|
||||
int rc = BIO_read(ssl_obj->write_bio, (void*)lean_sarray_cptr(out), (int)pending);
|
||||
|
||||
if (rc >= 0) {
|
||||
lean_sarray_set_size(out, (size_t)rc);
|
||||
return lean_io_result_mk_ok(out);
|
||||
}
|
||||
|
||||
lean_dec(out);
|
||||
|
||||
if (BIO_should_retry(ssl_obj->write_bio)) {
|
||||
return lean_io_result_mk_ok(mk_empty_byte_array());
|
||||
}
|
||||
|
||||
return mk_ssl_io_error("BIO_read failed");
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.pendingEncrypted (ssl : @& Session) : IO UInt64 */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_encrypted(b_obj_arg _role, b_obj_arg ssl) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
return lean_io_result_mk_ok(lean_box_uint64((uint64_t)BIO_ctrl_pending(ssl_obj->write_bio)));
|
||||
}
|
||||
|
||||
/* Std.Internal.SSL.Session.pendingPlaintext (ssl : @& Session) : IO UInt64 */
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_plaintext(b_obj_arg _role, b_obj_arg ssl) {
|
||||
lean_ssl_session_object * ssl_obj = lean_to_ssl_session_object(ssl);
|
||||
return lean_io_result_mk_ok(lean_box_uint64((uint64_t)SSL_pending(ssl_obj->ssl)));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void initialize_openssl_session() {}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_server(b_obj_arg ctx_obj) {
|
||||
(void)ctx_obj;
|
||||
return io_result_mk_error("lean_uv_ssl_mk_server is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_client(b_obj_arg ctx_obj) {
|
||||
(void)ctx_obj;
|
||||
return io_result_mk_error("lean_uv_ssl_mk_client is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_set_server_name(b_obj_arg ssl, b_obj_arg host) {
|
||||
(void)ssl;
|
||||
(void)host;
|
||||
return io_result_mk_error("lean_uv_ssl_set_server_name is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_verify_result(b_obj_arg _role, b_obj_arg ssl) {
|
||||
(void)ssl;
|
||||
return io_result_mk_error("lean_uv_ssl_verify_result is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_handshake(b_obj_arg _role, b_obj_arg ssl) {
|
||||
(void)ssl;
|
||||
return io_result_mk_error("lean_uv_ssl_handshake is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_write(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data) {
|
||||
(void)ssl;
|
||||
(void)data;
|
||||
return io_result_mk_error("lean_uv_ssl_write is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_read(b_obj_arg _role, b_obj_arg ssl, uint64_t max_bytes) {
|
||||
(void)ssl;
|
||||
(void)max_bytes;
|
||||
return io_result_mk_error("lean_uv_ssl_read is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_feed_encrypted(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data) {
|
||||
(void)ssl;
|
||||
(void)data;
|
||||
return io_result_mk_error("lean_uv_ssl_feed_encrypted is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_drain_encrypted(b_obj_arg _role, b_obj_arg ssl) {
|
||||
(void)ssl;
|
||||
return io_result_mk_error("lean_uv_ssl_drain_encrypted is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_encrypted(b_obj_arg _role, b_obj_arg ssl) {
|
||||
(void)ssl;
|
||||
return io_result_mk_error("lean_uv_ssl_pending_encrypted is not supported");
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_plaintext(b_obj_arg _role, b_obj_arg ssl) {
|
||||
(void)ssl;
|
||||
return io_result_mk_error("lean_uv_ssl_pending_plaintext is not supported");
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
/*
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Sofia Rodrigues
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <lean/lean.h>
|
||||
#include "runtime/io.h"
|
||||
#include "runtime/object.h"
|
||||
#include "runtime/openssl/context.h"
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/bio.h>
|
||||
#endif
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace lean {
|
||||
|
||||
static lean_external_class * g_ssl_session_external_class = nullptr;
|
||||
void initialize_openssl_session();
|
||||
|
||||
#ifndef LEAN_EMSCRIPTEN
|
||||
struct lean_ssl_session_object {
|
||||
SSL * ssl;
|
||||
BIO * read_bio;
|
||||
BIO * write_bio;
|
||||
std::vector<std::vector<char>> pending_writes;
|
||||
};
|
||||
|
||||
static inline lean_object * lean_ssl_session_object_new(lean_ssl_session_object * s) { return lean_alloc_external(g_ssl_session_external_class, s); }
|
||||
static inline lean_ssl_session_object * lean_to_ssl_session_object(lean_object * o) { return (lean_ssl_session_object*)(lean_get_external_data(o)); }
|
||||
#endif
|
||||
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_server(b_obj_arg ctx);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_mk_client(b_obj_arg ctx);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_set_server_name(b_obj_arg ssl, b_obj_arg host);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_verify_result(b_obj_arg _role, b_obj_arg ssl);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_handshake(b_obj_arg _role, b_obj_arg ssl);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_write(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_read(b_obj_arg _role, b_obj_arg ssl, uint64_t max_bytes);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_feed_encrypted(b_obj_arg _role, b_obj_arg ssl, b_obj_arg data);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_drain_encrypted(b_obj_arg _role, b_obj_arg ssl);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_encrypted(b_obj_arg _role, b_obj_arg ssl);
|
||||
extern "C" LEAN_EXPORT lean_obj_res lean_uv_ssl_pending_plaintext(b_obj_arg _role, b_obj_arg ssl);
|
||||
|
||||
}
|
||||
@@ -1,601 +0,0 @@
|
||||
import Std.Internal.Async.TCP.SSL
|
||||
import Std.Net.Addr
|
||||
|
||||
open Std.Internal.IO Async TCP.SSL
|
||||
open Std.Net
|
||||
open Std.Internal.SSL
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Helpers
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def assertEqStr (actual expected : String) : IO Unit := do
|
||||
unless actual == expected do
|
||||
throw <| IO.userError s!"expected '{expected}', got '{actual}'"
|
||||
|
||||
def assertGt (actual : UInt64) (bound : UInt64) (label : String) : IO Unit := do
|
||||
unless actual > bound do
|
||||
throw <| IO.userError s!"{label}: expected > {bound}, got {actual}"
|
||||
|
||||
def assertEqN (actual expected : UInt64) (label : String) : IO Unit := do
|
||||
unless actual == expected do
|
||||
throw <| IO.userError s!"{label}: expected {expected}, got {actual}"
|
||||
|
||||
-- Generate a self-signed certificate for testing.
|
||||
def setupTestCerts : IO (String × String) := do
|
||||
IO.FS.createDirAll "/tmp/lean_ssl_test"
|
||||
let keyFile := "/tmp/lean_ssl_test/key.pem"
|
||||
let certFile := "/tmp/lean_ssl_test/cert.pem"
|
||||
|
||||
discard <| IO.Process.output {
|
||||
cmd := "openssl"
|
||||
args := #["genrsa", "-out", keyFile, "2048"]
|
||||
}
|
||||
|
||||
discard <| IO.Process.output {
|
||||
cmd := "openssl"
|
||||
args := #["req", "-new", "-x509", "-key", keyFile, "-out", certFile,
|
||||
"-days", "1", "-subj", "/CN=localhost"]
|
||||
}
|
||||
|
||||
return (certFile, keyFile)
|
||||
|
||||
-- Drive one handshake step: advance both state machines and exchange encrypted
|
||||
-- bytes between their memory BIOs. Returns (clientDone, serverDone).
|
||||
def handshakeStep {rc rs : Role} (c : Session rc) (s : Session rs) : IO (Bool × Bool) := do
|
||||
let cd ← c.handshake
|
||||
let cOut ← c.drainEncrypted
|
||||
if cOut.size > 0 then
|
||||
discard <| s.feedEncrypted cOut
|
||||
let sd ← s.handshake
|
||||
let sOut ← s.drainEncrypted
|
||||
if sOut.size > 0 then
|
||||
discard <| c.feedEncrypted sOut
|
||||
return (cd.isNone, sd.isNone)
|
||||
|
||||
partial def runHandshake {rc rs : Role} (c : Session rc) (s : Session rs) : IO Unit := do
|
||||
let (cd, sd) ← handshakeStep c s
|
||||
unless cd && sd do runHandshake c s
|
||||
|
||||
-- Pipe all pending encrypted output from src into dst's read BIO.
|
||||
def pipeEncrypted {r1 r2 : Role} (src : Session r1) (dst : Session r2) : IO Unit := do
|
||||
let bytes ← src.drainEncrypted
|
||||
if bytes.size > 0 then
|
||||
discard <| dst.feedEncrypted bytes
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 1: Context creation and configuration (smoke test)
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testContextCreation (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
|
||||
let clientCtx ← Context.Client.mk
|
||||
clientCtx.configure "" false
|
||||
|
||||
-- Configuring with a CA file path (non-empty) exercises the other branch.
|
||||
let clientCtx2 ← Context.Client.mk
|
||||
clientCtx2.configure certFile false
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 2: In-process TLS handshake between two memory-BIO sessions
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testInProcessHandshake (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
|
||||
let clientCtx ← Context.Client.mk
|
||||
clientCtx.configure "" false -- skip peer verification
|
||||
|
||||
let serverSess ← Session.Server.mk serverCtx
|
||||
let clientSess ← Session.Client.mk clientCtx
|
||||
|
||||
-- setServerName exercises SSL_set_tlsext_host_name.
|
||||
clientSess.setServerName "localhost"
|
||||
|
||||
runHandshake clientSess serverSess
|
||||
|
||||
-- verifyResult: just verify the call succeeds (self-signed cert returns
|
||||
-- X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT even with VERIFY_NONE).
|
||||
discard <| clientSess.verifyResult
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 3: write / pendingEncrypted / drainEncrypted / feedEncrypted / read?
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testDataTransfer (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
|
||||
let clientCtx ← Context.Client.mk
|
||||
clientCtx.configure "" false
|
||||
|
||||
let serverSess ← Session.Server.mk serverCtx
|
||||
let clientSess ← Session.Client.mk clientCtx
|
||||
|
||||
runHandshake clientSess serverSess
|
||||
|
||||
-- write plaintext → encrypted bytes appear in the write BIO.
|
||||
let msg := "hello, tls!".toUTF8
|
||||
discard <| clientSess.write msg
|
||||
|
||||
-- pendingEncrypted > 0 before draining.
|
||||
let pending ← clientSess.pendingEncrypted
|
||||
assertGt pending 0 "pendingEncrypted"
|
||||
|
||||
-- Pipe to server and read back.
|
||||
pipeEncrypted clientSess serverSess
|
||||
let received ← serverSess.read? 1024
|
||||
match received with
|
||||
| .data bytes => assertEqStr (String.fromUTF8! bytes) "hello, tls!"
|
||||
| _ => throw <| IO.userError "expected data from server session"
|
||||
|
||||
-- After draining, pendingEncrypted drops to 0.
|
||||
let pendingAfter ← clientSess.pendingEncrypted
|
||||
assertEqN pendingAfter 0 "pendingEncrypted after drain"
|
||||
|
||||
-- read? returns wantIO when no data is available.
|
||||
let empty ← clientSess.read? 1024
|
||||
match empty with
|
||||
| .wantIO _ => return ()
|
||||
| _ => throw <| IO.userError "expected wantIO when no data available"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 4: pendingPlaintext — write 100 bytes, read 10, rest stays buffered
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testPendingPlaintext (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
|
||||
let clientCtx ← Context.Client.mk
|
||||
clientCtx.configure "" false
|
||||
|
||||
let serverSess ← Session.Server.mk serverCtx
|
||||
let clientSess ← Session.Client.mk clientCtx
|
||||
|
||||
runHandshake clientSess serverSess
|
||||
|
||||
let bigMsg := (String.ofList (List.replicate 100 'x')).toUTF8
|
||||
discard <| clientSess.write bigMsg
|
||||
pipeEncrypted clientSess serverSess
|
||||
|
||||
-- Read only 10 bytes; the remaining 90 stay in SSL's plaintext buffer.
|
||||
discard <| serverSess.read? 10
|
||||
let remaining ← serverSess.pendingPlaintext
|
||||
assertEqN remaining 90 "pendingPlaintext after partial read"
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 5: Full TCP/TLS round-trip
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def serverTask (server : TCP.SSL.Server) : Async Unit := do
|
||||
let client ← server.accept
|
||||
let msg ← client.recv? 1024
|
||||
client.send (msg.getD ByteArray.empty) -- echo
|
||||
client.shutdown
|
||||
|
||||
def clientTask (addr : SocketAddress) (clientCtx : Context.Client) : Async Unit := do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.noDelay
|
||||
client.send "hello over tls".toUTF8
|
||||
let resp ← client.recv? 1024
|
||||
let got := String.fromUTF8! (resp.getD ByteArray.empty)
|
||||
unless got == "hello over tls" do
|
||||
throw <| IO.userError s!"round-trip mismatch: '{got}'"
|
||||
let _ ← client.getPeerName
|
||||
let _ ← client.getSockName
|
||||
let _ ← client.verifyResult
|
||||
client.shutdown
|
||||
|
||||
def testTCPSSL (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.configureServer certFile keyFile -- idempotent re-configuration
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let _ ← server.getSockName
|
||||
|
||||
let srvTask ← (serverTask server).toIO
|
||||
let cliTask ← (clientTask addr clientCtx).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 6: Multiple sequential round-trips (no hang between messages)
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testMultipleRoundTrips (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
for _ in List.range 5 do
|
||||
let msg ← conn.recv? 1024
|
||||
conn.send (msg.getD ByteArray.empty)
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
for i in List.range 5 do
|
||||
let payload := s!"msg{i}".toUTF8
|
||||
client.send payload
|
||||
let resp ← client.recv? 1024
|
||||
let got := String.fromUTF8! (resp.getD ByteArray.empty)
|
||||
unless got == s!"msg{i}" do
|
||||
throw <| IO.userError s!"round-trip {i} mismatch: '{got}'"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 7: Large payload (> one TLS record = 16 KB), no hang on fragmentation
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testLargePayload (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let payloadSize := 64 * 1024 -- 64 KB: spans multiple TLS records
|
||||
let payload := ByteArray.mk (List.replicate payloadSize 0x42).toArray
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
-- Accumulate until we have all bytes, then echo back.
|
||||
let mut buf := ByteArray.empty
|
||||
while buf.size < payloadSize do
|
||||
let chunk ← conn.recv? (payloadSize - buf.size).toUInt64
|
||||
buf := buf ++ chunk.getD ByteArray.empty
|
||||
conn.send buf
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.send payload
|
||||
let mut buf := ByteArray.empty
|
||||
while buf.size < payloadSize do
|
||||
let chunk ← client.recv? (payloadSize - buf.size).toUInt64
|
||||
buf := buf ++ chunk.getD ByteArray.empty
|
||||
unless buf.size == payloadSize do
|
||||
throw <| IO.userError s!"large payload size mismatch: {buf.size}"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 8: recv? returns none after peer shutdown (no hang on closed conn)
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testRecvAfterShutdown (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
let msg ← conn.recv? 1024
|
||||
conn.send (msg.getD ByteArray.empty)
|
||||
conn.shutdown -- server closes write side first
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.send "ping".toUTF8
|
||||
-- Receive the echo
|
||||
let resp ← client.recv? 1024
|
||||
let got := String.fromUTF8! (resp.getD ByteArray.empty)
|
||||
unless got == "ping" do
|
||||
throw <| IO.userError s!"echo mismatch: '{got}'"
|
||||
-- After server shutdown, recv? must return none, not hang
|
||||
let closed ← client.recv? 1024
|
||||
unless closed.isNone do
|
||||
throw <| IO.userError "expected none after server shutdown"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 9: acceptSelector delivers a fully-handshaked connection
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testAcceptSelector (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← Selectable.one #[
|
||||
.case (selector := server.acceptSelector) (cont := fun c => return c)
|
||||
]
|
||||
let msg ← conn.recv? 1024
|
||||
conn.send (msg.getD ByteArray.empty)
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.send "via selector".toUTF8
|
||||
let resp ← client.recv? 1024
|
||||
let got := String.fromUTF8! (resp.getD ByteArray.empty)
|
||||
unless got == "via selector" do
|
||||
throw <| IO.userError s!"selector round-trip mismatch: '{got}'"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 10: sendAll — multiple buffers are fully delivered and echoed
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testSendAll (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
-- Three chunks whose concatenation we can verify.
|
||||
let chunks := #["alpha".toUTF8, "beta".toUTF8, "gamma".toUTF8]
|
||||
let expected := "alphabetagamma"
|
||||
let total := expected.length
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
let mut buf := ByteArray.empty
|
||||
while buf.size < total do
|
||||
let chunk ← conn.recv? total.toUInt64
|
||||
buf := buf ++ chunk.getD ByteArray.empty
|
||||
conn.send buf
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.sendAll chunks
|
||||
let mut buf := ByteArray.empty
|
||||
while buf.size < total do
|
||||
let chunk ← client.recv? total.toUInt64
|
||||
buf := buf ++ chunk.getD ByteArray.empty
|
||||
assertEqStr (String.fromUTF8! buf) expected
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 11: recvSelector — server pushes data, client receives via selector
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testRecvSelector (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
conn.send "pushed".toUTF8
|
||||
let ack ← conn.recv? 1024
|
||||
let got := String.fromUTF8! (ack.getD ByteArray.empty)
|
||||
unless got == "ack" do
|
||||
throw <| IO.userError s!"server: expected 'ack', got '{got}'"
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
-- Block until the server's push arrives, using recvSelector.
|
||||
let result ← Selectable.one #[
|
||||
.case (selector := client.recvSelector 1024) (cont := fun r => return r)
|
||||
]
|
||||
let got := String.fromUTF8! (result.getD ByteArray.empty)
|
||||
unless got == "pushed" do
|
||||
throw <| IO.userError s!"recvSelector mismatch: '{got}'"
|
||||
client.send "ack".toUTF8
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 12: Sequential reuse — same server socket handles N clients in a row
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testSequentialConnections (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
for i in List.range 3 do
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
let msg ← conn.recv? 1024
|
||||
conn.send (msg.getD ByteArray.empty)
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
let payload := s!"conn-{i}".toUTF8
|
||||
client.send payload
|
||||
let resp ← client.recv? 1024
|
||||
let got := String.fromUTF8! (resp.getD ByteArray.empty)
|
||||
unless got == s!"conn-{i}" do
|
||||
throw <| IO.userError s!"connection {i} echo mismatch: '{got}'"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Test 13: Simultaneous bidirectional — both sides send before either reads,
|
||||
-- verifying no deadlock occurs.
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
def testBidirectional (addr : SocketAddress) (certFile keyFile : String) : IO Unit := do
|
||||
let serverCtx ← Context.Server.mk
|
||||
serverCtx.configure certFile keyFile
|
||||
let clientCtx ← Context.Client.mk
|
||||
Client.configureContext clientCtx "" false
|
||||
|
||||
let server ← Server.mk serverCtx
|
||||
server.bind addr
|
||||
server.listen 128
|
||||
|
||||
let srvTask ← (do
|
||||
let conn ← server.accept
|
||||
conn.send "from-server".toUTF8 -- send without waiting for client first
|
||||
let msg ← conn.recv? 1024
|
||||
let got := String.fromUTF8! (msg.getD ByteArray.empty)
|
||||
unless got == "from-client" do
|
||||
throw <| IO.userError s!"server recv mismatch: '{got}'"
|
||||
conn.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
let cliTask ← (do
|
||||
let client ← Client.mk clientCtx
|
||||
client.setServerName "localhost"
|
||||
client.connect addr
|
||||
client.send "from-client".toUTF8 -- send without waiting for server first
|
||||
let msg ← client.recv? 1024
|
||||
let got := String.fromUTF8! (msg.getD ByteArray.empty)
|
||||
unless got == "from-server" do
|
||||
throw <| IO.userError s!"client recv mismatch: '{got}'"
|
||||
client.shutdown
|
||||
: Async Unit).toIO
|
||||
|
||||
srvTask.block
|
||||
cliTask.block
|
||||
|
||||
-- ---------------------------------------------------------------------------
|
||||
-- Run all tests
|
||||
-- ---------------------------------------------------------------------------
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testContextCreation certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testInProcessHandshake certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testDataTransfer certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testPendingPlaintext certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testTCPSSL (SocketAddressV4.mk (.ofParts 127 0 0 1) 18443) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testTCPSSL (SocketAddressV6.mk (.ofParts 0 0 0 0 0 0 0 1) 18444) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testMultipleRoundTrips (SocketAddressV4.mk (.ofParts 127 0 0 1) 18445) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testLargePayload (SocketAddressV4.mk (.ofParts 127 0 0 1) 18446) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testRecvAfterShutdown (SocketAddressV4.mk (.ofParts 127 0 0 1) 18447) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testAcceptSelector (SocketAddressV4.mk (.ofParts 127 0 0 1) 18448) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testSendAll (SocketAddressV4.mk (.ofParts 127 0 0 1) 18449) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testRecvSelector (SocketAddressV4.mk (.ofParts 127 0 0 1) 18450) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testSequentialConnections (SocketAddressV4.mk (.ofParts 127 0 0 1) 18451) certFile keyFile
|
||||
|
||||
#eval do
|
||||
let (certFile, keyFile) ← setupTestCerts
|
||||
testBidirectional (SocketAddressV4.mk (.ofParts 127 0 0 1) 18452) certFile keyFile
|
||||
@@ -1,6 +1,6 @@
|
||||
module
|
||||
public import Init.Grind.Ring.CommSolver
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
open Lean.Grind.CommRing
|
||||
|
||||
def w : Var := 0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
module
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
open Lean.Grind.CommRing
|
||||
|
||||
def w : Expr := .var 0
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
import Lean.Runtime
|
||||
|
||||
-- Non-emscripten build: expect the major version of OpenSSL (3)
|
||||
/-- info: 3 -/
|
||||
#guard_msgs in
|
||||
#eval if !System.Platform.isEmscripten then Lean.openSSLVersion >>> 28 else 3
|
||||
138
tests/elab/sym_arith_classify.lean
Normal file
138
tests/elab/sym_arith_classify.lean
Normal file
@@ -0,0 +1,138 @@
|
||||
import Lean
|
||||
|
||||
/-!
|
||||
# Tests for `Sym.Arith.Classify`, `Sym.Arith.EvalNum`, and `Sym.Arith.Functions`
|
||||
-/
|
||||
|
||||
open Lean Meta Sym Arith
|
||||
|
||||
/-- Extract the value of a definition by name. -/
|
||||
def getDefValue (n : Name) : MetaM Expr := do
|
||||
let some (.defnInfo info) := (← getEnv).find? n
|
||||
| throwError "expected definition: {n}"
|
||||
return info.value
|
||||
|
||||
/-! ## Classification tests -/
|
||||
|
||||
deriving instance Repr for ClassifyResult
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 0 -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Int))}"
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.commSemiring 0 -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Nat))}"
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Bool))}"
|
||||
|
||||
-- Classifying the same type twice should return cached result with same id
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
let .commRing id1 ← classify? (mkConst ``Int) | unreachable!
|
||||
let .commRing id2 ← classify? (mkConst ``Int) | unreachable!
|
||||
logInfo m!"{id1 == id2}"
|
||||
|
||||
/--
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 0
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commSemiring 0
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 2
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 1
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
let int ← shareCommon (mkConst ``Int)
|
||||
let nat ← shareCommon (mkConst ``Nat)
|
||||
let rat ← shareCommon (mkConst ``Rat)
|
||||
logInfo m!"{repr (← classify? int)}"
|
||||
logInfo m!"{repr (← classify? nat)}"
|
||||
logInfo m!"{repr (← classify? rat)}"
|
||||
let inst ← Sym.synthInstance (mkApp (mkConst ``Grind.Semiring [0]) nat)
|
||||
let ofSemiring ← shareCommon (← Sym.canon <| mkApp2 (mkConst ``Grind.Ring.OfSemiring.Q [0]) nat inst)
|
||||
logInfo m!"{repr (← classify? ofSemiring)}"
|
||||
|
||||
/-! ## EvalNum tests -/
|
||||
|
||||
def natZero : Nat := 0
|
||||
def natSucc3 : Nat := Nat.succ (Nat.succ (Nat.succ 0))
|
||||
def natSeven : Nat := 7
|
||||
def natAdd : Nat := 2 + 3
|
||||
def natMul : Nat := 2 * 3
|
||||
def natPow : Nat := 2 ^ 3
|
||||
def natBigPow : Nat := 2 ^ 100
|
||||
def natPow10 : Nat := 2 ^ 10
|
||||
|
||||
/-- info: some (0) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natZero)}"
|
||||
|
||||
/-- info: some (3) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natSucc3)}"
|
||||
|
||||
/-- info: some (7) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natSeven)}"
|
||||
|
||||
/-- info: some (5) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natAdd)}"
|
||||
|
||||
/-- info: some (6) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natMul)}"
|
||||
|
||||
/-- info: some (8) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natPow)}"
|
||||
|
||||
/-! ## Exp threshold tests -/
|
||||
|
||||
-- 2 ^ 100 should fail with default exp threshold (8)
|
||||
/-- info: none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natBigPow)}"
|
||||
|
||||
-- 2 ^ 10 succeeds with exp threshold raised to 20
|
||||
/-- info: some (1024) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
withExpThreshold 20 do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natPow10)}"
|
||||
|
||||
/-! ## Int EvalNum tests -/
|
||||
|
||||
def intNeg : Int := -5
|
||||
def intAdd : Int := 3 + (-2)
|
||||
def intMul : Int := (-3) * 4
|
||||
|
||||
/-- info: some (-5) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intNeg)}"
|
||||
|
||||
/-- info: some (1) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intAdd)}"
|
||||
|
||||
/-- info: some (-12) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intMul)}"
|
||||
172
tests/elab/sym_arith_reify.lean
Normal file
172
tests/elab/sym_arith_reify.lean
Normal file
@@ -0,0 +1,172 @@
|
||||
import Lean
|
||||
|
||||
/-!
|
||||
# Tests for `Sym.Arith.Reify`
|
||||
-/
|
||||
|
||||
open Lean Meta Sym Arith
|
||||
|
||||
/-- Extract the value of a definition by name. -/
|
||||
def getDefValue (n : Name) : MetaM Expr := do
|
||||
let some (.defnInfo info) := (← getEnv).find? n
|
||||
| throwError "expected definition: {n}"
|
||||
return info.value
|
||||
|
||||
/-!
|
||||
## Setup: a simple monad for testing reification
|
||||
-/
|
||||
|
||||
structure TestState where
|
||||
ring : CommRing
|
||||
vars : Array Expr := {}
|
||||
varMap : PHashMap ExprPtr Var := {}
|
||||
|
||||
abbrev TestM := StateRefT TestState SymM
|
||||
|
||||
instance : MonadCanon TestM where
|
||||
canonExpr e := Sym.canon e
|
||||
synthInstance? e := Sym.synthInstance? e
|
||||
|
||||
instance : MonadCommRing TestM where
|
||||
getCommRing := return (← get).ring
|
||||
modifyCommRing f := modify fun s => { s with ring := f s.ring }
|
||||
|
||||
instance : MonadMkVar TestM where
|
||||
mkVar e := do
|
||||
if let some v := (← get).varMap.find? { expr := e } then
|
||||
return v
|
||||
let v := (← get).vars.size
|
||||
modify fun s => { s with
|
||||
vars := s.vars.push e
|
||||
varMap := s.varMap.insert { expr := e } v
|
||||
}
|
||||
return v
|
||||
|
||||
instance : MonadGetVar TestM where
|
||||
getVar x := return (← get).vars[x]!
|
||||
|
||||
/-- Run a `TestM` on `Int`'s `CommRing`, canonicalizing `e` first. -/
|
||||
def reifyIntExpr (n : Name) (skipVar := true) : TestM (Option RingExpr) := do
|
||||
let e ← canonExpr (← getDefValue n)
|
||||
reifyRing? e (skipVar := skipVar)
|
||||
|
||||
def runTestOnInt (x : TestM α) : SymM α := do
|
||||
let .commRing id ← classify? (mkConst ``Int) | throwError "Int is not a CommRing"
|
||||
let ring := (← getArithState).rings[id]!
|
||||
x |>.run' { ring }
|
||||
|
||||
/-! ## Reify ring tests on Int -/
|
||||
|
||||
deriving instance Repr for Lean.Grind.CommRing.Expr
|
||||
|
||||
def intAdd : Int := 2 + 3
|
||||
def intMulAdd : Int := 2 * 3 + 1
|
||||
def intNeg : Int := -5
|
||||
def intPow : Int := 2 ^ 3
|
||||
def intSub : Int := 7 - 2
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.add (Lean.Grind.CommRing.Expr.num 2) (Lean.Grind.CommRing.Expr.num 3)) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intAdd)}"
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.num 2) (Lean.Grind.CommRing.Expr.num 3))
|
||||
(Lean.Grind.CommRing.Expr.num 1))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intMulAdd)}"
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.neg (Lean.Grind.CommRing.Expr.num 5)) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intNeg)}"
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.pow (Lean.Grind.CommRing.Expr.num 2) 3) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intPow)}"
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.sub (Lean.Grind.CommRing.Expr.num 7) (Lean.Grind.CommRing.Expr.num 2))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intSub)}"
|
||||
|
||||
-- skipVar test: a non-arithmetic term returns none with skipVar=true
|
||||
/-- info: none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
let a ← mkFreshExprMVar (mkConst ``Int)
|
||||
logInfo m!"{repr (← reifyRing? a)}"
|
||||
|
||||
-- skipVar=false: a non-arithmetic term becomes a variable
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.var 0) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
let a ← mkFreshExprMVar (mkConst ``Int)
|
||||
logInfo m!"{repr (← reifyRing? a (skipVar := false))}"
|
||||
|
||||
opaque a : Int
|
||||
opaque b : Int
|
||||
opaque c : Int
|
||||
def e := (a + b*2) - (c*a + a*(3*b + c))
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.sub
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.var 0)
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.var 1) (Lean.Grind.CommRing.Expr.num 2)))
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.var 2) (Lean.Grind.CommRing.Expr.var 0))
|
||||
(Lean.Grind.CommRing.Expr.mul
|
||||
(Lean.Grind.CommRing.Expr.var 0)
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.num 3) (Lean.Grind.CommRing.Expr.var 1))
|
||||
(Lean.Grind.CommRing.Expr.var 2)))))
|
||||
---
|
||||
info: #[a, b, c]
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``e)}"
|
||||
logInfo (← get).vars
|
||||
|
||||
/-! ## Roundtrip tests: reify then denote -/
|
||||
|
||||
/-- Reify an expression, denote it back, and check they're definitionally equal. -/
|
||||
def roundtrip (n : Name) : TestM Unit := do
|
||||
let orig ← canonExpr (← getDefValue n)
|
||||
let some re ← reifyRing? orig (skipVar := false) | throwError "reify failed"
|
||||
let vars := (← get).vars
|
||||
let denoted ← denoteRingExpr vars re
|
||||
let denoted ← canonExpr denoted
|
||||
unless (← isDefEq orig denoted) do
|
||||
logInfo m!"MISMATCH for {n}:\n orig: {orig}\n denoted: {denoted}"
|
||||
return
|
||||
logInfo m!"roundtrip OK: {n}: {denoted}"
|
||||
|
||||
/--
|
||||
info: roundtrip OK: intAdd: 2 + 3
|
||||
---
|
||||
info: roundtrip OK: intMulAdd: 2 * 3 + 1
|
||||
---
|
||||
info: roundtrip OK: intNeg: -5
|
||||
---
|
||||
info: roundtrip OK: intPow: 2 ^ 3
|
||||
---
|
||||
info: roundtrip OK: intSub: 7 - 2
|
||||
---
|
||||
info: roundtrip OK: e: a + b * 2 - (c * a + a * (3 * b + c))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
roundtrip ``intAdd
|
||||
roundtrip ``intMulAdd
|
||||
roundtrip ``intNeg
|
||||
roundtrip ``intPow
|
||||
roundtrip ``intSub
|
||||
roundtrip ``e
|
||||
Reference in New Issue
Block a user