Enzyme main
Loading...
Searching...
No Matches
ActivityAnnotations.cpp
Go to the documentation of this file.
3#include "DataFlowLattice.h"
4#include "Dialect/Dialect.h"
5#include "Dialect/Ops.h"
7
8#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
9#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Analysis/DataFlowFramework.h"
13#include "mlir/Interfaces/FunctionInterfaces.h"
14#include "llvm/Support/raw_ostream.h"
15
16// TODO: Remove dependency on dialects in favour of differential dependency
17// interface
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
19
20using namespace mlir;
21
22static bool isPossiblyActive(Type type) {
23 return isa<FloatType, ComplexType>(type);
24}
25
26void enzyme::ForwardOriginsLattice::print(raw_ostream &os) const {
27 os << serialize(getAnchor().getContext());
28}
29
30void enzyme::BackwardOriginsLattice::print(raw_ostream &os) const {
31 os << serialize(getAnchor().getContext());
32}
33
// Join for ForwardOriginsLattice: joins the `elements` of `other` into this
// lattice's `elements` and returns the resulting ChangeResult. The static_cast
// assumes `other` is itself a ForwardOriginsLattice.
// NOTE(review): the signature line (doxygen line 35) is missing from this
// extract.
34ChangeResult
36 const auto *otherValueOrigins =
37 static_cast<const ForwardOriginsLattice *>(&other);
38 return elements.join(otherValueOrigins->elements);
39}
40
// Presumably the forward sparse analysis' entry-state initializer; its header
// line (doxygen line 41) is missing from this extract.
// Only block arguments are seeded; any other anchor is expected to still be
// undefined. A float/complex block argument starts out with a single
// ArgumentOriginAttr naming the enclosing function and the argument position.
42 ForwardOriginsLattice *lattice) {
43 auto arg = dyn_cast<BlockArgument>(lattice->getAnchor());
44 if (!arg) {
45 assert(lattice->isUndefined());
46 return;
47 }
// Arguments that can never be active get no origin.
48 if (!isPossiblyActive(arg.getType())) {
49 return;
50 }
51
// NOTE(review): the cast assumes the argument's block is owned directly by a
// function-like op; a block argument of a nested region (e.g. a loop body)
// would fail this cast.
52 auto funcOp = cast<FunctionOpInterface>(arg.getOwner()->getParentOp());
53 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
54 arg.getArgNumber());
55 return propagateIfChanged(
56 lattice, lattice->join(ForwardOriginsLattice::single(lattice->getAnchor(),
57 origin)));
58}
58}
59
60void enzyme::ForwardActivityAnnotationAnalysis::markResultsUnknown(
61 ArrayRef<ForwardOriginsLattice *> results) {
62 for (ForwardOriginsLattice *result : results) {
63 propagateIfChanged(result, result->markUnknown());
64 }
65}
66
67/// True iff all results differentially depend on all operands
68// TODO: differential dependency/activity interface
69// TODO: Select cond is not fully active
70static bool isFullyActive(Operation *op) {
71 return isa<LLVM::FMulOp, LLVM::FAddOp, LLVM::FDivOp, LLVM::FSubOp,
72 LLVM::FNegOp, LLVM::FAbsOp, LLVM::SqrtOp, LLVM::SinOp, LLVM::CosOp,
73 LLVM::Exp2Op, LLVM::ExpOp, LLVM::LogOp, LLVM::InsertValueOp,
74 LLVM::ExtractValueOp, LLVM::BitcastOp, LLVM::SelectOp>(op);
75}
76
// Sparse forward transfer function for `op` (its header line, doxygen line 77,
// is missing from this extract). Parameters: the op, the operand lattices and
// the result lattices to update.
78 Operation *op, ArrayRef<const ForwardOriginsLattice *> operands,
79 ArrayRef<ForwardOriginsLattice *> results) {
// Fully active ops: every result joins the origins of every operand.
80 if (isFullyActive(op)) {
81 for (ForwardOriginsLattice *result : results) {
82 for (const ForwardOriginsLattice *operand : operands) {
83 join(result, *operand);
84 }
85 }
86 return success();
87 }
88
// Other pure ops, and ops that declare themselves inactive, contribute no
// origins to their results.
89 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
90 if (isPure(op) || (activityIface && activityIface.isInactive()))
91 return success();
92
// Without a memory-effect interface nothing is known about the op.
93 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
94 if (!memory) {
95 markResultsUnknown(results);
96 return success();
97 }
98
// Origins of every location read by the op flow into its results; a read
// with no known address makes the results unknown. Non-read effects are
// ignored here.
99 SmallVector<MemoryEffects::EffectInstance> effects;
100 memory.getEffects(effects);
101 for (const auto &effect : effects) {
102 if (!isa<MemoryEffects::Read>(effect.getEffect()))
103 continue;
104
105 Value value = effect.getValue();
106 if (!value) {
107 markResultsUnknown(results);
108 continue;
109 }
110 processMemoryRead(op, value, results);
111 }
112 return success();
113}
114
115void enzyme::ForwardActivityAnnotationAnalysis::processMemoryRead(
116 Operation *op, Value address, ArrayRef<ForwardOriginsLattice *> results) {
117 ProgramPoint *point = getProgramPointAfter(op);
118 auto *srcClasses = getOrCreateFor<AliasClassLattice>(point, address);
119 auto *originsMap = getOrCreateFor<ForwardOriginsMap>(point, point);
120 if (srcClasses->isUndefined())
121 return;
122 if (srcClasses->isUnknown())
123 return markResultsUnknown(results);
124
125 // Look up the alias class and see what its origins are, then propagate
126 // those origins to the read results.
127 for (DistinctAttr srcClass : srcClasses->getAliasClasses()) {
128 for (ForwardOriginsLattice *result : results) {
129 if (isPossiblyActive(result->getAnchor().getType())) {
130 propagateIfChanged(result,
131 result->merge(originsMap->getOrigins(srcClass)));
132 }
133 }
134 }
135}
136
// Decode a sparse activity summary attribute. The result is one ValueOriginSet
// per entry of `returnOrigins`, appended to `out` in order. Each entry is one of:
//   - the string "<unknown>": origins are marked unknown;
//   - any other string: origins are left undefined;
//   - an array of ArgumentOriginAttr: those attributes are inserted.
// NOTE(review): doxygen line 140 is missing from this extract; it presumably
// declares the per-entry `origins` set. `resultIdx` is unused.
// NOTE(review): deserializePointsTo matches "unknown"/"undefined" without angle
// brackets. Confirm the two summary encodings are meant to differ.
137void deserializeReturnOrigins(ArrayAttr returnOrigins,
138 SmallVectorImpl<enzyme::ValueOriginSet> &out) {
139 for (auto &&[resultIdx, argOrigins] : llvm::enumerate(returnOrigins)) {
141 if (auto strAttr = dyn_cast<StringAttr>(argOrigins)) {
142 if (strAttr.getValue() == "<unknown>") {
143 (void)origins.markUnknown();
144 } else {
145 // Leave origins undefined
146 }
147 } else {
148 for (enzyme::ArgumentOriginAttr originAttr :
149 cast<ArrayAttr>(argOrigins)
150 .getAsRange<enzyme::ArgumentOriginAttr>()) {
151 (void)origins.insert({originAttr});
152 }
153 }
154
155 out.push_back(origins);
156 }
157}
158
// Sparse forward transfer for a call (its header line, doxygen line 159, is
// missing from this extract). It handles three cases:
//   - Indirect call (callee is not a SymbolRefAttr): all results are marked
//     unknown.
//   - Callee resolves to a function carrying the sparse activity annotation
//     attribute: that summary is applied.
//   - Otherwise: every result joins every operand.
// NOTE(review): `markAllResultsUnknown` duplicates markResultsUnknown.
160 CallOpInterface call, ArrayRef<const ForwardOriginsLattice *> operands,
161 ArrayRef<ForwardOriginsLattice *> results) {
162 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
163 auto markAllResultsUnknown = [&]() {
164 for (ForwardOriginsLattice *result : results) {
165 propagateIfChanged(result, result->markUnknown());
166 }
167 };
168 if (!symbol)
169 return markAllResultsUnknown();
170
// Only the leaf of the symbol reference is used for the lookup.
171 if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
172 call, symbol.getLeafReference())) {
173 if (auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
174 EnzymeDialect::getSparseActivityAnnotationAttrName())) {
175 SmallVector<ValueOriginSet> returnOrigins;
176 deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
177 return processCallToSummarizedFunc(call, returnOrigins, operands,
178 results);
179 }
180 }
181
182 // In the absence of a summary attribute, assume all results differentially
183 // depend on all operands
184 for (ForwardOriginsLattice *result : results)
185 for (const ForwardOriginsLattice *operand : operands)
186 join(result, *operand);
187}
188
189/// Visit everything transitively pointed-to by any pointer in start.
// The walk proceeds level by level:
//   - Each round calls `visit` on every class in the current frontier, so the
//     classes in `start` itself are visited as well.
//   - The next frontier is the join of those classes' points-to sets.
//   - The walk stops once a round joins nothing, leaving the next frontier
//     undefined.
// NOTE(review): doxygen lines 190 and 193 (the first signature line and one
// body line) are missing from this extract.
// NOTE(review): no visited set is kept, which has two consequences:
//   - A class reachable along several paths is visited repeatedly.
//   - If the points-to graph contains a cycle, the frontier never becomes
//     undefined and this loop does not terminate.
// Confirm the points-to sets are acyclic here, or track visited classes.
// Callers must not pass an unknown set. This is only enforced by the assert,
// i.e. in debug builds.
191 const enzyme::PointsToSets &pointsToSets,
192 function_ref<void(DistinctAttr)> visit) {
194 AliasClassSet current(start);
195 while (!current.isUndefined()) {
196 AliasClassSet next;
197
198 assert(!current.isUnknown() && "Unhandled traversal of unknown");
199 for (DistinctAttr currentClass : current.getElements()) {
200 visit(currentClass);
201 (void)next.join(pointsToSets.getPointsTo(currentClass));
202 }
203 std::swap(current, next);
204 }
205}
206
207void enzyme::ForwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
208 CallOpInterface call, ArrayRef<ValueOriginSet> summary,
209 ArrayRef<const ForwardOriginsLattice *> operands,
210 ArrayRef<ForwardOriginsLattice *> results) {
211 for (const auto &[result, returnOrigin] : llvm::zip(results, summary)) {
212 // Convert the origins relative to the callee to relative to the caller
213 ValueOriginSet callerOrigins;
214 if (returnOrigin.isUndefined())
215 continue;
216
217 if (returnOrigin.isUnknown()) {
218 (void)callerOrigins.markUnknown();
219 } else {
220 ProgramPoint *point = getProgramPointAfter(call);
221 auto *denseOrigins = getOrCreateFor<ForwardOriginsMap>(point, point);
222 auto *pointsTo = getOrCreateFor<PointsToSets>(point, point);
223 (void)returnOrigin.foreachElement(
224 [&](OriginAttr calleeOrigin, ValueOriginSet::State state) {
225 assert(state == ValueOriginSet::State::Defined &&
226 "undefined and unknown must have been handled above");
227 auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
228 // If the caller is a pointer, need to join what it points to
229 const ForwardOriginsLattice *operandOrigins =
230 operands[calleeArgOrigin.getArgNumber()];
231 auto *callerAliasClass = getOrCreateFor<AliasClassLattice>(
232 getProgramPointAfter(call), operandOrigins->getAnchor());
233 traversePointsToSets(callerAliasClass->getAliasClassesObject(),
234 *pointsTo, [&](DistinctAttr aliasClass) {
235 (void)callerOrigins.join(
236 denseOrigins->getOrigins(aliasClass));
237 });
238 return callerOrigins.join(operandOrigins->getOriginsObject());
239 });
240 }
241 propagateIfChanged(result, result->merge(callerOrigins));
242 }
243}
244
// Presumably the backward sparse analysis' exit-state initializer; its header
// line (doxygen line 245) is missing from this extract. Every lattice reaching
// it is conservatively marked unknown.
246 BackwardOriginsLattice *lattice) {
247 propagateIfChanged(lattice, lattice->markUnknown());
248}
249
250void enzyme::BackwardActivityAnnotationAnalysis::markOperandsUnknown(
251 ArrayRef<BackwardOriginsLattice *> operands) {
252 for (BackwardOriginsLattice *operand : operands) {
253 propagateIfChanged(operand, operand->markUnknown());
254 }
255}
256
// Sparse backward transfer function for `op` (its header line, doxygen line
// 257, is missing from this extract).
258 Operation *op, ArrayRef<BackwardOriginsLattice *> operands,
259 ArrayRef<const BackwardOriginsLattice *> results) {
// Fully active ops: each operand meets the origins of every result.
// NOTE(review): unlike the forward transfer function, this branch does not
// `return success();`. Control therefore falls through to the purity and
// memory handling below. Confirm this is intended.
260 if (isFullyActive(op)) {
261 for (BackwardOriginsLattice *operand : operands)
262 for (const BackwardOriginsLattice *result : results)
263 meet(operand, *result);
264 }
265
// Pure ops and ops that declare themselves inactive propagate nothing further.
266 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
267 if (isPure(op) || (activityIface && activityIface.isInactive()))
268 return success();
269
// Without a memory-effect interface, all operands become unknown.
270 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
271 if (!memory) {
272 markOperandsUnknown(operands);
273 return success();
274 }
275
// For each read with a known address, record in the dense BackwardOriginsMap
// (at the point before `op`) that the read location's alias classes flow into
// each result's origins. A read without an address makes the operands unknown.
276 SmallVector<MemoryEffects::EffectInstance> effects;
277 memory.getEffects(effects);
278 for (const auto &effect : effects) {
279 if (!isa<MemoryEffects::Read>(effect.getEffect()))
280 continue;
281
282 Value value = effect.getValue();
283 if (!value) {
284 markOperandsUnknown(operands);
285 continue;
286 }
287
288 auto *srcClasses =
289 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
290 auto *originsMap =
291 getOrCreate<BackwardOriginsMap>(getProgramPointBefore(op));
292
293 ChangeResult changed = ChangeResult::NoChange;
294 for (const BackwardOriginsLattice *result : results)
295 changed |= originsMap->insert(srcClasses->getAliasClassesObject(),
296 result->getOriginsObject());
297 propagateIfChanged(originsMap, changed);
298 }
299 return success();
300}
301
// Sparse backward transfer for a call (its header line, doxygen line 302, is
// missing from this extract). It handles three cases:
//   - Indirect call: all operands are marked unknown.
//   - Callee resolves to a function with the sparse activity annotation
//     attribute: that summary is applied.
//   - Otherwise: every operand meets every result.
303 CallOpInterface call, ArrayRef<BackwardOriginsLattice *> operands,
304 ArrayRef<const BackwardOriginsLattice *> results) {
305 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
306 if (!symbol)
307 return markOperandsUnknown(operands);
308
// Only the leaf of the symbol reference is used for the lookup.
309 if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
310 call, symbol.getLeafReference())) {
311 if (auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
312 EnzymeDialect::getSparseActivityAnnotationAttrName())) {
313 SmallVector<ValueOriginSet> returnOrigins;
314 deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
315 return processCallToSummarizedFunc(call, returnOrigins, operands,
316 results);
317 }
318 }
319
320 // In the absence of a summary attribute, assume all results differentially
321 // depend on all operands
322 for (BackwardOriginsLattice *operand : operands)
323 for (const BackwardOriginsLattice *result : results)
324 meet(operand, *result);
325}
326
327void enzyme::BackwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
328 CallOpInterface call, ArrayRef<ValueOriginSet> summary,
329 ArrayRef<BackwardOriginsLattice *> operands,
330 ArrayRef<const BackwardOriginsLattice *> results) {
331 // collect the result origins, propagate them to the operands.
332 for (const auto &[result, calleeOrigins] : llvm::zip(results, summary)) {
333 ValueOriginSet resultOrigins = result->getOriginsObject();
334 if (calleeOrigins.isUndefined())
335 continue;
336 if (calleeOrigins.isUnknown())
337 (void)resultOrigins.markUnknown();
338 else {
339 (void)calleeOrigins.foreachElement(
340 [&](OriginAttr calleeOrigin, ValueOriginSet::State state) {
341 auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
342 BackwardOriginsLattice *operand =
343 operands[calleeArgOrigin.getArgNumber()];
344 propagateIfChanged(operand, operand->merge(resultOrigins));
345 return ChangeResult::NoChange;
346 });
347 }
348 }
349}
350
// Pretty-print a map from keys (alias classes, for the origins maps below) to
// set lattices. An empty map prints "<empty>". Otherwise each entry prints one
// line containing "<unknown>", "<undefined>" or the comma-separated elements.
// NOTE(review): doxygen line 352 (the function-name line) is missing from this
// extract.
// NOTE(review): the map parameter appears to be taken by value, so the map is
// copied on every call; `const DenseMap<...> &` would avoid the copy.
// NOTE(review): DenseMap iteration is unordered, so line order may vary.
351template <typename KeyT, typename ElementT>
353 const DenseMap<KeyT, enzyme::SetLattice<ElementT>> map, raw_ostream &os) {
354 if (map.empty()) {
355 os << "<empty>\n";
356 return;
357 }
358 for (const auto &[aliasClass, origins] : map) {
359 os << " " << aliasClass << " originates from {";
360 if (origins.isUnknown()) {
361 os << "<unknown>";
362 } else if (origins.isUndefined()) {
363 os << "<undefined>";
364 } else {
365 llvm::interleaveComma(origins.getElements(), os);
366 }
367 os << "}\n";
368 }
369}
370
371void enzyme::ForwardOriginsMap::print(raw_ostream &os) const {
372 printMapOfSetsLattice(this->map, os);
373}
374
375void enzyme::BackwardOriginsMap::print(raw_ostream &os) const {
376 printMapOfSetsLattice(this->map, os);
377}
378
// Presumably the dense forward analysis' entry-state initializer; its header
// line (doxygen line 379) is missing from this extract. For each argument of
// the enclosing function, the argument's own alias classes (not the classes
// transitively reachable from them) are tagged with that argument's origin.
380 ForwardOriginsMap *lattice) {
// NOTE(review): the dyn_cast result is dereferenced without a null check.
// `cast` would state the assumption that the anchor is a ProgramPoint.
381 auto point = dyn_cast<ProgramPoint *>(lattice->getAnchor());
382 auto *block = point->getBlock();
383 if (!block)
384 return;
385
386 auto funcOp = cast<FunctionOpInterface>(block->getParentOp());
387 ChangeResult changed = ChangeResult::NoChange;
388 for (BlockArgument arg : funcOp.getArguments()) {
389 auto *argClass = getOrCreateFor<AliasClassLattice>(point, arg);
390 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
391 arg.getArgNumber());
392 changed |= lattice->insert(argClass->getAliasClassesObject(),
393 ValueOriginSet(origin));
394 }
395 propagateIfChanged(lattice, changed);
396}
397
// Helpers declared here and defined outside this extract. As used below, each
// returns an empty optional when it does not apply to `op` (presumably: `op` is
// not a store / not a copy, respectively). getStored yields the value written
// to memory. getCopySource yields the source pointer handed to processCopy.
398std::optional<Value> getStored(Operation *op);
399std::optional<Value> getCopySource(Operation *op);
400
// Dense forward transfer function for `op`. Its header line (doxygen line 401)
// and doxygen line 441 (the tail of the unknown-source insert) are missing from
// this extract. `after` starts as the join with `before`; the op's memory
// effects then add origins on top.
402 Operation *op, const ForwardOriginsMap &before, ForwardOriginsMap *after) {
403 join(after, before);
404
405 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
406 if (activityIface && activityIface.isInactive())
407 return success();
408
// Unknown memory behaviour: every tracked origin becomes unknown.
409 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
410 if (!memory) {
411 propagateIfChanged(after, after->markAllOriginsUnknown());
412 return success();
413 }
414
415 SmallVector<MemoryEffects::EffectInstance> effects;
416 memory.getEffects(effects);
417 for (const auto &effect : effects) {
// Any effect (of any kind) without a known address makes everything unknown.
418 Value value = effect.getValue();
419 if (!value) {
420 propagateIfChanged(after, after->markAllOriginsUnknown());
421 return success();
422 }
423
// Read: when the single result is a pointer, the memory it points to
// receives the origins stored at the read location's alias classes.
424 if (isa<MemoryEffects::Read>(effect.getEffect())) {
425 // TODO: Really need that memory interface
426 if (op->getNumResults() != 1)
427 continue;
428 Value readDest = op->getResult(0);
429
430 auto *destClasses =
431 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), readDest);
432 if (destClasses->isUndefined())
433 // Not a pointer, so the sparse analysis will handle this.
434 continue;
435
436 auto *srcClasses =
437 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
438 if (srcClasses->isUnknown()) {
439 propagateIfChanged(after,
440 after->insert(destClasses->getAliasClassesObject(),
442 continue;
443 }
444
445 ChangeResult changed = ChangeResult::NoChange;
446 for (DistinctAttr srcClass : srcClasses->getAliasClasses()) {
447 changed |= after->insert(destClasses->getAliasClassesObject(),
448 before.getOrigins(srcClass));
449 }
450 propagateIfChanged(after, changed);
// Write: the handling depends on the kind of write.
//   - Store of a float/complex value: the value's sparse origins are inserted
//     at the address's alias classes. Stores of other types, including
//     pointers, are skipped here.
//   - Copy: handled by processCopy.
//   - Anything else: all origins become unknown.
451 } else if (isa<MemoryEffects::Write>(effect.getEffect())) {
452 if (std::optional<Value> stored = getStored(op)) {
453 if (!isPossiblyActive(stored->getType())) {
454 continue;
455 }
456 auto *origins = getOrCreateFor<ForwardOriginsLattice>(
457 getProgramPointAfter(op), *stored);
458 auto *dest =
459 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
460 propagateIfChanged(after, after->insert(dest->getAliasClassesObject(),
461 origins->getOriginsObject()));
462 } else if (std::optional<Value> copySource = getCopySource(op)) {
463 processCopy(op, *copySource, value, before, after);
464 } else {
465 propagateIfChanged(after, after->markAllOriginsUnknown());
466 }
467 }
468 }
469 return success();
470}
471
472void enzyme::DenseActivityAnnotationAnalysis::processCopy(
473 Operation *op, Value copySource, Value copyDest,
474 const ForwardOriginsMap &before, ForwardOriginsMap *after) {
475 auto *src =
476 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), copySource);
477 ValueOriginSet srcOrigins;
478 if (src->isUndefined())
479 return;
480 if (src->isUnknown())
481 (void)srcOrigins.markUnknown();
482
483 for (DistinctAttr srcClass : src->getAliasClasses())
484 (void)srcOrigins.join(before.getOrigins(srcClass));
485
486 auto *dest =
487 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), copyDest);
488 propagateIfChanged(after,
489 after->insert(dest->getAliasClassesObject(), srcOrigins));
490}
491
492// TODO: rename from pointsto
// Decode a dense activity summary attribute of the form [[key, value], ...]
// into `summaryMap`:
//   - key: a DistinctAttr alias class.
//   - value "unknown": the origins are marked unknown.
//   - value "undefined": the origins stay undefined.
//   - value is an array: its OriginAttrs are inserted.
// DenseMap::insert keeps the first entry if a key repeats.
// NOTE(review): the header line (doxygen line 493) is missing from this
// extract. deserializeReturnOrigins spells its unknown marker "<unknown>".
494 ArrayAttr summaryAttr,
495 DenseMap<DistinctAttr, enzyme::ValueOriginSet> &summaryMap) {
496 // TODO: investigate better encodings for the value origin summary
497 for (auto pair : summaryAttr.getAsRange<ArrayAttr>()) {
498 assert(pair.size() == 2 &&
499 "Expected summary to be in [[key, value]] format");
500 auto pointer = cast<DistinctAttr>(pair[0]);
501 auto pointsToSet = enzyme::ValueOriginSet::getUndefined();
502 if (auto strAttr = dyn_cast<StringAttr>(pair[1])) {
503 if (strAttr.getValue() == "unknown") {
504 (void)pointsToSet.markUnknown();
505 } else {
506 assert(strAttr.getValue() == "undefined" &&
507 "unrecognized points-to destination");
508 }
509 } else {
510 auto pointsTo = cast<ArrayAttr>(pair[1]).getAsRange<enzyme::OriginAttr>();
511 (void)pointsToSet.insert(
512 DenseSet<enzyme::OriginAttr>(pointsTo.begin(), pointsTo.end()));
513 }
514
515 summaryMap.insert({pointer, pointsToSet});
516 }
517}
518
// Dense forward transfer across a call (its header line, doxygen line 519, is
// missing from this extract). `after` starts as the join with `before`. For an
// external callee:
//   - Indirect call: all origins become unknown.
//   - Callee carries a dense activity summary attribute: the summary is
//     applied.
// NOTE(review): an external callee without a summary attribute (or that cannot
// be resolved by symbol) leaves `after` equal to `before`. That is, it is
// assumed not to change any memory origins. Confirm this is intended.
520 CallOpInterface call, dataflow::CallControlFlowAction action,
521 const ForwardOriginsMap &before, ForwardOriginsMap *after) {
522 join(after, before);
523 if (action == dataflow::CallControlFlowAction::ExternalCallee) {
524 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
525 if (!symbol)
526 return propagateIfChanged(after, after->markAllOriginsUnknown());
527
528 if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
529 call, symbol.getLeafReference())) {
530 if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
531 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
532 DenseMap<DistinctAttr, ValueOriginSet> summary;
533 deserializePointsTo(summaryAttr, summary);
534 return processCallToSummarizedFunc(call, summary, before, after);
535 }
536 }
537 }
538}
539
540void enzyme::DenseActivityAnnotationAnalysis::processCallToSummarizedFunc(
541 CallOpInterface call, const DenseMap<DistinctAttr, ValueOriginSet> &summary,
542 const ForwardOriginsMap &before, ForwardOriginsMap *after) {
543 ChangeResult changed = ChangeResult::NoChange;
544 ProgramPoint *point = getProgramPointAfter(call);
545 // Unify the value origin summary with the actual lattices of function
546 // arguments
547 // Collect the origins of the function arguments, then collect the alias
548 // classes of the destinations
549 auto *p2sets = getOrCreateFor<PointsToSets>(point, point);
550 SmallVector<ValueOriginSet> argumentOrigins;
551 SmallVector<AliasClassSet> argumentClasses;
552 for (auto &&[i, argOperand] : llvm::enumerate(call.getArgOperands())) {
553 // Value origin might be sparse, might be dense
554 ValueOriginSet argOrigins;
555 auto *argClasses = getOrCreateFor<AliasClassLattice>(point, argOperand);
556 if (argClasses->isUndefined()) {
557 // Not a pointer, use the sparse lattice state
558 auto *sparseOrigins =
559 getOrCreateFor<ForwardOriginsLattice>(point, argOperand);
560 (void)argOrigins.join(sparseOrigins->getOriginsObject());
561 } else {
562 // Unify all the origins
563 // Since we're not keeping track of argument depth, we need to union the
564 // arg origins with everything it points to.
565 traversePointsToSets(argClasses->getAliasClassesObject(), *p2sets,
566 [&](DistinctAttr aliasClass) {
567 (void)argOrigins.join(
568 before.getOrigins(aliasClass));
569 });
570 }
571 argumentClasses.push_back(argClasses->getAliasClassesObject());
572 argumentOrigins.push_back(argOrigins);
573 }
574
575 for (const auto &[destClass, sourceOrigins] : summary) {
576 ValueOriginSet callerOrigins;
577 for (Attribute sourceOrigin : sourceOrigins.getElements()) {
578 unsigned argNumber =
579 cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
580 (void)callerOrigins.join(argumentOrigins[argNumber]);
581 }
582
583 AliasClassSet callerDestClasses;
584 if (auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
585 destClass.getReferencedAttr())) {
586 // Traverse the points-to sets.
587 AliasClassSet current = argumentClasses[pseudoClass.getArgNumber()];
588 unsigned depth = pseudoClass.getDepth();
589 while (depth > 0) {
590 AliasClassSet next;
591 if (current.isUndefined()) {
592 // Activity annotations requires converged pointer info. If the alias
593 // class is undefined, this could be because (1) the points-to info
594 // hasn't _yet_ been computed (in which case we bail out here
595 // expecting to be called again with more complete info), or (2) if
596 // the points-to info has converged, this signifies reading from
597 // uninitialized memory. This is UB, so we assume it never happens.
598 return;
599 }
600 for (DistinctAttr currentClass : current.getElements())
601 (void)next.join(p2sets->getPointsTo(currentClass));
602 std::swap(current, next);
603 depth--;
604 }
605
606 (void)callerDestClasses.join(current);
607 } else {
608 (void)callerDestClasses.insert({destClass});
609 }
610 changed |= after->insert(callerDestClasses, callerOrigins);
611 }
612 propagateIfChanged(after, changed);
613}
614
// Dense backward transfer across a call. The qualified-name line (doxygen line
// 615) is missing from this extract. `before` starts as the meet with `after`.
// For an external callee:
//   - Indirect call: all origins become unknown.
//   - Callee carries a dense activity summary attribute: the summary is
//     applied backwards.
// NOTE(review): an external callee without a summary attribute leaves `before`
// equal to `after`. Confirm this is intended.
616 visitCallControlFlowTransfer(CallOpInterface call,
617 dataflow::CallControlFlowAction action,
618 const BackwardOriginsMap &after,
619 BackwardOriginsMap *before) {
620 meet(before, after);
621 if (action == dataflow::CallControlFlowAction::ExternalCallee) {
622 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
623 if (!symbol)
624 return propagateIfChanged(before, before->markAllOriginsUnknown());
625
626 if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
627 call, symbol.getLeafReference())) {
628 if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
629 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
630 DenseMap<DistinctAttr, ValueOriginSet> summary;
631 deserializePointsTo(summaryAttr, summary);
632 return processCallToSummarizedFunc(call, summary, after, before);
633 }
634 }
635 }
636}
637
// Presumably the dense backward analysis' exit-state initializer; its header
// line (doxygen line 638) is missing from this extract. For each function
// argument, every alias class transitively reachable from it is tagged with
// that argument's origin. Reachability uses the points-to sets at the point
// after the block's terminator.
// NOTE(review): as in the forward initializer, the dyn_cast result is used
// without a null check. The PointsToSets lookup does not depend on `arg` and
// could be hoisted out of the loop.
639 BackwardOriginsMap *lattice) {
640 auto point = dyn_cast<ProgramPoint *>(lattice->getAnchor());
641 auto *block = point->getBlock();
642 if (!block)
643 return;
644
645 auto funcOp = cast<FunctionOpInterface>(block->getParentOp());
646 ChangeResult changed = ChangeResult::NoChange;
647 for (BlockArgument arg : funcOp.getArguments()) {
648 auto *pointsToSets = getOrCreateFor<PointsToSets>(
649 point, getProgramPointAfter(block->getTerminator()));
650 auto *argClass = getOrCreateFor<AliasClassLattice>(point, arg);
651 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
652 arg.getArgNumber());
653
654 // Everything that a pointer argument may point to originates from that
655 // pointer argument.
656 traversePointsToSets(argClass->getAliasClassesObject(), *pointsToSets,
657 [&](DistinctAttr currentClass) {
658 changed |=
659 lattice->insert(AliasClassSet(currentClass),
660 ValueOriginSet(origin));
661 });
662 }
663 propagateIfChanged(lattice, changed);
664}
665
// Dense backward transfer function for `op` (its header line, doxygen line 666,
// is missing from this extract). `before` starts as the meet with `after`;
// only write effects are then processed.
667 Operation *op, const BackwardOriginsMap &after,
668 BackwardOriginsMap *before) {
669 meet(before, after);
670
671 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
672 if (activityIface && activityIface.isInactive())
673 return success();
674
675 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
676 if (!memory) {
677 propagateIfChanged(before, before->markAllOriginsUnknown());
678 return success();
679 }
680
681 SmallVector<MemoryEffects::EffectInstance> effects;
682 memory.getEffects(effects);
683 for (const auto &effect : effects) {
684 if (!isa<MemoryEffects::Write>(effect.getEffect()))
685 continue;
686
687 Value value = effect.getValue();
688 if (!value) {
689 propagateIfChanged(before, before->markAllOriginsUnknown());
690 return success();
691 }
692
// Store: the outcome depends on what is being stored.
//   - Non-pointer value (undefined alias classes): the value's sparse backward
//     lattice merges the origins that `after` holds for each alias class of
//     the address, or becomes unknown if the address is unknown.
//   - Pointer with unknown alias classes: everything becomes unknown.
//   - Pointer with known alias classes: nothing is done here.
// NOTE(review): writes that are neither stores nor copies are ignored here,
// whereas the forward dense analysis marks all origins unknown for them.
693 if (std::optional<Value> stored = getStored(op)) {
694 ProgramPoint *point = getProgramPointBefore(op);
695 auto *addressClasses = getOrCreateFor<AliasClassLattice>(point, value);
696 auto *storedClasses = getOrCreateFor<AliasClassLattice>(point, *stored);
697
698 if (storedClasses->isUndefined()) {
699 // Not a pointer being stored, do a sparse update
700 auto *storedOrigins = getOrCreate<BackwardOriginsLattice>(*stored);
701 propagateIfChanged(
702 storedOrigins,
703 addressClasses->getAliasClassesObject().foreachElement(
704 [&](DistinctAttr alloc, AliasClassSet::State state) {
705 if (state == AliasClassSet::State::Undefined) {
706 return ChangeResult::NoChange;
707 }
708 if (state == AliasClassSet::State::Unknown) {
709 return storedOrigins->markUnknown();
710 }
711 return storedOrigins->merge(after.getOrigins(alloc));
712 }));
713 } else if (storedClasses->isUnknown()) {
714 propagateIfChanged(before, before->markAllOriginsUnknown());
715 } else {
716 // Capturing stores are handled via the points-to relationship in
717 // setToExitState.
718 }
719 } else if (std::optional<Value> copySource = getCopySource(op)) {
720 processCopy(op, *copySource, value, after, before);
721 }
722 }
723 return success();
724}
725
726void enzyme::DenseBackwardActivityAnnotationAnalysis::
727 processCallToSummarizedFunc(
728 CallOpInterface call,
729 const DenseMap<DistinctAttr, ValueOriginSet> &summary,
730 const BackwardOriginsMap &after, BackwardOriginsMap *before) {
731 ChangeResult changed = ChangeResult::NoChange;
732 ProgramPoint *pointBefore = getProgramPointBefore(call);
733 ProgramPoint *pointAfter = getProgramPointAfter(call);
734 // Unify the value origin summary with the actual lattices of function
735 // arguments
736 auto *p2sets = getOrCreateFor<PointsToSets>(pointBefore, pointAfter);
737 SmallVector<AliasClassSet> argumentClasses;
738 for (Value argOperand : call.getArgOperands()) {
739 auto *argClasses =
740 getOrCreateFor<AliasClassLattice>(pointBefore, argOperand);
741 argumentClasses.push_back(argClasses->getAliasClassesObject());
742 }
743
744 for (const auto &[destClass, sourceOrigins] : summary) {
745 // Get the destination origins
746 ValueOriginSet destOrigins;
747 if (auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
748 destClass.getReferencedAttr())) {
749 traversePointsToSets(argumentClasses[pseudoClass.getArgNumber()], *p2sets,
750 [&](DistinctAttr aliasClass) {
751 (void)destOrigins.join(
752 after.getOrigins(aliasClass));
753 });
754 }
755
756 if (destOrigins.isUndefined())
757 continue;
758
759 // Get the source alias classes
760 AliasClassSet callerSourceClasses;
761 for (Attribute sourceOrigin : sourceOrigins.getElements()) {
762 unsigned argNumber =
763 cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
764
765 if (argumentClasses[argNumber].isUndefined()) {
766 // Not a pointer, do a sparse update
767 auto *backwardLattice = getOrCreate<BackwardOriginsLattice>(
768 call.getArgOperands()[argNumber]);
769 if (destOrigins.isUnknown()) {
770 propagateIfChanged(backwardLattice, backwardLattice->markUnknown());
771 continue;
772 }
773 propagateIfChanged(backwardLattice,
774 backwardLattice->insert(destOrigins.getElements()));
775 } else {
776 traversePointsToSets(argumentClasses[argNumber], *p2sets,
777 [&](DistinctAttr aliasClass) {
778 (void)callerSourceClasses.insert({aliasClass});
779 });
780 }
781 }
782 changed |= before->insert(callerSourceClasses, destOrigins);
783 }
784 propagateIfChanged(before, changed);
785}
786
787void enzyme::DenseBackwardActivityAnnotationAnalysis::processCopy(
788 Operation *op, Value copySource, Value copyDest,
789 const BackwardOriginsMap &after, BackwardOriginsMap *before) {
790 ProgramPoint *point = getProgramPointBefore(op);
791 auto *dest = getOrCreateFor<AliasClassLattice>(point, copyDest);
792 ValueOriginSet destOrigins;
793 if (dest->isUndefined())
794 return;
795 if (dest->isUnknown())
796 (void)destOrigins.markUnknown();
797
798 for (DistinctAttr destClass : dest->getAliasClasses())
799 (void)destOrigins.join(after.getOrigins(destClass));
800
801 auto *src = getOrCreateFor<AliasClassLattice>(point, copySource);
802 propagateIfChanged(before,
803 before->insert(src->getAliasClassesObject(), destOrigins));
804}
805
806namespace {
807
808// TODO: the alias summary attribute is sufficent to get the correct behaviour
809// here, but it would be nice if these were not hardcoded.
810void annotateHardcoded(FunctionOpInterface func) {
811 if (func.getName() == "lgamma" || func.getName() == "tanh") {
812 MLIRContext *ctx = func.getContext();
813 SmallVector<Attribute> arr = {StringAttr::get(ctx, "<undefined>")};
814 func->setAttr(enzyme::EnzymeDialect::getAliasSummaryAttrName(),
815 ArrayAttr::get(ctx, arr));
816 }
817}
818
819/// Starting from callee, compute a reverse (bottom-up) topological sorting of
820/// all functions transitively called from callee.
821void reverseToposortCallgraph(CallableOpInterface callee,
822 SymbolTableCollection *symbolTable,
823 SmallVectorImpl<CallableOpInterface> &sorted) {
824 DenseSet<CallableOpInterface> permanent;
825 DenseSet<CallableOpInterface> temporary;
826 std::function<void(CallableOpInterface)> visit =
827 [&](CallableOpInterface node) {
828 if (permanent.contains(node))
829 return;
830 if (temporary.contains(node))
831 assert(false && "unimplemented cycle in call graph");
832
833 temporary.insert(node);
834 node.walk([&](CallOpInterface call) {
835 auto neighbour = cast<CallableOpInterface>(
836 call.resolveCallableInTable(symbolTable));
837 visit(neighbour);
838 });
839
840 temporary.erase(node);
841 permanent.insert(node);
842 sorted.push_back(node);
843 };
844
845 visit(callee);
846}
847
848void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
849 DataFlowSolver &solver) {
850 using namespace mlir::enzyme;
851
852 for (Operation &op : func.getCallableRegion()->getOps()) {
853 if (!op.hasTrait<OpTrait::ReturnLike>())
854 continue;
855
856 for (OpOperand &returnOperand : op.getOpOperands()) {
857 auto *lattice =
858 solver.getOrCreateState<BackwardOriginsLattice>(returnOperand.get());
859 auto origin = ReturnOriginAttr::get(FlatSymbolRefAttr::get(func),
860 returnOperand.getOperandNumber());
861 (void)lattice->insert({origin});
862 }
863 }
864}
865
// A forward and a backward origins lattice paired together. It is the value
// type of the per-BlockArgument `blockArgOrigins` map taken by
// topDownActivityAnalysis.
866using OriginsPair =
867 std::pair<enzyme::ForwardOriginsLattice, enzyme::BackwardOriginsLattice>;
868
869/// Once having reached a top-level entry point, go top-down and convert the
870/// relative sources/sinks into concrete active/constant results.
871///
872/// This would ideally be done after lowering to LLVM and during differentiation
873/// because it loses context sensitivity, but this is faster to prototype with.
874void topDownActivityAnalysis(
875 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
876 ArrayRef<enzyme::Activity> retActivities,
877 DenseMap<BlockArgument, OriginsPair> &blockArgOrigins) {
878 using namespace mlir::enzyme;
879 MLIRContext *ctx = callee.getContext();
880 callee->setAttr("enzyme.visited", UnitAttr::get(ctx));
881 auto trueAttr = BoolAttr::get(ctx, true);
882 auto falseAttr = BoolAttr::get(ctx, false);
883
884 auto isOriginActive = [&](OriginAttr origin) {
885 if (auto argOriginAttr = dyn_cast<ArgumentOriginAttr>(origin)) {
886 return llvm::is_contained({Activity::enzyme_dup,
887 Activity::enzyme_dupnoneed,
888 Activity::enzyme_active},
889 argActivities[argOriginAttr.getArgNumber()]);
890 }
891 auto retOriginAttr = cast<ReturnOriginAttr>(origin);
892 return llvm::is_contained({Activity::enzyme_dup, Activity::enzyme_dupnoneed,
893 Activity::enzyme_active},
894 retActivities[retOriginAttr.getReturnNumber()]);
895 };
896 callee.getFunctionBody().walk([&](Operation *op) {
897 if (op->getNumResults() == 0) {
898 // Operations that don't return values are definitionally "constant"
899 op->setAttr("enzyme.icv", trueAttr);
900 } else {
901 // Value activity
902 if (op->hasAttr("enzyme.constantval")) {
903 op->setAttr("enzyme.icv", trueAttr);
904 } else if (op->hasAttr("enzyme.activeval")) {
905 op->setAttr("enzyme.icv", falseAttr);
906 } else {
907 auto valueSource = op->getAttrOfType<ArrayAttr>("enzyme.valsrc");
908 auto valueSink = op->getAttrOfType<ArrayAttr>("enzyme.valsink");
909 if (!(valueSource && valueSink)) {
910 llvm::errs() << "[activity] missing attributes for op: " << *op
911 << "\n";
912 }
913 assert(valueSource && valueSink && "missing attributes for op");
914 bool activeSource =
915 llvm::any_of(valueSource.getAsRange<OriginAttr>(), isOriginActive);
916 bool activeSink =
917 llvm::any_of(valueSink.getAsRange<OriginAttr>(), isOriginActive);
918 bool activeVal = activeSource && activeSink;
919 op->setAttr("enzyme.icv", BoolAttr::get(ctx, !activeVal));
920 }
921 }
922 op->removeAttr("enzyme.constantval");
923 op->removeAttr("enzyme.activeval");
924 op->removeAttr("enzyme.valsrc");
925 op->removeAttr("enzyme.valsink");
926
927 // Instruction activity
928 if (op->hasAttr("enzyme.constantop")) {
929 op->setAttr("enzyme.ici", trueAttr);
930 } else if (op->hasAttr("enzyme.activeop")) {
931 op->setAttr("enzyme.ici", falseAttr);
932 } else {
933 bool activeSource = llvm::any_of(
934 op->getAttrOfType<ArrayAttr>("enzyme.opsrc").getAsRange<OriginAttr>(),
935 isOriginActive);
936 bool activeSink =
937 llvm::any_of(op->getAttrOfType<ArrayAttr>("enzyme.opsink")
938 .getAsRange<OriginAttr>(),
939 isOriginActive);
940 bool activeOp = activeSource && activeSink;
941 op->setAttr("enzyme.ici", BoolAttr::get(ctx, !activeOp));
942 }
943
944 op->removeAttr("enzyme.constantop");
945 op->removeAttr("enzyme.activeop");
946 op->removeAttr("enzyme.opsrc");
947 op->removeAttr("enzyme.opsink");
948
949 if (auto callOp = dyn_cast<CallOpInterface>(op)) {
950 auto funcOp = cast<FunctionOpInterface>(callOp.resolveCallable());
951 if (!funcOp->hasAttr("enzyme.visited")) {
952 SmallVector<Activity> callArgActivities, callResActivities;
953 for (Value operand : callOp.getArgOperands()) {
954 if (auto *definingOp = operand.getDefiningOp()) {
955 bool icv =
956 definingOp->getAttrOfType<BoolAttr>("enzyme.icv").getValue();
957 callArgActivities.push_back(icv ? Activity::enzyme_const
958 : Activity::enzyme_active);
959 } else {
960 BlockArgument blockArg = cast<BlockArgument>(operand);
961 const OriginsPair &originsPair = blockArgOrigins.at(blockArg);
962 const ForwardOriginsLattice &sources = originsPair.first;
963 const BackwardOriginsLattice &sinks = originsPair.second;
964 bool argActive = false;
965 if (sources.isUnknown() || sinks.isUnknown()) {
966 argActive = true;
967 } else if (sources.isUndefined() || sinks.isUndefined()) {
968 argActive = false;
969 } else {
970 argActive = llvm::any_of(sources.getOrigins(), isOriginActive) &&
971 llvm::any_of(sinks.getOrigins(), isOriginActive);
972 }
973 callArgActivities.push_back(argActive ? Activity::enzyme_active
974 : Activity::enzyme_const);
975 }
976 }
977 if (op->getNumResults() != 0) {
978 bool icv = op->getAttrOfType<BoolAttr>("enzyme.icv").getValue();
979 callResActivities.push_back(icv ? Activity::enzyme_const
980 : Activity::enzyme_active);
981 }
982
983 topDownActivityAnalysis(funcOp, callArgActivities, callResActivities,
984 blockArgOrigins);
985 }
986 }
987 });
988}
989} // namespace
990
992 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
993 const ActivityPrinterConfig &activityConfig) {
994 SymbolTableCollection symbolTable;
995 SmallVector<CallableOpInterface> sorted;
996 reverseToposortCallgraph(callee, &symbolTable, sorted);
997 raw_ostream &os = llvm::outs();
998
999 // TODO: is there any way of serializing information in a block argument?
1000 DenseMap<BlockArgument, OriginsPair> blockArgOrigins;
1001
1002 StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName();
1003 for (CallableOpInterface node : sorted) {
1004 annotateHardcoded(cast<FunctionOpInterface>(node.getOperation()));
1005
1006 if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName))
1007 continue;
1008 auto funcOp = cast<FunctionOpInterface>(node.getOperation());
1009 if (activityConfig.verbose) {
1010 os << "[ata] processing function @" << funcOp.getName() << "\n";
1011 }
1012 DataFlowConfig dataFlowConfig;
1013 dataFlowConfig.setInterprocedural(false);
1014 DataFlowSolver solver(dataFlowConfig);
1015 SymbolTableCollection symbolTable;
1016
1017 solver.load<dataflow::SparseConstantPropagation>();
1018 solver.load<dataflow::DeadCodeAnalysis>();
1019 solver.load<enzyme::AliasAnalysis>(callee.getContext(),
1020 /*relative=*/true);
1021 solver.load<enzyme::PointsToPointerAnalysis>();
1024 solver.load<enzyme::BackwardActivityAnnotationAnalysis>(symbolTable);
1025 solver.load<enzyme::DenseBackwardActivityAnnotationAnalysis>(symbolTable);
1026
1027 initializeSparseBackwardActivityAnnotations(funcOp, solver);
1028
1029 if (failed(solver.initializeAndRun(node))) {
1030 assert(false && "dataflow solver failed");
1031 }
1032
1033 // Create the overall summary by joining sets at all return sites.
1034 enzyme::PointsToSets p2sets(nullptr);
1035 enzyme::ForwardOriginsMap forwardOriginsMap(nullptr);
1036 size_t numResults = node.getResultTypes().size();
1037 SmallVector<enzyme::ForwardOriginsLattice> returnOperandOrigins(
1038 numResults, ForwardOriginsLattice(nullptr));
1039 SmallVector<enzyme::AliasClassLattice> returnAliasClasses(
1040 numResults, AliasClassLattice(nullptr));
1041
1042 for (Operation &op : node.getCallableRegion()->getOps()) {
1043 if (op.hasTrait<OpTrait::ReturnLike>()) {
1044 ProgramPoint *point = solver.getProgramPointAfter(&op);
1045 (void)p2sets.join(*solver.lookupState<enzyme::PointsToSets>(point));
1046 auto *returnOrigins =
1047 solver.lookupState<enzyme::ForwardOriginsMap>(point);
1048 if (returnOrigins)
1049 (void)forwardOriginsMap.join(*returnOrigins);
1050
1051 for (OpOperand &operand : op.getOpOperands()) {
1052 (void)returnAliasClasses[operand.getOperandNumber()].join(
1053 *solver.lookupState<enzyme::AliasClassLattice>(operand.get()));
1054 (void)returnOperandOrigins[operand.getOperandNumber()].join(
1055 *solver.lookupState<enzyme::ForwardOriginsLattice>(
1056 operand.get()));
1057 }
1058 }
1059 }
1060
1061 // Sparse alias annotations
1062 SmallVector<Attribute> aliasAttributes(returnAliasClasses.size());
1063 llvm::transform(returnAliasClasses, aliasAttributes.begin(),
1064 [&](enzyme::AliasClassLattice lattice) {
1065 return lattice.serialize(node.getContext());
1066 });
1067 node->setAttr(EnzymeDialect::getAliasSummaryAttrName(),
1068 ArrayAttr::get(node.getContext(), aliasAttributes));
1069
1070 // Points-to-pointer annotations
1071 node->setAttr(pointerSummaryName, p2sets.serialize(node.getContext()));
1072 if (activityConfig.verbose) {
1073 os << "[ata] p2p summary:\n";
1074 if (node->getAttrOfType<ArrayAttr>(pointerSummaryName).size() == 0) {
1075 os << " <empty>\n";
1076 }
1077 for (ArrayAttr pair : node->getAttrOfType<ArrayAttr>(pointerSummaryName)
1078 .getAsRange<ArrayAttr>()) {
1079 os << " " << pair[0] << " -> " << pair[1] << "\n";
1080 }
1081 }
1082
1083 node->setAttr(EnzymeDialect::getDenseActivityAnnotationAttrName(),
1084 forwardOriginsMap.serialize(node.getContext()));
1085 if (activityConfig.verbose) {
1086 os << "[ata] forward value origins:\n";
1087 for (ArrayAttr pair :
1088 node->getAttrOfType<ArrayAttr>(
1089 EnzymeDialect::getDenseActivityAnnotationAttrName())
1090 .getAsRange<ArrayAttr>()) {
1091 os << " " << pair[0] << " originates from " << pair[1] << "\n";
1092 }
1093 }
1094
1095 auto *backwardOriginsMap =
1096 solver.getOrCreateState<enzyme::BackwardOriginsMap>(
1097 solver.getProgramPointBefore(
1098 &node.getCallableRegion()->front().front()));
1099 Attribute backwardOrigins =
1100 backwardOriginsMap->serialize(node.getContext());
1101 if (activityConfig.verbose) {
1102 os << "[ata] backward value origins:\n";
1103 for (ArrayAttr pair :
1104 cast<ArrayAttr>(backwardOrigins).getAsRange<ArrayAttr>()) {
1105 os << " " << pair[0] << " goes to " << pair[1] << "\n";
1106 }
1107 }
1108
1109 // Serialize return origins
1110 MLIRContext *ctx = node.getContext();
1111 SmallVector<Attribute> serializedReturnOperandOrigins(
1112 returnOperandOrigins.size());
1113 llvm::transform(returnOperandOrigins,
1114 serializedReturnOperandOrigins.begin(),
1115 [ctx](enzyme::ForwardOriginsLattice lattice) -> Attribute {
1116 return lattice.serialize(ctx);
1117 });
1118 node->setAttr(
1119 EnzymeDialect::getSparseActivityAnnotationAttrName(),
1120 ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
1121 if (activityConfig.verbose) {
1122 os << "[ata] return origins: "
1123 << node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName())
1124 << "\n";
1125 }
1126
1127 auto joinActiveDataState =
1128 [&](Value value,
1129 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1130 auto *sources = solver.getOrCreateState<ForwardOriginsLattice>(value);
1131 auto *sinks = solver.getOrCreateState<BackwardOriginsLattice>(value);
1132 (void)out.first.join(*sources);
1133 (void)out.second.meet(*sinks);
1134 };
1135
1136 auto joinActivePointerState =
1137 [&](const AliasClassSet &aliasClasses,
1138 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1140 aliasClasses, p2sets, [&](DistinctAttr aliasClass) {
1141 (void)out.first.merge(forwardOriginsMap.getOrigins(aliasClass));
1142 (void)out.second.merge(
1143 backwardOriginsMap->getOrigins(aliasClass));
1144 });
1145 };
1146
1147 auto joinActiveValueState =
1148 [&](Value value,
1149 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1150 if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
1151 auto *aliasClasses =
1152 solver.getOrCreateState<AliasClassLattice>(value);
1153 joinActivePointerState(aliasClasses->getAliasClassesObject(), out);
1154 } else {
1155 joinActiveDataState(value, out);
1156 }
1157 };
1158
1159 auto annotateActivity = [&](Operation *op) {
1160 assert(op->getNumResults() < 2 && op->getNumRegions() == 0 &&
1161 "annotation only supports the LLVM dialect");
1162 auto unitAttr = UnitAttr::get(ctx);
1163 // Check activity of values
1164 for (OpResult result : op->getResults()) {
1165 std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
1166 activityAttributes({result, ValueOriginSet()},
1167 {result, ValueOriginSet()});
1168 joinActiveValueState(result, activityAttributes);
1169 const auto &sources = activityAttributes.first;
1170 const auto &sinks = activityAttributes.second;
1171 // Possible states: if either source or sink is undefined or empty, the
1172 // value is always constant.
1173 if (sources.isUnknown() || sinks.isUnknown()) {
1174 // Always active
1175 op->setAttr("enzyme.activeval", unitAttr);
1176 } else if (sources.isUndefined() || sinks.isUndefined()) {
1177 // Always constant
1178 op->setAttr("enzyme.constantval", unitAttr);
1179 } else {
1180 // Conditionally active depending on the activity of sources and sinks
1181 op->setAttr("enzyme.valsrc", sources.serialize(ctx));
1182 op->setAttr("enzyme.valsink", sinks.serialize(ctx));
1183 }
1184 }
1185 // Check activity of operation
1186 StringRef opSourceAttrName = "enzyme.opsrc";
1187 StringRef opSinkAttrName = "enzyme.opsink";
1188 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> opAttributes(
1189 {nullptr, ValueOriginSet()}, {nullptr, ValueOriginSet()});
1190 if (isPure(op)) {
1191 // A pure operation can only propagate data via its results
1192 for (OpResult result : op->getResults()) {
1193 joinActiveDataState(result, opAttributes);
1194 }
1195 } else {
1196 // We need a special case because stores of active pointers don't fit
1197 // the definition but are active instructions
1198 if (auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
1199 auto *storedClass =
1200 solver.getOrCreateState<AliasClassLattice>(storeOp.getValue());
1201 joinActivePointerState(storedClass->getAliasClassesObject(),
1202 opAttributes);
1203 } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
1204 // TODO: tricky, requires some thought
1205 auto callable = cast<CallableOpInterface>(callOp.resolveCallable());
1206 if (callable->hasAttr(
1207 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
1208 for (Value operand : callOp.getArgOperands())
1209 joinActiveValueState(operand, opAttributes);
1210 }
1211 // We need to
1212 // determine if the body of the function contains active instructions
1213 }
1214
1215 // Default: the op is active iff any of its operands or results are
1216 // active data.
1217 for (Value operand : op->getOperands())
1218 joinActiveDataState(operand, opAttributes);
1219 for (OpResult result : op->getResults())
1220 joinActiveDataState(result, opAttributes);
1221 }
1222
1223 const auto &opSources = opAttributes.first;
1224 const auto &opSinks = opAttributes.second;
1225 if (opSources.isUnknown() || opSinks.isUnknown()) {
1226 op->setAttr("enzyme.activeop", unitAttr);
1227 } else if (opSources.isUndefined() || opSinks.isUndefined()) {
1228 op->setAttr("enzyme.constantop", unitAttr);
1229 } else {
1230 op->setAttr(opSourceAttrName, opAttributes.first.serialize(ctx));
1231 op->setAttr(opSinkAttrName, opAttributes.second.serialize(ctx));
1232 }
1233 };
1234
1235 // We lose the solver state when going top down and I don't know a better
1236 // way to serialize block argument information.
1237 node.getCallableRegion()->walk([&](Block *block) {
1238 for (BlockArgument blockArg : block->getArguments()) {
1239 OriginsPair blockArgAttributes({blockArg, ValueOriginSet()},
1240 {blockArg, ValueOriginSet()});
1241 joinActiveValueState(blockArg, blockArgAttributes);
1242 blockArgOrigins.try_emplace(blockArg, blockArgAttributes);
1243 }
1244 });
1245
1246 node.getCallableRegion()->walk([&](Operation *op) {
1247 if (activityConfig.annotate)
1248 annotateActivity(op);
1249 if (activityConfig.verbose) {
1250 if (op->hasAttr("tag")) {
1251 for (OpResult result : op->getResults()) {
1252 std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
1253 activityAttributes({result, ValueOriginSet()},
1254 {result, ValueOriginSet()});
1255 joinActiveValueState(result, activityAttributes);
1256 os << op->getAttr("tag") << "(#" << result.getResultNumber()
1257 << ")\n"
1258 << " sources: " << activityAttributes.first.serialize(ctx)
1259 << "\n"
1260 << " sinks: " << activityAttributes.second.serialize(ctx)
1261 << "\n";
1262 }
1263 }
1264 }
1265 });
1266 }
1267
1268 if (!argActivities.empty() && activityConfig.annotate) {
1269 SmallVector<enzyme::Activity> resActivities;
1270 for (Type resultType : callee.getResultTypes()) {
1271 resActivities.push_back(isa<FloatType, ComplexType>(resultType)
1272 ? Activity::enzyme_active
1273 : Activity::enzyme_const);
1274 }
1275
1276 topDownActivityAnalysis(callee, argActivities, resActivities,
1277 blockArgOrigins);
1278 }
1279}
static void deserializePointsTo(ArrayAttr summaryAttr, DenseMap< DistinctAttr, enzyme::ValueOriginSet > &summaryMap)
std::optional< Value > getStored(Operation *op)
static bool isPossiblyActive(Type type)
static void traversePointsToSets(const enzyme::AliasClassSet &start, const enzyme::PointsToSets &pointsToSets, function_ref< void(DistinctAttr)> visit)
Visit everything transitively pointed-to by any pointer in start.
void deserializeReturnOrigins(ArrayAttr returnOrigins, SmallVectorImpl< enzyme::ValueOriginSet > &out)
static bool isFullyActive(Operation *op)
True iff all results differentially depend on all operands.
std::optional< Value > getCopySource(Operation *op)
void printMapOfSetsLattice(const DenseMap< KeyT, enzyme::SetLattice< ElementT > > map, raw_ostream &os)
bool annotate
Annotate the IR with activity information for every operation.
bool verbose
Output extra information for debugging.
This analysis implements interprocedural alias analysis.
void visitExternalCall(CallOpInterface call, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void setToExitState(BackwardOriginsLattice *lattice) override
LogicalResult visitOperation(Operation *op, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void print(raw_ostream &os) const override
ChangeResult meet(const AbstractSparseLattice &other) override
const DenseSet< OriginAttr > & getOrigins() const
void print(raw_ostream &os) const override
const ValueOriginSet & getOrigins(DistinctAttr id) const
LogicalResult visitOperation(Operation *op, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void setToEntryState(ForwardOriginsMap *lattice) override
void setToExitState(BackwardOriginsMap *lattice) override
LogicalResult visitOperation(Operation *op, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void setToEntryState(ForwardOriginsLattice *lattice) override
void visitExternalCall(CallOpInterface call, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
LogicalResult visitOperation(Operation *op, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
ChangeResult join(const AbstractSparseLattice &other) override
const DenseSet< OriginAttr > & getOrigins() const
void print(raw_ostream &os) const override
static ForwardOriginsLattice single(Value point, OriginAttr value)
const ValueOriginSet & getOrigins(DistinctAttr id) const
void print(raw_ostream &os) const override
ChangeResult join(const AbstractDenseLattice &other)
Attribute serialize(MLIRContext *ctx) const
ChangeResult insert(const SetLattice< KeyT > &keysToUpdate, const SetLattice< ElementT > &values)
Map all keys to all values.
const AliasClassSet & getPointsTo(DistinctAttr id) const
ChangeResult insert(const DenseSet< ValueT > &newElements)
ChangeResult join(const SetLattice< ValueT > &other)
static const SetLattice< OriginAttr > & getUndefined()
static const SetLattice< OriginAttr > & getUnknown()
Attribute serialize(MLIRContext *ctx) const
SetLattice< OriginAttr > ValueOriginSet
SetLattice< DistinctAttr > AliasClassSet
A set of alias class identifiers to be treated as a single union.
void runActivityAnnotations(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argActivities={}, const ActivityPrinterConfig &config=ActivityPrinterConfig())