Enzyme main
Loading...
Searching...
No Matches
DataFlowLattice.h
Go to the documentation of this file.
1//===- DataFlowLattice.h - Declaration of common dataflow lattices --------===//
2//
3// Enzyme Project
4//
5// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
6// Exceptions. See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9// If using this code in an academic setting, please cite the following:
10// @inproceedings{NEURIPS2020_9332c513,
11// author = {Moses, William and Churavy, Valentin},
12// booktitle = {Advances in Neural Information Processing Systems},
13// editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H.
14// Lin}, pages = {12472--12485}, publisher = {Curran Associates, Inc.}, title =
15// {Instead of Rewriting Foreign Code for Machine Learning, Automatically
16// Synthesize Fast Gradients}, url =
17// {https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf},
18// volume = {33},
19// year = {2020}
20// }
21//
22//===----------------------------------------------------------------------===//
23//
24// This file contains the declaration of reusable lattices in dataflow analyses.
25//
26//===----------------------------------------------------------------------===//
27
28#ifndef ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
29#define ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
30
31#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
32#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
33#include "mlir/Analysis/DataFlowFramework.h"
34
35namespace mlir {
36namespace enzyme {
37
38constexpr llvm::StringLiteral undefinedSetString = "<undefined>";
39constexpr llvm::StringLiteral unknownSetString = "<unknown>";
40
41//===----------------------------------------------------------------------===//
42// SetLattice
43//
44// A data structure representing a set of elements. It may be undefined, meaning
45// the analysis has no information about it, or unknown, meaning the analysis
46// has conservatively assumed it could contain anything.
47//===----------------------------------------------------------------------===//
48
49template <typename ValueT> class SetLattice {
50public:
51 enum class State {
52 Undefined, ///< Has not been analyzed yet (lattice bottom).
53 Defined, ///< Has specific elements.
54 Unknown ///< Analyzed and may contain anything (lattice top).
55 };
56
57 SetLattice() : state(State::Undefined) {}
58
59 SetLattice(ValueT single) : state(State::Defined) { elements.insert(single); }
60
61 // TODO(zinenko): deprecate this and use a visitor instead.
62 DenseSet<ValueT> &getElements() {
63 assert(state == State::Defined);
64 return elements;
65 }
66
67 const DenseSet<ValueT> &getElements() const {
68 return const_cast<SetLattice<ValueT> *>(this)->getElements();
69 }
70
71 bool isUnknown() const { return state == State::Unknown; }
72 bool isUndefined() const { return state == State::Undefined; }
73
74 ChangeResult join(const SetLattice<ValueT> &other) {
75 if (isUnknown())
76 return ChangeResult::NoChange;
77 if (isUndefined() && other.isUndefined())
78 return ChangeResult::NoChange;
79 if (other.isUnknown()) {
80 state = State::Unknown;
81 return ChangeResult::Change;
82 }
83
84 ChangeResult result = updateStateToDefined();
85 return insert(other.elements) | result;
86 }
87
88 ChangeResult insert(const DenseSet<ValueT> &newElements) {
89 if (isUnknown())
90 return ChangeResult::NoChange;
91
92 size_t oldSize = elements.size();
93 elements.insert(newElements.begin(), newElements.end());
94 ChangeResult result = elements.size() == oldSize ? ChangeResult::NoChange
95 : ChangeResult::Change;
96 return updateStateToDefined() | result;
97 }
98
99 ChangeResult markUnknown() {
100 if (isUnknown())
101 return ChangeResult::NoChange;
102
103 state = State::Unknown;
104 elements.clear();
105 return ChangeResult::Change;
106 }
107
108 /// Returns true if this set is in the canonical form, i.e. either the state
109 /// is `State::Defined` or the explicit list of classes is empty, but not
110 /// both.
111 bool isCanonical() const {
112 return state == State::Defined || elements.empty();
113 }
114
115 /// Returns an instance of SetLattice known not to have any elements.
116 /// This is different from "undefined" and "unknown". The instance is *not* a
117 /// classical singleton.
118 static const SetLattice<ValueT> &getEmpty() {
119 static const SetLattice<ValueT> empty(State::Defined);
120 return empty;
121 }
122
123 /// Returns an instance of SetLattice in "undefined" state, i.e. without a set
124 /// of elements. This is different from empty set, which indicates that the
125 /// set is known not to contain any elements. The instance is *not* a
126 /// classical singleton, there are other ways of obtaining it.
127 static const SetLattice<ValueT> &getUndefined() { return undefinedSet; }
128
129 /// Returns an instance of SetLattice for the "unknown" class. The instance
130 /// is *not* a classical singleton, there are other ways of obtaining an
131 /// "unknown" alias set.
132 static const SetLattice<ValueT> &getUnknown() { return unknownSet; }
133
134 bool operator==(const SetLattice<ValueT> &other) const {
135 assert(isCanonical() && other.isCanonical());
136 return state == other.state && llvm::equal(elements, other.elements);
137 }
138
139 LLVM_DUMP_METHOD void print(llvm::raw_ostream &os) const {
140 if (isUnknown()) {
141 os << unknownSetString;
142 } else if (isUndefined()) {
143 os << undefinedSetString;
144 } else {
145 llvm::interleaveComma(elements, os << "{");
146 os << "}";
147 }
148 }
149
150 ChangeResult
151 foreachElement(function_ref<ChangeResult(ValueT, State)> callback) const {
152 if (state != State::Defined)
153 return callback(nullptr, state);
154
155 ChangeResult result = ChangeResult::NoChange;
156 for (ValueT element : elements)
157 result |= callback(element, state);
158 return result;
159 }
160
161private:
162 explicit SetLattice(State state) : state(state) {}
163
164 ChangeResult updateStateToDefined() {
165 assert(state != State::Unknown && "cannot go back from unknown state");
166 ChangeResult result = state == State::Undefined ? ChangeResult::Change
167 : ChangeResult::NoChange;
168 state = State::Defined;
169 return result;
170 }
171
172 const static SetLattice<ValueT> unknownSet;
173 const static SetLattice<ValueT> undefinedSet;
174
175 DenseSet<ValueT> elements;
176 State state;
177};
178
179template <typename ValueT>
180const SetLattice<ValueT> SetLattice<ValueT>::unknownSet =
181 SetLattice<ValueT>(SetLattice<ValueT>::State::Unknown);
182
183template <typename ValueT>
184const SetLattice<ValueT> SetLattice<ValueT>::undefinedSet =
185 SetLattice<ValueT>(SetLattice<ValueT>::State::Undefined);
186
187/// Used when serializing to ensure a consistent order.
188bool sortAttributes(Attribute a, Attribute b);
189bool sortArraysLexicographic(ArrayAttr a, ArrayAttr b);
190
191//===----------------------------------------------------------------------===//
192// SparseSetLattice
193//
194// An abstract lattice for sparse analyses that wraps a set lattice.
195//===----------------------------------------------------------------------===//
196
197template <typename ValueT>
198class SparseSetLattice : public dataflow::AbstractSparseLattice {
199public:
200 using AbstractSparseLattice::AbstractSparseLattice;
202 : dataflow::AbstractSparseLattice(value), elements(std::move(elements)) {}
203
204 Attribute serialize(MLIRContext *ctx) const { return serializeSetNaive(ctx); }
205
206 ChangeResult merge(const SetLattice<ValueT> &other) {
207 return elements.join(other);
208 }
209
210 ChangeResult insert(const DenseSet<ValueT> &newElements) {
211 return elements.insert(newElements);
212 }
213
214 ChangeResult markUnknown() { return elements.markUnknown(); }
215
216 bool isUnknown() const { return elements.isUnknown(); }
217
218 bool isUndefined() const { return elements.isUndefined(); }
219
220 const DenseSet<ValueT> &getElements() const { return elements.getElements(); }
221
222protected:
224
225private:
226 Attribute serializeSetNaive(MLIRContext *ctx) const {
227 if (elements.isUndefined())
228 return StringAttr::get(ctx, undefinedSetString);
229 if (elements.isUnknown())
230 return StringAttr::get(ctx, unknownSetString);
231 SmallVector<Attribute> elementsVec;
232 for (Attribute element : elements.getElements()) {
233 elementsVec.push_back(element);
234 }
235 llvm::sort(elementsVec, sortAttributes);
236 return ArrayAttr::get(ctx, elementsVec);
237 }
238};
239
240//===----------------------------------------------------------------------===//
241// MapOfSetsLattice
242//===----------------------------------------------------------------------===//
243
244template <typename KeyT, typename ElementT>
245class MapOfSetsLattice : public dataflow::AbstractDenseLattice {
246public:
247 using AbstractDenseLattice::AbstractDenseLattice;
248
249 Attribute serialize(MLIRContext *ctx) const {
250 return serializeMapOfSetsNaive(ctx);
251 }
252
253 ChangeResult join(const AbstractDenseLattice &other) {
254 const auto &rhs =
255 static_cast<const MapOfSetsLattice<KeyT, ElementT> &>(other);
256 llvm::SmallDenseSet<DistinctAttr> keys;
257 auto lhsRange = llvm::make_first_range(map);
258 auto rhsRange = llvm::make_first_range(rhs.map);
259 keys.insert(lhsRange.begin(), lhsRange.end());
260 keys.insert(rhsRange.begin(), rhsRange.end());
261
262 ChangeResult result = ChangeResult::NoChange;
263 for (DistinctAttr key : keys) {
264 auto lhsIt = map.find(key);
265 auto rhsIt = rhs.map.find(key);
266 assert(lhsIt != map.end() || rhsIt != rhs.map.end());
267
268 // If present in both, join.
269 if (lhsIt != map.end() && rhsIt != rhs.map.end()) {
270 result |= lhsIt->getSecond().join(rhsIt->getSecond());
271 continue;
272 }
273
274 // Copy from RHS if available only there.
275 if (lhsIt == map.end()) {
276 map.try_emplace(rhsIt->getFirst(), rhsIt->getSecond());
277 result = ChangeResult::Change;
278 }
279
280 // Do nothing if available only in LHS.
281 }
282 return result;
283 }
284
285 /// Map all keys to all values.
286 ChangeResult insert(const SetLattice<KeyT> &keysToUpdate,
287 const SetLattice<ElementT> &values) {
288 if (keysToUpdate.isUnknown())
289 return markAllUnknown();
290
291 if (keysToUpdate.isUndefined())
292 return ChangeResult::NoChange;
293
294 return keysToUpdate.foreachElement(
295 [&](DistinctAttr key, typename SetLattice<KeyT>::State state) {
296 assert(state == SetLattice<KeyT>::State::Defined &&
297 "unknown must have been handled above");
298 return joinPotentiallyMissing(key, values);
299 });
300 }
301
302 ChangeResult markAllUnknown() {
303 ChangeResult result = ChangeResult::NoChange;
304 for (auto &it : map)
305 result |= it.getSecond().join(SetLattice<ElementT>::getUnknown());
306 return result;
307 }
308
309 const SetLattice<ElementT> &lookup(KeyT key) const {
310 auto it = map.find(key);
311 if (it == map.end())
313 return it->getSecond();
314 }
315
316protected:
317 ChangeResult joinPotentiallyMissing(KeyT key,
318 const SetLattice<ElementT> &value) {
319 // Don't store explicitly undefined values in the mapping, keys absent from
320 // the mapping are treated as implicitly undefined.
321 if (value.isUndefined())
322 return ChangeResult::NoChange;
323
324 bool inserted;
325 decltype(map.begin()) iterator;
326 std::tie(iterator, inserted) = map.try_emplace(key, value);
327 if (!inserted)
328 return iterator->second.join(value);
329 return ChangeResult::Change;
330 }
331
332 /// Maps a key to a set of values. When a key is not present in this map, it
333 /// is considered to map to an uninitialized set.
334 DenseMap<KeyT, SetLattice<ElementT>> map;
335
336private:
337 Attribute serializeMapOfSetsNaive(MLIRContext *ctx) const {
338 SmallVector<Attribute> pointsToArray;
339
340 for (const auto &[srcClass, destClasses] : map) {
341 SmallVector<Attribute> pair = {srcClass};
342 SmallVector<Attribute> aliasClasses;
343 if (destClasses.isUnknown()) {
344 aliasClasses.push_back(StringAttr::get(ctx, unknownSetString));
345 } else if (destClasses.isUndefined()) {
346 aliasClasses.push_back(StringAttr::get(ctx, undefinedSetString));
347 } else {
348 for (const Attribute &destClass : destClasses.getElements()) {
349 aliasClasses.push_back(destClass);
350 }
351 llvm::sort(aliasClasses, sortAttributes);
352 }
353 pair.push_back(ArrayAttr::get(ctx, aliasClasses));
354 pointsToArray.push_back(ArrayAttr::get(ctx, pair));
355 }
356 llvm::sort(pointsToArray, [&](Attribute a, Attribute b) {
357 return sortArraysLexicographic(cast<ArrayAttr>(a), cast<ArrayAttr>(b));
358 });
359 return ArrayAttr::get(ctx, pointsToArray);
360 }
361};
362
363} // namespace enzyme
364} // namespace mlir
365
366#endif // ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
ChangeResult join(const AbstractDenseLattice &other)
Attribute serialize(MLIRContext *ctx) const
DenseMap< KeyT, SetLattice< ElementT > > map
Maps a key to a set of values.
ChangeResult insert(const SetLattice< KeyT > &keysToUpdate, const SetLattice< ElementT > &values)
Map all keys to all values.
ChangeResult joinPotentiallyMissing(KeyT key, const SetLattice< ElementT > &value)
const SetLattice< ElementT > & lookup(KeyT key) const
DenseSet< ValueT > & getElements()
ChangeResult insert(const DenseSet< ValueT > &newElements)
bool operator==(const SetLattice< ValueT > &other) const
ChangeResult foreachElement(function_ref< ChangeResult(ValueT, State)> callback) const
bool isCanonical() const
Returns true if this set is in the canonical form, i.e.
ChangeResult join(const SetLattice< ValueT > &other)
static const SetLattice< ValueT > & getEmpty()
Returns an instance of SetLattice known not to have any elements.
static const SetLattice< ValueT > & getUndefined()
Returns an instance of SetLattice in "undefined" state, i.e.
const DenseSet< ValueT > & getElements() const
@ Unknown
Analyzed and may contain anything (lattice top).
@ Defined
Has specific elements.
@ Undefined
Has not been analyzed yet (lattice bottom).
static const SetLattice< ValueT > & getUnknown()
Returns an instance of SetLattice for the "unknown" class.
LLVM_DUMP_METHOD void print(llvm::raw_ostream &os) const
Attribute serialize(MLIRContext *ctx) const
SparseSetLattice(Value value, SetLattice< ValueT > &&elements)
const DenseSet< ValueT > & getElements() const
ChangeResult merge(const SetLattice< ValueT > &other)
ChangeResult insert(const DenseSet< ValueT > &newElements)
constexpr llvm::StringLiteral unknownSetString
bool sortArraysLexicographic(ArrayAttr a, ArrayAttr b)
bool sortAttributes(Attribute a, Attribute b)
Used when serializing to ensure a consistent order.
constexpr llvm::StringLiteral undefinedSetString