Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5373d3f
Start implementing fuzzer
Colton1skees Aug 23, 2025
4cc6fe0
print widths of variables and constants in ast formatter
Colton1skees Aug 23, 2025
69252d6
Implement generation of random semi-linear expressions with mixed widths
Colton1skees Aug 23, 2025
a8fbc8b
Bug fix: Correctly compute random variable conjunction mask
Colton1skees Aug 23, 2025
fbb7dec
Add zext support to msimba
Colton1skees Aug 23, 2025
3d8e6f6
Constant fold pow operator
Colton1skees Aug 23, 2025
e133eea
cleanup
Colton1skees Aug 23, 2025
bde38b2
Add truncation operator
Colton1skees Aug 23, 2025
1108a7f
Add FFI wrapper for allocating trunc nodes
Colton1skees Aug 23, 2025
9e98775
update msimba to support truncation
Colton1skees Aug 23, 2025
d5f04e5
format
Colton1skees Aug 23, 2025
8354203
Add z3 support for trunc
Colton1skees Aug 23, 2025
209d1c9
Start implementing register allocating jit
Colton1skees Aug 26, 2025
105f18b
Collect uses
Colton1skees Aug 26, 2025
8ab2bc9
Implement more parts of JIT
Colton1skees Aug 27, 2025
ca50214
more jitting'
Colton1skees Aug 28, 2025
53c85ee
handle more edge cases
Colton1skees Aug 28, 2025
2dcc72e
actually fetch the DAG node value if > 1 use
Colton1skees Aug 28, 2025
8a15e3d
cleanup
Colton1skees Aug 28, 2025
76de50f
add todo notes for tomorrow
Colton1skees Aug 28, 2025
0a89091
Improve register allocation
Colton1skees Aug 29, 2025
e8e6407
Minimize register swapping; Correctly free up registers
Colton1skees Aug 29, 2025
eb180bc
changes
Colton1skees Aug 29, 2025
cfc9cf4
cleanup
Colton1skees Aug 29, 2025
4cd97f6
Emit prologue and epilogue; Respect x64 ABI
Colton1skees Aug 29, 2025
bf95bf6
Improve performance of linear and semi-linear mba evaluation via stor…
Colton1skees Aug 29, 2025
1884d38
Allocate rwx page & write compiled code; Implement throwaway code for…
Colton1skees Aug 29, 2025
38abc5f
Shift results down when evaluating signature vectors; Perf benchmarking
Colton1skees Aug 29, 2025
1f2ee64
more benchmarking
Colton1skees Aug 29, 2025
662bd5b
TODOs
Colton1skees Aug 29, 2025
586f749
Improve jit performance; Optionally disable packing one bit vars into…
Colton1skees Aug 30, 2025
1717030
more benchmarking
Colton1skees Aug 30, 2025
774a296
Add jit support for POW operator
Colton1skees Aug 30, 2025
342bdfa
Add zext support to jit
Colton1skees Aug 30, 2025
dd0beaf
Reduce result modulo 2**w
Colton1skees Aug 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions EqSat/src/isle/mba.isle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
(Constant (c u64) (width u8) )
(Symbol (id u32) (width u8) )
(Zext (a index) (to u8) )
(Trunc (a index) (to u8) )
))

(type ConstantWithExpectedValue extern
Expand Down Expand Up @@ -60,6 +61,8 @@
(extern constructor Symbol symbol)
(decl Zext (index u8) SimpleAst)
(extern constructor Zext zext)
(decl Trunc (index u8) SimpleAst)
(extern constructor Trunc trunc)

;; Constant folding utilities
(decl FoldAdd (index index) SimpleAst)
Expand Down
293 changes: 147 additions & 146 deletions EqSat/src/mba.rs

Large diffs are not rendered by default.

221 changes: 208 additions & 13 deletions EqSat/src/simple_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
collections::{hash_map::Entry, HashMap, HashSet},
f32::consts::PI,
ffi::{CStr, CString},
u64,
u64, vec,
};

use ahash::AHashMap;
Expand Down Expand Up @@ -199,6 +199,23 @@ impl Arena {
return self.insert_ast_node(SimpleAst::Zext { a, to: width }, data);
}

