1use crate::sync::Mutex;
27use core::{
28 alloc::{GlobalAlloc, Layout},
29 fmt::{Debug, Display},
30 ptr::{NonNull, null_mut},
31 sync::atomic::{AtomicPtr, Ordering},
32};
33use vera_portal::{
34 MapMemoryError, MemoryLocation, MemoryProtections,
35 sys_client::{map_memory, unmap_memory},
36};
37
/// Failure modes reported by the heap allocator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum MemoryAllocationError {
    /// No free node is large enough for the request.
    OutOfMemory,
    /// Mapping more memory failed because the system is out of memory.
    OutOfSystemMemory,
    /// The pointer belongs to a node that is already free.
    DoubleFree,
    /// The pointer/layout pair does not describe a live allocation.
    NotAllocated,
    /// The pointer lies outside the region that was asked to free it.
    NotInRegion,
}

/// Prints the bare variant name, identical to the derived `Debug` output.
impl Display for MemoryAllocationError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::OutOfMemory => "OutOfMemory",
            Self::OutOfSystemMemory => "OutOfSystemMemory",
            Self::DoubleFree => "DoubleFree",
            Self::NotAllocated => "NotAllocated",
            Self::NotInRegion => "NotInRegion",
        };
        f.write_str(name)
    }
}

impl core::error::Error for MemoryAllocationError {}
54
55impl From<MapMemoryError> for MemoryAllocationError {
56 fn from(value: MapMemoryError) -> Self {
57 match value {
58 MapMemoryError::InvalidLength(l) => {
59 panic!("Inner alloc error: Invalid allocation length = {l}")
60 }
61 MapMemoryError::MappingMemoryError => {
62 panic!("Inner alloc error: OS mapping memory error!")
63 }
64 MapMemoryError::OutOfMemory => Self::OutOfSystemMemory,
65 }
66 }
67}
68
69pub type Result<T> = core::result::Result<T, MemoryAllocationError>;
70
/// Allocation state of a single node's payload.
#[derive(Debug, PartialEq, Eq)]
enum BuddyState {
    /// The payload is available for allocation.
    Free,
    /// The payload is handed out; `layout` is the layout the caller requested
    /// and must be presented again when freeing.
    Used { layout: Layout },
}

/// Header stored in-line in the managed region, directly in front of its payload.
///
/// Nodes form a doubly linked list in address order. `size` counts only the
/// payload bytes after this header, up to the next header or the region end.
#[derive(Debug)]
struct BuddyNode {
    next: Option<NonNull<BuddyNode>>,
    prev: Option<NonNull<BuddyNode>>,
    state: BuddyState,
    size: usize,
}

