19#include "mlir/Analysis/AliasAnalysis.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/Interfaces/FunctionInterfaces.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24#include "llvm/ADT/DenseMap.h"
26#define DEBUG_TYPE "enzyme-diff-batch"
27#define ENZYME_DBGS llvm::dbgs() << "[" << DEBUG_TYPE << "]"
31using namespace enzyme;
35#define GEN_PASS_DEF_BATCHDIFFPASS
36#include "Passes/Passes.h.inc"
42struct BatchDiffPass :
public enzyme::impl::BatchDiffPassBase<BatchDiffPass> {
43 void runOnOperation()
override;
45 void mergeFwddiffCalls(SymbolTableCollection &symbolTable,
46 FunctionOpInterface op) {
49 llvm::DenseMap<FunctionOpInterface,
50 SmallVector<MemoryEffects::EffectInstance>>
53 OpBuilder builder(op);
55 op->walk([&](Block *blk) {
58 SmallVector<enzyme::ForwardDiffOp>>
61 for (
auto fwdOp : blk->getOps<enzyme::ForwardDiffOp>()) {
62 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
63 symbolTable.lookupNearestSymbolFrom(fwdOp, fwdOp.getFnAttr()));
70 toMerge[key].push_back(fwdOp);
73 for (
auto &pair : toMerge) {
74 auto key = pair.first;
75 auto allDiffs = pair.second;
76 if (allDiffs.size() < 2)
80 if (!innerEffectCache.contains(key.function)) {
81 innerEffectCache[key.function] =
85 SmallVector<MemoryEffects::EffectInstance> &calleeEffects =
86 innerEffectCache[key.function];
89 bool skipMergeEntry =
false;
104 llvm::DenseMap<ForwardDiffOp,
105 SmallVector<MemoryEffects::EffectInstance>>
108 for (
auto &eff : calleeEffects) {
109 if (!isa<MemoryEffects::Read>(eff.getEffect()) &&
110 !isa<MemoryEffects::Write>(eff.getEffect())) {
112 skipMergeEntry =
true;
116 Value effVal = eff.getValue();
119 skipMergeEntry =
true;
124 size_t primalArgPos = 0;
125 bool foundPrimal =
false;
126 if (
auto effBA = dyn_cast<BlockArgument>(effVal)) {
127 if (llvm::is_contained(key.function.getArguments(), effBA)) {
129 primalArgPos = effBA.getArgNumber();
137 skipMergeEntry =
true;
142 Value primalVal = key.inputs[primalArgPos];
143 for (
auto dop : allDiffs) {
145 primalVal, eff.getEffect(), eff.getResource()));
154 (key.inActivity[primalArgPos] == Activity::enzyme_dup) ||
155 (key.inActivity[primalArgPos] == Activity::enzyme_dupnoneed);
158 size_t gradArgPos = 0;
159 for (
auto [idx, act] : llvm::enumerate(key.inActivity)) {
162 if (idx == primalArgPos)
165 if (act == Activity::enzyme_dup ||
166 act == Activity::enzyme_dupnoneed) {
171 for (
auto dop : allDiffs) {
172 Value dVal = dop.getInputs()[gradArgPos];
174 dVal, eff.getEffect(), eff.getResource()));
182 SmallVector<ForwardDiffOp> prunedSources =
186 symbolTable, key, prunedSources, callerEffectMap, innerEffectCache);
190 SmallVector<enzyme::ForwardDiffOp> &allOps = legalMerge;
191 int64_t width = allOps.size();
197 auto firstDiffOp = allOps.front();
198 IRRewriter::InsertionGuard insertGuard(builder);
199 builder.setInsertionPoint(firstDiffOp);
200 auto loc = firstDiffOp->getLoc();
201 auto context = builder.getContext();
203 SmallVector<mlir::Value> in_args;
204 SmallVector<ActivityAttr, 2> inActivityAttrs;
205 SmallVector<ActivityAttr, 2> retActivityAttrs;
206 SmallVector<mlir::Type, 2> out_ty;
210 for (
auto [idx, act] : llvm::enumerate(key.inActivity)) {
211 ActivityAttr iattr = ActivityAttr::get(context, act);
212 inActivityAttrs.push_back(iattr);
213 in_args.push_back(key.inputs[in_idx]);
216 SmallVector<mlir::Value> derivList;
217 if (act == Activity::enzyme_dup ||
218 act == Activity::enzyme_dupnoneed) {
219 for (
auto uop : allOps) {
220 derivList.push_back(uop.getInputs()[in_idx]);
223 mlir::Value batchedDeriv =
225 in_args.push_back(batchedDeriv);
232 for (
auto [idx, ract] : llvm::enumerate(key.retActivity)) {
233 ActivityAttr iattr = ActivityAttr::get(context, ract);
235 retActivityAttrs.push_back(iattr);
238 case Activity::enzyme_active: {
239 mlir::Value res = firstDiffOp.getOutputs()[out_idx];
240 out_ty.push_back(res.getType());
245 case Activity::enzyme_const: {
246 mlir::Value res = firstDiffOp.getOutputs()[out_idx];
247 out_ty.push_back(res.getType());
252 case Activity::enzyme_dupnoneed: {
255 mlir::Value dres = firstDiffOp.getOutputs()[out_idx];
261 case Activity::enzyme_dup: {
262 mlir::Value res = firstDiffOp.getOutputs()[out_idx];
263 out_ty.push_back(res.getType());
268 mlir::Value dres = firstDiffOp.getOutputs()[out_idx];
274 case Activity::enzyme_constnoneed: {
278 case Activity::enzyme_activenoneed: {
279 mlir::Value res = firstDiffOp.getOutputs()[out_idx];
280 out_ty.push_back(res.getType());
287 "unknown activity value encountered for ret_activity");
292 ArrayAttr newInActivity = ArrayAttr::get(
293 context, llvm::ArrayRef<Attribute>(inActivityAttrs.begin(),
294 inActivityAttrs.end()));
296 ArrayAttr newRetActivity = ArrayAttr::get(
297 context, llvm::ArrayRef<Attribute>(retActivityAttrs.begin(),
298 retActivityAttrs.end()));
300 IntegerAttr newWidthAttr =
301 IntegerAttr::get(firstDiffOp.getWidthAttr().getType(), width);
303 auto newDiffOp = ForwardDiffOp::create(
304 builder, loc, out_ty, firstDiffOp.getFnAttr(), in_args,
305 newInActivity, newRetActivity, newWidthAttr,
306 firstDiffOp.getStrongZeroAttr());
310 for (
auto [idx, ract] : llvm::enumerate(key.retActivity)) {
312 case Activity::enzyme_constnoneed:
315 case Activity::enzyme_const: {
316 auto new_out = newDiffOp.getOutputs()[out_idx];
318 for (
auto dop : allOps) {
319 dop.getOutputs()[out_idx].replaceAllUsesWith(new_out);
326 case Activity::enzyme_dupnoneed: {
328 auto batch_dout = newDiffOp.getOutputs()[out_idx];
329 for (
auto [dop_idx, dop] : llvm::enumerate(allOps)) {
330 auto old_dout = dop.getOutputs()[out_idx];
331 auto doutTy = old_dout.getType();
335 old_dout.replaceAllUsesWith(new_dout);
341 case Activity::enzyme_dup: {
342 mlir::Value new_out = newDiffOp.getOutputs()[out_idx];
344 for (ForwardDiffOp dop : allOps) {
345 dop.getOutputs()[out_idx].replaceAllUsesWith(new_out);
350 auto batch_dout = newDiffOp.getOutputs()[out_idx];
351 for (
auto [dop_idx, dop] : llvm::enumerate(allOps)) {
353 auto old_dout = dop.getOutputs()[out_idx];
354 auto doutTy = old_dout.getType();
358 old_dout.replaceAllUsesWith(new_dout);
364 case Activity::enzyme_active: {
365 auto new_out = newDiffOp.getOutputs()[out_idx];
367 for (ForwardDiffOp dop : allOps) {
368 dop.getOutputs()[out_idx].replaceAllUsesWith(new_out);
373 case Activity::enzyme_activenoneed: {
374 auto new_out = newDiffOp.getOutputs()[out_idx];
376 for (ForwardDiffOp dop : allOps) {
377 dop.getOutputs()[out_idx].replaceAllUsesWith(new_out);
386 for (
auto dop : allOps) {
394 void mergeRevdiffCalls(SymbolTableCollection &symbolTable,
395 FunctionOpInterface op) {
401 llvm::DenseMap<FunctionOpInterface,
402 SmallVector<MemoryEffects::EffectInstance>>
405 OpBuilder builder(op);
407 op->walk([&](Block *blk) {
410 SmallVector<enzyme::AutoDiffOp>>
413 for (
auto revOp : blk->getOps<enzyme::AutoDiffOp>()) {
414 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
415 symbolTable.lookupNearestSymbolFrom(revOp, revOp.getFnAttr()));
422 toMerge[key].push_back(revOp);
425 for (
auto &pair : toMerge) {
426 auto key = pair.first;
427 auto allDiffs = pair.second;
428 if (allDiffs.size() < 2)
432 if (!innerEffectCache.contains(key.
function)) {
437 SmallVector<MemoryEffects::EffectInstance> &calleeEffects =
441 bool skipMergeEntry =
false;
443 llvm::DenseMap<AutoDiffOp, SmallVector<MemoryEffects::EffectInstance>>
446 for (
auto &eff : calleeEffects) {
447 if (!isa<MemoryEffects::Read>(eff.getEffect()) &&
448 !isa<MemoryEffects::Write>(eff.getEffect())) {
450 skipMergeEntry =
true;
454 Value effVal = eff.getValue();
457 skipMergeEntry =
true;
462 size_t primalArgPos = 0;
463 bool foundPrimal =
false;
464 if (
auto effBA = dyn_cast<BlockArgument>(effVal)) {
465 if (llvm::is_contained(key.
function.getArguments(), effBA)) {
467 primalArgPos = effBA.getArgNumber();
475 skipMergeEntry =
true;
480 Value primalVal = key.
inputs[primalArgPos];
481 for (
auto dop : allDiffs) {
483 primalVal, eff.getEffect(), eff.getResource()));
492 (key.
inActivity[primalArgPos] == Activity::enzyme_dup) ||
493 (key.
inActivity[primalArgPos] == Activity::enzyme_dupnoneed);
496 size_t gradArgPos = 0;
497 for (
auto [idx, act] : llvm::enumerate(key.
inActivity)) {
500 if (idx == primalArgPos)
503 if (act == Activity::enzyme_dup ||
504 act == Activity::enzyme_dupnoneed) {
509 for (
auto dop : allDiffs) {
510 Value dVal = dop.getInputs()[gradArgPos];
512 dVal, MemoryEffects::Write::get(), eff.getResource()));
514 dVal, MemoryEffects::Read::get(), eff.getResource()));
522 SmallVector<AutoDiffOp> prunedSources =
526 symbolTable, key, prunedSources, callerEffectMap, innerEffectCache);
531 SmallVector<enzyme::AutoDiffOp> &allOps = legalMerge;
532 int64_t width = allOps.size();
537 auto firstDiffOp = allOps.front();
538 IRRewriter::InsertionGuard insertGuard(builder);
539 builder.setInsertionPoint(firstDiffOp);
540 auto loc = firstDiffOp->getLoc();
541 auto context = builder.getContext();
545 SmallVector<mlir::Value> in_args;
546 SmallVector<ActivityAttr, 2> inActivityAttrs;
547 SmallVector<ActivityAttr, 2> retActivityAttrs;
548 SmallVector<mlir::Type, 2> out_ty;
552 for (
auto [idx, act] : llvm::enumerate(key.
inActivity)) {
553 auto iattr = ActivityAttr::get(context, act);
554 inActivityAttrs.push_back(iattr);
555 in_args.push_back(key.
inputs[call_idx]);
558 if (act == Activity::enzyme_dup ||
559 act == Activity::enzyme_dupnoneed) {
561 SmallVector<mlir::Value> derivList;
562 for (
auto uop : allOps) {
563 derivList.push_back(uop.getInputs()[call_idx]);
568 in_args.push_back(b_din);
574 if (call_idx == firstDiffOp.getInputs().size()) {
581 auto iattr = ActivityAttr::get(context, ract);
582 retActivityAttrs.push_back(iattr);
585 if (ract == Activity::enzyme_constnoneed ||
586 ract == Activity::enzyme_dupnoneed) {
591 if (ract == Activity::enzyme_active ||
592 ract == Activity::enzyme_activenoneed) {
593 SmallVector<mlir::Value> derivList;
594 for (
auto uop : allOps) {
595 derivList.push_back(uop.getInputs()[call_idx]);
599 in_args.push_back(batch_dout);
604 if (ract == Activity::enzyme_active ||
605 ract == Activity::enzyme_const ||
606 ract == Activity::enzyme_dup) {
607 Value out = firstDiffOp.getOutputs()[out_idx];
608 out_ty.push_back(out.getType());
615 if (act == Activity::enzyme_active) {
616 Value din = firstDiffOp.getOutputs()[out_idx];
622 ArrayAttr newInActivity = ArrayAttr::get(
623 context, llvm::ArrayRef<Attribute>(inActivityAttrs.begin(),
624 inActivityAttrs.end()));
626 ArrayAttr newRetActivity = ArrayAttr::get(
627 context, llvm::ArrayRef<Attribute>(retActivityAttrs.begin(),
628 retActivityAttrs.end()));
630 IntegerAttr newWidthAttr =
631 IntegerAttr::get(firstDiffOp.getWidthAttr().getType(), width);
634 AutoDiffOp::create(builder, loc, out_ty, firstDiffOp.getFnAttr(),
635 in_args, newInActivity, newRetActivity,
636 newWidthAttr, firstDiffOp.getStrongZeroAttr());
641 if (ract == Activity::enzyme_active ||
642 ract == Activity::enzyme_const ||
643 ract == Activity::enzyme_dup) {
644 Value new_out = newDiffOp.getOutputs()[out_idx];
645 for (
auto dop : allOps) {
646 dop.getOutputs()[out_idx].replaceAllUsesWith(new_out);
653 if (act == Activity::enzyme_active) {
654 Value batch_din = newDiffOp.getOutputs()[out_idx];
655 for (
auto [dop_idx, dop] : llvm::enumerate(allOps)) {
656 Value old_din = dop.getOutputs()[out_idx];
657 auto dinTy = old_din.getType();
661 old_din.replaceAllUsesWith(new_din);
666 for (
auto dop : allOps) {
677void BatchDiffPass::runOnOperation() {
678 SymbolTableCollection symbolTable;
679 symbolTable.getSymbolTable(getOperation());
680 getOperation()->walk([&](FunctionOpInterface op) {
681 mergeFwddiffCalls(symbolTable, op);
682 mergeRevdiffCalls(symbolTable, op);
BatchDiffCacheKey createDiffCacheKey(SourceOp uop, FunctionOpInterface fn)
llvm::SmallVector< SourceOp, 2 > pruneGradDefs(BatchDiffCacheKey &key, SmallVector< SourceOp > &allDiffs)
llvm::SmallVector< SourceOp > pruneMemoryEffects(SymbolTableCollection &symbolTable, BatchDiffCacheKey &key, SmallVector< SourceOp > &prunedSources, DenseMap< SourceOp, SmallVector< MemoryEffects::EffectInstance > > &callerEffectMap, llvm::DenseMap< FunctionOpInterface, SmallVector< MemoryEffects::EffectInstance > > &innerEffectCache)
SmallVector< MemoryEffects::EffectInstance > collectFnEffects(FunctionOpInterface fnOp)
MemoryEffects::EffectInstance getEffectOfVal(Value val, MemoryEffects::Effect *effect, SideEffects::Resource *resource)
Type getConcatType(Value val, int64_t width)
Value getConcatValue(OpBuilder &builder, Location loc, ArrayRef< Value > argList)
Value getExtractValue(OpBuilder &builder, Location loc, Type argTy, Value val, int64_t index)
SmallVector< enzyme::Activity > retActivity
SmallVector< mlir::Value > inputs
SmallVector< enzyme::Activity > inActivity
FunctionOpInterface function