pub fn trunc(&mut self, a: AstIdx, width: u8) -> AstIdx {
let cost = (1 as u32).saturating_add(self.get_data(a).cost);
let has_poly = self.get_has_poly(a);

let mask = get_modulo_mask(width);
let mask_node = self.constant(mask, width);
let class = self.compute_bitwise_class(a, mask_node);
let data = AstData {
width: width,
cost: cost,
has_poly: has_poly,
class: class,
};

return self.insert_ast_node(SimpleAst::Zext { a, to: width }, data);
}

pub fn constant(&mut self, c: u64, width: u8) -> AstIdx {
let data = AstData {
width: width,
Expand Down Expand Up @@ -428,6 +445,7 @@ pub enum SimpleAst {
Symbol { id: u32, width: u8 },
// Special operators
Zext { a: AstIdx, to: u8 },
Trunc { a: AstIdx, to: u8 },
}

pub struct Context {
Expand Down Expand Up @@ -535,6 +553,12 @@ impl mba::Context for Context {
self.arena.get_node(zext).clone()
}

fn trunc(&mut self, arg0: AstIdx, width: u8) -> SimpleAst {
let zext = self.arena.zext(arg0, width);

self.arena.get_node(zext).clone()
}

fn any(&mut self, arg0: AstIdx) -> SimpleAst {
return self.arena.get_node(arg0).clone();
}
Expand Down Expand Up @@ -624,6 +648,7 @@ impl AstPrinter {
SimpleAst::Constant { c, width } => "",
SimpleAst::Symbol { id, width } => "",
SimpleAst::Zext { a, to } => "zx",
SimpleAst::Trunc { a, to } => "tr",
};

// Don't put parens for constants or symbols
Expand All @@ -644,17 +669,25 @@ impl AstPrinter {
}
SimpleAst::Zext { a, to } => {
self.print_node(ctx, ctx.arena.get_node(*a));
self.output.push_str(&format!(" {} ", operator));
self.output.push_str(&(*to).to_string());
self.output.push_str(&format!(" {} i{}", operator, to));
}
SimpleAst::Trunc { a, to } => {
self.print_node(ctx, ctx.arena.get_node(*a));
self.output.push_str(&format!(" {} i{}", operator, to));
}
SimpleAst::Neg { a } => {
self.output.push('~');
self.print_node(ctx, ctx.arena.get_node(*a));
}
SimpleAst::Constant { c, width } => self.output.push_str(&(*c as i64).to_string()),
SimpleAst::Symbol { id, width } => self
.output
.push_str(&ctx.arena.get_symbol_name(*id).clone()),
SimpleAst::Constant { c, width } => {
self.output
.push_str(&format!("{}:i{}", (*c as i64).to_string(), width))
}
SimpleAst::Symbol { id, width } => self.output.push_str(&format!(
"{}:i{}",
ctx.arena.get_symbol_name(*id).clone(),
width
)),
}

if operator != "" {
Expand All @@ -681,6 +714,7 @@ pub fn eval_ast(ctx: &Context, idx: AstIdx, value_mapping: &HashMap<AstIdx, u64>
SimpleAst::Constant { c, width } => *c,
SimpleAst::Symbol { id, width } => *value_mapping.get(&idx).unwrap(),
SimpleAst::Zext { a, to } => get_modulo_mask(ctx.arena.get_width(*a)) & e(a),
SimpleAst::Trunc { a, to } => get_modulo_mask(*to) & e(a),
}
}

Expand Down Expand Up @@ -715,6 +749,10 @@ pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx {
let op1 = recursive_simplify(ctx, a);
ast = SimpleAst::Zext { a: op1, to }
}
SimpleAst::Trunc { a, to } => {
let op1 = recursive_simplify(ctx, a);
ast = SimpleAst::Zext { a: op1, to }
}
SimpleAst::Constant { c, width } => return idx,
SimpleAst::Symbol { id, width } => return idx,
}
Expand Down Expand Up @@ -805,7 +843,7 @@ fn collect_var_indices_internal(
| SimpleAst::And { a, b }
| SimpleAst::Or { a, b }
| SimpleAst::Xor { a, b } => vbin(*a, *b),
SimpleAst::Neg { a } | SimpleAst::Zext { a, .. } => {
SimpleAst::Neg { a } | SimpleAst::Zext { a, .. } | SimpleAst::Trunc { a, .. } => {
collect_var_indices_internal(ctx, *a, visited, out_vars)
}
SimpleAst::Constant { c, width } => return,
Expand Down Expand Up @@ -962,6 +1000,14 @@ pub extern "C" fn ContextZext(ctx: *mut Context, a: AstIdx, width: u8) -> AstIdx
}
}

#[no_mangle]
pub extern "C" fn ContextTrunc(ctx: *mut Context, a: AstIdx, width: u8) -> AstIdx {
unsafe {
let id = (*ctx).arena.trunc(a, width);
return id;
}
}

#[no_mangle]
pub extern "C" fn ContextConstant(ctx: *mut Context, c: u64, width: u8) -> AstIdx {
unsafe {
Expand Down Expand Up @@ -1010,6 +1056,7 @@ pub fn get_opcode(ctx: &Context, id: AstIdx) -> u8 {
SimpleAst::Constant { c, width } => 8,
SimpleAst::Symbol { id, width } => 9,
SimpleAst::Zext { a, to } => 10,
SimpleAst::Trunc { a, to } => 11,
};
}

Expand Down Expand Up @@ -1067,6 +1114,7 @@ pub fn get_op0(ctx: &Context, id: AstIdx) -> AstIdx {
SimpleAst::Xor { a, b } => *a,
SimpleAst::Neg { a } => *a,
SimpleAst::Zext { a, to } => *a,
SimpleAst::Trunc { a, to } => *a,
_ => unreachable!("Type has no first operand!"),
};
}
Expand Down Expand Up @@ -1105,6 +1153,30 @@ pub extern "C" fn ContextGetConstantValue(ctx: *mut Context, id: AstIdx) -> u64
panic!("ast is not a constant!");
}

