Enzyme main
Loading...
Searching...
No Matches
DataFlowActivityAnalysis.cpp
Go to the documentation of this file.
//===- DataFlowActivityAnalysis.cpp - Implementation of Activity Analysis ===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
//          Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of Activity Analysis -- an AD-specific
// analysis that deduces if a given instruction or value can impact the
// calculation of a derivative. This file formulates activity analysis within
// a dataflow framework.
//
//===----------------------------------------------------------------------===//
29#include "Dialect/Ops.h"
31
32#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
33#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
34#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
35#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
36#include "mlir/Analysis/DataFlowFramework.h"
37#include "mlir/Interfaces/SideEffectInterfaces.h"
38
39// TODO: Don't depend on specific dialects
40#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
41#include "mlir/Dialect/Linalg/IR/Linalg.h"
42#include "mlir/Dialect/MemRef/IR/MemRef.h"
43
44#include "mlir/Analysis/AliasAnalysis/LocalAliasAnalysis.h"
45
47
48using namespace mlir;
49using namespace mlir::dataflow;
51
52/// From LLVM Enzyme's activity analysis, there are four activity states.
53// constant instruction vs constant value, a value/instruction (one and the same
54// in LLVM) can be a constant instruction but active value, active instruction
55// but constant value, or active/constant both.
56
57// The result of activity states are potentially different for multiple
58// enzyme.autodiff calls.
59
61
62using llvm::errs;
64public:
68
72
76
77 bool isActiveVal() const { return value == ActivityKind::ActiveVal; }
78
79 bool isConstant() const { return value == ActivityKind::Constant; }
80
81 bool isUnknown() const { return value == ActivityKind::Unknown; }
82
84 ValueActivity(ActivityKind value) : value(value) {}
85
86 /// Get the known activity state.
87 const ActivityKind &getValue() const { return value; }
88
89 bool operator==(const ValueActivity &rhs) const { return value == rhs.value; }
90
92 const ValueActivity &rhs) {
93 if (lhs.isUnknown() || rhs.isUnknown())
95
96 if (lhs.isConstant() && rhs.isConstant())
99 }
100
102 const ValueActivity &rhs) {
103 return ValueActivity::merge(lhs, rhs);
104 }
105
106 void print(raw_ostream &os) const {
107 switch (value) {
109 os << "ActiveVal";
110 break;
112 os << "Constant";
113 break;
115 os << "Unknown";
116 break;
117 }
118 }
119
120 raw_ostream &operator<<(raw_ostream &os) const {
121 print(os);
122 return os;
123 }
124
125private:
126 /// The activity kind. Optimistically initialized to constant.
128};
129
130raw_ostream &operator<<(raw_ostream &os, const ValueActivity &val) {
131 val.print(os);
132 return os;
133}
134
135class ForwardValueActivity : public Lattice<ValueActivity> {
136public:
137 using Lattice::Lattice;
138};
139
141public:
142 using AbstractSparseLattice::AbstractSparseLattice;
143
144 ChangeResult meet(const AbstractSparseLattice &other) override {
145 const auto *rhs = reinterpret_cast<const BackwardValueActivity *>(&other);
146 return meet(rhs->getValue());
147 }
148
149 void print(raw_ostream &os) const override { value.print(os); }
150
151 ValueActivity getValue() const { return value; }
152
153 ChangeResult meet(ValueActivity other) {
154 auto met = ValueActivity::merge(getValue(), other);
155 if (getValue() == met) {
156 return ChangeResult::NoChange;
157 }
158
159 value = met;
160 return ChangeResult::Change;
161 }
162
163private:
164 ValueActivity value;
165};
166
167raw_ostream &operator<<(raw_ostream &os, const CallControlFlowAction &action) {
168 switch (action) {
169 case CallControlFlowAction::EnterCallee:
170 os << "EnterCallee";
171 break;
172 case CallControlFlowAction::ExitCallee:
173 os << "ExitCallee";
174 break;
175 case CallControlFlowAction::ExternalCallee:
176 os << "ExternalCallee";
177 break;
178 }
179 return os;
180}
181
182/// This needs to keep track of three things:
183/// 1. Could active info store in?
184/// 2. Could active info load out?
185/// TODO: Necessary for run-time activity
186/// 3. Could constant info propagate (store?) in?
187///
188/// Active: (forward) active in && (backward) active out && (??) !const in
189/// ActiveOrConstant: active in && active out && const in
190/// Constant: everything else
192 /// Whether active data has stored into this memory location.
193 bool activeIn = false;
194 /// Whether active data was loaded out of this memory location.
195 bool activeOut = false;
196
197 bool operator==(const MemoryActivityState &other) {
198 return activeIn == other.activeIn && activeOut == other.activeOut;
199 }
200
201 bool operator!=(const MemoryActivityState &other) {
202 return !(*this == other);
203 }
204
205 ChangeResult reset() {
206 if (!activeIn && !activeOut)
207 return ChangeResult::NoChange;
208 activeIn = false;
209 activeOut = false;
210 return ChangeResult::Change;
211 }
212
213 ChangeResult merge(const MemoryActivityState &other) {
214 if (*this == other) {
215 return ChangeResult::NoChange;
216 }
217
218 activeIn |= other.activeIn;
219 activeOut |= other.activeOut;
220 return ChangeResult::Change;
221 }
222};
223
225public:
226 using AbstractDenseLattice::AbstractDenseLattice;
227
228 /// Clear all modifications.
229 ChangeResult reset() {
230 if (activityStates.empty())
231 return otherMemoryActivity.reset();
232 activityStates.clear();
233 return otherMemoryActivity.reset();
234 }
235
236 bool hasActiveData(DistinctAttr aliasClass) const {
237 if (!aliasClass)
238 return otherMemoryActivity.activeIn;
239 auto it = activityStates.find(aliasClass);
240 if (it != activityStates.end())
241 return it->getSecond().activeIn;
242 return otherMemoryActivity.activeIn;
243 }
244
245 bool activeDataFlowsOut(DistinctAttr aliasClass) const {
246 if (!aliasClass)
247 return otherMemoryActivity.activeOut;
248
249 auto it = activityStates.find(aliasClass);
250 if (it != activityStates.end())
251 return it->getSecond().activeOut;
252 return otherMemoryActivity.activeOut;
253 }
254
255 /// Set the internal activity state. Accepts null attribute to indicate "other
256 /// classes".
257 ChangeResult setActiveIn(DistinctAttr aliasClass) {
258 if (!aliasClass)
259 return setActiveIn();
260
261 auto &state = activityStates[aliasClass];
262 ChangeResult result =
263 state.activeIn ? ChangeResult::NoChange : ChangeResult::Change;
264 state.activeIn = true;
265 return result;
266 }
267 ChangeResult setActiveIn() {
268 if (otherMemoryActivity.activeIn && activityStates.empty())
269 return ChangeResult::NoChange;
270 otherMemoryActivity.activeIn = true;
271 activityStates.clear();
272 return ChangeResult::Change;
273 }
274 ChangeResult setActiveOut(DistinctAttr aliasClass) {
275 if (!aliasClass)
276 return setActiveOut();
277
278 auto &state = activityStates[aliasClass];
279 ChangeResult result =
280 state.activeOut ? ChangeResult::NoChange : ChangeResult::Change;
281 state.activeOut = true;
282 return result;
283 }
284 ChangeResult setActiveOut() {
285 if (otherMemoryActivity.activeOut && activityStates.empty())
286 return ChangeResult::NoChange;
287 otherMemoryActivity.activeOut = true;
288 activityStates.clear();
289 return ChangeResult::Change;
290 }
291
292 void print(raw_ostream &os) const override {
293 if (activityStates.empty()) {
294 os << "<memory activity state was empty>"
295 << "\n";
296 }
297 for (const auto &[value, state] : activityStates) {
298 os << value << ": in " << state.activeIn << " out " << state.activeOut
299 << "\n";
300 }
301 os << "other classes: in " << otherMemoryActivity.activeIn << " out "
302 << otherMemoryActivity.activeOut << "\n";
303 }
304
305 raw_ostream &operator<<(raw_ostream &os) const {
306 print(os);
307 return os;
308 }
309
310protected:
311 ChangeResult merge(const AbstractDenseLattice &lattice) {
312 const auto &rhs = static_cast<const MemoryActivity &>(lattice);
313 ChangeResult result = ChangeResult::NoChange;
314 DenseSet<DistinctAttr> known;
315 auto lhsRange = llvm::make_first_range(activityStates);
316 auto rhsRange = llvm::make_first_range(rhs.activityStates);
317 known.insert(lhsRange.begin(), lhsRange.end());
318 known.insert(rhsRange.begin(), rhsRange.end());
319
320 MemoryActivityState updatedOther(otherMemoryActivity);
321 result |= updatedOther.merge(rhs.otherMemoryActivity);
322 DenseMap<DistinctAttr, MemoryActivityState> updated;
323 for (DistinctAttr d : known) {
324 auto lhsIt = activityStates.find(d);
325 auto rhsIt = rhs.activityStates.find(d);
326 bool isKnownInLHS = lhsIt != activityStates.end();
327 bool isKnownInRHS = rhsIt != rhs.activityStates.end();
328 const MemoryActivityState *lhsActivity =
329 isKnownInLHS ? &lhsIt->getSecond() : &otherMemoryActivity;
330 const MemoryActivityState *rhsActivity =
331 isKnownInRHS ? &rhsIt->getSecond() : &rhs.otherMemoryActivity;
332 MemoryActivityState updatedActivity(*lhsActivity);
333 (void)updatedActivity.merge(*rhsActivity);
334 if ((lhsIt != activityStates.end() &&
335 updatedActivity != lhsIt->getSecond()) ||
336 (lhsIt == activityStates.end() &&
337 updatedActivity != otherMemoryActivity)) {
338 result |= ChangeResult::Change;
339 }
340 if (updatedActivity != updatedOther)
341 updated.try_emplace(d, updatedActivity);
342 }
343 std::swap(updated, activityStates);
344 return otherMemoryActivity.merge(rhs.otherMemoryActivity) | result;
345 }
346
347private:
348 DenseMap<DistinctAttr, MemoryActivityState> activityStates;
349 MemoryActivityState otherMemoryActivity;
350};
351
353public:
354 using MemoryActivity::MemoryActivity;
355
356 /// Join the activity states.
357 ChangeResult join(const AbstractDenseLattice &lattice) {
358 return merge(lattice);
359 }
360};
361
363public:
364 using MemoryActivity::MemoryActivity;
365
366 ChangeResult meet(const AbstractDenseLattice &lattice) override {
367 return merge(lattice);
368 }
369};
370
371/// Sparse activity analysis reasons about activity by traversing forward down
372/// the def-use chains starting from active function arguments.
374 : public SparseForwardDataFlowAnalysis<ForwardValueActivity> {
375public:
376 using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
377
378 /// In general, we don't know anything about entry operands.
379 void setToEntryState(ForwardValueActivity *lattice) override {
380 // errs() << "sparse forward setting to entry state\n";
381 propagateIfChanged(lattice, lattice->join(ValueActivity()));
382 }
383
384 LogicalResult
385 visitOperation(Operation *op, ArrayRef<const ForwardValueActivity *> operands,
386 ArrayRef<ForwardValueActivity *> results) override {
387 if (op->hasTrait<OpTrait::ConstantLike>())
388 return success();
389
390 // Bail out if this op affects memory.
391 if (!isPure(op))
392 return success();
393
394 transfer(op, operands, results);
395
396 return success();
397 }
398
399 void visitExternalCall(CallOpInterface call,
400 ArrayRef<const ForwardValueActivity *> operands,
401 ArrayRef<ForwardValueActivity *> results) override {
402 transfer(call, operands, results);
403 }
404
405 void transfer(Operation *op, ArrayRef<const ForwardValueActivity *> operands,
406 ArrayRef<ForwardValueActivity *> results) {
407 // For value-based AA, assume any active argument leads to an active
408 // result.
409 ValueActivity joinedResult;
410 for (const ForwardValueActivity *operand : operands)
411 joinedResult = ValueActivity::merge(joinedResult, operand->getValue());
412
413 // Only mark results as active data if the type can carry perturbations and
414 // has value semantics
415 for (ForwardValueActivity *result : results) {
416 if (joinedResult.isActiveVal())
417 propagateIfChanged(result,
418 result->join(isa<FloatType, ComplexType>(
419 result->getAnchor().getType())
420 ? joinedResult
422 else
423 propagateIfChanged(result, result->join(joinedResult));
424 }
425 }
426};
427
429 : public SparseBackwardDataFlowAnalysis<BackwardValueActivity> {
430public:
431 using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;
432
433 void setToExitState(BackwardValueActivity *lattice) override {
434 // errs() << "backward sparse setting to exit state\n";
435 }
436
437 void visitBranchOperand(OpOperand &operand) override {}
438
439 void visitCallOperand(OpOperand &operand) override {}
440
441 void
442 visitNonControlFlowArguments(RegionSuccessor &successor,
443 ArrayRef<BlockArgument> arguments) override {}
444
445 void transfer(Operation *op, ArrayRef<BackwardValueActivity *> operands,
446 ArrayRef<const BackwardValueActivity *> results) {
447 // Propagate all operands to all results
448 for (auto operand : operands)
449 for (auto result : results)
450 meet(operand, *result);
451 }
452
453 LogicalResult
454 visitOperation(Operation *op, ArrayRef<BackwardValueActivity *> operands,
455 ArrayRef<const BackwardValueActivity *> results) override {
456 // Bail out if the op propagates memory
457 if (!isPure(op)) {
458 return success();
459 }
460
461 transfer(op, operands, results);
462 return success();
463 }
464
465 void
466 visitExternalCall(CallOpInterface call,
467 ArrayRef<BackwardValueActivity *> operands,
468 ArrayRef<const BackwardValueActivity *> results) override {
469 transfer(call, operands, results);
470 }
471};
472
473// When applying a transfer function to a store from memory, we need to know
474// what value is being stored.
475std::optional<Value> getStored(Operation *op) {
476 if (auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
477 return storeOp.getValue();
478 } else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
479 return storeOp.getValue();
480 }
481 return std::nullopt;
482}
483
484// TODO consider making this an interface ourselves
485std::optional<Value> getCopySource(Operation *op) {
486 if (isa<LLVM::MemcpyOp, LLVM::MemcpyInlineOp, LLVM::MemmoveOp>(op)) {
487 return op->getOperand(1);
488 }
489 return std::nullopt;
490}
491
492/// The dense analyses operate using a pointer's "canonical allocation", the
493/// Value corresponding to its allocation.
494/// The callback may receive null allocation when the class alias set is
495/// unknown.
496/// If the classes are undefined, the callback will not be called at all.
497void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass,
498 function_ref<void(DistinctAttr)> forEachFn) {
499 (void)ptrAliasClass->getAliasClassesObject().foreachElement(
500 [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) {
501 if (state != enzyme::AliasClassSet::State::Undefined)
502 forEachFn(alloc);
503 return ChangeResult::NoChange;
504 });
505}
506
508 : public DenseForwardDataFlowAnalysis<ForwardMemoryActivity> {
509public:
510 DenseForwardActivityAnalysis(DataFlowSolver &solver, Block *entryBlock,
511 ArrayRef<enzyme::Activity> argumentActivity)
512 : DenseForwardDataFlowAnalysis(solver), entryBlock(entryBlock),
513 argumentActivity(argumentActivity) {}
514
515 LogicalResult visitOperation(Operation *op,
516 const ForwardMemoryActivity &before,
517 ForwardMemoryActivity *after) override {
518 join(after, before);
519 ChangeResult result = ChangeResult::NoChange;
520
521 // TODO If we know this is inactive by definition
522 // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
523 // if (ifaceOp.isInactive()) {
524 // propagateIfChanged(after, result);
525 // return;
526 // }
527 // }
528
529 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
530 // If we can't reason about the memory effects, then conservatively assume
531 // we can't deduce anything about activity via side-effects.
532 if (!memory)
533 return success();
534
535 SmallVector<MemoryEffects::EffectInstance> effects;
536 memory.getEffects(effects);
537
538 for (const auto &effect : effects) {
539 Value value = effect.getValue();
540
541 // If we see an effect on anything other than a value, assume we can't
542 // deduce anything about the activity.
543 if (!value)
544 return success();
545
546 // In forward-flow, a value is active if loaded from a memory resource
547 // that has previously been actively stored to.
548 if (isa<MemoryEffects::Read>(effect.getEffect())) {
549 auto *ptrAliasClass =
550 getOrCreateFor<AliasClassLattice>(getProgramPointAfter(op), value);
551 forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
552 if (before.hasActiveData(alloc)) {
553 for (OpResult opResult : op->getResults()) {
554 // Mark the result as (forward) active
555 // TODO: We might need type analysis here
556 // Structs and tensors also have value semantics
557 if (isa<FloatType, ComplexType>(opResult.getType())) {
558 auto *valueState = getOrCreate<ForwardValueActivity>(opResult);
559 propagateIfChanged(
560 valueState,
561 valueState->join(ValueActivity::getActiveVal()));
562 }
563 }
564
565 if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
566 // propagate from input to block argument
567 for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
568 if (inputOperand->get() == value) {
569 auto *valueState = getOrCreate<ForwardValueActivity>(
570 linalgOp.getMatchingBlockArgument(inputOperand));
571 propagateIfChanged(
572 valueState,
573 valueState->join(ValueActivity::getActiveVal()));
574 }
575 }
576 }
577 }
578 });
579 }
580
581 if (isa<MemoryEffects::Write>(effect.getEffect())) {
582 std::optional<Value> stored = getStored(op);
583 if (stored.has_value()) {
584 auto *valueState = getOrCreateFor<ForwardValueActivity>(
585 getProgramPointAfter(op), *stored);
586 if (valueState->getValue().isActiveVal()) {
587 auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
588 getProgramPointAfter(op), value);
589 forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
590 // Mark the pointer as having been actively stored into
591 result |= after->setActiveIn(alloc);
592 });
593 }
594 } else if (auto copySource = getCopySource(op)) {
595 auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
596 getProgramPointAfter(op), *copySource);
597 forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
598 if (before.hasActiveData(srcAlloc)) {
599 auto *destAliasClass = getOrCreateFor<AliasClassLattice>(
600 getProgramPointAfter(op), value);
601 forEachAliasedAlloc(destAliasClass, [&](DistinctAttr destAlloc) {
602 result |= after->setActiveIn(destAlloc);
603 });
604 }
605 });
606 } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
607 // linalg.yield stores to the corresponding value.
608 for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
609 if (dpsInit.get() == value) {
610 int64_t resultIndex =
611 dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
612 Value yieldOperand =
613 linalgOp.getBlock()->getTerminator()->getOperand(resultIndex);
614 auto *valueState = getOrCreateFor<ForwardValueActivity>(
615 getProgramPointAfter(op), yieldOperand);
616 if (valueState->getValue().isActiveVal()) {
617 auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
618 getProgramPointAfter(op), value);
619 forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
620 result |= after->setActiveIn(alloc);
621 });
622 }
623 }
624 }
625 }
626 }
627 }
628 propagateIfChanged(after, result);
629 return success();
630 }
631
632 void visitCallControlFlowTransfer(CallOpInterface call,
633 CallControlFlowAction action,
634 const ForwardMemoryActivity &before,
635 ForwardMemoryActivity *after) override {
636 join(after, before);
637 }
638
639 /// Initialize the entry block with the supplied argument activities.
640 void setToEntryState(ForwardMemoryActivity *lattice) override {
641 if (auto pp = dyn_cast_if_present<ProgramPoint *>(lattice->getAnchor()))
642 if (Block *block = pp->getBlock();
643 block && block == entryBlock && pp->isBlockStart()) {
644 for (const auto &[arg, activity] :
645 llvm::zip(block->getArguments(), argumentActivity)) {
646 if (activity != enzyme::Activity::enzyme_dup &&
647 activity != enzyme::Activity::enzyme_dupnoneed)
648 continue;
649 auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(
650 getProgramPointBefore(block), arg);
651 ChangeResult changed =
652 argAliasClasses->getAliasClassesObject().foreachElement(
653 [lattice](DistinctAttr argAliasClass,
655 if (state == enzyme::AliasClassSet::State::Undefined)
656 return ChangeResult::NoChange;
657 return lattice->setActiveIn(argAliasClass);
658 });
659 propagateIfChanged(lattice, changed);
660 }
661 }
662 }
663
664private:
665 // A pointer to the entry block and argument activities of the top-level
666 // function being differentiated. This is used to set the entry state
667 // because we need access to the results of points-to analysis.
668 Block *entryBlock;
669 SmallVector<enzyme::Activity> argumentActivity;
670};
671
673 : public DenseBackwardDataFlowAnalysis<BackwardMemoryActivity> {
674public:
675 DenseBackwardActivityAnalysis(DataFlowSolver &solver,
676 SymbolTableCollection &symbolTable,
677 FunctionOpInterface parentOp,
678 ArrayRef<enzyme::Activity> argumentActivity)
679 : DenseBackwardDataFlowAnalysis(solver, symbolTable), parentOp(parentOp),
680 argumentActivity(argumentActivity) {}
681
682 LogicalResult visitOperation(Operation *op,
683 const BackwardMemoryActivity &after,
684 BackwardMemoryActivity *before) override {
685
686 // TODO: If we know this is inactive by definition
687 // if (auto ifaceOp = dyn_cast<enzyme::ActivityOpInterface>(op)) {
688 // if (ifaceOp.isInactive()) {
689 // return;
690 // }
691 // }
692
693 // Initialize the return activity of arguments.
694 if (op->hasTrait<OpTrait::ReturnLike>() && op->getParentOp() == parentOp) {
695 for (const auto &[arg, argActivity] :
696 llvm::zip(parentOp->getRegions().front().getArguments(),
697 argumentActivity)) {
698 if (argActivity != enzyme::Activity::enzyme_dup &&
699 argActivity != enzyme::Activity::enzyme_dupnoneed) {
700 continue;
701 }
702 auto *argAliasClasses =
703 getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), arg);
704 ChangeResult changed =
705 argAliasClasses->getAliasClassesObject().foreachElement(
706 [before](DistinctAttr argAliasClass,
708 if (state == enzyme::AliasClassSet::State::Undefined)
709 return ChangeResult::NoChange;
710 return before->setActiveOut(argAliasClass);
711 });
712 propagateIfChanged(before, changed);
713 }
714
715 // Initialize the return activity of the operands
716 for (Value operand : op->getOperands()) {
717 if (isa<MemRefType, LLVM::LLVMPointerType>(operand.getType())) {
718 auto *retAliasClasses = getOrCreateFor<AliasClassLattice>(
719 getProgramPointBefore(op), operand);
720 ChangeResult changed =
721 retAliasClasses->getAliasClassesObject().foreachElement(
722 [before](DistinctAttr retAliasClass,
724 if (state == enzyme::AliasClassSet::State::Undefined)
725 return ChangeResult::NoChange;
726 return before->setActiveOut(retAliasClass);
727 });
728 propagateIfChanged(before, changed);
729 }
730 }
731 }
732
733 meet(before, after);
734 ChangeResult result = ChangeResult::NoChange;
735 auto memory = dyn_cast<MemoryEffectOpInterface>(op);
736 // If we can't reason about the memory effects, then conservatively assume
737 // we can't deduce anything about activity via side-effects.
738 if (!memory)
739 return success();
740
741 SmallVector<MemoryEffects::EffectInstance> effects;
742 memory.getEffects(effects);
743
744 for (const auto &effect : effects) {
745 Value value = effect.getValue();
746
747 // If we see an effect on anything other than a value, assume we can't
748 // deduce anything about the activity.
749 if (!value)
750 return success();
751
752 // In backward-flow, a value is active if stored into a memory resource
753 // that has subsequently been actively loaded from.
754 if (isa<MemoryEffects::Read>(effect.getEffect())) {
755 for (Value opResult : op->getResults()) {
756 auto *valueState = getOrCreateFor<BackwardValueActivity>(
757 getProgramPointBefore(op), opResult);
758 if (valueState->getValue().isActiveVal()) {
759 auto *ptrAliasClass = getOrCreateFor<AliasClassLattice>(
760 getProgramPointBefore(op), value);
761 forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
762 result |= before->setActiveOut(alloc);
763 });
764 }
765 }
766 }
767 if (isa<MemoryEffects::Write>(effect.getEffect())) {
768 auto *ptrAliasClass =
769 getOrCreateFor<AliasClassLattice>(getProgramPointBefore(op), value);
770 std::optional<Value> stored = getStored(op);
771 std::optional<Value> copySource = getCopySource(op);
772 forEachAliasedAlloc(ptrAliasClass, [&](DistinctAttr alloc) {
773 if (stored.has_value() && after.activeDataFlowsOut(alloc)) {
774 if (isa<FloatType, ComplexType>(stored->getType())) {
775 auto *valueState = getOrCreate<BackwardValueActivity>(*stored);
776 propagateIfChanged(
777 valueState, valueState->meet(ValueActivity::getActiveVal()));
778 }
779 } else if (copySource.has_value() &&
780 after.activeDataFlowsOut(alloc)) {
781 auto *srcAliasClass = getOrCreateFor<AliasClassLattice>(
782 getProgramPointBefore(op), *copySource);
783 forEachAliasedAlloc(srcAliasClass, [&](DistinctAttr srcAlloc) {
784 result |= before->setActiveOut(srcAlloc);
785 });
786 } else if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
787 if (after.activeDataFlowsOut(alloc)) {
788 for (OpOperand &dpsInit : linalgOp.getDpsInitsMutable()) {
789 if (dpsInit.get() == value) {
790 int64_t resultIndex =
791 dpsInit.getOperandNumber() - linalgOp.getNumDpsInputs();
792 Value yieldOperand =
793 linalgOp.getBlock()->getTerminator()->getOperand(
794 resultIndex);
795 auto *valueState =
796 getOrCreate<BackwardValueActivity>(yieldOperand);
797 propagateIfChanged(
798 valueState,
799 valueState->meet(ValueActivity::getActiveVal()));
800 }
801 }
802 }
803 }
804 });
805 }
806 }
807 propagateIfChanged(before, result);
808 return success();
809 }
810
811 void visitCallControlFlowTransfer(CallOpInterface call,
812 CallControlFlowAction action,
813 const BackwardMemoryActivity &after,
814 BackwardMemoryActivity *before) override {
815 meet(before, after);
816 }
817
818 void setToExitState(BackwardMemoryActivity *lattice) override {}
819
820private:
821 FunctionOpInterface parentOp;
822 SmallVector<enzyme::Activity> argumentActivity;
823};
824
825void traverseCallGraph(FunctionOpInterface root,
826 SymbolTableCollection *symbolTable,
827 function_ref<void(FunctionOpInterface)> processFunc) {
828 std::deque<FunctionOpInterface> frontier{root};
829 DenseSet<FunctionOpInterface> visited{root};
830
831 while (!frontier.empty()) {
832 FunctionOpInterface curr = frontier.front();
833 frontier.pop_front();
834 processFunc(curr);
835
836 curr.walk([&](CallOpInterface call) {
837 auto neighbor = dyn_cast_if_present<FunctionOpInterface>(
838 call.resolveCallableInTable(symbolTable));
839 if (neighbor && !visited.contains(neighbor)) {
840 frontier.push_back(neighbor);
841 visited.insert(neighbor);
842 }
843 });
844 }
845}
846
847void printActivityAnalysisResults(DataFlowSolver &solver,
848 FunctionOpInterface callee,
849 const SmallPtrSet<Operation *, 2> &returnOps,
850 SymbolTableCollection *symbolTable,
851 bool verbose, bool annotate) {
852 auto isActiveData = [&](Value value) {
853 auto fva = solver.lookupState<ForwardValueActivity>(value);
854 auto bva = solver.lookupState<BackwardValueActivity>(value);
855 bool forwardActive = fva && fva->getValue().isActiveVal();
856 bool backwardActive = bva && bva->getValue().isActiveVal();
857 return forwardActive && backwardActive;
858 };
859
860 auto isConstantValue = [&](Value value) {
861 // TODO: integers/vectors that might be pointers
862 if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
863 assert(returnOps.size() == 1);
864 auto *fma = solver.lookupState<ForwardMemoryActivity>(
865 solver.getProgramPointAfter(*returnOps.begin()));
866 auto *bma = solver.lookupState<BackwardMemoryActivity>(
867 solver.getProgramPointBefore(
868 &callee.getFunctionBody().front().front()));
869
870 const enzyme::PointsToSets *pointsToSets =
871 solver.lookupState<enzyme::PointsToSets>(
872 solver.getProgramPointAfter(*returnOps.begin()));
873 auto *aliasClassLattice = solver.lookupState<AliasClassLattice>(value);
874 // Traverse the points-to sets in a simple BFS
875 std::deque<DistinctAttr> frontier;
876 DenseSet<DistinctAttr> visited;
877 auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) {
878 (void)aliasClasses.foreachElement(
879 [&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) {
880 assert(neighbor &&
881 "unhandled undefined/unknown case before visit");
882 if (!visited.contains(neighbor)) {
883 visited.insert(neighbor);
884 frontier.push_back(neighbor);
885 }
886 return ChangeResult::NoChange;
887 });
888 };
889
890 // If this triggers, investigate why the alias classes weren't computed.
891 // If they weren't computed legitimately, treat the value as
892 // conservatively non-constant or change the return type to be tri-state.
893 assert(!aliasClassLattice->isUndefined() &&
894 "didn't compute alias classes");
895
896 if (aliasClassLattice->isUnknown()) {
897 // Pointers of unknown class may point to active data.
898 // TODO: is this overly conservative? Should we rather check
899 // if listed classes may point to non-constants?
900 return false;
901 } else {
902 scheduleVisit(aliasClassLattice->getAliasClassesObject());
903 }
904 while (!frontier.empty()) {
905 DistinctAttr aliasClass = frontier.front();
906 frontier.pop_front();
907
908 // It's an active pointer if active data flows in from the forward
909 // direction and out from the backward direction.
910 if (fma->hasActiveData(aliasClass) &&
911 bma->activeDataFlowsOut(aliasClass))
912 return false;
913
914 // If this triggers, investigate why points-to sets couldn't be
915 // computed. Treat conservatively as "unknown" if necessary.
916 assert(!pointsToSets->getPointsTo(aliasClass).isUndefined() &&
917 "couldn't compute points-to sets");
918
919 // Pointers to unknown classes may (transitively) point to active data.
920 if (pointsToSets->getPointsTo(aliasClass).isUnknown())
921 return false;
922
923 scheduleVisit(pointsToSets->getPointsTo(aliasClass));
924 }
925 // Otherwise, it's constant
926 return true;
927 }
928
929 return !isActiveData(value);
930 };
931
932 std::function<bool(Operation *)> isConstantInstruction = [&](Operation *op) {
933 if (isPure(op)) {
934 // If an operation doesn't have side effects, the only way it can
935 // propagate active data is through its results.
936 return llvm::none_of(op->getResults(), isActiveData);
937 }
938 // We need a special case because stores of active pointers don't fit the
939 // definition but are active instructions
940 if (auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
941 if (!isConstantValue(storeOp.getValue())) {
942 return false;
943 }
944 } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
945 // TODO: Should traverse bottom-up for performance (or cache
946 // intermediate results)
947 auto callable = cast<CallableOpInterface>(callOp.resolveCallable());
948 if (callable.getCallableRegion()) {
949 // If any of the instructions in the body are active instructions, the
950 // function is active.
951 WalkResult result = callable->walk([&](Operation *op) {
952 if (!isConstantInstruction(op)) {
953 return WalkResult::interrupt();
954 }
955 return WalkResult::advance();
956 });
957 return !result.wasInterrupted();
958 } else {
959 // fall back to seeing if any operand or result is active data
960 }
961 }
962 return llvm::none_of(op->getOperands(), isActiveData) &&
963 llvm::none_of(op->getResults(), isActiveData);
964 };
965
966 errs() << FlatSymbolRefAttr::get(callee) << ":\n";
967 for (BlockArgument arg : callee.getArguments()) {
968 if (Attribute tagAttr =
969 callee.getArgAttr(arg.getArgNumber(), "enzyme.tag")) {
970 errs() << " " << tagAttr << ": "
971 << (isConstantValue(arg) ? "Constant" : "Active") << "\n";
972 }
973 }
974
975 if (annotate) {
976 MLIRContext *ctx = callee.getContext();
977 traverseCallGraph(callee, symbolTable, [&](FunctionOpInterface func) {
978 func.walk([&](Operation *op) {
979 if (op == func) {
980 SmallVector<bool> argICVs(func.getNumArguments());
981 llvm::transform(func.getArguments(), argICVs.begin(),
982 isConstantValue);
983 func->setAttr("enzyme.icv", DenseBoolArrayAttr::get(ctx, argICVs));
984 return;
985 }
986
987 op->setAttr("enzyme.ici",
988 BoolAttr::get(ctx, isConstantInstruction(op)));
989
990 bool icv;
991 if (op->getNumResults() == 0) {
992 icv = true;
993 } else if (op->getNumResults() == 1) {
994 icv = isConstantValue(op->getResult(0));
995 } else {
996 op->emitWarning(
997 "annotating icv for op that produces multiple results");
998 icv = false;
999 }
1000 op->setAttr("enzyme.icv", BoolAttr::get(ctx, icv));
1001 });
1002 });
1003 }
1004 callee.walk([&](Operation *op) {
1005 if (op->hasAttr("tag")) {
1006 errs() << " " << op->getAttr("tag") << ": ";
1007 for (OpResult opResult : op->getResults()) {
1008 errs() << (isConstantValue(opResult) ? "Constant" : "Active") << "\n";
1009 }
1010 }
1011 if (verbose) {
1012 // Annotate each op's results with its value activity states
1013 for (OpResult result : op->getResults()) {
1014 auto forwardValueActivity =
1015 solver.lookupState<ForwardValueActivity>(result);
1016 if (forwardValueActivity) {
1017 std::string dest, key{"fva"};
1018 llvm::raw_string_ostream os(dest);
1019 if (op->getNumResults() != 1)
1020 key += result.getResultNumber();
1021 forwardValueActivity->getValue().print(os);
1022 op->setAttr(key, StringAttr::get(op->getContext(), dest));
1023 }
1024
1025 auto backwardValueActivity =
1026 solver.lookupState<BackwardValueActivity>(result);
1027 if (backwardValueActivity) {
1028 std::string dest, key{"bva"};
1029 llvm::raw_string_ostream os(dest);
1030 if (op->getNumResults() != 1)
1031 key += result.getResultNumber();
1032 backwardValueActivity->getValue().print(os);
1033 op->setAttr(key, StringAttr::get(op->getContext(), dest));
1034 }
1035 }
1036 }
1037 });
1038
1039 if (verbose) {
1040 // Annotate function attributes
1041 for (BlockArgument arg : callee.getArguments()) {
1042 auto backwardValueActivity =
1043 solver.lookupState<BackwardValueActivity>(arg);
1044 if (backwardValueActivity) {
1045 std::string dest;
1046 llvm::raw_string_ostream os(dest);
1047 backwardValueActivity->getValue().print(os);
1048 callee.setArgAttr(arg.getArgNumber(), "enzyme.bva",
1049 StringAttr::get(callee->getContext(), dest));
1050 }
1051 }
1052
1053 for (Operation *returnOp : returnOps) {
1054 auto *state = solver.lookupState<ForwardMemoryActivity>(
1055 solver.getProgramPointAfter(returnOp));
1056 if (state)
1057 errs() << "forward end state:\n" << *state << "\n";
1058 else
1059 errs() << "state was null\n";
1060 }
1061
1062 auto startState = solver.lookupState<BackwardMemoryActivity>(
1063 solver.getProgramPointAfter(&callee.getFunctionBody().front().front()));
1064 if (startState)
1065 errs() << "backwards end state:\n" << *startState << "\n";
1066 else
1067 errs() << "backwards end state was null\n";
1068 }
1069}
1070
    FunctionOpInterface callee, ArrayRef<enzyme::Activity> argumentActivity,
    bool print, bool verbose, bool annotate) {
  SymbolTableCollection symbolTable;
  DataFlowSolver solver;

  // Alias information (interprocedural points-to sets and alias classes) is
  // required by the dense memory activity analyses loaded below.
  solver.load<enzyme::PointsToPointerAnalysis>();
  solver.load<enzyme::AliasAnalysis>(callee.getContext());
  solver.load<SparseForwardActivityAnalysis>();
  solver.load<DenseForwardActivityAnalysis>(&callee.getFunctionBody().front(),
                                            argumentActivity);
  solver.load<SparseBackwardActivityAnalysis>(symbolTable);
  solver.load<DenseBackwardActivityAnalysis>(symbolTable, callee,
                                             argumentActivity);

  // Required for the dataflow framework to traverse region-based control flow
  solver.load<DeadCodeAnalysis>();
  solver.load<SparseConstantPropagation>();

  // Initialize the argument states based on the given activity annotations.
  for (const auto &[arg, activity] :
       llvm::zip(callee.getArguments(), argumentActivity)) {
    // enzyme_dup, dupnoneed are initialized within the dense forward/backward
    // analyses, enzyme_const is the default.
    if (activity == enzyme::Activity::enzyme_active) {
      auto *argLattice = solver.getOrCreateState<ForwardValueActivity>(arg);
      (void)argLattice->join(ValueActivity::getActiveVal());
    }
  }

  // Detect function returns as direct children of the FunctionOpInterface
  // that have the ReturnLike trait.
  SmallPtrSet<Operation *, 2> returnOps;
  for (Operation &op : callee.getFunctionBody().getOps()) {
    if (op.hasTrait<OpTrait::ReturnLike>()) {
      returnOps.insert(&op);
      for (Value operand : op.getOperands()) {
        auto *returnLattice =
            solver.getOrCreateState<BackwardValueActivity>(operand);
        // Very basic type inference of the type
        // (only float/complex-typed returned operands are seeded as active
        // for the backward analysis).
        if (isa<FloatType, ComplexType>(operand.getType())) {
          (void)returnLattice->meet(ValueActivity::getActiveVal());
        }
      }
    }
  }

  // NOTE(review): the assert below is compiled out in NDEBUG builds, so a
  // failed dataflow run would be silently ignored in release mode — consider
  // emitting a diagnostic instead.
  if (failed(solver.initializeAndRun(callee->getParentOfType<ModuleOp>()))) {
    assert(false && "dataflow analysis failed\n");
  }

  if (print) {
    printActivityAnalysisResults(solver, callee, returnOps, &symbolTable,
                                 verbose, annotate);
  }
}
void printActivityAnalysisResults(DataFlowSolver &solver, FunctionOpInterface callee, const SmallPtrSet< Operation *, 2 > &returnOps, SymbolTableCollection *symbolTable, bool verbose, bool annotate)
std::optional< Value > getStored(Operation *op)
void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, function_ref< void(DistinctAttr)> forEachFn)
The dense analyses operate using a pointer's "canonical allocation", the Value corresponding to its a...
raw_ostream & operator<<(raw_ostream &os, const ValueActivity &val)
std::optional< Value > getCopySource(Operation *op)
ActivityKind
From LLVM Enzyme's activity analysis, there are four activity states.
void traverseCallGraph(FunctionOpInterface root, SymbolTableCollection *symbolTable, function_ref< void(FunctionOpInterface)> processFunc)
ChangeResult meet(const AbstractDenseLattice &lattice) override
ChangeResult meet(ValueActivity other)
void print(raw_ostream &os) const override
ChangeResult meet(const AbstractSparseLattice &other) override
void visitCallControlFlowTransfer(CallOpInterface call, CallControlFlowAction action, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override
LogicalResult visitOperation(Operation *op, const BackwardMemoryActivity &after, BackwardMemoryActivity *before) override
void setToExitState(BackwardMemoryActivity *lattice) override
DenseBackwardActivityAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, FunctionOpInterface parentOp, ArrayRef< enzyme::Activity > argumentActivity)
LogicalResult visitOperation(Operation *op, const ForwardMemoryActivity &before, ForwardMemoryActivity *after) override
DenseForwardActivityAnalysis(DataFlowSolver &solver, Block *entryBlock, ArrayRef< enzyme::Activity > argumentActivity)
void setToEntryState(ForwardMemoryActivity *lattice) override
Initialize the entry block with the supplied argument activities.
void visitCallControlFlowTransfer(CallOpInterface call, CallControlFlowAction action, const ForwardMemoryActivity &before, ForwardMemoryActivity *after) override
ChangeResult join(const AbstractDenseLattice &lattice)
Join the activity states.
raw_ostream & operator<<(raw_ostream &os) const
void print(raw_ostream &os) const override
ChangeResult merge(const AbstractDenseLattice &lattice)
bool hasActiveData(DistinctAttr aliasClass) const
ChangeResult setActiveOut(DistinctAttr aliasClass)
bool activeDataFlowsOut(DistinctAttr aliasClass) const
ChangeResult setActiveIn(DistinctAttr aliasClass)
Set the internal activity state.
ChangeResult reset()
Clear all modifications.
void visitCallOperand(OpOperand &operand) override
LogicalResult visitOperation(Operation *op, ArrayRef< BackwardValueActivity * > operands, ArrayRef< const BackwardValueActivity * > results) override
void visitExternalCall(CallOpInterface call, ArrayRef< BackwardValueActivity * > operands, ArrayRef< const BackwardValueActivity * > results) override
void transfer(Operation *op, ArrayRef< BackwardValueActivity * > operands, ArrayRef< const BackwardValueActivity * > results)
void visitNonControlFlowArguments(RegionSuccessor &successor, ArrayRef< BlockArgument > arguments) override
void setToExitState(BackwardValueActivity *lattice) override
void visitBranchOperand(OpOperand &operand) override
Sparse activity analysis reasons about activity by traversing forward down the def-use chains startin...
void setToEntryState(ForwardValueActivity *lattice) override
In general, we don't know anything about entry operands.
LogicalResult visitOperation(Operation *op, ArrayRef< const ForwardValueActivity * > operands, ArrayRef< ForwardValueActivity * > results) override
void transfer(Operation *op, ArrayRef< const ForwardValueActivity * > operands, ArrayRef< ForwardValueActivity * > results)
void visitExternalCall(CallOpInterface call, ArrayRef< const ForwardValueActivity * > operands, ArrayRef< ForwardValueActivity * > results) override
ValueActivity(ActivityKind value)
static ValueActivity getActiveVal()
bool operator==(const ValueActivity &rhs) const
raw_ostream & operator<<(raw_ostream &os) const
static ValueActivity join(const ValueActivity &lhs, const ValueActivity &rhs)
static ValueActivity getConstant()
static ValueActivity getUnknown()
void print(raw_ostream &os) const
static ValueActivity merge(const ValueActivity &lhs, const ValueActivity &rhs)
const ActivityKind & getValue() const
Get the known activity state.
This analysis implements interprocedural alias analysis.
void runDataFlowActivityAnalysis(FunctionOpInterface callee, ArrayRef< enzyme::Activity > argumentActivity, bool print=false, bool verbose=false, bool annotate=false)
This needs to keep track of three things:
bool operator!=(const MemoryActivityState &other)
ChangeResult merge(const MemoryActivityState &other)
bool activeOut
Whether active data was loaded out of this memory location.
bool operator==(const MemoryActivityState &other)
bool activeIn
Whether active data has been stored into this memory location.