1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
//! Constraint solving
//!
//! The final phase iterates over the constraints, refining the variance
//! of each inferred variable until a fixed point is reached. This will be
//! the optimal solution to the constraints. The final variances are then
//! collected into the `ty::CrateVariancesMap` returned by `solve_constraints`.

use rustc_data_structures::fx::FxHashMap;
use rustc_hir::def_id::DefId;
use rustc_middle::ty;

use super::constraints::*;
use super::terms::VarianceTerm::*;
use super::terms::*;
use super::xform::*;

/// Working state of the constraint solver: the term context and constraint
/// list taken out of the `ConstraintContext`, plus the solution vector that
/// `solve` refines and `create_map` reads.
struct SolveContext<'a, 'tcx> {
    // Supplies the `tcx` and the `inferred_starts` table used by `create_map`.
    terms_cx: TermsContext<'a, 'tcx>,
    // Constraints that `solve` re-applies until no solution changes.
    constraints: Vec<Constraint<'a>>,

    // Maps from an InferredIndex to the inferred value for that variable.
    solutions: Vec<ty::Variance>,
}

pub fn solve_constraints<'tcx>(
    constraints_cx: ConstraintContext<'_, 'tcx>,
) -> ty::CrateVariancesMap<'tcx> {
    let ConstraintContext { terms_cx, constraints, .. } = constraints_cx;

    let mut solutions = vec![ty::Bivariant; terms_cx.inferred_terms.len()];
    for &(id, ref variances) in &terms_cx.lang_items {
        let InferredIndex(start) = terms_cx.inferred_starts[&id];
        for (i, &variance) in variances.iter().enumerate() {
            solutions[start + i] = variance;
        }
    }

    let mut solutions_cx = SolveContext { terms_cx, constraints, solutions };
    solutions_cx.solve();
    let variances = solutions_cx.create_map();

    ty::CrateVariancesMap { variances }
}

impl<'a, 'tcx> SolveContext<'a, 'tcx> {
    /// Re-applies every constraint until a complete pass leaves all
    /// solutions unchanged.
    fn solve(&mut self) {
        // Fixed-point iteration. The number of passes is bounded by 2C,
        // where C is the number of constraints, because each variable can
        // change values at most twice. C is linear in the size of the
        // input, and therefore so is the inference process.
        loop {
            let mut any_update = false;

            for constraint in &self.constraints {
                let Constraint { inferred: InferredIndex(slot), variance: term } = *constraint;

                // Evaluate the term first, then combine it with the value
                // currently stored for this variable.
                let evaluated = self.evaluate(term);
                let current = self.solutions[slot];
                let refined = glb(evaluated, current);
                if refined == current {
                    continue;
                }

                debug!(
                    "updating inferred {} \
                        from {:?} to {:?} due to {:?}",
                    slot, current, refined, term
                );

                self.solutions[slot] = refined;
                any_update = true;
            }

            if !any_update {
                break;
            }
        }
    }

    /// Overwrites the variance of every const parameter with
    /// `ty::Invariant`, for `generics` itself and for each ancestor reached
    /// through `parent`.
    ///
    /// `variances` is indexed by each parameter's `index`.
    fn enforce_const_invariance(&self, generics: &ty::Generics, variances: &mut [ty::Variance]) {
        let tcx = self.terms_cx.tcx;

        // Start at `generics` and walk up the chain of parents.
        let mut current = Some(generics);
        while let Some(level) = current {
            let const_params = level
                .params
                .iter()
                .filter(|param| matches!(param.kind, ty::GenericParamDefKind::Const { .. }));
            for param in const_params {
                variances[param.index as usize] = ty::Invariant;
            }

            current = level.parent.map(|def_id| tcx.generics_of(def_id));
        }
    }

    /// Builds the final map from each item in `inferred_starts` to an
    /// arena-allocated slice holding its solved variances.
    fn create_map(&self) -> FxHashMap<DefId, &'tcx [ty::Variance]> {
        let tcx = self.terms_cx.tcx;

        self.terms_cx
            .inferred_starts
            .iter()
            .map(|(&def_id, &InferredIndex(start))| {
                let generics = tcx.generics_of(def_id);
                let end = start + generics.count();

                // Copy this item's run of solutions into the arena so the
                // adjustments below can be made in place.
                let item_variances = tcx.arena.alloc_slice(&self.solutions[start..end]);

                // Const parameters are always invariant.
                self.enforce_const_invariance(generics, item_variances);

                // Functions may leave generic parameters unused; any variance
                // still bivariant at this point becomes invariant.
                if matches!(tcx.type_of(def_id).kind(), ty::FnDef(..)) {
                    item_variances
                        .iter_mut()
                        .filter(|variance| **variance == ty::Bivariant)
                        .for_each(|variance| *variance = ty::Invariant);
                }

                (def_id.to_def_id(), &*item_variances)
            })
            .collect()
    }

    /// Computes the variance denoted by `term` under the current solutions.
    fn evaluate(&self, term: VarianceTermPtr<'a>) -> ty::Variance {
        match *term {
            // An inferred term reads whatever is currently stored for it.
            InferredTerm(InferredIndex(index)) => self.solutions[index],

            ConstantTerm(v) => v,

            // Evaluate the left operand, then the right, and combine them.
            TransformTerm(lhs, rhs) => self.evaluate(lhs).xform(self.evaluate(rhs)),
        }
    }
}