#[no_mangle]
pub extern "C" fn ContextGetZextToWidth(ctx: *const Context, id: AstIdx) -> u8 {
unsafe {
let ast = (*ctx).arena.get_node(id);
if let SimpleAst::Zext { a, to } = ast {
return (*to);
}
}

panic!("ast is not a constant!");
}

#[no_mangle]
pub extern "C" fn ContextGetTruncToWidth(ctx: *const Context, id: AstIdx) -> u8 {
unsafe {
let ast = (*ctx).arena.get_node(id);
if let SimpleAst::Trunc { a, to } = ast {
return (*to);
}
}

panic!("ast is not a constant!");
}

#[no_mangle]
pub extern "C" fn ContextGetSymbolName(ctx: *mut Context, id: AstIdx) -> *mut c_char {
unsafe {
Expand Down Expand Up @@ -1384,7 +1456,7 @@ unsafe fn jit_rec(
emit(page, offset, &[0x48, 0xc1, 0xE6, var_idx]);

// // varValue = i & varMask
// mov rdi, combIdxRegister
// mov rdi, combIdxRegister (rdx)
emit(page, offset, &[0x48, 0x89, 0xD7]);
// and rdi, rsi
emit(page, offset, &[0x48, 0x21, 0xF7]);
Expand All @@ -1403,14 +1475,31 @@ unsafe fn jit_rec(
emit_u8(page, offset, PUSH_RDI);
}
SimpleAst::Zext { a, to } => {
// Zero extend is a no-op in our JIT, since we always AND with a mask after every operation.
jit_rec(ctx, *a, node_to_var, page, offset);
}
SimpleAst::Trunc { a, to } => {
jit_rec(ctx, *a, node_to_var, page, offset);
emit_u8(page, offset, POP_RSI);

// TODO: AND with mask!
//emit(page, offset, &[0x48, 0xF7, 0xD6]);
emit_u8(page, offset, PUSH_RSI);
// mov rax, constant
emit_u8(page, offset, 0x48);
emit_u8(page, offset, 0xB8);
// Fill in the constant
let trunc_mask = get_modulo_mask(*to);
emit_u64(page, offset, trunc_mask);
// and [rsp+8], rax
emit(page, offset, &[0x48, 0x21, 0x44, 0x24, 0x08]);
}
};

// mov rax, constant
emit_u8(page, offset, 0x48);
emit_u8(page, offset, 0xB8);
// Fill in the constant
let c = get_modulo_mask(ctx.arena.get_width(node));
emit_u64(page, offset, c);
// and [rsp+8], rax
emit(page, offset, &[0x48, 0x21, 0x44, 0x24, 0x08]);
}

unsafe fn jit_constant(c: u64, page: *mut u8, offset: &mut usize) {
Expand All @@ -1437,6 +1526,112 @@ pub extern "C" fn Pow(mut base: u64, mut exp: u64) -> u64 {
return res;
}

#[no_mangle]
pub unsafe extern "C" fn ContextCompile(
ctx_p: *mut Context,
node: AstIdx,
mask: u64,
multi_bit_u: u32,
bit_width: u32,
variables: *const AstIdx,
var_count: u64,
num_combinations: u64,
page: *mut u8,
) {
let multi_bit = multi_bit_u != 0;
let num_bit_iterations: u32 = if multi_bit { bit_width } else { 1 };

let mut ctx: &mut Context = &mut (*ctx_p);

let mut offset: usize = 0;

// Push all clobbered registers
emit_u8(page, &mut offset, PUSH_RBX);
emit_u8(page, &mut offset, PUSH_RSI);
emit_u8(page, &mut offset, PUSH_RDI);

// JIT code
let mut node_to_var: HashMap<AstIdx, u8> = HashMap::with_capacity(var_count as usize);
for i in 0..var_count {
node_to_var.insert(*variables.add(i as usize), i as u8);
}

jit_rec(ctx, node, &node_to_var, page, &mut offset);

// Pop the evaluation result
emit_u8(page, &mut offset, POP_RAX);

// Mask off bits that we don't care about
// mov rsi, mask
emit(page, &mut offset, &[0x48, 0xBE]);
emit_u64(page, &mut offset, mask);

// and rax, rsi
emit(page, &mut offset, &[0x48, 0x21, 0xF0]);

// Shift the value back down to bit index zero,
// varValue = varValue >> (ushort)v
// shr rax, bitIdxRegister
emit(page, &mut offset, &[0x48, 0xD3, 0xE8]);

// Restore the clobbered registers.
emit_u8(page, &mut offset, POP_RDI);
emit_u8(page, &mut offset, POP_RSI);
emit_u8(page, &mut offset, POP_RBX);

emit_u8(page, &mut offset, RET);

// Don't execute
}

#[no_mangle]
pub unsafe extern "C" fn ContextExecute(
multi_bit_u: u32,
bit_width: u32,
var_count: u64,
num_combinations: u64,
page: *mut u8,
output: *mut u64,
one_bit_vars: u32,
) {
let multi_bit = multi_bit_u != 0;
let num_bit_iterations: u32 = if multi_bit { bit_width } else { 1 };

if (one_bit_vars != 0) {
let fptr: unsafe extern "C" fn(u32, u64) -> u64 = std::mem::transmute(page);

let mut arr_idx: usize = 0;
for bit_index in 0..num_bit_iterations {
for i in 0..num_combinations {
let result = fptr(bit_index, i);
*output.add(arr_idx) = result;
arr_idx += 1;
}
}

return;
}

let mut var_values = vec![0u64; var_count as usize];
let vptr = var_values.as_mut_slice();

let mut arr_idx: usize = 0;
let fptr: unsafe extern "C" fn(*mut u64) -> u64 = std::mem::transmute(page);
for bit_index in 0..num_bit_iterations {
for i in 0..num_combinations {
for v_idx in 0..var_count {
vptr[v_idx as usize] = ((i >> v_idx) & 1) << bit_index;
// *var_values.get_mut(v_idx as usize) = (i >> v_idx) & 1;
// (v_idx as usize);
}

let result = fptr(vptr.as_mut_ptr()) >> bit_index;
*output.add(arr_idx) = result;
arr_idx += 1;
}
}
}

#[no_mangle]
pub unsafe extern "C" fn ContextJit(
ctx_p: *mut Context,
Expand Down
Loading