Skip to main content

ftml_solver/
split.rs

1use std::hint::unreachable_unchecked;
2
3use ftml_solver_trace::traceref;
4use rayon::iter::{IntoParallelIterator, ParallelIterator};
5use smallvec::SmallVec;
6
7use crate::{
8    CheckRef,
9    rules::{
10        CheckerRule,
11        extractors::{RuleExtractor, SymbolRuleExtractor},
12    },
13    trace::{CheckingTask, RefCheckLog},
14};
15
/// A cancellation flag that can be shared (by reference) between threads.
///
/// A fresh, independent flag is obtained through [`Default`]; both operations
/// take `&self`, so cancelling requires no exclusive access.
pub trait Cancellation: Default + Send + Sync {
    /// Requests cancellation through this flag.
    fn cancel(&self);

    /// Reports whether this flag is currently in the cancelled state.
    fn is_cancelled(&self) -> bool;
}
20#[derive(Default)]
21pub struct CancelToken<'a, C: Cancellation> {
22    cancelled: C,
23    parent: Option<&'a Self>,
24}
25impl Cancellation for std::sync::atomic::AtomicBool {
26    #[inline]
27    fn is_cancelled(&self) -> bool {
28        self.load(std::sync::atomic::Ordering::Acquire)
29    }
30    #[inline]
31    fn cancel(&self) {
32        self.store(true, std::sync::atomic::Ordering::Release);
33    }
34}
35impl<C: Cancellation> CancelToken<'_, C> {
36    pub fn is_cancelled(&self) -> bool {
37        self.cancelled.is_cancelled()
38            || self.parent.as_ref().is_some_and(|tk| {
39                tk.is_cancelled() && {
40                    self.cancelled.cancel();
41                    true
42                }
43            })
44    }
45    #[inline]
46    pub fn cancel(&self) {
47        self.cancelled.cancel();
48    }
49    pub fn derive(&self) -> CancelToken<'_, C> {
50        CancelToken {
51            cancelled: C::default(),
52            parent: Some(self),
53        }
54    }
55}
56impl Cancellation for () {
57    #[inline]
58    fn is_cancelled(&self) -> bool {
59        false
60    }
61    #[inline(always)]
62    fn cancel(&self) {}
63}
64
65pub trait SplitStrategy:
66    Send + Sync + Sized + 'static + Copy + Clone + Default + std::fmt::Debug + PartialEq + Eq
67{
68    type CancelToken: Cancellation;
69    const SYMBOL_EXTRACTORS: &[SymbolRuleExtractor<Self>] =
70        { super::rules::extractors::all_symbol_extractors() };
71
72    const RULE_EXTRACTORS: &[RuleExtractor<Self>] =
73        { super::rules::extractors::all_rule_extractors() };
74
75    fn strategies<'t, A, B, R>(
76        solver: &mut CheckRef<'t, '_, Self>,
77        strategy_a: &'static str,
78        oper_a: A,
79        strategy_b: &'static str,
80        oper_b: B,
81    ) -> Option<R>
82    where
83        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
84        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
85        R: Send + std::fmt::Debug + Clone;
86
87    /// ### Errors
88    #[allow(clippy::result_large_err)]
89    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
90        slf: &mut CheckRef<'t, '_, Self>,
91        msg: bool,
92        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
93        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
94    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>>;
95
96    fn split<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
97        slf: &mut CheckRef<'t, '_, Self>,
98        msg: bool,
99        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
100        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
101    ) -> Option<R> {
102        match Self::split_i(slf, msg, rules, then) {
103            Ok(r) => Some(r),
104            Err(ls) => {
105                for e in ls {
106                    slf.add_msg(e.into());
107                }
108                None
109            }
110        }
111    }
112
113    // -----------------------------
114
115    fn strategies_st<'t, A, B, R>(
116        solver: &mut CheckRef<'t, '_, Self>,
117        strategy_a: &'static str,
118        oper_a: A,
119        strategy_b: &'static str,
120        oper_b: B,
121    ) -> Option<R>
122    where
123        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
124        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
125        R: Send + std::fmt::Debug + Clone,
126    {
127        let l1 = match solver
128            .branch_traced(CheckingTask::Strategy(strategy_a), |mut c| oper_a(&mut c))
129        {
130            Ok(r) => return Some(r),
131            Err(l) => l,
132        };
133        match solver.traced(CheckingTask::Strategy(strategy_b), oper_b) {
134            Ok(r) => Some(r),
135            Err(l) => {
136                solver.add_msg(l1.into());
137                solver.add_msg(l.into());
138                None
139            }
140        }
141    }
142
143    /// ### Errors
144    #[allow(clippy::result_large_err)]
145    fn split_i_st<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
146        slf: &mut CheckRef<'t, '_, Self>,
147        msg: bool,
148        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
149        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
150    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
151        //let mut rules = rules.peekable();
152        if rules.is_empty() {
153            return Err(if msg {
154                smallvec::smallvec![traceref!(FAIL "No rule applicable")]
155            } else {
156                smallvec::SmallVec::default()
157            });
158        }
159        let mut failures = SmallVec::<_, 2>::new();
160        for rule in rules {
161            match slf.branch_traced(CheckingTask::Rule(rule.as_dyn()), |slf| {
162                tracing::debug!("Applying rule {rule:?}");
163                then(slf, rule)
164            }) {
165                Ok(r) => {
166                    return Ok(r);
167                }
168                Err(l) => failures.push(l),
169            }
170        }
171        Err(failures)
172    }
173
174    fn strategies_mt<'t, A, B, R>(
175        solver: &mut CheckRef<'t, '_, Self>,
176        strategy_a: &'static str,
177        oper_a: A,
178        strategy_b: &'static str,
179        oper_b: B,
180    ) -> Option<R>
181    where
182        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
183        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
184        R: Send + std::fmt::Debug + Clone,
185    {
186        //solver.comment("Splitting");
187        solver.cancellable(|solver| {
188            let mut top1 = solver.copied();
189            let mut top2 = solver.copied();
190            let (a, b) = rayon::join(
191                || {
192                    let mut solver = top1.get_ref();
193                    solver
194                        .traced(CheckingTask::Strategy(strategy_a), oper_a)
195                        .inspect(|_| solver.cancel.cancel())
196                },
197                || {
198                    let mut solver = top2.get_ref();
199                    solver
200                        .traced(CheckingTask::Strategy(strategy_b), oper_b)
201                        .inspect(|_| solver.cancel.cancel())
202                },
203            );
204            match (a, b) {
205                (Ok(r), b) => {
206                    drop(b);
207                    //solver.comment("1 Succeeded");
208                    top1.close(solver);
209                    Some(r)
210                }
211                (a, Ok(r)) => {
212                    drop(a);
213                    //solver.comment("2 Succeeded");
214                    top2.close(solver);
215                    Some(r)
216                }
217                (Err(i1), Err(i2)) => {
218                    solver.add_msg(i1.into());
219                    solver.add_msg(i2.into());
220                    None
221                }
222            }
223        })
224    }
225
226    /// ### Errors
227    #[allow(clippy::result_large_err)]
228    fn split_i_mt<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
229        slf: &mut CheckRef<'t, '_, Self>,
230        msg: bool,
231        rules: smallvec::SmallVec<&'t Rl, 2>,
232        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
233    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
234        let then = &then;
235        macro_rules! then {
236            ($rl:ident !) => {{
237                let mut top = slf.copied();
238                let mut slf = top.get_ref();
239                slf.branch_traced(CheckingTask::Rule($rl.as_dyn()), move |slf| then(slf, $rl))
240                    .inspect(|_| slf.cancel.cancel())
241            }};
242            ($rl:expr) => {{
243                slf.branch_traced(CheckingTask::Rule($rl.as_dyn()), move |slf| then(slf, $rl))
244                    .inspect(|_| slf.cancel.cancel())
245            }};
246        }
247
248        match rules.len() {
249            0 => {
250                return Err(if msg {
251                    smallvec::smallvec![traceref!(FAIL "No rule applicable")]
252                } else {
253                    smallvec::SmallVec::default()
254                });
255            }
256            1 => {
257                // SAFETY: len == 1
258                let rule = unsafe { rules.first().unwrap_unchecked() };
259                return slf
260                    .branch_traced(CheckingTask::Rule(rule.as_dyn()), |slf| then(slf, rule))
261                    .map_err(|l| smallvec::smallvec![l]);
262            }
263            2 => {
264                // SAFETY: len == 2
265                let [rule_a, rule_b] = &*rules else {
266                    unsafe { unreachable_unchecked() }
267                };
268                return match rayon::join(|| then!(rule_a!), || then!(rule_b!)) {
269                    (Ok(t), _) | (_, Ok(t)) => Ok(t),
270                    (Err(i1), Err(i2)) => Err(smallvec::smallvec_inline![i1, i2]),
271                };
272            }
273            _ => (),
274        }
275        let result = parking_lot::Mutex::new(None);
276        let failures = parking_lot::Mutex::new(SmallVec::<_, 2>::new());
277        rules.into_vec().into_par_iter().for_each(|rule| {
278            if slf.cancel.is_cancelled() {
279                return;
280            }
281            match then!(rule!) {
282                Ok(r) => *result.lock() = Some(r),
283                Err(l) => failures.lock().push(l),
284            }
285        });
286        result
287            .into_inner()
288            .map_or_else(|| Err(failures.into_inner()), Ok)
289    }
290}
291
292#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
293pub struct SingleThreadedSplit;
294impl SplitStrategy for SingleThreadedSplit {
295    type CancelToken = ();
296
297    #[inline]
298    fn strategies<'t, A, B, R>(
299        solver: &mut CheckRef<'t, '_, Self>,
300        strategy_a: &'static str,
301        oper_a: A,
302        strategy_b: &'static str,
303        oper_b: B,
304    ) -> Option<R>
305    where
306        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
307        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
308        R: Send + std::fmt::Debug + Clone,
309    {
310        Self::strategies_st(solver, strategy_a, oper_a, strategy_b, oper_b)
311    }
312
313    #[inline]
314    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
315        slf: &mut CheckRef<'t, '_, Self>,
316        msg: bool,
317        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
318        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
319    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
320        Self::split_i_st(slf, msg, rules, then)
321    }
322}
323
324#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
325pub struct RayonStrategiesDepth<const DEPTH: usize>;
326impl<const DEPTH: usize> SplitStrategy for RayonStrategiesDepth<DEPTH> {
327    type CancelToken = std::sync::atomic::AtomicBool;
328    #[inline]
329    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
330        slf: &mut CheckRef<'t, '_, Self>,
331        msg: bool,
332        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
333        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
334    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
335        Self::split_i_st(slf, msg, rules, then)
336    }
337
338    #[inline]
339    fn strategies<'t, A, B, R>(
340        solver: &mut CheckRef<'t, '_, Self>,
341        strategy_a: &'static str,
342        oper_a: A,
343        strategy_b: &'static str,
344        oper_b: B,
345    ) -> Option<R>
346    where
347        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
348        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
349        R: Send + std::fmt::Debug + Clone,
350    {
351        if solver.depth() <= DEPTH {
352            Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
353        } else {
354            Self::strategies_st(solver, strategy_a, oper_a, strategy_b, oper_b)
355        }
356    }
357}
358
359#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
360pub struct RayonStrategiesOnly;
361impl SplitStrategy for RayonStrategiesOnly {
362    type CancelToken = std::sync::atomic::AtomicBool;
363
364    #[inline]
365    fn strategies<'t, A, B, R>(
366        solver: &mut CheckRef<'t, '_, Self>,
367        strategy_a: &'static str,
368        oper_a: A,
369        strategy_b: &'static str,
370        oper_b: B,
371    ) -> Option<R>
372    where
373        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
374        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
375        R: Send + std::fmt::Debug + Clone,
376    {
377        Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
378    }
379
380    #[inline]
381    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
382        slf: &mut CheckRef<'t, '_, Self>,
383        msg: bool,
384        rules: smallvec::SmallVec<&'t Rl, 2>, //smallvec::SmallVec<&Rl, 2>,
385        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
386    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
387        Self::split_i_st(slf, msg, rules, then)
388    }
389}
390
391#[derive(Copy, Clone, Default, Debug, PartialEq, Eq)]
392pub struct RayonSplit;
393impl SplitStrategy for RayonSplit {
394    type CancelToken = std::sync::atomic::AtomicBool;
395
396    #[inline]
397    fn strategies<'t, A, B, R>(
398        solver: &mut CheckRef<'t, '_, Self>,
399        strategy_a: &'static str,
400        oper_a: A,
401        strategy_b: &'static str,
402        oper_b: B,
403    ) -> Option<R>
404    where
405        A: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
406        B: FnOnce(&mut CheckRef<'t, '_, Self>) -> Option<R> + Send,
407        R: Send + std::fmt::Debug + Clone,
408    {
409        Self::strategies_mt(solver, strategy_a, oper_a, strategy_b, oper_b)
410    }
411
412    #[inline]
413    fn split_i<'t, Rl: CheckerRule + ?Sized, R: Send + std::fmt::Debug + Clone + 'static>(
414        slf: &mut CheckRef<'t, '_, Self>,
415        msg: bool,
416        rules: smallvec::SmallVec<&'t Rl, 2>,
417        then: impl Fn(CheckRef<'t, '_, Self>, &Rl) -> Option<R> + Send + Sync,
418    ) -> Result<R, smallvec::SmallVec<RefCheckLog<'t>, 2>> {
419        Self::split_i_mt(slf, msg, rules, then)
420    }
421}