8#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
9#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
10#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
11#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
12#include "mlir/Analysis/DataFlowFramework.h"
13#include "mlir/Interfaces/FunctionInterfaces.h"
14#include "llvm/Support/raw_ostream.h"
18#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
23 return isa<FloatType, ComplexType>(type);
27 os <<
serialize(getAnchor().getContext());
31 os << serialize(getAnchor().getContext());
36 const auto *otherValueOrigins =
38 return elements.
join(otherValueOrigins->elements);
43 auto arg = dyn_cast<BlockArgument>(lattice->getAnchor());
52 auto funcOp = cast<FunctionOpInterface>(arg.getOwner()->getParentOp());
53 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
55 return propagateIfChanged(
60void enzyme::ForwardActivityAnnotationAnalysis::markResultsUnknown(
61 ArrayRef<ForwardOriginsLattice *> results) {
63 propagateIfChanged(result, result->markUnknown());
71 return isa<LLVM::FMulOp, LLVM::FAddOp, LLVM::FDivOp, LLVM::FSubOp,
72 LLVM::FNegOp, LLVM::FAbsOp, LLVM::SqrtOp, LLVM::SinOp, LLVM::CosOp,
73 LLVM::Exp2Op, LLVM::ExpOp, LLVM::LogOp, LLVM::InsertValueOp,
74 LLVM::ExtractValueOp, LLVM::BitcastOp, LLVM::SelectOp>(op);
78 Operation *op, ArrayRef<const ForwardOriginsLattice *> operands,
79 ArrayRef<ForwardOriginsLattice *> results) {
83 join(result, *operand);
89 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
90 if (isPure(op) || (activityIface && activityIface.isInactive()))
93 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
95 markResultsUnknown(results);
99 SmallVector<MemoryEffects::EffectInstance> effects;
100 memory.getEffects(effects);
101 for (
const auto &effect : effects) {
102 if (!isa<MemoryEffects::Read>(effect.getEffect()))
105 Value value = effect.getValue();
107 markResultsUnknown(results);
110 processMemoryRead(op, value, results);
115void enzyme::ForwardActivityAnnotationAnalysis::processMemoryRead(
116 Operation *op, Value address, ArrayRef<ForwardOriginsLattice *> results) {
117 ProgramPoint *point = getProgramPointAfter(op);
118 auto *srcClasses = getOrCreateFor<AliasClassLattice>(point, address);
119 auto *originsMap = getOrCreateFor<ForwardOriginsMap>(point, point);
120 if (srcClasses->isUndefined())
122 if (srcClasses->isUnknown())
123 return markResultsUnknown(results);
127 for (DistinctAttr srcClass : srcClasses->getAliasClasses()) {
130 propagateIfChanged(result,
131 result->merge(originsMap->getOrigins(srcClass)));
138 SmallVectorImpl<enzyme::ValueOriginSet> &out) {
139 for (
auto &&[resultIdx, argOrigins] : llvm::enumerate(returnOrigins)) {
141 if (
auto strAttr = dyn_cast<StringAttr>(argOrigins)) {
142 if (strAttr.getValue() ==
"<unknown>") {
148 for (enzyme::ArgumentOriginAttr originAttr :
149 cast<ArrayAttr>(argOrigins)
150 .getAsRange<enzyme::ArgumentOriginAttr>()) {
151 (void)origins.
insert({originAttr});
155 out.push_back(origins);
160 CallOpInterface call, ArrayRef<const ForwardOriginsLattice *> operands,
161 ArrayRef<ForwardOriginsLattice *> results) {
162 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
163 auto markAllResultsUnknown = [&]() {
165 propagateIfChanged(result, result->markUnknown());
169 return markAllResultsUnknown();
171 if (
auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
172 call, symbol.getLeafReference())) {
173 if (
auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
174 EnzymeDialect::getSparseActivityAnnotationAttrName())) {
175 SmallVector<ValueOriginSet> returnOrigins;
177 return processCallToSummarizedFunc(call, returnOrigins, operands,
186 join(result, *operand);
192 function_ref<
void(DistinctAttr)> visit) {
195 while (!current.isUndefined()) {
198 assert(!current.isUnknown() &&
"Unhandled traversal of unknown");
199 for (DistinctAttr currentClass : current.getElements()) {
201 (void)next.join(pointsToSets.
getPointsTo(currentClass));
203 std::swap(current, next);
207void enzyme::ForwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
208 CallOpInterface call, ArrayRef<ValueOriginSet> summary,
209 ArrayRef<const ForwardOriginsLattice *> operands,
210 ArrayRef<ForwardOriginsLattice *> results) {
211 for (
const auto &[result, returnOrigin] : llvm::zip(results, summary)) {
214 if (returnOrigin.isUndefined())
217 if (returnOrigin.isUnknown()) {
218 (void)callerOrigins.markUnknown();
220 ProgramPoint *point = getProgramPointAfter(call);
221 auto *denseOrigins = getOrCreateFor<ForwardOriginsMap>(point, point);
222 auto *pointsTo = getOrCreateFor<PointsToSets>(point, point);
223 (void)returnOrigin.foreachElement(
225 assert(state == ValueOriginSet::State::Defined &&
226 "undefined and unknown must have been handled above");
227 auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
229 const ForwardOriginsLattice *operandOrigins =
230 operands[calleeArgOrigin.getArgNumber()];
231 auto *callerAliasClass = getOrCreateFor<AliasClassLattice>(
232 getProgramPointAfter(call), operandOrigins->getAnchor());
233 traversePointsToSets(callerAliasClass->getAliasClassesObject(),
234 *pointsTo, [&](DistinctAttr aliasClass) {
235 (void)callerOrigins.join(
236 denseOrigins->getOrigins(aliasClass));
238 return callerOrigins.join(operandOrigins->getOriginsObject());
241 propagateIfChanged(result, result->merge(callerOrigins));
247 propagateIfChanged(lattice, lattice->
markUnknown());
250void enzyme::BackwardActivityAnnotationAnalysis::markOperandsUnknown(
251 ArrayRef<BackwardOriginsLattice *> operands) {
253 propagateIfChanged(operand, operand->markUnknown());
258 Operation *op, ArrayRef<BackwardOriginsLattice *> operands,
259 ArrayRef<const BackwardOriginsLattice *> results) {
263 meet(operand, *result);
266 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
267 if (isPure(op) || (activityIface && activityIface.isInactive()))
270 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
272 markOperandsUnknown(operands);
276 SmallVector<MemoryEffects::EffectInstance> effects;
277 memory.getEffects(effects);
278 for (
const auto &effect : effects) {
279 if (!isa<MemoryEffects::Read>(effect.getEffect()))
282 Value value = effect.getValue();
284 markOperandsUnknown(operands);
289 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
291 getOrCreate<BackwardOriginsMap>(getProgramPointBefore(op));
293 ChangeResult changed = ChangeResult::NoChange;
295 changed |= originsMap->insert(srcClasses->getAliasClassesObject(),
296 result->getOriginsObject());
297 propagateIfChanged(originsMap, changed);
303 CallOpInterface call, ArrayRef<BackwardOriginsLattice *> operands,
304 ArrayRef<const BackwardOriginsLattice *> results) {
305 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
307 return markOperandsUnknown(operands);
309 if (
auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
310 call, symbol.getLeafReference())) {
311 if (
auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
312 EnzymeDialect::getSparseActivityAnnotationAttrName())) {
313 SmallVector<ValueOriginSet> returnOrigins;
315 return processCallToSummarizedFunc(call, returnOrigins, operands,
324 meet(operand, *result);
327void enzyme::BackwardActivityAnnotationAnalysis::processCallToSummarizedFunc(
328 CallOpInterface call, ArrayRef<ValueOriginSet> summary,
329 ArrayRef<BackwardOriginsLattice *> operands,
330 ArrayRef<const BackwardOriginsLattice *> results) {
332 for (
const auto &[result, calleeOrigins] : llvm::zip(results, summary)) {
334 if (calleeOrigins.isUndefined())
336 if (calleeOrigins.isUnknown())
339 (void)calleeOrigins.foreachElement(
341 auto calleeArgOrigin = cast<ArgumentOriginAttr>(calleeOrigin);
342 BackwardOriginsLattice *operand =
343 operands[calleeArgOrigin.getArgNumber()];
344 propagateIfChanged(operand, operand->merge(resultOrigins));
345 return ChangeResult::NoChange;
351template <
typename KeyT,
typename ElementT>
358 for (
const auto &[aliasClass, origins] : map) {
359 os <<
" " << aliasClass <<
" originates from {";
360 if (origins.isUnknown()) {
362 }
else if (origins.isUndefined()) {
365 llvm::interleaveComma(origins.getElements(), os);
381 auto point = dyn_cast<ProgramPoint *>(lattice->getAnchor());
382 auto *block = point->getBlock();
386 auto funcOp = cast<FunctionOpInterface>(block->getParentOp());
387 ChangeResult changed = ChangeResult::NoChange;
388 for (BlockArgument arg : funcOp.getArguments()) {
389 auto *argClass = getOrCreateFor<AliasClassLattice>(point, arg);
390 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
392 changed |= lattice->
insert(argClass->getAliasClassesObject(),
395 propagateIfChanged(lattice, changed);
398std::optional<Value>
getStored(Operation *op);
405 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
406 if (activityIface && activityIface.isInactive())
409 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
415 SmallVector<MemoryEffects::EffectInstance> effects;
416 memory.getEffects(effects);
417 for (
const auto &effect : effects) {
418 Value value = effect.getValue();
424 if (isa<MemoryEffects::Read>(effect.getEffect())) {
426 if (op->getNumResults() != 1)
428 Value readDest = op->getResult(0);
431 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), readDest);
432 if (destClasses->isUndefined())
437 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
438 if (srcClasses->isUnknown()) {
439 propagateIfChanged(after,
440 after->
insert(destClasses->getAliasClassesObject(),
445 ChangeResult changed = ChangeResult::NoChange;
446 for (DistinctAttr srcClass : srcClasses->getAliasClasses()) {
447 changed |= after->
insert(destClasses->getAliasClassesObject(),
450 propagateIfChanged(after, changed);
451 }
else if (isa<MemoryEffects::Write>(effect.getEffect())) {
452 if (std::optional<Value> stored =
getStored(op)) {
456 auto *origins = getOrCreateFor<ForwardOriginsLattice>(
457 getProgramPointAfter(op), *stored);
459 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
460 propagateIfChanged(after, after->
insert(dest->getAliasClassesObject(),
461 origins->getOriginsObject()));
462 }
else if (std::optional<Value> copySource =
getCopySource(op)) {
463 processCopy(op, *copySource, value, before, after);
472void enzyme::DenseActivityAnnotationAnalysis::processCopy(
473 Operation *op, Value copySource, Value copyDest,
476 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), copySource);
478 if (src->isUndefined())
480 if (src->isUnknown())
483 for (DistinctAttr srcClass : src->getAliasClasses())
487 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), copyDest);
488 propagateIfChanged(after,
489 after->
insert(dest->getAliasClassesObject(), srcOrigins));
494 ArrayAttr summaryAttr,
495 DenseMap<DistinctAttr, enzyme::ValueOriginSet> &summaryMap) {
497 for (
auto pair : summaryAttr.getAsRange<ArrayAttr>()) {
498 assert(pair.size() == 2 &&
499 "Expected summary to be in [[key, value]] format");
500 auto pointer = cast<DistinctAttr>(pair[0]);
502 if (
auto strAttr = dyn_cast<StringAttr>(pair[1])) {
503 if (strAttr.getValue() ==
"unknown") {
504 (void)pointsToSet.markUnknown();
506 assert(strAttr.getValue() ==
"undefined" &&
507 "unrecognized points-to destination");
510 auto pointsTo = cast<ArrayAttr>(pair[1]).getAsRange<enzyme::OriginAttr>();
511 (void)pointsToSet.insert(
512 DenseSet<enzyme::OriginAttr>(pointsTo.begin(), pointsTo.end()));
515 summaryMap.insert({pointer, pointsToSet});
520 CallOpInterface call, dataflow::CallControlFlowAction action,
523 if (action == dataflow::CallControlFlowAction::ExternalCallee) {
524 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
528 if (
auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
529 call, symbol.getLeafReference())) {
530 if (
auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
531 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
532 DenseMap<DistinctAttr, ValueOriginSet> summary;
534 return processCallToSummarizedFunc(call, summary, before, after);
540void enzyme::DenseActivityAnnotationAnalysis::processCallToSummarizedFunc(
541 CallOpInterface call,
const DenseMap<DistinctAttr, ValueOriginSet> &summary,
543 ChangeResult changed = ChangeResult::NoChange;
544 ProgramPoint *point = getProgramPointAfter(call);
549 auto *p2sets = getOrCreateFor<PointsToSets>(point, point);
550 SmallVector<ValueOriginSet> argumentOrigins;
551 SmallVector<AliasClassSet> argumentClasses;
552 for (
auto &&[i, argOperand] : llvm::enumerate(call.getArgOperands())) {
555 auto *argClasses = getOrCreateFor<AliasClassLattice>(point, argOperand);
556 if (argClasses->isUndefined()) {
558 auto *sparseOrigins =
559 getOrCreateFor<ForwardOriginsLattice>(point, argOperand);
560 (void)argOrigins.
join(sparseOrigins->getOriginsObject());
566 [&](DistinctAttr aliasClass) {
567 (void)argOrigins.join(
568 before.getOrigins(aliasClass));
571 argumentClasses.push_back(argClasses->getAliasClassesObject());
572 argumentOrigins.push_back(argOrigins);
575 for (
const auto &[destClass, sourceOrigins] : summary) {
577 for (Attribute sourceOrigin : sourceOrigins.getElements()) {
579 cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
580 (void)callerOrigins.join(argumentOrigins[argNumber]);
584 if (
auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
585 destClass.getReferencedAttr())) {
587 AliasClassSet current = argumentClasses[pseudoClass.getArgNumber()];
588 unsigned depth = pseudoClass.getDepth();
591 if (current.isUndefined()) {
600 for (DistinctAttr currentClass : current.getElements())
601 (void)next.
join(p2sets->getPointsTo(currentClass));
602 std::swap(current, next);
606 (void)callerDestClasses.
join(current);
608 (void)callerDestClasses.insert({destClass});
610 changed |= after->
insert(callerDestClasses, callerOrigins);
612 propagateIfChanged(after, changed);
617 dataflow::CallControlFlowAction action,
621 if (action == dataflow::CallControlFlowAction::ExternalCallee) {
622 auto symbol = dyn_cast<SymbolRefAttr>(call.getCallableForCallee());
626 if (
auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
627 call, symbol.getLeafReference())) {
628 if (
auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
629 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
630 DenseMap<DistinctAttr, ValueOriginSet> summary;
632 return processCallToSummarizedFunc(call, summary, after, before);
640 auto point = dyn_cast<ProgramPoint *>(lattice->getAnchor());
641 auto *block = point->getBlock();
645 auto funcOp = cast<FunctionOpInterface>(block->getParentOp());
646 ChangeResult changed = ChangeResult::NoChange;
647 for (BlockArgument arg : funcOp.getArguments()) {
648 auto *pointsToSets = getOrCreateFor<PointsToSets>(
649 point, getProgramPointAfter(block->getTerminator()));
650 auto *argClass = getOrCreateFor<AliasClassLattice>(point, arg);
651 auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
657 [&](DistinctAttr currentClass) {
659 lattice->insert(AliasClassSet(currentClass),
660 ValueOriginSet(origin));
663 propagateIfChanged(lattice, changed);
671 auto activityIface = dyn_cast<enzyme::ActivityOpInterface>(op);
672 if (activityIface && activityIface.isInactive())
675 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
681 SmallVector<MemoryEffects::EffectInstance> effects;
682 memory.getEffects(effects);
683 for (
const auto &effect : effects) {
684 if (!isa<MemoryEffects::Write>(effect.getEffect()))
687 Value value = effect.getValue();
693 if (std::optional<Value> stored =
getStored(op)) {
694 ProgramPoint *point = getProgramPointBefore(op);
695 auto *addressClasses = getOrCreateFor<AliasClassLattice>(point, value);
696 auto *storedClasses = getOrCreateFor<AliasClassLattice>(point, *stored);
698 if (storedClasses->isUndefined()) {
700 auto *storedOrigins = getOrCreate<BackwardOriginsLattice>(*stored);
703 addressClasses->getAliasClassesObject().foreachElement(
705 if (state == AliasClassSet::State::Undefined) {
706 return ChangeResult::NoChange;
708 if (state == AliasClassSet::State::Unknown) {
709 return storedOrigins->markUnknown();
711 return storedOrigins->merge(after.
getOrigins(alloc));
713 }
else if (storedClasses->isUnknown()) {
719 }
else if (std::optional<Value> copySource =
getCopySource(op)) {
720 processCopy(op, *copySource, value, after, before);
726void enzyme::DenseBackwardActivityAnnotationAnalysis::
727 processCallToSummarizedFunc(
728 CallOpInterface call,
729 const DenseMap<DistinctAttr, ValueOriginSet> &summary,
730 const BackwardOriginsMap &after, BackwardOriginsMap *before) {
731 ChangeResult changed = ChangeResult::NoChange;
732 ProgramPoint *pointBefore = getProgramPointBefore(call);
733 ProgramPoint *pointAfter = getProgramPointAfter(call);
736 auto *p2sets = getOrCreateFor<PointsToSets>(pointBefore, pointAfter);
737 SmallVector<AliasClassSet> argumentClasses;
738 for (Value argOperand : call.getArgOperands()) {
740 getOrCreateFor<AliasClassLattice>(pointBefore, argOperand);
741 argumentClasses.push_back(argClasses->getAliasClassesObject());
744 for (
const auto &[destClass, sourceOrigins] : summary) {
747 if (
auto pseudoClass = dyn_cast_if_present<PseudoAliasClassAttr>(
748 destClass.getReferencedAttr())) {
750 [&](DistinctAttr aliasClass) {
751 (void)destOrigins.join(
752 after.getOrigins(aliasClass));
756 if (destOrigins.isUndefined())
761 for (Attribute sourceOrigin : sourceOrigins.getElements()) {
763 cast<ArgumentOriginAttr>(sourceOrigin).getArgNumber();
765 if (argumentClasses[argNumber].isUndefined()) {
767 auto *backwardLattice = getOrCreate<BackwardOriginsLattice>(
768 call.getArgOperands()[argNumber]);
769 if (destOrigins.isUnknown()) {
770 propagateIfChanged(backwardLattice, backwardLattice->markUnknown());
773 propagateIfChanged(backwardLattice,
774 backwardLattice->insert(destOrigins.getElements()));
777 [&](DistinctAttr aliasClass) {
778 (void)callerSourceClasses.insert({aliasClass});
782 changed |= before->
insert(callerSourceClasses, destOrigins);
784 propagateIfChanged(before, changed);
787void enzyme::DenseBackwardActivityAnnotationAnalysis::processCopy(
788 Operation *op, Value copySource, Value copyDest,
789 const BackwardOriginsMap &after, BackwardOriginsMap *before) {
790 ProgramPoint *point = getProgramPointBefore(op);
791 auto *dest = getOrCreateFor<AliasClassLattice>(point, copyDest);
793 if (dest->isUndefined())
795 if (dest->isUnknown())
798 for (DistinctAttr destClass : dest->getAliasClasses())
799 (void)destOrigins.join(after.getOrigins(destClass));
801 auto *src = getOrCreateFor<AliasClassLattice>(point, copySource);
802 propagateIfChanged(before,
803 before->insert(src->getAliasClassesObject(), destOrigins));
810void annotateHardcoded(FunctionOpInterface func) {
811 if (func.getName() ==
"lgamma" || func.getName() ==
"tanh") {
812 MLIRContext *ctx = func.getContext();
813 SmallVector<Attribute> arr = {StringAttr::get(ctx,
"<undefined>")};
814 func->setAttr(enzyme::EnzymeDialect::getAliasSummaryAttrName(),
815 ArrayAttr::get(ctx, arr));
821void reverseToposortCallgraph(CallableOpInterface callee,
822 SymbolTableCollection *symbolTable,
823 SmallVectorImpl<CallableOpInterface> &sorted) {
824 DenseSet<CallableOpInterface> permanent;
825 DenseSet<CallableOpInterface> temporary;
826 std::function<void(CallableOpInterface)> visit =
827 [&](CallableOpInterface node) {
828 if (permanent.contains(node))
830 if (temporary.contains(node))
831 assert(
false &&
"unimplemented cycle in call graph");
833 temporary.insert(node);
834 node.walk([&](CallOpInterface call) {
835 auto neighbour = cast<CallableOpInterface>(
836 call.resolveCallableInTable(symbolTable));
840 temporary.erase(node);
841 permanent.insert(node);
842 sorted.push_back(node);
848void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
849 DataFlowSolver &solver) {
852 for (Operation &op : func.getCallableRegion()->getOps()) {
853 if (!op.hasTrait<OpTrait::ReturnLike>())
856 for (OpOperand &returnOperand : op.getOpOperands()) {
859 auto origin = ReturnOriginAttr::get(FlatSymbolRefAttr::get(func),
860 returnOperand.getOperandNumber());
861 (void)lattice->insert({origin});
867 std::pair<enzyme::ForwardOriginsLattice, enzyme::BackwardOriginsLattice>;
874void topDownActivityAnalysis(
875 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
876 ArrayRef<enzyme::Activity> retActivities,
877 DenseMap<BlockArgument, OriginsPair> &blockArgOrigins) {
879 MLIRContext *ctx = callee.getContext();
880 callee->setAttr(
"enzyme.visited", UnitAttr::get(ctx));
881 auto trueAttr = BoolAttr::get(ctx,
true);
882 auto falseAttr = BoolAttr::get(ctx,
false);
884 auto isOriginActive = [&](OriginAttr origin) {
885 if (
auto argOriginAttr = dyn_cast<ArgumentOriginAttr>(origin)) {
886 return llvm::is_contained({Activity::enzyme_dup,
887 Activity::enzyme_dupnoneed,
888 Activity::enzyme_active},
889 argActivities[argOriginAttr.getArgNumber()]);
891 auto retOriginAttr = cast<ReturnOriginAttr>(origin);
892 return llvm::is_contained({Activity::enzyme_dup, Activity::enzyme_dupnoneed,
893 Activity::enzyme_active},
894 retActivities[retOriginAttr.getReturnNumber()]);
896 callee.getFunctionBody().walk([&](Operation *op) {
897 if (op->getNumResults() == 0) {
899 op->setAttr(
"enzyme.icv", trueAttr);
902 if (op->hasAttr(
"enzyme.constantval")) {
903 op->setAttr(
"enzyme.icv", trueAttr);
904 }
else if (op->hasAttr(
"enzyme.activeval")) {
905 op->setAttr(
"enzyme.icv", falseAttr);
907 auto valueSource = op->getAttrOfType<ArrayAttr>(
"enzyme.valsrc");
908 auto valueSink = op->getAttrOfType<ArrayAttr>(
"enzyme.valsink");
909 if (!(valueSource && valueSink)) {
910 llvm::errs() <<
"[activity] missing attributes for op: " << *op
913 assert(valueSource && valueSink &&
"missing attributes for op");
915 llvm::any_of(valueSource.getAsRange<OriginAttr>(), isOriginActive);
917 llvm::any_of(valueSink.getAsRange<OriginAttr>(), isOriginActive);
918 bool activeVal = activeSource && activeSink;
919 op->setAttr(
"enzyme.icv", BoolAttr::get(ctx, !activeVal));
922 op->removeAttr(
"enzyme.constantval");
923 op->removeAttr(
"enzyme.activeval");
924 op->removeAttr(
"enzyme.valsrc");
925 op->removeAttr(
"enzyme.valsink");
928 if (op->hasAttr(
"enzyme.constantop")) {
929 op->setAttr(
"enzyme.ici", trueAttr);
930 }
else if (op->hasAttr(
"enzyme.activeop")) {
931 op->setAttr(
"enzyme.ici", falseAttr);
933 bool activeSource = llvm::any_of(
934 op->getAttrOfType<ArrayAttr>(
"enzyme.opsrc").getAsRange<OriginAttr>(),
937 llvm::any_of(op->getAttrOfType<ArrayAttr>(
"enzyme.opsink")
938 .getAsRange<OriginAttr>(),
940 bool activeOp = activeSource && activeSink;
941 op->setAttr(
"enzyme.ici", BoolAttr::get(ctx, !activeOp));
944 op->removeAttr(
"enzyme.constantop");
945 op->removeAttr(
"enzyme.activeop");
946 op->removeAttr(
"enzyme.opsrc");
947 op->removeAttr(
"enzyme.opsink");
949 if (
auto callOp = dyn_cast<CallOpInterface>(op)) {
950 auto funcOp = cast<FunctionOpInterface>(callOp.resolveCallable());
951 if (!funcOp->hasAttr(
"enzyme.visited")) {
952 SmallVector<Activity> callArgActivities, callResActivities;
953 for (Value operand : callOp.getArgOperands()) {
954 if (
auto *definingOp = operand.getDefiningOp()) {
956 definingOp->getAttrOfType<BoolAttr>(
"enzyme.icv").getValue();
957 callArgActivities.push_back(icv ? Activity::enzyme_const
958 : Activity::enzyme_active);
960 BlockArgument blockArg = cast<BlockArgument>(operand);
961 const OriginsPair &originsPair = blockArgOrigins.at(blockArg);
964 bool argActive =
false;
970 argActive = llvm::any_of(sources.
getOrigins(), isOriginActive) &&
971 llvm::any_of(sinks.
getOrigins(), isOriginActive);
973 callArgActivities.push_back(argActive ? Activity::enzyme_active
974 : Activity::enzyme_const);
977 if (op->getNumResults() != 0) {
978 bool icv = op->getAttrOfType<BoolAttr>(
"enzyme.icv").getValue();
979 callResActivities.push_back(icv ? Activity::enzyme_const
980 : Activity::enzyme_active);
983 topDownActivityAnalysis(funcOp, callArgActivities, callResActivities,
992 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
994 SymbolTableCollection symbolTable;
995 SmallVector<CallableOpInterface> sorted;
996 reverseToposortCallgraph(callee, &symbolTable, sorted);
997 raw_ostream &os = llvm::outs();
1000 DenseMap<BlockArgument, OriginsPair> blockArgOrigins;
1002 StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName();
1003 for (CallableOpInterface node : sorted) {
1004 annotateHardcoded(cast<FunctionOpInterface>(node.getOperation()));
1006 if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName))
1008 auto funcOp = cast<FunctionOpInterface>(node.getOperation());
1010 os <<
"[ata] processing function @" << funcOp.getName() <<
"\n";
1012 DataFlowConfig dataFlowConfig;
1013 dataFlowConfig.setInterprocedural(
false);
1014 DataFlowSolver solver(dataFlowConfig);
1015 SymbolTableCollection symbolTable;
1017 solver.load<dataflow::SparseConstantPropagation>();
1018 solver.load<dataflow::DeadCodeAnalysis>();
1027 initializeSparseBackwardActivityAnnotations(funcOp, solver);
1029 if (failed(solver.initializeAndRun(node))) {
1030 assert(
false &&
"dataflow solver failed");
1036 size_t numResults = node.getResultTypes().size();
1037 SmallVector<enzyme::ForwardOriginsLattice> returnOperandOrigins(
1039 SmallVector<enzyme::AliasClassLattice> returnAliasClasses(
1042 for (Operation &op : node.getCallableRegion()->getOps()) {
1043 if (op.hasTrait<OpTrait::ReturnLike>()) {
1044 ProgramPoint *point = solver.getProgramPointAfter(&op);
1046 auto *returnOrigins =
1049 (void)forwardOriginsMap.
join(*returnOrigins);
1051 for (OpOperand &operand : op.getOpOperands()) {
1052 (void)returnAliasClasses[operand.getOperandNumber()].join(
1054 (void)returnOperandOrigins[operand.getOperandNumber()].join(
1062 SmallVector<Attribute> aliasAttributes(returnAliasClasses.size());
1063 llvm::transform(returnAliasClasses, aliasAttributes.begin(),
1065 return lattice.serialize(node.getContext());
1067 node->setAttr(EnzymeDialect::getAliasSummaryAttrName(),
1068 ArrayAttr::get(node.getContext(), aliasAttributes));
1071 node->setAttr(pointerSummaryName, p2sets.
serialize(node.getContext()));
1073 os <<
"[ata] p2p summary:\n";
1074 if (node->getAttrOfType<ArrayAttr>(pointerSummaryName).size() == 0) {
1077 for (ArrayAttr pair : node->getAttrOfType<ArrayAttr>(pointerSummaryName)
1078 .getAsRange<ArrayAttr>()) {
1079 os <<
" " << pair[0] <<
" -> " << pair[1] <<
"\n";
1083 node->setAttr(EnzymeDialect::getDenseActivityAnnotationAttrName(),
1084 forwardOriginsMap.
serialize(node.getContext()));
1086 os <<
"[ata] forward value origins:\n";
1087 for (ArrayAttr pair :
1088 node->getAttrOfType<ArrayAttr>(
1089 EnzymeDialect::getDenseActivityAnnotationAttrName())
1090 .getAsRange<ArrayAttr>()) {
1091 os <<
" " << pair[0] <<
" originates from " << pair[1] <<
"\n";
1095 auto *backwardOriginsMap =
1097 solver.getProgramPointBefore(
1098 &node.getCallableRegion()->front().front()));
1099 Attribute backwardOrigins =
1100 backwardOriginsMap->serialize(node.getContext());
1102 os <<
"[ata] backward value origins:\n";
1103 for (ArrayAttr pair :
1104 cast<ArrayAttr>(backwardOrigins).getAsRange<ArrayAttr>()) {
1105 os <<
" " << pair[0] <<
" goes to " << pair[1] <<
"\n";
1110 MLIRContext *ctx = node.getContext();
1111 SmallVector<Attribute> serializedReturnOperandOrigins(
1112 returnOperandOrigins.size());
1113 llvm::transform(returnOperandOrigins,
1114 serializedReturnOperandOrigins.begin(),
1116 return lattice.serialize(ctx);
1119 EnzymeDialect::getSparseActivityAnnotationAttrName(),
1120 ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
1122 os <<
"[ata] return origins: "
1123 << node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName())
1127 auto joinActiveDataState =
1129 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1132 (void)out.first.join(*sources);
1133 (void)out.second.
meet(*sinks);
1136 auto joinActivePointerState =
1138 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1140 aliasClasses, p2sets, [&](DistinctAttr aliasClass) {
1141 (void)out.first.merge(forwardOriginsMap.
getOrigins(aliasClass));
1142 (void)out.second.merge(
1143 backwardOriginsMap->getOrigins(aliasClass));
1147 auto joinActiveValueState =
1149 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
1150 if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
1151 auto *aliasClasses =
1153 joinActivePointerState(aliasClasses->getAliasClassesObject(), out);
1155 joinActiveDataState(value, out);
1159 auto annotateActivity = [&](Operation *op) {
1160 assert(op->getNumResults() < 2 && op->getNumRegions() == 0 &&
1161 "annotation only supports the LLVM dialect");
1162 auto unitAttr = UnitAttr::get(ctx);
1164 for (OpResult result : op->getResults()) {
1165 std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
1168 joinActiveValueState(result, activityAttributes);
1169 const auto &sources = activityAttributes.first;
1170 const auto &sinks = activityAttributes.second;
1173 if (sources.isUnknown() || sinks.isUnknown()) {
1175 op->setAttr(
"enzyme.activeval", unitAttr);
1176 }
else if (sources.isUndefined() || sinks.isUndefined()) {
1178 op->setAttr(
"enzyme.constantval", unitAttr);
1181 op->setAttr(
"enzyme.valsrc", sources.serialize(ctx));
1182 op->setAttr(
"enzyme.valsink", sinks.serialize(ctx));
1186 StringRef opSourceAttrName =
"enzyme.opsrc";
1187 StringRef opSinkAttrName =
"enzyme.opsink";
1188 std::pair<ForwardOriginsLattice, BackwardOriginsLattice> opAttributes(
1192 for (OpResult result : op->getResults()) {
1193 joinActiveDataState(result, opAttributes);
1198 if (
auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
1201 joinActivePointerState(storedClass->getAliasClassesObject(),
1203 }
else if (
auto callOp = dyn_cast<CallOpInterface>(op)) {
1205 auto callable = cast<CallableOpInterface>(callOp.resolveCallable());
1206 if (callable->hasAttr(
1207 EnzymeDialect::getDenseActivityAnnotationAttrName())) {
1208 for (Value operand : callOp.getArgOperands())
1209 joinActiveValueState(operand, opAttributes);
1217 for (Value operand : op->getOperands())
1218 joinActiveDataState(operand, opAttributes);
1219 for (OpResult result : op->getResults())
1220 joinActiveDataState(result, opAttributes);
1223 const auto &opSources = opAttributes.first;
1224 const auto &opSinks = opAttributes.second;
1225 if (opSources.isUnknown() || opSinks.isUnknown()) {
1226 op->setAttr(
"enzyme.activeop", unitAttr);
1227 }
else if (opSources.isUndefined() || opSinks.isUndefined()) {
1228 op->setAttr(
"enzyme.constantop", unitAttr);
1230 op->setAttr(opSourceAttrName, opAttributes.first.serialize(ctx));
1231 op->setAttr(opSinkAttrName, opAttributes.second.serialize(ctx));
1237 node.getCallableRegion()->walk([&](Block *block) {
1238 for (BlockArgument blockArg : block->getArguments()) {
1241 joinActiveValueState(blockArg, blockArgAttributes);
1242 blockArgOrigins.try_emplace(blockArg, blockArgAttributes);
1246 node.getCallableRegion()->walk([&](Operation *op) {
1248 annotateActivity(op);
1250 if (op->hasAttr(
"tag")) {
1251 for (OpResult result : op->getResults()) {
1252 std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
1255 joinActiveValueState(result, activityAttributes);
1256 os << op->getAttr(
"tag") <<
"(#" << result.getResultNumber()
1258 <<
" sources: " << activityAttributes.first.serialize(ctx)
1260 <<
" sinks: " << activityAttributes.second.serialize(ctx)
1268 if (!argActivities.empty() && activityConfig.
annotate) {
1269 SmallVector<enzyme::Activity> resActivities;
1270 for (Type resultType : callee.getResultTypes()) {
1271 resActivities.push_back(isa<FloatType, ComplexType>(resultType)
1272 ? Activity::enzyme_active
1273 : Activity::enzyme_const);
1276 topDownActivityAnalysis(callee, argActivities, resActivities,
static void deserializePointsTo(ArrayAttr summaryAttr, DenseMap< DistinctAttr, enzyme::ValueOriginSet > &summaryMap)
std::optional< Value > getStored(Operation *op)
static bool isPossiblyActive(Type type)
static void traversePointsToSets(const enzyme::AliasClassSet &start, const enzyme::PointsToSets &pointsToSets, function_ref< void(DistinctAttr)> visit)
Visit everything transitively pointed-to by any pointer in start.
void deserializeReturnOrigins(ArrayAttr returnOrigins, SmallVectorImpl< enzyme::ValueOriginSet > &out)
static bool isFullyActive(Operation *op)
True iff all results differentially depend on all operands.
std::optional< Value > getCopySource(Operation *op)
void printMapOfSetsLattice(const DenseMap< KeyT, enzyme::SetLattice< ElementT > > map, raw_ostream &os)
bool annotate
Annotate the IR with activity information for every operation.
bool verbose
Output extra information for debugging.
This analysis implements interprocedural alias analysis.
void visitExternalCall(CallOpInterface call, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void setToExitState(BackwardOriginsLattice *lattice) override
LogicalResult visitOperation(Operation *op, ArrayRef< BackwardOriginsLattice * > operands, ArrayRef< const BackwardOriginsLattice * > results) override
void print(raw_ostream &os) const override
ChangeResult meet(const AbstractSparseLattice &other) override
const DenseSet< OriginAttr > & getOrigins() const
ChangeResult markAllOriginsUnknown()
void print(raw_ostream &os) const override
const ValueOriginSet & getOrigins(DistinctAttr id) const
LogicalResult visitOperation(Operation *op, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const ForwardOriginsMap &before, ForwardOriginsMap *after) override
void setToEntryState(ForwardOriginsMap *lattice) override
void setToExitState(BackwardOriginsMap *lattice) override
LogicalResult visitOperation(Operation *op, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void visitCallControlFlowTransfer(CallOpInterface call, dataflow::CallControlFlowAction action, const BackwardOriginsMap &after, BackwardOriginsMap *before) override
void setToEntryState(ForwardOriginsLattice *lattice) override
void visitExternalCall(CallOpInterface call, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
LogicalResult visitOperation(Operation *op, ArrayRef< const ForwardOriginsLattice * > operands, ArrayRef< ForwardOriginsLattice * > results) override
ChangeResult join(const AbstractSparseLattice &other) override
const DenseSet< OriginAttr > & getOrigins() const
void print(raw_ostream &os) const override
static ForwardOriginsLattice single(Value point, OriginAttr value)
const ValueOriginSet & getOrigins(DistinctAttr id) const
ChangeResult markAllOriginsUnknown()
void print(raw_ostream &os) const override
ChangeResult join(const AbstractDenseLattice &other)
Attribute serialize(MLIRContext *ctx) const
ChangeResult insert(const SetLattice< KeyT > &keysToUpdate, const SetLattice< ElementT > &values)
Map all keys to all values.
const AliasClassSet & getPointsTo(DistinctAttr id) const
ChangeResult markUnknown()
ChangeResult insert(const DenseSet< ValueT > &newElements)
ChangeResult join(const SetLattice< ValueT > &other)
static const SetLattice< OriginAttr > & getUndefined()
static const SetLattice< OriginAttr > & getUnknown()
ChangeResult markUnknown()
Attribute serialize(MLIRContext *ctx) const
SetLattice< OriginAttr > ValueOriginSet
SetLattice< DistinctAttr > AliasClassSet
A set of alias class identifiers to be treated as a single union.
void runActivityAnnotations(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argActivities={}, const ActivityPrinterConfig &config=ActivityPrinterConfig())