Enzyme main
Loading...
Searching...
No Matches
ActivityAnnotations.h
Go to the documentation of this file.
1#ifndef ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H
2#define ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H
3
5#include "DataFlowLattice.h"
6#include "Dialect/Ops.h"
7
8#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
9#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
10#include "mlir/Analysis/DataFlowFramework.h"
11
12namespace mlir {
13class FunctionOpInterface;
14
15namespace enzyme {
16
18
19//===----------------------------------------------------------------------===//
20// ForwardOriginsLattice
21//===----------------------------------------------------------------------===//
22
/// Sparse lattice that attaches a set of OriginAttr values to a single SSA
/// value. It derives from SparseSetLattice<OriginAttr> and is the per-value
/// state of the sparse forward activity-annotation analysis.
/// NOTE(review): listing lines 26, 29 and 40 are missing from this rendering,
/// so the body of `single` is not shown. This page's member summary also
/// lists getOriginsObject(), which presumably sits on the missing line 40.
23// TODO: specialize this to only arguments
24class ForwardOriginsLattice : public SparseSetLattice<OriginAttr> {
25public:
27
/// Presumably builds a lattice for `point` that holds only `value`; the body
/// is not visible in this rendering — TODO confirm.
28 static ForwardOriginsLattice single(Value point, OriginAttr value) {
30 }
31
/// Writes a textual form of this lattice to `os`; defined out of line.
32 void print(raw_ostream &os) const override;
33
/// Framework join hook: combines `other` into this lattice; defined out of
/// line.
34 ChangeResult join(const AbstractSparseLattice &other) override;
35
/// Returns the underlying set of origins (elements.getElements()).
36 const DenseSet<OriginAttr> &getOrigins() const {
37 return elements.getElements();
38 }
39
41};
42
/// Sparse lattice that attaches a set of OriginAttr values to a single SSA
/// value. It derives from SparseSetLattice<OriginAttr> and is the per-value
/// state of the sparse backward activity-annotation analysis.
/// NOTE(review): listing lines 45, 48 and 64 are missing from this rendering,
/// so the body of `single` is not shown here.
43class BackwardOriginsLattice : public SparseSetLattice<OriginAttr> {
44public:
46
/// Presumably builds a lattice for `point` that holds only `value`; the body
/// is not visible in this rendering — TODO confirm.
47 static BackwardOriginsLattice single(Value point, OriginAttr value) {
49 }
50
/// Writes a textual form of this lattice to `os`; defined out of line.
51 void print(raw_ostream &os) const override;
52
/// Framework merge hook for backward analyses. Although MLIR calls this
/// "meet", the implementation merges origin sets with elements.join().
/// The static_cast assumes `other` is always a BackwardOriginsLattice.
53 ChangeResult meet(const AbstractSparseLattice &other) override {
54 // MLIR framework again misusing terminology
55 const auto *otherValueOrigins =
56 static_cast<const BackwardOriginsLattice *>(&other);
57 return elements.join(otherValueOrigins->elements);
58 }
59
/// Returns the underlying set of origins (elements.getElements()).
60 const DenseSet<OriginAttr> &getOrigins() const {
61 return elements.getElements();
62 }
63
65};
66
68 : public dataflow::SparseForwardDataFlowAnalysis<ForwardOriginsLattice> {
69public:
70 ForwardActivityAnnotationAnalysis(DataFlowSolver &solver)
72 assert(!solver.getConfig().isInterprocedural());
73 }
74
75 void setToEntryState(ForwardOriginsLattice *lattice) override;
76
77 LogicalResult
78 visitOperation(Operation *op,
79 ArrayRef<const ForwardOriginsLattice *> operands,
80 ArrayRef<ForwardOriginsLattice *> results) override;
81
82 void visitExternalCall(CallOpInterface call,
83 ArrayRef<const ForwardOriginsLattice *> operands,
84 ArrayRef<ForwardOriginsLattice *> results) override;
85
86private:
87 void processMemoryRead(Operation *op, Value address,
88 ArrayRef<ForwardOriginsLattice *> results);
89
90 void markResultsUnknown(ArrayRef<ForwardOriginsLattice *> results);
91
92 void
93 processCallToSummarizedFunc(CallOpInterface call,
94 ArrayRef<ValueOriginSet> summary,
95 ArrayRef<const ForwardOriginsLattice *> operands,
96 ArrayRef<ForwardOriginsLattice *> results);
97};
98
100 : public dataflow::SparseBackwardDataFlowAnalysis<BackwardOriginsLattice> {
101public:
103 SymbolTableCollection &symbolTable)
104 : SparseBackwardDataFlowAnalysis(solver, symbolTable) {
105 assert(!solver.getConfig().isInterprocedural());
106 }
107
108 void visitBranchOperand(OpOperand &operand) override {}
109
110 void visitCallOperand(OpOperand &operand) override {}
111
112 void
113 visitNonControlFlowArguments(RegionSuccessor &successor,
114 ArrayRef<BlockArgument> arguments) override {}
115
116 void setToExitState(BackwardOriginsLattice *lattice) override;
117
118 LogicalResult
119 visitOperation(Operation *op, ArrayRef<BackwardOriginsLattice *> operands,
120 ArrayRef<const BackwardOriginsLattice *> results) override;
121
122 void
123 visitExternalCall(CallOpInterface call,
124 ArrayRef<BackwardOriginsLattice *> operands,
125 ArrayRef<const BackwardOriginsLattice *> results) override;
126
127private:
128 void
129 processCallToSummarizedFunc(CallOpInterface call,
130 ArrayRef<ValueOriginSet> summary,
131 ArrayRef<BackwardOriginsLattice *> operands,
132 ArrayRef<const BackwardOriginsLattice *> results);
133
134 void markOperandsUnknown(ArrayRef<BackwardOriginsLattice *> operands);
135};
136
137//===----------------------------------------------------------------------===//
138// ForwardOriginsMap
139//===----------------------------------------------------------------------===//
140
141class ForwardOriginsMap : public MapOfSetsLattice<DistinctAttr, OriginAttr> {
142public:
143 using MapOfSetsLattice::MapOfSetsLattice;
144
145 void print(raw_ostream &os) const override;
146
147 ChangeResult markAllOriginsUnknown() { return markAllUnknown(); }
148
149 const ValueOriginSet &getOrigins(DistinctAttr id) const { return lookup(id); }
150};
151
152class BackwardOriginsMap : public MapOfSetsLattice<DistinctAttr, OriginAttr> {
153public:
154 using MapOfSetsLattice::MapOfSetsLattice;
155
156 void print(raw_ostream &os) const override;
157
158 ChangeResult markAllOriginsUnknown() { return markAllUnknown(); }
159
160 const ValueOriginSet &getOrigins(DistinctAttr id) const { return lookup(id); }
161
162 ChangeResult meet(const AbstractDenseLattice &other) override {
163 return join(other);
164 }
165};
166
167//===----------------------------------------------------------------------===//
168// DenseActivityAnnotationAnalysis
169//===----------------------------------------------------------------------===//
170
172 : public dataflow::DenseForwardDataFlowAnalysis<ForwardOriginsMap> {
173public:
174 using DenseForwardDataFlowAnalysis::DenseForwardDataFlowAnalysis;
175
176 void setToEntryState(ForwardOriginsMap *lattice) override;
177
178 LogicalResult visitOperation(Operation *op, const ForwardOriginsMap &before,
179 ForwardOriginsMap *after) override;
180
181 void visitCallControlFlowTransfer(CallOpInterface call,
182 dataflow::CallControlFlowAction action,
183 const ForwardOriginsMap &before,
184 ForwardOriginsMap *after) override;
185
186private:
187 void processCallToSummarizedFunc(
188 CallOpInterface call,
189 const DenseMap<DistinctAttr, ValueOriginSet> &summary,
190 const ForwardOriginsMap &before, ForwardOriginsMap *after);
191
192 void processCopy(Operation *op, Value copySource, Value copyDest,
193 const ForwardOriginsMap &before, ForwardOriginsMap *after);
194
195 OriginalClasses originalClasses;
196};
197
// NOTE(review): listing line 198 is missing from this rendering. It
// presumably carries the class head of this dense backward analysis over
// BackwardOriginsMap — confirm the class name against the original header.
199 : public dataflow::DenseBackwardDataFlowAnalysis<BackwardOriginsMap> {
200public:
/// Inherits the base analysis constructors unchanged.
201 using DenseBackwardDataFlowAnalysis::DenseBackwardDataFlowAnalysis;
202
/// Transfer function: computes the state `before` `op` from the state `after`
/// it; defined out of line.
203 LogicalResult visitOperation(Operation *op, const BackwardOriginsMap &after,
204 BackwardOriginsMap *before) override;
205
/// Transfers state across `call` for the given control-flow `action`;
/// defined out of line.
206 void visitCallControlFlowTransfer(CallOpInterface call,
207 dataflow::CallControlFlowAction action,
208 const BackwardOriginsMap &after,
209 BackwardOriginsMap *before) override;
210
/// Initializes `lattice` at the analysis exit; defined out of line.
211 void setToExitState(BackwardOriginsMap *lattice) override;
212
213private:
/// Applies the precomputed per-key origin `summary` (keyed by DistinctAttr)
/// of the callee of `call`.
214 void processCallToSummarizedFunc(
215 CallOpInterface call,
216 const DenseMap<DistinctAttr, ValueOriginSet> &summary,
217 const BackwardOriginsMap &after, BackwardOriginsMap *before);
218
/// Presumably propagates origins for an `op` that copies memory from
/// `copySource` to `copyDest` — confirm against the definition.
219 void processCopy(Operation *op, Value copySource, Value copyDest,
220 const BackwardOriginsMap &after, BackwardOriginsMap *before);
221};
222
// NOTE(review): listing lines 223 and 225 are missing from this rendering.
// Line 223 presumably opens the ActivityPrinterConfig type, which is the type
// of runActivityAnnotations' `config` parameter — confirm. These are the
// options for that entry point; every field is public and has an in-class
// default.
224public:
226
227 /// Whether to use the data-flow based algorithm or the classic activity
228 /// analysis. Defaults to true (data-flow).
229 bool dataflow = true;
230 /// Use function summaries. Defaults to true.
231 bool relative = true;
232 /// Output extra information for debugging. Defaults to false.
233 bool verbose = false;
234 /// Annotate the IR with activity information for every operation. Currently
235 /// only supports the LLVM dialect. Defaults to false.
236 bool annotate = false;
237 /// Infer the starting argument state from an __enzyme_autodiff call.
/// Defaults to false.
238 bool inferFromAutodiff = false;
239};
240
242 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities = {},
243 const ActivityPrinterConfig &config = ActivityPrinterConfig());
244
245} // namespace enzyme
246} // namespace mlir
247
248#endif // ENZYME_MLIR_ANALYSIS_ACTIVITYANNOTATIONS_H
bool annotate
Annotate the IR with activity information for every operation.
bool inferFromAutodiff
Infer the starting argument state from an __enzyme_autodiff call.
bool verbose
Output extra information for debugging.
bool dataflow
Whether to use the data-flow based algorithm or the classic activity analysis.
bool relative
Use function summaries.
void visitBranchOperand(OpOperand &operand) override
BackwardActivityAnnotationAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable)
void visitExternalCall(CallOpInterface call, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void visitCallOperand(OpOperand &operand) override
void setToExitState(BackwardOriginsLattice *lattice) override
LogicalResult visitOperation(Operation *op, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void visitNonControlFlowArguments(RegionSuccessor &successor, ArrayRef< BlockArgument > arguments) override
void print(raw_ostream &os) const override
ChangeResult meet(const AbstractSparseLattice &other) override
static BackwardOriginsLattice single(Value point, OriginAttr value)
const DenseSet< OriginAttr > & getOrigins() const
const SetLattice< OriginAttr > & getOriginsObject() const
ChangeResult meet(const AbstractDenseLattice &other) override
void print(raw_ostream &os) const override
const ValueOriginSet & getOrigins(DistinctAttr id) const
LogicalResult visitOperation(Operation *op, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void setToEntryState(ForwardOriginsMap *lattice) override
void setToExitState(BackwardOriginsMap *lattice) override
LogicalResult visitOperation(Operation *op, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void setToEntryState(ForwardOriginsLattice *lattice) override
void visitExternalCall(CallOpInterface call, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
LogicalResult visitOperation(Operation *op, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
ChangeResult join(const AbstractSparseLattice &other) override
const SetLattice< OriginAttr > & getOriginsObject() const
const DenseSet< OriginAttr > & getOrigins() const
void print(raw_ostream &os) const override
static ForwardOriginsLattice single(Value point, OriginAttr value)
const ValueOriginSet & getOrigins(DistinctAttr id) const
void print(raw_ostream &os) const override
ChangeResult join(const AbstractDenseLattice &other)
const SetLattice< OriginAttr > & lookup(DistinctAttr key) const
Alias classes for freshly created, e.g., allocated values.
DenseSet< ValueT > & getElements()
ChangeResult join(const SetLattice< ValueT > &other)
SparseSetLattice(Value value, SetLattice< ValueT > &&elements)
void runActivityAnnotations(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argActivities={}, const ActivityPrinterConfig &config=ActivityPrinterConfig())