1use super::Symbols;
6use crate::{
7 drcore::{self, get_proc_address, log},
8 ffi,
9 utils::{Boolean, ReadError},
10};
11use enum_iterator::all;
12use spin::Mutex;
13use std::{collections::HashSet, os::raw::c_void};
14use std::{ffi::CStr, sync::LazyLock};
15use std::{os::raw::c_char, ptr};
16
17static WRAPPED: LazyLock<Mutex<Option<HashSet<usize>>>> = LazyLock::new(|| Mutex::new(None));
21
22fn push_wrapped(addr: usize) {
23 let mut wrapped_fns = WRAPPED.lock();
24 match wrapped_fns.as_mut() {
25 Some(wrapped_set) => {
26 wrapped_set.insert(addr);
27 }
28 None => {
29 let mut wrapped_set = HashSet::new();
30 wrapped_set.insert(addr);
31 *wrapped_fns = Some(wrapped_set);
32 }
33 }
34}
35
36fn is_wrapped(addr: usize) -> bool {
37 match WRAPPED.lock().as_ref() {
38 Some(wrapped_set) => wrapped_set.contains(&addr),
39 None => false,
40 }
41}
42
43#[unsafe(no_mangle)]
45pub extern "C" fn wrap_compare_symbols(mod_base_addr: ffi::app_pc, mod_name_ptr: *const c_char) {
46 unsafe {
47 let mod_name = CStr::from_ptr(mod_name_ptr)
48 .to_str()
49 .unwrap_or("<UTF8-parse-err>");
50 all::<Symbols>().for_each(|symbol| {
51 if let Some(addr) = get_proc_address(mod_base_addr, symbol.to_str()) {
52 if !is_wrapped(addr as usize) {
53 match ffi::drwrap_wrap(addr as *mut u8, Some(symbol.wrapper_fn()), None)
54 .as_bool()
55 {
56 true => {
57 log(&format!(
58 "[module load] wrapped {} in module {mod_name} @ 0x{:?}",
59 symbol.to_str(),
60 addr
61 ));
62 push_wrapped(addr as usize);
63 }
64 false => log(&format!(
65 "[module load] failed to wrap {} in module {mod_name} @ 0x{:?}",
66 symbol.to_str(),
67 addr
68 )),
69 }
70 }
71 };
72 });
73 }
74}
75
76pub unsafe fn compare_ptrs(
81 fn_name: &str,
82 afl_cmp_area: *mut u8,
83 base_idx: usize,
84 a: *const u8,
85 b: *const u8,
86 n: usize,
87) -> Result<(), ReadError> {
88 let aa = drcore::safe_read(a as *mut c_void, n)?;
89 let bb = drcore::safe_read(b as *mut c_void, n)?;
90 log(&format!(
91 "[{fn_name}] comparing {n} bytes:\na: {:?}\nb: {:?}",
92 aa, bb
93 ));
94 unsafe { Ok(write_coverage(afl_cmp_area, base_idx, &aa, &bb, n)) }
95}
96
97pub unsafe fn base_idx_of_app_pc(app_pc: *const u8) -> usize {
100 unsafe {
101 ((app_pc as usize).wrapping_mul(0x9E37_79B1) as usize)
103 & (super::WINAFL_CMP_MAP_SIZE as usize - super::CMP_MAX_LEN)
104 }
105}
106
107pub unsafe fn write_coverage(afl_cmp_area: *mut u8, base_idx: usize, a: &[u8], b: &[u8], n: usize) {
112 unsafe {
113 unsafe fn incr_at_idx(afl_cmp_area: *mut u8, idx: usize) {
114 unsafe {
115 let p = afl_cmp_area.add(idx);
116 let val = ptr::read_volatile(p);
118 if val < u8::MAX {
119 ptr::write_volatile(p, val.wrapping_add(1));
120 }
121 }
122 }
123
124 let truncated_cmp_len = super::CMP_MAX_LEN.min(n);
125
126 for i in 0..truncated_cmp_len {
127 let idx = base_idx + i;
128 if a[i] == b[i] {
129 incr_at_idx(afl_cmp_area, idx);
130 } else {
131 break;
132 }
133 }
134 }
135}
136
137pub unsafe fn get_afl_cmp_area() -> Result<*mut u8, String> {
138 unsafe {
139 let drcontext = ffi::dr_get_current_drcontext();
140 let thread_data = ffi::drmgr_get_tls_field(drcontext, super::winafl_tls_field);
141 if thread_data.is_null() {
142 return Err(
143 "pointer to winafl coverage data (in thread-local storage) is null".to_string(),
144 );
145 }
146
147 let thread_data = thread_data as *mut *mut c_void;
149 let afl_area = *thread_data.add(1) as *mut u8;
150 if afl_area.is_null() {
151 return Err("pointer to winafl bitmap (in thread-local storage) is null".to_string());
152 }
153
154 Ok(afl_area.add(super::WINAFL_COV_MAP_SIZE as usize))
157 }
158}