/// First-fit allocator over one contiguous region, built on a list of in-line
/// `BuddyNode` headers. Despite the name, blocks are split at arbitrary sizes,
/// not power-of-two buddies.
struct BuddyAllocator {
    /// First node of the list; `None` until the region is first used.
    head: Option<NonNull<BuddyNode>>,
    /// Inclusive start of the managed region.
    region_start: NonNull<u8>,
    /// Exclusive end of the managed region.
    region_end: NonNull<u8>,
}
90
91impl BuddyAllocator {
92 const fn new(ptr: NonNull<u8>, len: usize) -> Self {
93 let buddy_allocator = Self {
94 head: None,
95 region_start: ptr,
96 region_end: unsafe { ptr.byte_add(len) },
97 };
98 buddy_allocator
99 }
100
101 fn head(&mut self) -> NonNull<BuddyNode> {
102 let buddy = *self.head.get_or_insert_with(|| {
103 let region = self.region_start..self.region_end;
104 let offset = self.region_start.align_offset(align_of::<BuddyNode>());
105 let new_buddy = unsafe { self.region_start.byte_add(offset) }.cast::<BuddyNode>();
106
107 assert!(region.contains(&new_buddy.cast::<u8>()));
108
109 unsafe {
110 new_buddy.write(BuddyNode {
111 next: None,
112 prev: None,
113 state: BuddyState::Free,
114 size: (self.region_end.addr().get() - new_buddy.addr().get())
115 - size_of::<BuddyNode>(),
116 });
117 }
118
119 new_buddy
120 });
121
122 self.safety_check_buddy(buddy);
123 buddy
124 }
125
126 #[inline]
128 fn safety_check_buddy(&self, buddy: NonNull<BuddyNode>) -> BuddyNode {
129 let region = self.region_start..self.region_end;
130
131 debug_assert!(buddy.is_aligned());
132 assert!(region.contains(&buddy.cast::<u8>()));
133
134 let buddy_read = unsafe { buddy.read() };
135 debug_assert!(
136 buddy_read
137 .next
138 .is_none_or(|next| { region.contains(&next.cast::<u8>()) })
139 );
140 debug_assert!(
141 buddy_read
142 .prev
143 .is_none_or(|prev| { region.contains(&prev.cast::<u8>()) })
144 );
145 debug_assert_ne!(buddy_read.size, 0, "{:#?}", self);
146
147 buddy_read
148 }
149
150 unsafe fn alloc(&mut self, layout: Layout) -> Result<*mut u8> {
151 if layout.size() == 0 {
152 return Ok(self.region_start.as_ptr());
153 }
154
155 let mut cursor = self.head();
156
157 loop {
158 let cursor_read = self.safety_check_buddy(cursor);
159 if matches!(cursor_read.state, BuddyState::Used { .. }) {
160 cursor = cursor_read.next.ok_or(MemoryAllocationError::OutOfMemory)?;
161 continue;
162 }
163
164 let post_header_ptr = unsafe { cursor.byte_add(size_of::<BuddyNode>()) };
165 let post_header_size = cursor_read.size;
166 let end_region_ptr = unsafe { post_header_ptr.byte_add(post_header_size) };
167
168 let type_alignment_cost = post_header_ptr.cast::<u8>().align_offset(layout.align());
169 let type_size = type_alignment_cost + layout.size();
170
171 if post_header_size < type_size {
173 if let Some(next_cursor) = cursor_read.next {
174 cursor = next_cursor;
175 } else {
176 return Err(MemoryAllocationError::OutOfMemory);
177 }
178 continue;
179 }
180
181 let post_allocation_bytes = post_header_size - type_size;
182 let next_header_alignmnet_cost = unsafe {
183 post_header_ptr
184 .cast::<u8>()
185 .byte_add(type_size)
186 .align_offset(align_of::<BuddyNode>())
187 };
188
189 if post_allocation_bytes > next_header_alignmnet_cost + (2 * size_of::<BuddyNode>()) {
191 let mut next_buddy_ptr =
192 unsafe { post_header_ptr.byte_add(type_size + next_header_alignmnet_cost) };
193
194 debug_assert!(next_buddy_ptr.is_aligned());
195 debug_assert!(
196 unsafe { next_buddy_ptr.byte_add(size_of::<BuddyNode>()) } < end_region_ptr
197 );
198
199 let new_post_header_size =
200 next_buddy_ptr.addr().get() - post_header_ptr.addr().get();
201 let next_size = (end_region_ptr.addr().get() - next_buddy_ptr.addr().get())
202 - size_of::<BuddyNode>();
203
204 unsafe {
206 let next_mut = next_buddy_ptr.as_mut();
207
208 next_mut.prev = Some(cursor);
209 next_mut.size = next_size;
210 next_mut.state = BuddyState::Free;
211
212 if let Some(mut next) = cursor_read.next {
213 next_mut.next = Some(next);
214 next.as_mut().prev = Some(next_buddy_ptr);
215 } else {
216 next_mut.next = None;
217 }
218
219 let cursor_mut = cursor.as_mut();
220 cursor_mut.next = Some(next_buddy_ptr);
221 cursor_mut.size = new_post_header_size;
222 };
223 }
224
225 unsafe {
227 let cursor_mut = cursor.as_mut();
228 cursor_mut.state = BuddyState::Used { layout };
229 }
230
231 let ret_ptr: *mut u8 = unsafe { post_header_ptr.byte_add(type_alignment_cost) }
232 .cast()
233 .as_ptr();
234
235 debug_assert_eq!(ret_ptr.addr() as u64 & (layout.align() as u64 - 1), 0);
236 unsafe { ret_ptr.write_bytes(0, layout.size()) };
237
238 return Ok(ret_ptr);
239 }
240 }
241
242 fn combine(&mut self, cursor: NonNull<BuddyNode>) {
243 let mut current = cursor;
245 loop {
246 let current_read = self.safety_check_buddy(current);
247 let Some(prev) = current_read.prev else {
248 break;
249 };
250 let prev_read = self.safety_check_buddy(prev);
251
252 if !matches!(prev_read.state, BuddyState::Free)
253 || !matches!(current_read.state, BuddyState::Free)
254 {
255 break;
256 }
257
258 unsafe {
259 prev.write(BuddyNode {
260 next: current_read.next,
261 prev: prev_read.prev,
262 state: BuddyState::Free,
263 size: current_read.size + prev_read.size + size_of::<BuddyNode>(),
264 });
265
266 if let Some(mut next) = current_read.next {
267 next.as_mut().prev = Some(prev);
268 }
269 }
270
271 current = prev;
272 }
273
274 loop {
276 let current_read = self.safety_check_buddy(current);
277 let Some(next) = current_read.next else {
278 break;
279 };
280 let next_read = self.safety_check_buddy(next);
281
282 if !matches!(next_read.state, BuddyState::Free)
283 || !matches!(current_read.state, BuddyState::Free)
284 {
285 break;
286 }
287
288 unsafe {
289 current.write(BuddyNode {
290 next: next_read.next,
291 prev: current_read.prev,
292 state: BuddyState::Free,
293 size: current_read.size + next_read.size + size_of::<BuddyNode>(),
294 });
295
296 if let Some(mut next_next) = next_read.next {
297 next_next.as_mut().prev = Some(current);
298 }
299 }
300
301 current = next;
302 }
303 }
304
305 unsafe fn dealloc(&mut self, ptr: *mut u8, layout: Layout) -> Result<()> {
306 if self.region_end.as_ptr() < ptr && self.region_start.as_ptr() > ptr {
307 return Err(MemoryAllocationError::NotInRegion);
308 }
309
310 if layout.size() == 0 {
311 assert_eq!(ptr, self.region_start.as_ptr());
312 }
313
314 let mut cursor = self.head();
315
316 loop {
317 let cursor_read = self.safety_check_buddy(cursor);
318 let post_header_size = cursor_read.size;
319 let post_header_ptr = unsafe { cursor.byte_add(size_of::<BuddyNode>()) }.cast::<u8>();
320 let post_header_end =
321 unsafe { post_header_ptr.byte_add(post_header_size) }.cast::<u8>();
322
323 if !(post_header_ptr.as_ptr()..post_header_end.as_ptr()).contains(&ptr) {
324 cursor = cursor_read
325 .next
326 .expect("reached end of region, but didn't find ptr to free!");
327 continue;
328 }
329
330 match cursor_read.state {
332 BuddyState::Free => return Err(MemoryAllocationError::DoubleFree),
333 BuddyState::Used {
334 layout: state_layout,
335 } if state_layout != layout => {
336 return Err(MemoryAllocationError::NotAllocated);
337 }
338 _ => (),
339 }
340
341 unsafe { cursor.as_mut().state = BuddyState::Free };
342 self.combine(cursor);
343
344 break Ok(());
345 }
346 }
347}
348
349impl Debug for BuddyAllocator {
350 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
351 struct Fields {
352 head_ptr: Option<NonNull<BuddyNode>>,
353 }
354
355 impl Debug for Fields {
356 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
357 let mut list = f.debug_list();
358
359 if let Some(mut alloc) = self.head_ptr {
360 list.entry(&unsafe { alloc.read() });
361
362 while let Some(next) = unsafe { alloc.read().next } {
363 list.entry(&unsafe { next.read() });
364 alloc = next;
365 }
366 }
367
368 list.finish()
369 }
370 }
371
372 f.debug_struct(stringify!(BuddyAllocator))
373 .field("head", &self.head)
374 .field(
375 "region_len",
376 &(self.region_end.addr().get() - self.region_start.addr().get()),
377 )
378 .field(
379 "alloc",
380 &Fields {
381 head_ptr: self.head,
382 },
383 )
384 .finish()
385 }
386}
387
388struct MemoryMapRegion {
390 alloc: Mutex<Option<BuddyAllocator>>,
391 next: AtomicPtr<MemoryMapRegion>,
392}
393
394impl MemoryMapRegion {
395 const REGION_SIZE: usize = 1024 * 1024;
396
397 pub const fn new() -> Self {
399 Self {
400 alloc: Mutex::new(None),
401 next: AtomicPtr::new(null_mut()),
402 }
403 }
404
405 fn new_buddy() -> Result<BuddyAllocator> {
406 let memory_region_ptr = map_memory(
407 MemoryLocation::Anywhere,
408 MemoryProtections::ReadWrite,
409 Self::REGION_SIZE,
410 )?;
411 Ok(BuddyAllocator::new(
412 NonNull::new(memory_region_ptr).expect("Mapping memory should never return 0"),
413 Self::REGION_SIZE,
414 ))
415 }
416
417 pub unsafe fn alloc(&self, layout: Layout) -> Result<*mut u8> {
418 let mut alloc_lock = self.alloc.lock();
419
420 if alloc_lock.is_none() {
422 if let Some(previous_alloc) = alloc_lock.replace(Self::new_buddy()?) {
423 unreachable!(
424 "Tried to replace an existing MemoryMapRegion -- {:#?}",
425 previous_alloc
426 );
427 }
428 }
429
430 let alloc_inner = alloc_lock
431 .as_mut()
432 .expect("Just allocated a region, cannot be none!");
433
434 unsafe {
435 match alloc_inner.alloc(layout) {
436 Ok(allocated) => Ok(allocated),
437 Err(MemoryAllocationError::OutOfMemory) => {
438 let mut next_ptr = self.next.load(Ordering::Relaxed);
443 if next_ptr.is_null() {
444 let mut new_buddy = Self::new_buddy()?;
445 let new_region_start = new_buddy.region_start;
446
447 next_ptr = new_buddy
449 .alloc(Layout::new::<MemoryMapRegion>())?
450 .cast::<MemoryMapRegion>();
451
452 *next_ptr = MemoryMapRegion {
453 alloc: Mutex::new(Some(new_buddy)),
454 next: AtomicPtr::new(null_mut()),
455 };
456
457 while let Err(failed_set) = self.next.compare_exchange(
458 null_mut(),
459 next_ptr,
460 Ordering::SeqCst,
461 Ordering::Relaxed,
462 ) {
463 if failed_set.is_null() {
465 continue;
466 }
467
468 unmap_memory(new_region_start.as_ptr());
471 next_ptr = failed_set;
472 }
473 }
474
475 (&mut *next_ptr).alloc(layout)
476 }
477 Err(other_error) => Err(other_error),
478 }
479 }
480 }
481
482 pub unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) -> Result<()> {
483 let mut alloc_lock = self.alloc.lock();
484
485 let Some(ref mut inner) = *alloc_lock else {
486 return Err(MemoryAllocationError::NotAllocated);
487 };
488
489 unsafe {
490 match inner.dealloc(ptr, layout) {
491 Ok(_) => Ok(()),
492 Err(MemoryAllocationError::NotInRegion) => {
493 let next_ptr = self.next.load(Ordering::Relaxed);
494
495 if next_ptr.is_null() {
496 return Err(MemoryAllocationError::NotAllocated);
497 }
498
499 (&*next_ptr).dealloc(ptr, layout)
500 }
501 Err(err) => Err(err),
502 }
503 }
504 }
505}
506
507impl Drop for MemoryMapRegion {
508 fn drop(&mut self) {
509 let next_ptr = self.next.load(Ordering::Relaxed);
510
511 if !next_ptr.is_null() {
512 unsafe {
513 core::ptr::drop_in_place(next_ptr);
514 }
515 }
516
517 let alloc_lock = self.alloc.lock();
518 let Some(ref alloc) = *alloc_lock else {
519 return;
520 };
521
522 let ptr = alloc.region_start.as_ptr();
524 unmap_memory(ptr);
525 }
526}
527
528pub struct QuantumHeap {
530 head_region: MemoryMapRegion,
531}
532
533impl QuantumHeap {
534 pub const fn new() -> Self {
536 Self {
537 head_region: MemoryMapRegion::new(),
538 }
539 }
540
541 #[inline]
543 pub unsafe fn inner_allocate(&self, layout: Layout) -> Result<*mut u8> {
544 unsafe { self.head_region.alloc(layout) }
545 }
546
547 #[inline]
549 pub unsafe fn inner_deallocate(&self, ptr: *mut u8, layout: Layout) -> Result<()> {
550 unsafe { self.head_region.dealloc(ptr, layout) }
551 }
552}
553
554unsafe impl GlobalAlloc for QuantumHeap {
555 unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
556 unsafe { self.inner_allocate(layout).unwrap() }
557 }
558
559 unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
560 unsafe { self.inner_deallocate(ptr, layout).unwrap() }
561 }
562}
563
#[cfg(test)]
mod test {
    use super::*;

