kernel_api/sync/
once.rs

1#![unstable(feature = "kernel_sync_once", issue = "none")]
2
3use core::cell::UnsafeCell;
4use core::mem::MaybeUninit;
5use core::ops::Deref;
6use core::sync::atomic::{AtomicU8, fence, Ordering};
7
/// A one-shot initialization primitive.
///
/// The wrapped atomic holds a `State` value encoded as a `u8`.
pub struct Once(core::sync::atomic::AtomicU8);
9
/// Lifecycle of a `Once`, stored as the raw `u8` inside its atomic.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
enum State {
    /// No caller has started running its closure.
    Uncalled = 0,
    /// A caller is currently running its closure.
    Running = 1,
    /// A closure finished without panicking.
    Called = 2,
    /// A closure panicked; later `call_once` calls panic.
    Poison = 3,
}
18
19impl State {
20    const fn const_into_u8(self) -> u8 {
21        match self {
22            State::Uncalled => 0,
23            State::Running => 1,
24            State::Called => 2,
25            State::Poison => 3
26        }
27    }
28
29    const fn const_from_u8(value: u8) -> Result<Self, ()> {
30        match value {
31            0 => Ok(State::Uncalled),
32            1 => Ok(State::Running),
33            2 => Ok(State::Called),
34            3 => Ok(State::Poison),
35            _ => Err(())
36        }
37    }
38}
39
40#[stable(feature = "kernel_core_api", since = "1.0.0")]
41impl From<State> for u8 {
42    fn from(value: State) -> Self {
43        value.const_into_u8()
44    }
45}
46
47impl TryFrom<u8> for State {
48    type Error = ();
49
50    fn try_from(value: u8) -> Result<Self, Self::Error> {
51        Self::const_from_u8(value)
52    }
53}
54
55impl Once {
56    pub const fn new() -> Self {
57        Self(AtomicU8::new(State::Uncalled.const_into_u8()))
58    }
59
60    pub fn call_once<F: FnOnce()>(&self, f: F) {
61        loop {
62            let current = self.0.compare_exchange_weak(State::Uncalled.into(), State::Running.into(), Ordering::Relaxed, Ordering::Acquire);
63            match current {
64                Ok(_) => break, // Switched from Uncalled to Running, call the function
65                Err(s) if s == State::Poison.into() => panic!("poisoned `Once`"),
66                Err(s) if s == State::Running.into() => {}, // Currently running, spin until state changes
67                Err(s) if s == State::Called.into() => return, // Already called, return immediately
68                Err(s) if s == State::Uncalled.into() => {}, // weak cas fail, try again
69                _ => unreachable!()
70            }
71            core::hint::spin_loop();
72        }
73
74        struct DropGuard<'a>(&'a Once);
75        impl Drop for DropGuard<'_> {
76            fn drop(&mut self) {
77                self.0.0.store(State::Poison.into(), Ordering::Relaxed);
78            }
79        }
80        let drop_guard = DropGuard(self);
81
82        f();
83
84        core::mem::forget(drop_guard);
85
86        self.0.store(State::Called.into(), Ordering::Release);
87    }
88
89    pub fn is_complete(&self) -> bool {
90        let state = self.0.load(Ordering::Relaxed).try_into().unwrap();
91        match state {
92            State::Called => true,
93            _ => false
94        }
95    }
96}
97
98pub struct OnceLock<T> {
99    data: UnsafeCell<MaybeUninit<T>>,
100    once: Once
101}
102
103unsafe impl<T: Send> Send for OnceLock<T> {}
104unsafe impl<T: Send + Sync> Sync for OnceLock<T> {}
105
106impl<T> OnceLock<T> {
107    pub const fn new() -> Self {
108        Self {
109            data: UnsafeCell::new(MaybeUninit::uninit()),
110            once: Once::new()
111        }
112    }
113
114    pub fn get(&self) -> Option<&T> {
115        if !self.once.is_complete() { return None; }
116        fence(Ordering::Acquire);
117
118        unsafe {
119            Some((*self.data.get()).assume_init_ref())
120        }
121    }
122
123    pub fn get_mut(&mut self) -> Option<&mut T> {
124        if !self.once.is_complete() { return None; }
125        fence(Ordering::Acquire);
126
127        unsafe {
128            Some((*self.data.get()).assume_init_mut())
129        }
130    }
131
132    pub fn get_or_init(&self, f: impl FnOnce() -> T) -> &T {
133        self.once.call_once(|| unsafe { (*self.data.get()).write(f()); });
134        unsafe { (*self.data.get()).assume_init_ref() }
135    }
136}
137
138pub struct LazyLock<T, F: FnOnce() -> T = fn() -> T> {
139    once: OnceLock<T>,
140    // FIXME: actually drop this when needed
141    f: MaybeUninit<F>
142}
143
144unsafe impl<T, F: FnOnce() -> T> Send for LazyLock<T, F> {}
145unsafe impl<T, F: FnOnce() -> T> Sync for LazyLock<T, F> {}
146
147impl<T, F: FnOnce() -> T> LazyLock<T, F> {
148    pub const fn new(f: F) -> Self {
149        Self {
150            once: OnceLock::new(),
151            f: MaybeUninit::new(f)
152        }
153    }
154
155    pub fn force(this: &Self) -> &T {
156        this.once.get_or_init(unsafe {
157            core::ptr::read(this.f.as_ptr())
158        })
159    }
160}
161
162impl<T, F: FnOnce() -> T> Deref for LazyLock<T, F> {
163    type Target = T;
164
165    fn deref(&self) -> &Self::Target {
166        Self::force(self)
167    }
168}
169
170pub use bootstrap::BootstrapOnceLock;
171
mod bootstrap {
    use core::cell::UnsafeCell;
    use core::mem::MaybeUninit;
    use core::sync::atomic::{AtomicU8, fence, Ordering};

