Enzyme main
Loading...
Searching...
No Matches
ActivityAnalysis.h
Go to the documentation of this file.
1#ifndef ENZYME_MLIR_ANALYSIS_ACTIVITYANALYSIS_H
2#define ENZYME_MLIR_ANALYSIS_ACTIVITYANALYSIS_H
3
4#include "../../Utils.h"
5#include "mlir/IR/Block.h"
6
7namespace mlir {
8
9class CallOpInterface;
10class FunctionOpInterface;
11
12namespace enzyme {
13
14// class TypeResults {};
15
16class MTypeResults;
17
18/// Helper class to analyze the differential activity
20 // PreProcessCache &PPC;
21
22 // TODO: MLIR aliasing Information
23 // llvm::AAResults &AA;
24
25 // Blocks not to be analyzed
26 const llvm::SmallPtrSetImpl<Block *> &notForAnalysis;
27
28 // Blocks not to be analyzed
29 llvm::DenseMap<Operation *, bool> &readOnlyCache;
30
31 /// Library Information
32 // llvm::TargetLibraryInfo &TLI;
33
34public:
35 /// Whether the returns of the function being analyzed are active
36 const llvm::ArrayRef<DIFFE_TYPE> ActiveReturns;
37
38private:
39 /// Direction of current analysis
40 const uint8_t directions;
41 /// Analyze up based off of operands
42 static constexpr uint8_t UP = 1;
43 /// Analyze down based off uses
44 static constexpr uint8_t DOWN = 2;
45
46 /// Operations that don't propagate adjoints
47 /// These operations could return an active pointer, but
48 /// do not propagate adjoints themselves
49 llvm::SmallPtrSet<Operation *, 4> ConstantOperations;
50
51 /// Operations that could propagate adjoints
52 llvm::SmallPtrSet<Operation *, 20> ActiveOperations;
53
54 /// Values that do not contain derivative information, either
55 /// directly or as a pointer to
56 llvm::SmallPtrSet<Value, 4> ConstantValues;
57
58 /// Values that may contain derivative information
59 llvm::SmallPtrSet<Value, 2> ActiveValues;
60
61 /// Intermediate pointers which are created by inactive instructions
62 /// but are marked as active values to inductively determine their
63 /// activity.
64 llvm::SmallPtrSet<Value, 1> DeducingPointers;
65
66public:
67 /// Construct the analyzer from the a previous set of constant and active
68 /// values and whether returns are active. The all arguments of the functions
69 /// being analyzed must be in the set of constant and active values, lest an
70 /// error occur during analysis
72 // PreProcessCache &PPC, llvm::AAResults &AA_,
73 const llvm::SmallPtrSetImpl<Block *> &notForAnalysis_,
74 llvm::DenseMap<Operation *, bool> &readOnlyCache_,
75 // llvm::TargetLibraryInfo &TLI_,
76 const llvm::SmallPtrSetImpl<Value> &ConstantValues,
77 const llvm::SmallPtrSetImpl<Value> &ActiveValues,
78 llvm::ArrayRef<DIFFE_TYPE> ActiveReturns)
79 : notForAnalysis(notForAnalysis_), readOnlyCache(readOnlyCache_),
80 ActiveReturns(ActiveReturns), directions(UP | DOWN),
81 ConstantValues(ConstantValues.begin(), ConstantValues.end()),
82 ActiveValues(ActiveValues.begin(), ActiveValues.end()) {}
83
84 /// Return whether this operation is known not to propagate adjoints
85 /// Note that operations could return an active pointer, but
86 /// do not propagate adjoints themselves
87 bool isConstantOperation(MTypeResults const &TR, Operation *op);
88
89 /// Return whether this values is known not to contain derivative
90 /// information, either directly or as a pointer to
91 bool isConstantValue(MTypeResults const &TR, Value val);
92
93 bool isReadOnly(Operation *val);
94
95private:
96 DenseMap<Operation *, llvm::SmallPtrSet<Value, 4>>
97 ReEvaluateValueIfInactiveOp;
98 DenseMap<Value, llvm::SmallPtrSet<Value, 4>> ReEvaluateValueIfInactiveValue;
99 DenseMap<Value, llvm::SmallPtrSet<Operation *, 4>>
100 ReEvaluateOpIfInactiveValue;
101
102 void InsertConstantOperation(MTypeResults const &TR, Operation *op);
103 void InsertConstantValue(MTypeResults const &TR, Value V);
104
105 /// Create a new analyzer starting from an existing Analyzer
106 /// This is used to perform inductive assumptions
107 ActivityAnalyzer(ActivityAnalyzer &Other, uint8_t directions)
108 : notForAnalysis(Other.notForAnalysis),
109 readOnlyCache(Other.readOnlyCache), ActiveReturns(Other.ActiveReturns),
110 directions(directions), ConstantOperations(Other.ConstantOperations),
111 ActiveOperations(Other.ActiveOperations),
112 ConstantValues(Other.ConstantValues), ActiveValues(Other.ActiveValues) {
113 // DeducingPointers(Other.DeducingPointers) {
114 assert(directions != 0);
115 assert((directions & Other.directions) == directions);
116 assert((directions & Other.directions) != 0);
117 }
118
119 /// Import known constants from an existing analyzer
120 void insertConstantsFrom(MTypeResults const &TR,
121 ActivityAnalyzer &Hypothesis) {
122 for (auto I : Hypothesis.ConstantOperations) {
123 InsertConstantOperation(TR, I);
124 }
125 for (auto V : Hypothesis.ConstantValues) {
126 InsertConstantValue(TR, V);
127 }
128 }
129
130 /// Import known data from an existing analyzer
131 void insertAllFrom(MTypeResults const &TR, ActivityAnalyzer &Hypothesis,
132 Value Orig) {
133 insertConstantsFrom(TR, Hypothesis);
134 for (auto I : Hypothesis.ActiveOperations) {
135 bool inserted = ActiveOperations.insert(I).second;
136 if (inserted && directions == 3) {
137 ReEvaluateOpIfInactiveValue[Orig].insert(I);
138 }
139 }
140 for (auto V : Hypothesis.ActiveValues) {
141 bool inserted = ActiveValues.insert(V).second;
142 if (inserted && directions == 3) {
143 ReEvaluateValueIfInactiveValue[Orig].insert(V);
144 }
145 }
146 }
147
148 /// Is the use of value val as an argument of call CI known to be inactive
149 bool isFunctionArgumentConstant(mlir::CallOpInterface CI, Value val);
150
151 /// Is the value guaranteed to be inactive because of how it's produced.
152 /// If active and inactArg is non-null, store any values which may allow this
153 /// to succeed in the future
154 bool isValueInactiveFromOrigin(
155 MTypeResults const &TR, Value val,
156 llvm::SmallPtrSetImpl<mlir::Value> *inactArg = nullptr);
157
158 /// Is the operation guaranteed to be inactive because of how its operands are
159 /// produced.
160 bool isOperationInactiveFromOrigin(
161 MTypeResults const &TR, Operation *op,
162 std::optional<unsigned> resultNo = std::nullopt,
163 llvm::SmallPtrSetImpl<mlir::Value> *inactArg = nullptr);
164
165public:
166 enum class UseActivity {
167 // No Additional use activity info
168 None = 0,
169
170 // Only consider loads of memory
171 OnlyLoads = 1,
172
173 // Only consider active stores into
174 OnlyStores = 2,
175
176 // Only consider active stores and pointer-style loads
178
179 // Only consider any (active or not) stores into
180 AllStores = 4
181 };
182 /// Is the value free of any active uses
183 bool isValueInactiveFromUsers(MTypeResults const &TR, Value val,
184 UseActivity UA,
185 Operation **FoundInst = nullptr);
186
187 /// Is the value potentially actively returned or stored
188 bool isValueActivelyStoredOrReturned(MTypeResults const &TR, Value val,
189 bool outside = false);
190
191private:
192 /// StoredOrReturnedCache acts as an inductive cache of results for
193 /// isValueActivelyStoredOrReturned
194 std::map<std::pair<bool, Value>, bool> StoredOrReturnedCache;
195};
196
197} // namespace enzyme
198
199inline bool operator<(const Value &lhs, const Value &rhs) {
200 return lhs.getAsOpaquePointer() < rhs.getAsOpaquePointer();
201}
202} // namespace mlir
203
204#endif // ENZYME_MLIR_ANALYSIS_ACTIVITYANALYSIS_H
Helper class to analyze the differential activity.
Helper class to analyze the differential activity.
bool isConstantOperation(MTypeResults const &TR, Operation *op)
Return whether this operation is known not to propagate adjoints. Note that operations could return an active pointer, but do not propagate adjoints themselves.
const llvm::ArrayRef< DIFFE_TYPE > ActiveReturns
Whether the returns of the function being analyzed are active.
ActivityAnalyzer(const llvm::SmallPtrSetImpl< Block * > &notForAnalysis_, llvm::DenseMap< Operation *, bool > &readOnlyCache_, const llvm::SmallPtrSetImpl< Value > &ConstantValues, const llvm::SmallPtrSetImpl< Value > &ActiveValues, llvm::ArrayRef< DIFFE_TYPE > ActiveReturns)
Construct the analyzer from a previous set of constant and active values and whether returns are active.
bool isValueInactiveFromUsers(MTypeResults const &TR, Value val, UseActivity UA, Operation **FoundInst=nullptr)
Is the value free of any active uses.
bool isValueActivelyStoredOrReturned(MTypeResults const &TR, Value val, bool outside=false)
Is the value potentially actively returned or stored.
bool isConstantValue(MTypeResults const &TR, Value val)
Return whether this value is known not to contain derivative information, either directly or as a pointer to it.
bool operator<(const Value &lhs, const Value &rhs)