    use core::alloc::Layout;
    extern crate std;

    /// Gives `body` a `BuddyAllocator` over a zeroed, byte-aligned host buffer
    /// of `len` bytes, and frees the buffer once `body` returns.
    fn with_buddy(len: usize, body: impl FnOnce(&mut BuddyAllocator)) {
        let backing = Layout::from_size_align(len, 1).unwrap();
        let region = unsafe { std::alloc::alloc_zeroed(backing) };
        let mut buddy = BuddyAllocator::new(NonNull::new(region).unwrap(), len);

        body(&mut buddy);

        unsafe { std::alloc::dealloc(region, backing) };
    }

    /// Layout of the `index`-th allocation in `alloc_random`: 8..=128 bytes, 8-aligned.
    fn sized_layout(index: usize) -> Layout {
        Layout::from_size_align((index * 8) % 128 + 8, 8).unwrap()
    }

    #[test]
    fn test_buddy_new() {
        with_buddy(10 * 1024, |buddy| {
            let byte = Layout::new::<u8>();

            let ptrs: std::vec::Vec<*mut u8> = (0u8..3)
                .map(|value| {
                    let ptr = unsafe { buddy.alloc(byte) }.unwrap();
                    unsafe { ptr.write(value) };
                    assert_eq!(unsafe { ptr.read() }, value);
                    ptr
                })
                .collect();

            for (value, ptr) in (0u8..).zip(ptrs) {
                assert_eq!(unsafe { ptr.read() }, value);
                unsafe { buddy.dealloc(ptr, byte) }.unwrap();
            }
        });
    }

    #[test]
    fn alloc_random() {
        with_buddy(32 * 1024, |buddy| {
            let ptrs: std::vec::Vec<*mut u8> = (0..100usize)
                .map(|index| {
                    let ptr = unsafe { buddy.alloc(sized_layout(index)) }.unwrap();
                    unsafe { ptr.write(index as u8) };
                    assert_eq!(unsafe { ptr.read() }, index as u8);
                    ptr
                })
                .collect();

            for (index, ptr) in ptrs.into_iter().enumerate() {
                assert_eq!(unsafe { ptr.read() }, index as u8);
                unsafe { buddy.dealloc(ptr, sized_layout(index)) }.unwrap();
            }
        });
    }
}