    /// Lifecycle of a [`BootstrapOnceLock`], stored as a `u8` in its `state` atomic.
    #[derive(Debug, Copy, Clone, Eq, PartialEq)]
    #[repr(u8)]
    enum State {
        /// Nothing has been stored yet.
        Uncalled = 0,
        /// `data` is being written; `get` returns `None`.
        Saving = 1,
        /// `data` holds the bootstrap value while the initializer runs.
        Running = 2,
        /// `data` holds the value returned by the initializer.
        Init = 3,
        /// The initializer panicked; `get` and `bootstrap` panic.
        Poison = 4
    }

    impl State {
        /// Encodes this state as the raw byte stored in the `state` atomic.
        const fn const_into_u8(self) -> u8 {
            match self {
                State::Uncalled => 0,
                State::Saving => 1,
                State::Running => 2,
                State::Init => 3,
                State::Poison => 4
            }
        }

        /// Decodes a raw `state` byte, returning `Err(())` for unknown values.
        const fn const_from_u8(value: u8) -> Result<Self, ()> {
            match value {
                0 => Ok(State::Uncalled),
                1 => Ok(State::Saving),
                2 => Ok(State::Running),
                3 => Ok(State::Init),
                4 => Ok(State::Poison),
                _ => Err(())
            }
        }
    }

    impl From<State> for u8 {
        fn from(value: State) -> Self {
            value.const_into_u8()
        }
    }

    impl TryFrom<u8> for State {
        type Error = ();

        fn try_from(value: u8) -> Result<Self, Self::Error> {
            Self::const_from_u8(value)
        }
    }

    /// A once-initialized cell that exposes a temporary "bootstrap" value while its real
    /// initializer runs, then replaces it with the initializer's result.
    ///
    /// NOTE(review): references returned by `get` while the initializer runs point at
    /// memory that is overwritten when the initializer returns. Callers must not hold such
    /// references past the end of the initializer, and must not call `get` concurrently with
    /// the `Running -> Saving` transition. The state machine alone cannot rule out a reader
    /// racing with that final write.
    pub struct BootstrapOnceLock<T> {
        data: UnsafeCell<MaybeUninit<T>>,
        state: AtomicU8
    }

    // SAFETY: moving the lock to another thread moves the stored `T` with it.
    unsafe impl<T: Send> Send for BootstrapOnceLock<T> {}
    // SAFETY: a shared lock hands out `&T` to any thread (needs `T: Sync`) and stores `T`
    // values produced on whichever thread calls `bootstrap` (needs `T: Send`). These are
    // the same bounds as `OnceLock`.
    unsafe impl<T: Send + Sync> Sync for BootstrapOnceLock<T> {}

    impl<T> BootstrapOnceLock<T> {
        /// Creates an empty lock in the `Uncalled` state.
        pub const fn new() -> Self {
            Self {
                data: UnsafeCell::new(MaybeUninit::uninit()),
                state: AtomicU8::new(State::Uncalled.const_into_u8())
            }
        }

        /// Returns the currently readable value.
        ///
        /// While `Running` this is the bootstrap value; once `Init` it is the final value.
        /// Returns `None` before `bootstrap` has stored anything and while a value is being
        /// written.
        ///
        /// # Panics
        ///
        /// Panics if the initializer passed to `bootstrap` panicked.
        pub fn get(&self) -> Option<&T> {
            let state: State = self.state.load(Ordering::Relaxed).try_into().unwrap();

            if state == State::Poison { panic!("poisoned `BootstrapOnceLock`") }
            if state == State::Uncalled || state == State::Saving { return None; }
            // Pairs with the `Release` stores of `Running` / `Init` in `bootstrap`.
            fence(Ordering::Acquire);

            // SAFETY: `Running` and `Init` are only published after `data` was written.
            unsafe {
                Some((*self.data.get()).assume_init_ref())
            }
        }

