1use std::hint::unreachable_unchecked;
2
3use ftml_solver_trace::traceref;
4use rayon::iter::{IntoParallelIterator, ParallelIterator};
5use smallvec::SmallVec;
6
7use crate::{
8 CheckRef,
9 rules::{
10 CheckerRule,
11 extractors::{RuleExtractor, SymbolRuleExtractor},
12 },
13 trace::{CheckingTask, RefCheckLog},
14};
15
/// A shareable flag through which one party can ask others to stop working.
///
/// Both methods take `&self`, so an implementor either relies on interior
/// mutability (as the `AtomicBool` implementation does) or is a no-op (as
/// the `()` implementation is). The `Default` supertrait is what
/// `CancelToken::derive` uses to create the flag of a new child token.
pub trait Cancellation: Send + Sync + Default {
    /// Reports whether cancellation has been requested on this flag.
    fn is_cancelled(&self) -> bool;

    /// Requests cancellation.
    fn cancel(&self);
}
20#[derive(Default)]
21pub struct CancelToken<'a, C: Cancellation> {
22 cancelled: C,
23 parent: Option<&'a Self>,
24}
25impl Cancellation for std::sync::atomic::AtomicBool {
26 #[inline]
27 fn is_cancelled(&self) -> bool {
28 self.load(std::sync::atomic::Ordering::Acquire)
29 }
30 #[inline]
31 fn cancel(&self) {
32 self.store(true, std::sync::atomic::Ordering::Release);
33 }
34}
35impl<C: Cancellation> CancelToken<'_, C> {
36 pub fn is_cancelled(&self) -> bool {
37 self.cancelled.is_cancelled()
38 || self.parent.as_ref().is_some_and(|tk| {
39 tk.is_cancelled() && {
40 self.cancelled.cancel();
41 true
42 }
43 })
44 }
45 #[inline]
46 pub fn cancel(&self) {
47 self.cancelled.cancel();
48 }
49 pub fn derive(&self) -> CancelToken<'_, C> {
50 CancelToken {
51 cancelled: C::default(),
52 parent: Some(self),
53 }
54 }
55}
56impl Cancellation for () {
57 #[inline]
58 fn is_cancelled(&self) -> bool {
59 false
60 }
61 #[inline(always)]
62 fn cancel(&self) {}
63}
64
/// Policy deciding how the checker explores alternatives: one after another
/// on the current thread, or concurrently through `rayon`.
///
/// An implementor chooses a [`Cancellation`] flag type and provides
/// [`Self::strategies`] and [`Self::split_i`]. The provided `*_st`
/// (sequential) and `*_mt` (rayon-based) methods are the building blocks the
/// implementations in this file delegate to.
pub trait SplitStrategy:
    Send + Sync + Sized + 'static + Copy + Clone + Default + std::fmt::Debug + PartialEq + Eq
{
    /// The cancellation flag implementation associated with this strategy.
    type CancelToken: Cancellation;

    /// Symbol rule extractors for this strategy; defaults to
    /// `all_symbol_extractors()`.
    const SYMBOL_EXTRACTORS: &[SymbolRuleExtractor<Self>] =
        { super::rules::extractors::all_symbol_extractors() };

    /// Rule extractors for this strategy; defaults to `all_rule_extractors()`.
    const RULE_EXTRACTORS: &[RuleExtractor<Self>] =
        { super::rules::extractors::all_rule_extractors() };

    /// Tries two named strategies and returns the result of one that
    /// succeeds, or `None` if neither does.
    ///
    /// The implementations in this file forward to [`Self::strategies_st`]
    /// or [`Self::strategies_mt`].
    fn strategies<'t, A, B, R>(
        solver: &mut CheckRef<'t, '_, Self>,
        strategy_a: &'static str,
        oper_a: A,
        strategy_b: &'static str,
        oper_b: B,
    ) -> Option<R>
    where
        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        R: Send + std::fmt::Debug + Clone;

    /// Applies `then` to the candidate `rules` and returns the result of a
    /// successful application, or the logs of the failed attempts.
    ///
    /// The implementations in this file forward to [`Self::split_i_st`] or
    /// [`Self::split_i_mt`].
    // The error type is an inline `SmallVec` of logs, which is presumably
    // what trips `result_large_err` here.
    #[allow(clippy::result_large_err)]
    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
        slf: &mut CheckRef<'t, '_, Self>,
        msg: bool,
        rules: smallvec::SmallVec<&'t Rl, 2>,
        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>>;

    /// Convenience wrapper around [`Self::split_i`]: on failure every
    /// returned log is attached to `slf` via `add_msg` and `None` is returned.
    fn split<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
        slf: &mut CheckRef<'t, '_, Self>,
        msg: bool,
        rules: smallvec::SmallVec<&'t Rl, 2>,
        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
    ) -> Option<R> {
        match Self::split_i(slf, msg, rules, then) {
            Ok(r) => Some(r),
            Err(ls) => {
                for e in ls {
                    slf.add_msg(e.into());
                }
                None
            }
        }
    }

    /// Sequential two-strategy attempt.
    ///
    /// Runs `oper_a` through `branch_traced` under the task
    /// `Strategy(strategy_a)` and returns immediately if it succeeds.
    /// Otherwise runs `oper_b` through `traced` on `solver` itself. If both
    /// fail, both logs are attached to `solver` (a's first) and `None` is
    /// returned.
    fn strategies_st<'t, A, B, R>(
        solver: &mut CheckRef<'t, '_, Self>,
        strategy_a: &'static str,
        oper_a: A,
        strategy_b: &'static str,
        oper_b: B,
    ) -> Option<R>
    where
        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        R: Send + std::fmt::Debug + Clone,
    {
        // Keep a's failure log around; it is only reported if b fails too.
        let l1 = match solver
            .branch_traced(CheckingTask::Strategy(strategy_a), |mut c| oper_a(&mut c))
        {
            Ok(r) => return Some(r),
            Err(l) => l,
        };
        match solver.traced(CheckingTask::Strategy(strategy_b), oper_b) {
            Ok(r) => Some(r),
            Err(l) => {
                solver.add_msg(l1.into());
                solver.add_msg(l.into());
                None
            }
        }
    }

    /// Sequential rule split.
    ///
    /// With no rules, fails at once: the error holds a single
    /// "No rule applicable" log when `msg` is set and is empty otherwise.
    /// Otherwise each rule is tried in order inside `branch_traced`; the
    /// first success is returned, and if none succeeds the logs of all
    /// attempts are returned in order.
    #[allow(clippy::result_large_err)]
    fn split_i_st<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
        slf: &mut CheckRef<'t, '_, Self>,
        msg: bool,
        rules: smallvec::SmallVec<&'t Rl, 2>,
        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
        if rules.is_empty() {
            return Err(if msg {
                smallvec::smallvec![traceref!(FAIL "No rule applicable")]
            } else {
                smallvec::SmallVec::default()
            });
        }
        let mut failures = SmallVec::<_, 2>::new();
        for rule in rules {
            match slf.branch_traced(CheckingTask::Rule(rule.as_dyn()), |slf| {
                tracing::debug!("Applying rule {rule:?}");
                then(slf, rule)
            }) {
                Ok(r) => {
                    return Ok(r);
                }
                Err(l) => failures.push(l),
            }
        }
        Err(failures)
    }

    /// Parallel two-strategy attempt.
    ///
    /// Inside `solver.cancellable`, two copies of the solver are made and
    /// the two operations run concurrently through `rayon::join`, each on
    /// its own copy. A branch that succeeds raises its `cancel` flag. If
    /// both succeed, a's result is used. The winning copy is handed to
    /// `close` together with `solver`. If both fail, both logs are attached
    /// to `solver` and `None` is returned.
    fn strategies_mt<'t, A, B, R>(
        solver: &mut CheckRef<'t, '_, Self>,
        strategy_a: &'static str,
        oper_a: A,
        strategy_b: &'static str,
        oper_b: B,
    ) -> Option<R>
    where
        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
        R: Send + std::fmt::Debug + Clone,
    {
        solver.cancellable(|solver| {
            // One independent copy per branch so that both can be mutated
            // concurrently.
            let mut top1 = solver.copied();
            let mut top2 = solver.copied();
            let (a, b) = rayon::join(
                || {
                    let mut solver = top1.get_ref();
                    solver
                        .traced(CheckingTask::Strategy(strategy_a), oper_a)
                        // On success, signal cancellation (presumably so the
                        // sibling branch can stop early - confirm against
                        // `CheckRef`).
                        .inspect(|_| solver.cancel.cancel())
                },
                || {
                    let mut solver = top2.get_ref();
                    solver
                        .traced(CheckingTask::Strategy(strategy_b), oper_b)
                        .inspect(|_| solver.cancel.cancel())
                },
            );
            // Arm order matters: a's result is preferred when both succeeded.
            match (a, b) {
                (Ok(r), b) => {
                    drop(b);
                    top1.close(solver);
                    Some(r)
                }
                (a, Ok(r)) => {
                    drop(a);
                    top2.close(solver);
                    Some(r)
                }
                (Err(i1), Err(i2)) => {
                    solver.add_msg(i1.into());
                    solver.add_msg(i2.into());
                    None
                }
            }
        })
    }

    /// Parallel rule split.
    ///
    /// * 0 rules: fails like [`Self::split_i_st`] does for an empty list.
    /// * 1 rule: applied directly on `slf`, without copying.
    /// * 2 rules: both run concurrently through `rayon::join`, each on its
    ///   own copy; the first rule's result is preferred if both succeed.
    /// * more: a rayon parallel iterator is used; rules whose turn comes
    ///   after `slf.cancel` reports cancellation are skipped, the result
    ///   slot is overwritten by every success (so the last writer wins), and
    ///   failure logs are collected in completion order.
    ///
    /// NOTE(review): unlike `strategies_mt`, the copies created here are
    /// never passed to `close`, and the body is not wrapped in
    /// `cancellable` - confirm that `branch_traced` on a copy propagates a
    /// successful branch back to `slf` and that raising `cancel` here cannot
    /// affect an enclosing scope.
    #[allow(clippy::result_large_err)]
    fn split_i_mt<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
        slf: &mut CheckRef<'t, '_, Self>,
        msg: bool,
        rules: smallvec::SmallVec<&'t Rl, 2>,
        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
        // Borrow once so the `move` closures below copy the reference
        // instead of trying to move the callback itself.
        let then = &then;
        macro_rules! then {
            // `then!(rule!)`: run the rule on a fresh copy of `slf` and raise
            // the copy's `cancel` flag on success.
            ($rl:ident !) => {{
                let mut top = slf.copied();
                let mut slf = top.get_ref();
                slf.branch_traced(CheckingTask::Rule($rl.as_dyn()), move |slf| then(slf, $rl))
                    .inspect(|_| slf.cancel.cancel())
            }};
            // Same, but directly on `slf`. No invocation in this function
            // uses this arm.
            ($rl:expr) => {{
                slf.branch_traced(CheckingTask::Rule($rl.as_dyn()), move |slf| then(slf, $rl))
                    .inspect(|_| slf.cancel.cancel())
            }};
        }

        match rules.len() {
            0 => {
                return Err(if msg {
                    smallvec::smallvec![traceref!(FAIL "No rule applicable")]
                } else {
                    smallvec::SmallVec::default()
                });
            }
            1 => {
                // SAFETY: this arm is only reached when `rules.len() == 1`,
                // so `first()` is `Some`.
                let rule = unsafe { rules.first().unwrap_unchecked() };
                return slf
                    .branch_traced(CheckingTask::Rule(rule.as_dyn()), |slf| then(slf, rule))
                    .map_err(|l| smallvec::smallvec![l]);
            }
            2 => {
                let [rule_a, rule_b] = &*rules else {
                    // SAFETY: this arm is only reached when
                    // `rules.len() == 2`, so the two-element slice pattern
                    // always matches.
                    unsafe { unreachable_unchecked() }
                };
                return match rayon::join(|| then!(rule_a!), || then!(rule_b!)) {
                    (Ok(t), _) | (_, Ok(t)) => Ok(t),
                    (Err(i1), Err(i2)) => Err(smallvec::smallvec_inline![i1, i2]),
                };
            }
            _ => (),
        }
        // Three or more rules: fan out over rayon's parallel iterator and
        // gather results behind mutexes.
        let result = parking_lot::Mutex::new(None);
        let failures = parking_lot::Mutex::new(SmallVec::<_, 2>::new());
        rules.into_vec().into_par_iter().for_each(|rule| {
            // Skip rules that have not started once cancellation is visible.
            if slf.cancel.is_cancelled() {
                return;
            }
            match then!(rule!) {
                Ok(r) => *result.lock() = Some(r),
                Err(l) => failures.lock().push(l),
            }
        });
        // Any success wins; the failure logs are only reported when nothing
        // succeeded.
        result
            .into_inner()
            .map_or_else(|| Err(failures.into_inner()), Ok)
    }
}
291
292#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
293pub struct SingleThreadedSplit;
294impl SplitStrategy for SingleThreadedSplit {
295 type CancelToken = ();
296
297 #[inline]
298 fn strategies<'t, A, B, R>(
299 solver: &mut CheckRef<'t, '_, Self>,
300 strategy_a: &'static str,
301 oper_a: A,
302 strategy_b: &'static str,
303 oper_b: B,
304 ) -> Option<R>
305 where
306 A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
307 B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
308 R: Send + std::fmt::Debug + Clone,
309 {
310 Self::strategies_st(solver, strategy_a, oper_a, strategy_b, oper_b)
311 }
312
313 #[inline]
314 fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
315 slf: &mut CheckRef<'t, '_, Self>,
316 msg: bool,
317 rules: smallvec::SmallVec<&'t Rl, 2>, then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
319 ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
320 Self::split_i_st(slf, msg, rules, then)
321 }
322}
323
324#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
325pub struct RayonStrategiesDepth<const DEPTH: usize>;
326impl<const DEPTH: usize> SplitStrategy for RayonStrategiesDepth<DEPTH> {
327 type CancelToken = std::sync::atomic::AtomicBool;
328 #[inline]
329 fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
330 slf: &mut CheckRef<'t, '_, Self>,
331 msg: bool,
332 rules: smallvec::SmallVec<&'t Rl, 2>, then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
334 ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
335 Self::split_i_st(slf, msg, rules, then)
336 }
337
338 #[inline]
339 fn strategies<'t, A, B, R>(
340 solver: &mut CheckRef<'t, '_, Self>,
341 strategy_a: &'static str,
342 oper_a: A,
343 strategy_b: &'static str,
344 oper_b: B,
345 ) -> Option<R>
346 where
347 A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
348 B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
349 R: Send + std::fmt::Debug + Clone,
350 {
351 if solver.depth() <= DEPTH {
352 Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
353 } else {
354 Self::strategies_st(solver, strategy_a, oper_a, strategy_b, oper_b)
355 }
356 }
357}
358
359#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
360pub struct RayonStrategiesOnly;
361impl SplitStrategy for RayonStrategiesOnly {
362 type CancelToken = std::sync::atomic::AtomicBool;
363
364 #[inline]
365 fn strategies<'t, A, B, R>(
366 solver: &mut CheckRef<'t, '_, Self>,
367 strategy_a: &'static str,
368 oper_a: A,
369 strategy_b: &'static str,
370 oper_b: B,
371 ) -> Option<R>
372 where
373 A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
374 B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
375 R: Send + std::fmt::Debug + Clone,
376 {
377 Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
378 }
379
380 #[inline]
381 fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
382 slf: &mut CheckRef<'t, '_, Self>,
383 msg: bool,
384 rules: smallvec::SmallVec<&'t Rl, 2>, then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
386 ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
387 Self::split_i_st(slf, msg, rules, then)
388 }
389}
390
391#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
392pub struct RayonSplit;
393impl SplitStrategy for RayonSplit {
394 type CancelToken = std::sync::atomic::AtomicBool;
395
396 #[inline]
397 fn strategies<'t, A, B, R>(
398 solver: &mut CheckRef<'t, '_, Self>,
399 strategy_a: &'static str,
400 oper_a: A,
401 strategy_b: &'static str,
402 oper_b: B,
403 ) -> Option<R>
404 where
405 A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
406 B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
407 R: Send + std::fmt::Debug + Clone,
408 {
409 Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
410 }
411
412 #[inline]
413 fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
414 slf: &mut CheckRef<'t, '_, Self>,
415 msg: bool,
416 rules: smallvec::SmallVec<&'t Rl, 2>,
417 then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
418 ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
419 Self::split_i_mt(slf, msg, rules, then)
420 }
421}