28#ifndef ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
29#define ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
31#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
32#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
33#include "mlir/Analysis/DataFlowFramework.h"
76 return ChangeResult::NoChange;
78 return ChangeResult::NoChange;
81 return ChangeResult::Change;
84 ChangeResult result = updateStateToDefined();
85 return insert(other.elements) | result;
88 ChangeResult
insert(
const DenseSet<ValueT> &newElements) {
90 return ChangeResult::NoChange;
92 size_t oldSize = elements.size();
93 elements.insert(newElements.begin(), newElements.end());
94 ChangeResult result = elements.size() == oldSize ? ChangeResult::NoChange
95 : ChangeResult::Change;
96 return updateStateToDefined() | result;
101 return ChangeResult::NoChange;
105 return ChangeResult::Change;
136 return state == other.state && llvm::equal(elements, other.elements);
139 LLVM_DUMP_METHOD
void print(llvm::raw_ostream &os)
const {
145 llvm::interleaveComma(elements, os <<
"{");
153 return callback(
nullptr, state);
155 ChangeResult result = ChangeResult::NoChange;
156 for (ValueT element : elements)
157 result |= callback(element, state);
162 explicit SetLattice(State state) : state(state) {}
164 ChangeResult updateStateToDefined() {
165 assert(state !=
State::Unknown &&
"cannot go back from unknown state");
167 : ChangeResult::NoChange;
172 const static SetLattice<ValueT> unknownSet;
173 const static SetLattice<ValueT> undefinedSet;
175 DenseSet<ValueT> elements;
179template <
typename ValueT>
180const SetLattice<ValueT> SetLattice<ValueT>::unknownSet =
183template <
typename ValueT>
184const SetLattice<ValueT> SetLattice<ValueT>::undefinedSet =
197template <
typename ValueT>
200 using AbstractSparseLattice::AbstractSparseLattice;
204 Attribute
serialize(MLIRContext *ctx)
const {
return serializeSetNaive(ctx); }
210 ChangeResult
insert(
const DenseSet<ValueT> &newElements) {
211 return elements.insert(newElements);
226 Attribute serializeSetNaive(MLIRContext *ctx)
const {
231 SmallVector<Attribute> elementsVec;
232 for (Attribute element :
elements.getElements()) {
233 elementsVec.push_back(element);
236 return ArrayAttr::get(ctx, elementsVec);
244template <
typename KeyT,
typename ElementT>
247 using AbstractDenseLattice::AbstractDenseLattice;
250 return serializeMapOfSetsNaive(ctx);
256 llvm::SmallDenseSet<DistinctAttr> keys;
257 auto lhsRange = llvm::make_first_range(
map);
258 auto rhsRange = llvm::make_first_range(rhs.map);
259 keys.
insert(lhsRange.begin(), lhsRange.end());
260 keys.
insert(rhsRange.begin(), rhsRange.end());
262 ChangeResult result = ChangeResult::NoChange;
263 for (DistinctAttr key : keys) {
264 auto lhsIt =
map.find(key);
265 auto rhsIt = rhs.map.find(key);
266 assert(lhsIt !=
map.end() || rhsIt != rhs.map.end());
269 if (lhsIt !=
map.end() && rhsIt != rhs.map.end()) {
270 result |= lhsIt->getSecond().join(rhsIt->getSecond());
275 if (lhsIt ==
map.end()) {
276 map.try_emplace(rhsIt->getFirst(), rhsIt->getSecond());
277 result = ChangeResult::Change;
292 return ChangeResult::NoChange;
297 "unknown must have been handled above");
303 ChangeResult result = ChangeResult::NoChange;
310 auto it =
map.find(key);
313 return it->getSecond();
322 return ChangeResult::NoChange;
325 decltype(
map.begin()) iterator;
326 std::tie(iterator, inserted) =
map.try_emplace(key, value);
328 return iterator->second.join(value);
329 return ChangeResult::Change;
334 DenseMap<KeyT, SetLattice<ElementT>>
map;
337 Attribute serializeMapOfSetsNaive(MLIRContext *ctx)
const {
338 SmallVector<Attribute> pointsToArray;
340 for (
const auto &[srcClass, destClasses] :
map) {
341 SmallVector<Attribute> pair = {srcClass};
342 SmallVector<Attribute> aliasClasses;
343 if (destClasses.isUnknown()) {
345 }
else if (destClasses.isUndefined()) {
348 for (
const Attribute &destClass : destClasses.getElements()) {
349 aliasClasses.push_back(destClass);
353 pair.push_back(ArrayAttr::get(ctx, aliasClasses));
354 pointsToArray.push_back(ArrayAttr::get(ctx, pair));
356 llvm::sort(pointsToArray, [&](Attribute a, Attribute b) {
359 return ArrayAttr::get(ctx, pointsToArray);
ChangeResult join(const AbstractDenseLattice &other)
Attribute serialize(MLIRContext *ctx) const
DenseMap< KeyT, SetLattice< ElementT > > map
Maps a key to a set of values.
ChangeResult markAllUnknown()
ChangeResult insert(const SetLattice< KeyT > &keysToUpdate, const SetLattice< ElementT > &values)
Map all keys to all values.
ChangeResult joinPotentiallyMissing(KeyT key, const SetLattice< ElementT > &value)
const SetLattice< ElementT > & lookup(KeyT key) const
ChangeResult markUnknown()
DenseSet< ValueT > & getElements()
ChangeResult insert(const DenseSet< ValueT > &newElements)
bool operator==(const SetLattice< ValueT > &other) const
ChangeResult foreachElement(function_ref< ChangeResult(ValueT, State)> callback) const
bool isCanonical() const
Returns true if this set is in the canonical form, i.e.
ChangeResult join(const SetLattice< ValueT > &other)
static const SetLattice< ValueT > & getEmpty()
Returns an instance of SetLattice known not to have any elements.
static const SetLattice< ValueT > & getUndefined()
Returns an instance of SetLattice in "undefined" state, i.e.
SetLattice(ValueT single)
const DenseSet< ValueT > & getElements() const
@ Unknown
Analyzed and may contain anything (lattice top).
@ Defined
Has specific elements.
@ Undefined
Has not been analyzed yet (lattice bottom).
static const SetLattice< ValueT > & getUnknown()
Returns an instance of SetLattice for the "unknown" class.
LLVM_DUMP_METHOD void print(llvm::raw_ostream &os) const
ChangeResult markUnknown()
Attribute serialize(MLIRContext *ctx) const
SparseSetLattice(Value value, SetLattice< ValueT > &&elements)
const DenseSet< ValueT > & getElements() const
SetLattice< ValueT > elements
ChangeResult merge(const SetLattice< ValueT > &other)
ChangeResult insert(const DenseSet< ValueT > &newElements)
constexpr llvm::StringLiteral unknownSetString
bool sortArraysLexicographic(ArrayAttr a, ArrayAttr b)
bool sortAttributes(Attribute a, Attribute b)
Used when serializing to ensure a consistent order.
constexpr llvm::StringLiteral undefinedSetString