        /// Initializes the lock in two phases and returns the final value.
        ///
        /// State transitions performed by the winning caller:
        /// - `Uncalled -> Saving`: `bootstrap_value` is stored into `data`.
        /// - `Saving -> Running`: the bootstrap value is now readable through `get`.
        /// - `f` runs; it may call `get` to use the bootstrap value.
        /// - `Running -> Saving`: the value is no longer readable; `f`'s result overwrites it.
        /// - `Saving -> Init`: the final value is now readable.
        ///
        /// If the lock is already `Init`, the stored value is returned and `bootstrap_value`
        /// and `f` are dropped unused. If another call is in progress, this spins until it
        /// finishes.
        ///
        /// The bootstrap value is overwritten without being dropped, because references
        /// handed out by `get` during `Running` may still exist.
        ///
        /// # Panics
        ///
        /// Panics if the lock is poisoned. If `f` panics, the lock becomes poisoned.
        pub fn bootstrap(&self, bootstrap_value: T, f: impl FnOnce() -> T) -> &T {
            loop {
                // Failure is `Acquire` so that observing `Init` synchronizes with the
                // `Release` store of `Init` below.
                let current = self.state.compare_exchange_weak(
                    State::Uncalled.into(),
                    State::Saving.into(),
                    Ordering::Relaxed,
                    Ordering::Acquire
                );
                match current {
                    Ok(_) => break, // Switched from Uncalled to Saving, bootstrap then call the function
                    Err(s) if s == State::Poison.into() => panic!("poisoned `BootstrapOnceLock`"),
                    Err(s) if s == State::Running.into() || s == State::Saving.into() => {}, // Another call is in progress; spin until state changes
                    Err(s) if s == State::Init.into() => {
                        // Already initialized, return immediately.
                        // SAFETY: `Init` is published after the final value is written.
                        return unsafe {
                            (*self.data.get()).assume_init_ref()
                        };
                    },
                    Err(s) if s == State::Uncalled.into() => {}, // Weak CAS failure so retry
                    _ => unreachable!()
                }
                core::hint::spin_loop();
            }

            // SAFETY: this call owns the `Saving` state, so no other thread touches `data`.
            unsafe { (*self.data.get()).write(bootstrap_value); }
            // Release ordering so the bootstrap value syncs with the Acquire fence in `get`.
            self.state.store(State::Running.into(), Ordering::Release);

            /// Stores `Poison` if dropped, i.e. if `f` unwinds. Waiters in `bootstrap` and
            /// readers in `get` then panic, instead of spinning on `Running` forever or
            /// serving the bootstrap value indefinitely.
            struct PoisonOnUnwind<'a>(&'a AtomicU8);
            impl Drop for PoisonOnUnwind<'_> {
                fn drop(&mut self) {
                    self.0.store(State::Poison.into(), Ordering::Relaxed);
                }
            }
            let poison_guard = PoisonOnUnwind(&self.state);
            let true_value = f();
            core::mem::forget(poison_guard);

            // Hide the bootstrap value from new readers before overwriting it. No data is
            // published by this store, so `Relaxed` suffices.
            // NOTE(review): a reader that loaded `Running` just before this store may still
            // be reading `data` during the write below. No ordering on this store can
            // prevent that; see the type-level note.
            self.state.store(State::Saving.into(), Ordering::Relaxed);
            // SAFETY: this call owns the `Saving` state (subject to the note above).
            unsafe { (*self.data.get()).write(true_value); }
            // Release ordering so the final value syncs with Acquire in `get` / the CAS above.
            self.state.store(State::Init.into(), Ordering::Release);

            // SAFETY: `data` now holds the final value and is never written again.
            unsafe { (*self.data.get()).assume_init_ref() }
        }
    }

    impl<T> Drop for BootstrapOnceLock<T> {
        fn drop(&mut self) {
            // `Init` holds the final value. `Poison` holds the bootstrap value, because
            // poisoning only happens after it was written. `Uncalled` holds nothing.
            // `Saving` and `Running` only exist while a `bootstrap` call is in flight,
            // which `&mut self` rules out.
            let state = State::const_from_u8(*self.state.get_mut());
            if matches!(state, Ok(State::Init | State::Poison)) {
                // SAFETY: `data` is initialized in these states, and `&mut self` means no
                // reference to it is alive.
                unsafe { self.data.get_mut().assume_init_drop() }
            }
        }
    }
}