Skip to main content

ftml_solver/impls/
equality.rs

1use std::borrow::Cow;
2
3use crate::{
4    CheckRef, impls::solving::TermExtSolvable, rules::implicits::ImplicitExtApp,
5    split::SplitStrategy, trace::CheckingTask,
6};
7use ftml_ontology::terms::{
8    ApplicationTerm, Argument, BindingTerm, BoundArgument, ComponentVar, MaybeSequence, Term,
9    Variable, eq::Alpha,
10};
11use ftml_solver_trace::traceref;
12
13fn same_shape(lhs: &Term, rhs: &Term) -> bool {
14    if lhs.is_solvable().is_some() || rhs.is_solvable().is_some() {
15        return true;
16    }
17    matches!(
18        (lhs, rhs),
19        (Term::Symbol { .. }, Term::Symbol { .. })
20            | (Term::Var { .. }, Term::Var { .. })
21            | (Term::Field(_), Term::Field(_))
22            | (Term::Label { .. }, Term::Label { .. })
23            | (Term::Number(_), Term::Number(_))
24            | (Term::Application(_), Term::Application(_))
25            | (Term::Bound(_), Term::Bound(_))
26    )
27    /*
28    if r {
29        tracing::warn!(
30            "Same shape:\n  - {:?}\n  - {:?}",
31            lhs.debug_short(),
32            rhs.debug_short()
33        );
34    } else {
35        tracing::error!(
36            "Not same shape:\n  - {:?}\n  - {:?}",
37            lhs.debug_short(),
38            rhs.debug_short()
39        );
40    }
41    r
42     */
43}
44
45impl<'t, Split: SplitStrategy> CheckRef<'t, '_, Split> {
46    pub fn check_equality(&mut self, lhs: &'t Term, rhs: &'t Term) -> Option<bool> {
47        tracing::debug!(
48            "Checking equality {:?}   ==   {:?}",
49            lhs.debug_short(),
50            rhs.debug_short()
51        );
52        self.wrap_check(CheckingTask::Equality(lhs, rhs), |slf| {
53            slf.check_equality_i(lhs, rhs)
54        })
55    }
56    pub(crate) fn check_equality_i(&mut self, lhs: &'t Term, rhs: &'t Term) -> Option<bool> {
57        if lhs.alpha_equal(rhs) {
58            self.comment("trivial");
59            return Some(true);
60        }
61        if let Some(unk) = lhs.is_solvable() {
62            return self.solve_equality(unk, rhs);
63        }
64        if let Some(unk) = rhs.is_solvable() {
65            return self.solve_equality(unk, lhs);
66        }
67        let lhs = self.subst(lhs.clone());
68        let rhs = self.subst(rhs.clone());
69        self.scoped(|slf| {
70            match slf.simplify_rules_two(
71                slf.top.rules.equality(),
72                &lhs,
73                &rhs,
74                |slf, rl, lhs, rhs| rl.applicable(lhs, rhs),
75                |slf, rl, lhs, rhs| rl.apply(slf, lhs, rhs),
76                |lhs, rhs| {
77                    lhs.alpha_equal(rhs)
78                        || lhs.is_solvable().is_some()
79                        || rhs.is_solvable().is_some()
80                },
81            ) {
82                either::Left(opt) => {
83                    if opt.is_some() {
84                        if opt == Some(false) {
85                            slf.failure("Disproven");
86                        }
87                        return opt;
88                    }
89                }
90                either::Right((lhs, rhs)) => {
91                    if lhs.alpha_equal(&rhs) {
92                        slf.comment("trivial");
93                        return Some(true);
94                    }
95                    if let Some(unk) = lhs.is_solvable() {
96                        return slf.solve_equality(unk, &rhs);
97                    }
98                    if let Some(unk) = rhs.is_solvable() {
99                        return slf.solve_equality(unk, &lhs);
100                    }
101                    return slf.scoped(|slf| slf.congruence(&lhs, &rhs));
102                }
103            }
104
105            slf.congruence(&lhs, &rhs)
106        })
107    }
108
109    fn congruence(&mut self, lhs: &'t Term, rhs: &'t Term) -> Option<bool> {
110        tracing::debug!("Trying congruence");
111        let Some((lhs, rhs)) = self.simplify_until_two(lhs, rhs, |_, lhs, rhs| {
112            lhs.unapply_implicits().is_some()
113                || rhs.unapply_implicits().is_some()
114                || same_shape(lhs, rhs)
115        }) else {
116            return self.congruence_cont(lhs, rhs);
117        };
118        match (lhs, rhs) {
119            (Cow::Borrowed(lhs), Cow::Borrowed(rhs)) => self.congruence_i(lhs, rhs),
120            (lhs, rhs) => self.scoped(|slf| slf.congruence_i(&lhs, &rhs)),
121        }
122    }
123
124    fn congruence_i(&mut self, lhs: &'t Term, rhs: &'t Term) -> Option<bool> {
125        if lhs.unapply_implicits().is_some() || rhs.unapply_implicits().is_some() {
126            let nlhs = self
127                .simplify_implicit(lhs)
128                .map_or(Cow::Borrowed(lhs), Cow::Owned);
129            let nrhs = self
130                .simplify_implicit(rhs)
131                .map_or(Cow::Borrowed(rhs), Cow::Owned);
132            if !lhs.alpha_equal(&nlhs) || !rhs.alpha_equal(&nrhs) {
133                return self.scoped(|slf| slf.congruence(&nlhs, &nrhs));
134            }
135        }
136        match (lhs, rhs) {
137            (Term::Application(l), Term::Application(r))
138                if l.arguments.len() == r.arguments.len() =>
139            {
140                match self.traced(CheckingTask::Strategy("Trying congruence"), |slf| {
141                    slf.congruence_app(l, r)
142                }) {
143                    Ok(r) => Some(r),
144                    Err(l) => {
145                        self.add_msg(l.into());
146                        self.congruence_cont(lhs, rhs)
147                    }
148                }
149            }
150            (Term::Bound(l), Term::Bound(r)) if l.arguments.len() == r.arguments.len() => {
151                match self.traced(CheckingTask::Strategy("Trying congruence"), |slf| {
152                    slf.congruence_bind(l, r)
153                }) {
154                    Ok(r) => Some(r),
155                    Err(l) => {
156                        self.add_msg(l.into());
157                        self.congruence_cont(lhs, rhs)
158                    }
159                }
160            }
161            (Term::Field(a), Term::Field(b)) if a.key == b.key => {
162                self.congruence_cont(&a.record, &b.record)
163            }
164            (Term::Field(a), Term::Field(b)) => Some(false),
165            (Term::Number(a), Term::Number(b)) => Some(a == b),
166            _ => self.congruence_cont(lhs, rhs),
167        }
168    }
169
170    fn congruence_cont(&mut self, lhs: &'t Term, rhs: &'t Term) -> Option<bool> {
171        self.add_msg(traceref!("shapes don't match: ", lhs, " and ", rhs).into());
172        // LAST RESORT
173        let nlhs = self
174            .simplify_full(true, lhs)
175            .map_or(Cow::Borrowed(lhs), Cow::Owned);
176        let nrhs = self
177            .simplify_full(true, rhs)
178            .map_or(Cow::Borrowed(rhs), Cow::Owned);
179        if *lhs != *nlhs || *rhs != *nrhs {
180            self.scoped(|slf| slf.check_equality_i(&nlhs, &nrhs))
181        } else {
182            None
183        }
184        /*
185        // todo: preserve logs on recursive fail
186        if let Some(lhs) = self.simplify_one(true, lhs) {
187            if self.alpha_equal(&lhs, rhs) {
188                self.comment("trivial");
189                return Some(true);
190            }
191            return self.scoped(|slf| slf.check_equality_i(&lhs, rhs));
192        }
193        if let Some(rhs) = self.simplify_one(true, rhs) {
194            if self.alpha_equal(lhs, &rhs) {
195                self.comment("trivial");
196                return Some(true);
197            }
198            return self.scoped(|slf| slf.check_equality_i(lhs, &rhs));
199        }
200        None
201         */
202    }
203
204    // invariant: lhs.arguments.len() == rhs.arguments.len()
205    fn congruence_app(
206        &mut self,
207        lhs: &'t ApplicationTerm,
208        rhs: &'t ApplicationTerm,
209    ) -> Option<bool> {
210        tracing::trace!("Comparing operators");
211        self.comment("Comparing operators");
212        if !self.check_equality(&lhs.head, &rhs.head)? {
213            return None;
214        }
215        for (i, (a, b)) in lhs.arguments.iter().zip(&rhs.arguments).enumerate() {
216            tracing::trace!("Comparing argument {}", i + 1);
217            self.counter("Comparing arguments ", i + 1);
218            if let (Argument::Simple(a), Argument::Simple(b)) = (a, b) {
219                if !self.check_equality(a, b)? {
220                    return None;
221                }
222            } else if let (
223                Argument::Sequence(MaybeSequence::One(a)),
224                Argument::Sequence(MaybeSequence::One(b)),
225            ) = (a, b)
226            {
227                if !self.check_equality(a, b)? {
228                    return None;
229                }
230            } else if let (
231                Argument::Sequence(MaybeSequence::Seq(a)),
232                Argument::Sequence(MaybeSequence::Seq(b)),
233            ) = (a, b)
234            {
235                if a.len() != b.len() {
236                    return None;
237                }
238                for (a, b) in a.iter().zip(b.iter()) {
239                    if !self.check_equality(a, b)? {
240                        return None;
241                    }
242                }
243            } else {
244                self.failure("Arguments don't match");
245                return None;
246            }
247        }
248        Some(true)
249    }
250
251    // invariant: lhs.arguments.len() == rhs.arguments.len()
252    fn congruence_bind(&mut self, lhs: &'t BindingTerm, rhs: &'t BindingTerm) -> Option<bool> {
253        self.comment("Comparing operators");
254        if !self.check_equality(&lhs.head, &rhs.head)? {
255            return None;
256        }
257        let mut substs = Alpha::new();
258        macro_rules! maybe_subst {
259            ($a:expr,$b:expr) => {
260                if substs.is_empty()
261                    || !$a.has_free_such_that(|av| substs.iter().any(|(v, _)| *v == av.name()))
262                {
263                    if !self.check_equality($a, $b)? {
264                        return None;
265                    }
266                } else {
267                    let subst = substs
268                        .iter()
269                        .map(|(n, v)| {
270                            (
271                                n,
272                                Term::Var {
273                                    variable: (*v).clone(),
274                                    presentation: None,
275                                },
276                            )
277                        })
278                        .collect::<smallvec::SmallVec<_, 2>>();
279                    let r = match $a / &*subst {
280                        Cow::Borrowed(_) => self.check_equality($a, $b)?,
281                        Cow::Owned(a) => self.scoped(|slf| slf.check_equality(&a, $b))?,
282                    };
283                    if !r {
284                        return None;
285                    }
286                }
287            };
288        }
289        for (i, (a, b)) in lhs.arguments.iter().zip(&rhs.arguments).enumerate() {
290            self.counter("Comparing arguments ", i + 1);
291            match (a, b) {
292                (BoundArgument::Simple(a), BoundArgument::Simple(b)) => {
293                    maybe_subst!(a, b);
294                }
295                (BoundArgument::Bound(a), BoundArgument::Bound(b))
296                | (
297                    BoundArgument::BoundSeq(MaybeSequence::One(a)),
298                    BoundArgument::BoundSeq(MaybeSequence::One(b)),
299                ) => {
300                    match (a.tp.as_ref(), b.tp.as_ref()) {
301                        (Some(a), Some(b)) => {
302                            maybe_subst!(a, b);
303                        }
304                        (None, None) => (),
305                        _ => return None,
306                    }
307                    match (a.df.as_ref(), b.df.as_ref()) {
308                        (Some(a), Some(b)) => {
309                            maybe_subst!(a, b);
310                        }
311                        (None, None) => (),
312                        _ => return None,
313                    }
314                    self.extend_context(b);
315                    if a.var.name() != b.var.name() {
316                        substs.push((a.var.name(), &b.var));
317                    }
318                }
319                /*
320                (BoundArgument::BoundSeq(MaybeSequence::One(a)), BoundArgument::Bound(b))
321                | (BoundArgument::Bound(a), BoundArgument::BoundSeq(MaybeSequence::One(b)))
322                    if a.var.is_solvable().is_some() || b.var.is_solvable().is_some() =>
323                {
324                    match (a.tp.as_ref(), b.tp.as_ref()) {
325                        (Some(a), Some(b)) => {
326                            maybe_subst!(a, b);
327                        }
328                        (None, None) => (),
329                        _ => return None,
330                    }
331                    match (a.df.as_ref(), b.df.as_ref()) {
332                        (Some(a), Some(b)) => {
333                            maybe_subst!(a, b);
334                        }
335                        (None, None) => (),
336                        _ => return None,
337                    }
338                    self.extend_context(b);
339                    if a.var.name() != b.var.name() {
340                        substs.push((a.var.name(), &b.var));
341                    }
342                }
343                 */
344                _ => {
345                    self.failure(format!("Argument not simple: {a:?}  <-->  {b:?}"));
346                    return None;
347                }
348            }
349        }
350        Some(true)
351    }
352}
353
354/*
355*
356// -----------------------------------------------------------
357
358pub fn alpha_equal_traced(lhs: &Term, rhs: &Term) -> bool {
359    alpha_equal_with_traced(lhs, rhs, &mut Alpha::default())
360}
361
362const CHECK: bool = true;
363
364macro_rules! rep_eq {
365    (false@$lhs:expr,$rhs:expr) => {
366        if CHECK {
367            ::tracing::error!(
368                "Not equal ({},{}): {:?}    and   {:?}",
369                line!(),
370                column!(),
371                $lhs,
372                $rhs
373            );
374            false
375        } else {
376            false
377        }
378    };
379    ($lhs:expr,$rhs:expr) => {
380        if CHECK {
381            $lhs == $rhs || {
382                ::tracing::error!(
383                    "Not equal ({},{}): {:?}    and   {:?}",
384                    line!(),
385                    column!(),
386                    $lhs,
387                    $rhs
388                );
389                false
390            }
391        } else {
392            $lhs == $rhs
393        }
394    };
395}
396
397pub fn alpha_equal_with_traced<'t>(lhs: &'t Term, rhs: &'t Term, alpha: &mut Alpha<'t>) -> bool {
398    if lhs == rhs {
399        return true;
400    }
401    match (lhs, rhs) {
402        (Term::Var { variable: v1, .. }, Term::Var { variable: v2, .. }) => {
403            rep_eq!(v1.name(), v2.name())
404                || alpha.iter().any(|(a, b)| {
405                    (*a == v1.name() && b.name() == v2.name())
406                        || (b.name() == v1.name() && *a == v2.name())
407                })
408        }
409        (Term::Application(a), Term::Application(b))
410            if rep_eq!(a.arguments.len(), b.arguments.len()) =>
411        {
412            alpha_equal_with_traced(&a.head, &b.head, alpha)
413                && a.arguments
414                    .iter()
415                    .zip(b.arguments.iter())
416                    .all(|(a, b)| alpha_arg_traced(a, b, alpha))
417        }
418        (Term::Bound(a), Term::Bound(b)) if rep_eq!(a.arguments.len(), b.arguments.len()) => {
419            let mut pop = 0;
420            if !alpha_equal_with_traced(&a.head, &b.head, alpha)
421                || a.arguments.iter().zip(b.arguments.iter()).any(|(a, b)| {
422                    alpha_barg_traced(a, b, alpha)
423                        .inspect(|i| pop += i)
424                        .is_none()
425                })
426            {
427                return false;
428            }
429            for _ in 0..pop {
430                alpha.pop();
431            }
432            true
433        }
434        (Term::Field(a), Term::Field(b)) => {
435            alpha_equal_with_traced(&a.record, &b.record, alpha) && a.key == b.key
436        }
437        (
438            Term::Label {
439                name: na,
440                df: da,
441                tp: ta,
442            },
443            Term::Label {
444                name: nb,
445                df: db,
446                tp: tb,
447            },
448        ) if *na == *nb => {
449            match (da, db) {
450                (Some(a), Some(b)) => {
451                    if !alpha_equal_with_traced(a, b, alpha) {
452                        return false;
453                    }
454                }
455                (None, None) => (),
456                _ => return false,
457            }
458            match (ta, tb) {
459                (Some(a), Some(b)) => alpha_equal_with_traced(a, b, alpha),
460                (None, None) => true,
461                _ => false,
462            }
463        }
464        (Term::Number(a), Term::Number(b)) => rep_eq!(a, b),
465        _ => rep_eq!(false@lhs,rhs),
466    }
467}
468
469fn alpha_arg_traced<'t>(lhs: &'t Argument, rhs: &'t Argument, alpha: &mut Alpha<'t>) -> bool {
470    match (lhs, rhs) {
471        (Argument::Simple(lhs), Argument::Simple(rhs))
472        | (
473            Argument::Sequence(MaybeSequence::One(lhs)),
474            Argument::Sequence(MaybeSequence::One(rhs)),
475        ) => alpha_equal_with_traced(lhs, rhs, alpha),
476        (
477            Argument::Sequence(MaybeSequence::Seq(lhs)),
478            Argument::Sequence(MaybeSequence::Seq(rhs)),
479        ) if lhs.len() == rhs.len() => lhs
480            .iter()
481            .zip(rhs.iter())
482            .all(|(lhs, rhs)| alpha_equal_with_traced(lhs, rhs, alpha)),
483        _ => rep_eq!(false@lhs,rhs),
484    }
485}
486fn alpha_barg_traced<'t>(
487    lhs: &'t BoundArgument,
488    rhs: &'t BoundArgument,
489    alpha: &mut Alpha<'t>,
490) -> Option<usize> {
491    macro_rules! ret {
492        ($e:expr) => {
493            if $e { Some(0) } else { None }
494        };
495    }
496    match (lhs, rhs) {
497        (BoundArgument::Simple(lhs), BoundArgument::Simple(rhs))
498        | (
499            BoundArgument::Sequence(MaybeSequence::One(lhs)),
500            BoundArgument::Sequence(MaybeSequence::One(rhs)),
501        ) => ret!(alpha_equal_with_traced(lhs, rhs, alpha)),
502        (
503            BoundArgument::Sequence(MaybeSequence::Seq(lhs)),
504            BoundArgument::Sequence(MaybeSequence::Seq(rhs)),
505        ) if rep_eq!(lhs.len(), rhs.len()) => ret!(
506            lhs.iter()
507                .zip(rhs.iter())
508                .all(|(lhs, rhs)| alpha_equal_with_traced(lhs, rhs, alpha))
509        ),
510        (BoundArgument::Bound(lhs), BoundArgument::Bound(rhs))
511        | (
512            BoundArgument::BoundSeq(MaybeSequence::One(lhs)),
513            BoundArgument::BoundSeq(MaybeSequence::One(rhs)),
514        ) => {
515            if alpha_cv_traced(lhs, rhs, alpha) {
516                Some(1)
517            } else {
518                None
519            }
520        }
521        (
522            BoundArgument::BoundSeq(MaybeSequence::Seq(lhs)),
523            BoundArgument::BoundSeq(MaybeSequence::Seq(rhs)),
524        ) if rep_eq!(lhs.len(), rhs.len()) => {
525            if lhs
526                .iter()
527                .zip(rhs.iter())
528                .all(|(a, b)| alpha_cv_traced(a, b, alpha))
529            {
530                Some(lhs.len())
531            } else {
532                None
533            }
534        }
535        _ => None,
536    }
537}
538fn alpha_cv_traced<'t>(
539    lhs: &'t ComponentVar,
540    rhs: &'t ComponentVar,
541    alpha: &mut Alpha<'t>,
542) -> bool {
543    match (lhs.tp.as_ref(), rhs.tp.as_ref()) {
544        (Some(lhs), Some(rhs)) => {
545            if !alpha_equal_with_traced(lhs, rhs, alpha) {
546                return false;
547            }
548        }
549        (None, None) => (),
550        _ => return rep_eq!(false@lhs,rhs),
551    }
552    match (lhs.df.as_ref(), rhs.df.as_ref()) {
553        (Some(lhs), Some(rhs)) => {
554            if !alpha_equal_with_traced(lhs, rhs, alpha) {
555                return false;
556            }
557        }
558        (None, None) => (),
559        _ => return rep_eq!(false@lhs,rhs),
560    }
561    alpha.push((lhs.var.name(), &rhs.var));
562    true
563}
564
565// -----------------------------------------------------------
566*/