Enzyme main
Loading...
Searching...
No Matches
Ops.cpp
Go to the documentation of this file.
1//===- EnzymeOps.cpp - Enzyme dialect ops -----------------------*- C++ -*-===//
2// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
3// See https://llvm.org/LICENSE.txt for license information.
4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5//
6//===----------------------------------------------------------------------===//
7
8#include "Ops.h"
10#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/IR/Value.h"
16#include "mlir/Interfaces/MemorySlotInterfaces.h"
17
18#include "mlir/Dialect/Func/IR/FuncOps.h"
19#include "mlir/Dialect/MemRef/IR/MemRef.h"
20#include "mlir/IR/IntegerSet.h"
21
22#include "llvm/ADT/STLExtras.h"
23#include "llvm/ADT/SmallVector.h"
24
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/LogicalResult.h"
27#include <type_traits>
28
29#define DEBUG_TYPE "enzyme"
30
31using namespace mlir;
32using namespace enzyme;
33using namespace mlir::arith;
34
35//===----------------------------------------------------------------------===//
36// InitOp
37//===----------------------------------------------------------------------===//
38
39llvm::SmallVector<MemorySlot> InitOp::getPromotableSlots() {
40 auto Ty = this->getType();
41 if (isa<CacheType>(Ty))
42 return {};
43
44 if (!getOperation()->getBlock()->isEntryBlock())
45 return {};
46
47 auto gTy = cast<GradientType>(Ty);
48 MemorySlot slot = {this->getResult(), gTy.getBasetype()};
49
50 return {slot};
51}
52
53Value InitOp::getDefaultValue(const MemorySlot &slot, OpBuilder &builder) {
54 auto gTy = cast<GradientType>(this->getType());
55 return cast<AutoDiffTypeInterface>(gTy.getBasetype())
56 .createNullValue(builder, this->getLoc());
57}
58
// Nothing to do: promotion of gradient slots requires no extra handling for
// block arguments introduced by mem2reg.
void InitOp::handleBlockArgument(const MemorySlot &slot, BlockArgument argument,
                                 OpBuilder &builder) {}
61
62std::optional<mlir::PromotableAllocationOpInterface>
63InitOp::handlePromotionComplete(const MemorySlot &slot, Value defaultValue,
64 OpBuilder &builder) {
65 if (defaultValue && defaultValue.use_empty())
66 defaultValue.getDefiningOp()->erase();
67 this->erase();
68 return std::nullopt;
69}
70
71//===----------------------------------------------------------------------===//
72// GetOp
73//===----------------------------------------------------------------------===//
74
75bool GetOp::loadsFrom(const MemorySlot &slot) {
76 return this->getGradient() == slot.ptr;
77}
78
79bool GetOp::storesTo(const MemorySlot &slot) { return false; }
80
// GetOp is a pure load; it contributes no stored value to the slot.
Value GetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                       Value reachingDef, const DataLayout &dataLayout) {
  return {};
}
85
86bool GetOp::canUsesBeRemoved(
87 const MemorySlot &slot,
88 const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses,
89 llvm::SmallVectorImpl<OpOperand *> &newBlockingUses,
90 const mlir::DataLayout &dataLayout) {
91 if (blockingUses.size() != 1)
92 return false;
93
94 Value blockingUse = (*blockingUses.begin())->get();
95 return blockingUse == slot.ptr && getGradient() == slot.ptr;
96}
97
98DeletionKind GetOp::removeBlockingUses(
99 const MemorySlot &slot,
100 const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder,
101 Value reachingDefinition, const DataLayout &dataLayout) {
102 this->getResult().replaceAllUsesWith(reachingDefinition);
103 return DeletionKind::Delete;
104}
105
106llvm::LogicalResult GetOp::ensureOnlySafeAccesses(
107 const MemorySlot &slot, llvm::SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
108 const DataLayout &dataLayout) {
109 return success(slot.ptr == getGradient());
110}
111
112//===----------------------------------------------------------------------===//
113// SetOp
114//===----------------------------------------------------------------------===//
115
116bool SetOp::loadsFrom(const MemorySlot &slot) { return false; }
117
118bool SetOp::storesTo(const MemorySlot &slot) {
119 return this->getGradient() == slot.ptr;
120}
121
122Value SetOp::getStored(const MemorySlot &slot, OpBuilder &builder,
123 Value reachingDef, const DataLayout &dataLayout) {
124 return this->getValue();
125}
126
// Stores into a promotable slot can always be elided: mem2reg forwards the
// stored value directly, so no use blocks removal.
bool SetOp::canUsesBeRemoved(
    const MemorySlot &slot,
    const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses,
    llvm::SmallVectorImpl<OpOperand *> &newBlockingUses,
    const mlir::DataLayout &dataLayout) {
  return true;
}
134
// The store produces no results to rewrite; deleting the op is sufficient.
DeletionKind SetOp::removeBlockingUses(
    const MemorySlot &slot,
    const llvm::SmallPtrSetImpl<OpOperand *> &blockingUses, OpBuilder &builder,
    Value reachingDefinition, const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
141
142llvm::LogicalResult SetOp::ensureOnlySafeAccesses(
143 const MemorySlot &slot, llvm::SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
144 const DataLayout &dataLayout) {
145 return success(slot.ptr == getGradient());
146}
147
//===----------------------------------------------------------------------===//
// ForwardDiffOp / JacobianOp symbol verification
//===----------------------------------------------------------------------===//
151
152LogicalResult
153ForwardDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
154 // TODO: Verify that the result type is same as the type of the referenced
155 // func.func op.
156 auto global =
157 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
158 if (!global)
159 return emitOpError("'")
160 << getFn() << "' does not reference a valid global funcOp";
161
162 return success();
163}
164
165LogicalResult JacobianOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
166 auto global =
167 symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFnAttr());
168 if (!global)
169 return emitOpError("'")
170 << getFn() << "' does not reference a valid global funcOp";
171
172 return success();
173}
174
175//===----------------------------------------------------------------------===//
176// ForwardDiffOp
177//===----------------------------------------------------------------------===//
178
// Some templated helpers for rewriting EnzymeOps(we can overload the create
// definitions as and when necessary)
// The primary template is only declared; each supported op provides an
// explicit specialization with a static `create` entry point below.
template <typename SourceOp> struct EnzymeOpCreator;
182
// Rebuilds an AutoDiffOp with new result types, operands, and activity
// annotations, preserving the original callee, width, and strong-zero
// attributes.
template <> struct EnzymeOpCreator<AutoDiffOp> {
  static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop,
                           TypeRange out_ty, ValueRange in_args,
                           ArrayAttr newInActivity, ArrayAttr newRetActivity) {

    return AutoDiffOp::create(rewriter, uop.getLoc(), out_ty, uop.getFnAttr(),
                              in_args, newInActivity, newRetActivity,
                              uop.getWidthAttr(), uop.getStrongZeroAttr());
  }
};
193
// Rebuilds an AutoDiffRegionOp with new results/operands/activities and moves
// the original body region into the freshly created op.
template <> struct EnzymeOpCreator<AutoDiffRegionOp> {
  static AutoDiffRegionOp create(PatternRewriter &rewriter,
                                 AutoDiffRegionOp uop, TypeRange out_ty,
                                 ValueRange in_args, ArrayAttr newInActivity,
                                 ArrayAttr newRetActivity) {
    auto newOp = AutoDiffRegionOp::create(
        rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity,
        uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr());

    // Transfer the differentiated body from the old op to the new one.
    rewriter.inlineRegionBefore(uop.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    return newOp;
  }
};
208
// Rebuilds a ForwardDiffOp with new result types, operands, and activity
// annotations, preserving the original callee, width, and strong-zero
// attributes.
template <> struct EnzymeOpCreator<ForwardDiffOp> {
  static ForwardDiffOp create(PatternRewriter &rewriter, ForwardDiffOp uop,
                              TypeRange out_ty, ValueRange in_args,
                              ArrayAttr newInActivity,
                              ArrayAttr newRetActivity) {
    return ForwardDiffOp::create(
        rewriter, uop.getLoc(), out_ty, uop.getFnAttr(), in_args, newInActivity,
        newRetActivity, uop.getWidthAttr(), uop.getStrongZeroAttr());
  }
};
219
// Rebuilds a ForwardDiffRegionOp with new results/operands/activities and
// moves the original body region into the freshly created op.
template <> struct EnzymeOpCreator<ForwardDiffRegionOp> {
  static ForwardDiffRegionOp create(PatternRewriter &rewriter,
                                    ForwardDiffRegionOp uop, TypeRange out_ty,
                                    ValueRange in_args, ArrayAttr newInActivity,
                                    ArrayAttr newRetActivity) {
    auto newOp = ForwardDiffRegionOp::create(
        rewriter, uop.getLoc(), out_ty, in_args, newInActivity, newRetActivity,
        uop.getWidthAttr(), uop.getStrongZeroAttr(), uop.getFnAttr());
    // Transfer the differentiated body from the old op to the new one.
    rewriter.inlineRegionBefore(uop.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    return newOp;
  }
};
233
234// Helper: check if any input is mutable.
235static inline bool isMutable(Type type) {
236 if (isa<mlir::MemRefType>(type) || isa<mlir::UnrankedMemRefType>(type) ||
237 isa<mlir::LLVM::LLVMPointerType>(type)) {
238 return true;
239 }
240
241 return false;
242}
243
244/**
245 *
246 * Modifies input activites for the FwdDiffOp
247 * The activity promotion flow is as follows
248 * (depending on variable use):
249 *
250 * -----> enzyme_dupnoneed
251 * / /
252 * enzyme_dup /
253 * \ v
254 * ------> enzyme_const
255 *
256 */
257template <typename SourceOp>
258class FwdInpOpt final : public OpRewritePattern<SourceOp> {
259public:
261
262 LogicalResult matchAndRewrite(SourceOp uop,
263 PatternRewriter &rewriter) const override {
264
265 if (uop.getOutputs().size() == 0)
266 return failure();
267
268 auto inActivity = uop.getActivity();
269
270 auto in_idx = 0;
271 SmallVector<mlir::Value, 2> in_args;
272 SmallVector<ActivityAttr, 2> newInActivityArgs;
273 bool changed = false;
274 for (auto [idx, act] : llvm::enumerate(inActivity)) {
275 auto iattr = cast<ActivityAttr>(act);
276 auto val = iattr.getValue();
277
278 // Forward mode Input activities can only take values {dup, dupnoneed,
279 // const }
280
281 mlir::Value inp = uop.getInputs()[in_idx];
282
283 switch (val) {
284
285 case mlir::enzyme::Activity::enzyme_const:
286 in_args.push_back(inp);
287 newInActivityArgs.push_back(iattr);
288 break;
289
290 case Activity::enzyme_dupnoneed: {
291 // always pass in primal
292 in_args.push_back(inp);
293 in_idx++;
294
295 // selectively push or skip directional derivative
296 inp = uop.getInputs()[in_idx];
297 auto ET = inp.getType();
298 auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
299
300 if (ETintf && !isMutable(ET) && ETintf.isZero(inp)) {
301 // skip and promote to const
302 auto new_const = mlir::enzyme::ActivityAttr::get(
303 rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
304 newInActivityArgs.push_back(new_const);
305 changed = true;
306 } else {
307 // push derivative value
308 in_args.push_back(inp);
309 newInActivityArgs.push_back(iattr);
310 }
311 break;
312 }
313
314 case Activity::enzyme_dup: {
315 // always pass in primal
316 in_args.push_back(inp);
317 in_idx++;
318
319 // selectively push or skip directional derivative
320 inp = uop.getInputs()[in_idx];
321 auto ET = inp.getType();
322 auto ETintf = dyn_cast<AutoDiffTypeInterface>(ET);
323
324 if (ETintf && !isMutable(ET) && ETintf.isZero(inp)) {
325 // skip and promote to const
326 auto new_const = mlir::enzyme::ActivityAttr::get(
327 rewriter.getContext(), mlir::enzyme::Activity::enzyme_const);
328 newInActivityArgs.push_back(new_const);
329 changed = true;
330 } else {
331 // push derivative value
332 in_args.push_back(inp);
333 newInActivityArgs.push_back(iattr);
334 }
335 break;
336 }
337 default:
338 llvm_unreachable("unexpected input activity arg");
339 }
340
341 in_idx++;
342 }
343
344 if (!changed)
345 return failure();
346
347 // create the new op
348 ArrayAttr newInActivity =
349 ArrayAttr::get(rewriter.getContext(),
350 llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
351 newInActivityArgs.end()));
352
353 if constexpr (std::is_same_v<SourceOp, ForwardDiffOp>) {
354
355 rewriter.replaceOpWithNewOp<ForwardDiffOp>(
356 uop, uop->getResultTypes(), uop.getFnAttr(), in_args, newInActivity,
357 uop.getRetActivityAttr(), uop.getWidthAttr(),
358 uop.getStrongZeroAttr());
359 } else {
360 rewriter.replaceOpWithNewOp<ForwardDiffRegionOp>(
361 uop, uop->getResultTypes(), in_args, newInActivity,
362 uop.getRetActivityAttr(), uop.getWidthAttr(), uop.getStrongZeroAttr(),
363 uop.getFnAttr());
364 }
365 return success();
366 }
367};
368
369/**
370 *
371 * Modifies return activites for the FwdDiffOp
372 * The activity promotion flow is as follows
373 * (depending on variable use):
374 *
375 * -----> enzyme_dupnoneed ----
376 * / \
377 * enzyme_dup ---> enzyme_constnoneed
378 * \ /
379 * ------> enzyme_const -----
380 *
381 */
382template <typename SourceOp>
383class FwdRetOpt final : public OpRewritePattern<SourceOp> {
384private:
386
387public:
389
390 LogicalResult matchAndRewrite(SourceOp uop,
391 PatternRewriter &rewriter) const override {
392
393 if (uop.getOutputs().size() == 0)
394 return failure();
395
396 auto retActivity = uop.getRetActivity();
397 auto out_idx = 0;
398 SmallVector<mlir::Value, 2> outs_args;
399 SmallVector<Type, 2> out_ty;
400 SmallVector<ActivityAttr, 2> newRetActivityArgs;
401 bool changed = false;
402
403 for (auto [idx, act] : llvm::enumerate(retActivity)) {
404 auto iattr = cast<ActivityAttr>(act);
405 auto val = iattr.getValue();
406
407 // const_noneed does not have a value associated with it
408 // so we can't index into outputs.
409 if (val == Activity::enzyme_constnoneed) {
410 newRetActivityArgs.push_back(iattr);
411 continue;
412 }
413
414 mlir::Value res = uop.getOutputs()[out_idx];
415
416 switch (val) {
417 case Activity::enzyme_active:
418 outs_args.push_back(res);
419 out_ty.push_back(res.getType());
420 newRetActivityArgs.push_back(iattr);
421 break;
422
423 case mlir::enzyme::Activity::enzyme_const:
424 if (!res.use_empty()) {
425 outs_args.push_back(res);
426 out_ty.push_back(res.getType());
427 newRetActivityArgs.push_back(iattr);
428 } else {
429 changed = true;
430 auto new_constnn = mlir::enzyme::ActivityAttr::get(
431 rewriter.getContext(),
432 mlir::enzyme::Activity::enzyme_constnoneed);
433 newRetActivityArgs.push_back(new_constnn);
434 }
435 break;
436 case Activity::enzyme_dupnoneed:
437
438 if (!res.use_empty()) {
439 outs_args.push_back(res);
440 out_ty.push_back(res.getType());
441 newRetActivityArgs.push_back(iattr);
442 } else {
443 if (!isMutable(res.getType())) {
444 changed = true;
445 auto new_constnn = mlir::enzyme::ActivityAttr::get(
446 rewriter.getContext(),
447 mlir::enzyme::Activity::enzyme_constnoneed);
448 newRetActivityArgs.push_back(new_constnn);
449 } else {
450 outs_args.push_back(res);
451 out_ty.push_back(res.getType());
452 newRetActivityArgs.push_back(iattr);
453 }
454 }
455 break;
456 case Activity::enzyme_constnoneed:
457 outs_args.push_back(res);
458 out_ty.push_back(res.getType());
459 newRetActivityArgs.push_back(iattr);
460 break;
461 case Activity::enzyme_activenoneed:
462 outs_args.push_back(res);
463 out_ty.push_back(res.getType());
464 newRetActivityArgs.push_back(iattr);
465 break;
466 case Activity::enzyme_dup: {
467 ActivityAttr new_dup = iattr;
468 if (!res.use_empty()) {
469 outs_args.push_back(res);
470 out_ty.push_back(res.getType());
471 } else {
472 changed = true;
473 // discard return, change attr
474 new_dup = ActivityAttr::get(rewriter.getContext(),
475 Activity::enzyme_dupnoneed);
476 }
477
478 out_idx++;
479
480 // derivative
481 res = uop.getOutputs()[out_idx];
482 if (!res.use_empty()) {
483 // activity arg doesn't update
484 out_ty.push_back(res.getType());
485 outs_args.push_back(res);
486 } else {
487 // no uses, can discard
488 if (!isMutable(res.getType())) {
489 changed = true;
490 // check if primal is used
491 if (new_dup.getValue() == Activity::enzyme_dupnoneed) {
492 new_dup = ActivityAttr::get(rewriter.getContext(),
493 Activity::enzyme_constnoneed);
494 } else {
495 new_dup = ActivityAttr::get(rewriter.getContext(),
496 Activity::enzyme_const);
497 }
498 } else {
499 out_ty.push_back(res.getType());
500 outs_args.push_back(res);
501 }
502 }
503 newRetActivityArgs.push_back(new_dup);
504 break;
505 }
506 default:
507 llvm_unreachable("unexpected activity arg");
508 }
509
510 out_idx++;
511 }
512
513 if (!changed)
514 return failure();
515
516 ArrayAttr newRetActivity =
517 ArrayAttr::get(rewriter.getContext(),
518 llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
519 newRetActivityArgs.end()));
520
521 SmallVector<Value> in_args = uop.getInputs();
522 SourceOp newOp = SourceOpCreator::create(
523 rewriter, uop, out_ty, in_args, uop.getActivityAttr(), newRetActivity);
524
525 // Map old uses of uop to newOp
526 auto oldIdx = 0;
527 auto newIdx = 0;
528 for (auto [idx, old_act, new_act] :
529 llvm::enumerate(retActivity, newRetActivityArgs)) {
530
531 auto iattr = cast<ActivityAttr>(old_act);
532 auto old_val = iattr.getValue();
533 auto new_val = new_act.getValue();
534
535 if (old_val == new_val) {
536 // don't index into op if its a const_noneed
537 if (old_val == Activity::enzyme_constnoneed) {
538 continue;
539 }
540 // replace use
541 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
542 newOp.getOutputs()[newIdx++]);
543 if (old_val == Activity::enzyme_dup) {
544 // 2nd replacement for derivative
545 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
546 newOp.getOutputs()[newIdx++]);
547 }
548 } else {
549 // handle all substitutions
550 if (new_val == Activity::enzyme_dupnoneed &&
551 old_val == Activity::enzyme_dup) {
552 ++oldIdx; // skip primal
553 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
554 newOp.getOutputs()[newIdx++]);
555 } else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
556 old_val == mlir::enzyme::Activity::enzyme_const) {
557 ++oldIdx; // skip const
558 } else if (new_val == mlir::enzyme::Activity::enzyme_constnoneed &&
559 old_val == mlir::enzyme::Activity::enzyme_dupnoneed) {
560 ++oldIdx; // skip gradient too
561 } else if (new_val == Activity::enzyme_const &&
562 old_val == Activity::enzyme_dup) {
563
564 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
565 newOp.getOutputs()[newIdx++]);
566 ++oldIdx; // skip derivative
567 } else if (new_val == Activity::enzyme_constnoneed &&
568 old_val == Activity::enzyme_dup) {
569 ++oldIdx; // skip primal
570 ++oldIdx; // skip derivative
571 }
572 }
573 }
574
575 rewriter.eraseOp(uop);
576 return success();
577 }
578};
579
580void ForwardDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
581 MLIRContext *context) {
582
584}
585
586LogicalResult AutoDiffOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
587 // TODO: Verify that the result type is same as the type of the referenced
588 // func.func op.
589 auto global = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
590 *this, getFnAttr());
591 if (!global)
592 return emitOpError("'")
593 << getFn() << "' does not reference a valid global funcOp";
594
595 return success();
596}
597
598LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
599 // TODO: Verify that the result type is same as the type of the referenced
600 // func.func op.
601 auto global = symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(
602 *this, getFnAttr());
603 if (!global)
604 return emitOpError("'")
605 << getFn() << "' does not reference a valid global funcOp";
606
607 return success();
608}
609
610//===----------------------------------------------------------------------===//
611// BroadcastOp
612//===----------------------------------------------------------------------===//
613
614void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
615 ArrayRef<int64_t> shape) {
616 auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
617 auto resultTy = input.getType();
618 for (auto s : llvm::reverse(shape)) {
619 resultTy = cast<AutoDiffTypeInterface>(resultTy).getShadowType(s);
620 }
621 build(builder, result, resultTy, input, shapeAttr);
622}
623
624/**
625 *
626 * Modifies activities for the AutoDiffOp.
627 *
628 * This is also a nice place to understand the semantics of the rev-mode
629 * autodiff op. At it's core, the reverse mode autodiff takes in a function
630 *
631 * f: def f (pInput):
632 * ...perform computation
633 * return pOutput
634 *
635 * One can assign a very simple function signature to f here:
636 * f: pInput -> pOutput
637 *
638 * When trying to differentiate this function (using autodiff op), Enzyme
639 * creates a new function which also takes in the co-tangents of the outputs
640 * (dOutput), and computes and returns both the output and the input co-tangent
641 * (dInput). This is how the generated autodiff op eventually looks like:
642 *
643 * def revdiff_f(pInput, dOutput):
644 * ...perform computation to compute pOutput
645 * ...perform computation to compute dInput
646 * return pOutput, dInput
647 *
648 * The new function signature now becomes
649 * revdiff_f : (pInput', dOutput) -> (pOutput, dInput)
650 *
651 * I mention pInput' here because it is not exactly the input arguments to the
652 * function we are differentiating, for example, if the input argument type is
653 * an `enzyme_dup`, we will provide both tht primal value along with the
654 * shadow(which we then accumulate into and return). So specifically,
655 *
656 * pInput' = pInput (if the activity is enzyme_active, enzyme_const)
657 * | pInput, dInput (if the activity is enzyme_dup,
658 * enzyme_dupnoneed)
659 *
660 * Now that we have fixed the codegen semantics, we can go ahead and optimize
661 * for both the input return activities based on usage. Possible activity
662 * promotion flow for the inputs can be as follows:
663 * 1. enzyme_active --> enzyme_const (dInput is never used, so we simply don't
664 * compute it)
665 * 2. enzyme_dup --> enzyme_const (pInput is mutable, readonly, nocapture,
666 * dInput is never used post AD)
667 *
668 * Similarly, one can define a similar activity promotion flow for the outputs:
669 * 1. enzyme_active --> enzyme_activenoneed (we do need to pass in dOutput into
670 * the function, but we can see that pOutput is never used, so let's just not
671 * return it. This has the advantage of triggering some additional DCE inside
672 * the generated derivative function)
673 * 2. enzyme_const --> enzyme_constnoneed (same as above, but we now simply skip
674 * over this output)
675 *
676 * One other thing to note here is that these optimizations preserve the input
677 * function signature, and only modify the number of outputs.
678 *
679 */
680
681template <typename SourceOp>
682class ReverseRetOpt final : public OpRewritePattern<SourceOp> {
683private:
685
686public:
688
689 LogicalResult matchAndRewrite(SourceOp uop,
690 PatternRewriter &rewriter) const override {
691 // early return if there are no outputs
692 if (uop.getOutputs().size() == 0)
693 return failure();
694
695 auto inpActivity = uop.getActivity();
696 auto retActivity = uop.getRetActivity();
697 auto out_idx = 0;
698 SmallVector<mlir::Value, 2> in_args;
699 SmallVector<mlir::Value, 2> outs_args;
700 SmallVector<Type, 2> in_ty;
701 SmallVector<Type, 2> out_ty;
702 SmallVector<ActivityAttr, 2> newInActivityArgs;
703 SmallVector<ActivityAttr, 2> newRetActivityArgs;
704
705 bool changed = false;
706 auto in_idx = 0;
707
708 // go upto dOutput
709 for (auto [idx, act] : llvm::enumerate(inpActivity)) {
710 auto iattr = cast<ActivityAttr>(act);
711 auto val = iattr.getValue();
712 mlir::Value res = uop.getInputs()[in_idx];
713 in_args.push_back(res);
714 in_ty.push_back(res.getType());
715 in_idx++;
716
717 if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
718 mlir::Value dres = uop.getInputs()[in_idx];
719 in_args.push_back(dres);
720 in_ty.push_back(dres.getType());
721 in_idx++;
722 }
723 }
724 // function isn't differentiable
725 if (in_idx == uop.getInputs().size())
726 return failure();
727
728 // handle pOutput
729 for (auto [idx, act] : llvm::enumerate(retActivity)) {
730 auto iattr = cast<ActivityAttr>(act);
731 auto val = iattr.getValue();
732
733 // skip primal return
734 if (val == Activity::enzyme_constnoneed ||
735 val == Activity::enzyme_dupnoneed) {
736 newRetActivityArgs.push_back(iattr);
737 continue;
738 }
739
740 mlir::Value res = uop.getOutputs()[out_idx];
741
742 switch (val) {
743 case Activity::enzyme_active: {
744 // active -> activenoneed(if res isn't used)
745 // active -> const(if dres == 0)
746 // active -> constnoneed(both)
747
748 mlir::Value dres = uop.getInputs()[in_idx];
749 in_idx++;
750
751 auto dres_type = dres.getType();
752 auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
753
754 if (!res.use_empty()) {
755 outs_args.push_back(res);
756 out_ty.push_back(res.getType());
757 ActivityAttr new_act = iattr;
758 if (dres_type_intf && !isMutable(dres_type) &&
759 dres_type_intf.isZero(dres)) {
760 // const
761 changed = true;
762 new_act = ActivityAttr::get(rewriter.getContext(),
763 Activity::enzyme_const);
764 } else {
765 in_args.push_back(dres);
766 in_ty.push_back(dres_type);
767 }
768 newRetActivityArgs.push_back(new_act);
769 } else {
770 changed = true;
771 ActivityAttr new_act = ActivityAttr::get(
772 rewriter.getContext(), Activity::enzyme_activenoneed);
773 if (dres_type_intf && !isMutable(dres_type) &&
774 dres_type_intf.isZero(dres)) {
775 // constnoneed
776 new_act = ActivityAttr::get(rewriter.getContext(),
777 Activity::enzyme_constnoneed);
778 } else {
779 // activenoneed
780 in_args.push_back(dres);
781 in_ty.push_back(dres_type);
782 }
783 newRetActivityArgs.push_back(new_act);
784 }
785
786 ++out_idx;
787 break;
788 }
789
790 case Activity::enzyme_activenoneed:
791 // activenoneed -> constnoneed
792 {
793 mlir::Value dres = uop.getInputs()[in_idx];
794 in_idx++;
795 auto new_act = iattr;
796
797 auto dres_type = dres.getType();
798 auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
799 if (dres_type_intf && !isMutable(dres_type) &&
800 dres_type_intf.isZero(dres)) {
801 // constnoneed
802 new_act = ActivityAttr::get(rewriter.getContext(),
803 Activity::enzyme_constnoneed);
804 } else {
805 in_args.push_back(dres);
806 in_ty.push_back(dres_type);
807 }
808 newRetActivityArgs.push_back(iattr);
809 break;
810 }
811 case Activity::enzyme_const:
812 // const -> constnoneed
813 {
814 auto new_act = iattr;
815 if (!res.use_empty()) {
816 outs_args.push_back(res);
817 out_ty.push_back(res.getType());
818 newRetActivityArgs.push_back(new_act);
819 } else {
820 changed = true;
821 new_act = ActivityAttr::get(rewriter.getContext(),
822 Activity::enzyme_constnoneed);
823 newRetActivityArgs.push_back(new_act);
824 }
825 ++out_idx;
826 break;
827 }
828
829 case Activity::enzyme_dup:
830 // TODO: check if ret_arg == enzyme_dup inserts a derivative as the
831 // output and input both
832 outs_args.push_back(res);
833 out_ty.push_back(res.getType());
834 newRetActivityArgs.push_back(iattr);
835 ++out_idx;
836 break;
837
838 case Activity::enzyme_constnoneed:
839 case Activity::enzyme_dupnoneed:
840 break;
841
842 default:
843 llvm_unreachable("unexpected activity arg");
844 }
845 }
846
847 // handle dInputs
848 for (auto [idx, act] : llvm::enumerate(inpActivity)) {
849 auto iattr = cast<ActivityAttr>(act);
850 auto val = iattr.getValue();
851
852 if (val == Activity::enzyme_active) {
853 mlir::Value res = uop.getOutputs()[out_idx];
854 if (!res.use_empty()) {
855 out_ty.push_back(res.getType());
856 outs_args.push_back(res);
857 newInActivityArgs.push_back(iattr);
858 } else {
859 // TODO: check if we can relax immutability here
860 if (!isMutable(res.getType())) {
861 changed = true;
862 auto new_const = ActivityAttr::get(rewriter.getContext(),
863 Activity::enzyme_const);
864 newInActivityArgs.push_back(new_const);
865 } else {
866 // noop even if its not used.
867 out_ty.push_back(res.getType());
868 outs_args.push_back(res);
869 newInActivityArgs.push_back(iattr);
870 }
871 }
872
873 ++out_idx;
874 } else if (val == Activity::enzyme_activenoneed) {
875 mlir::Value res = uop.getOutputs()[out_idx];
876 out_ty.push_back(res.getType());
877 outs_args.push_back(res);
878 newInActivityArgs.push_back(iattr);
879 ++out_idx;
880 llvm_unreachable("unsupported arg activenoneed");
881 } else {
882 newInActivityArgs.push_back(iattr);
883 }
884 }
885
886 if (!changed)
887 return failure();
888
889 ArrayAttr newInActivity =
890 ArrayAttr::get(rewriter.getContext(),
891 llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
892 newInActivityArgs.end()));
893 ArrayAttr newRetActivity =
894 ArrayAttr::get(rewriter.getContext(),
895 llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
896 newRetActivityArgs.end()));
897
898 SourceOp newOp = SourceOpCreator::create(rewriter, uop, out_ty, in_args,
899 newInActivity, newRetActivity);
900
901 // Map old uses of uop to newOp
902 auto oldIdx = 0;
903 auto newIdx = 0;
904 for (auto [idx, old_act, new_act] :
905 llvm::enumerate(retActivity, newRetActivityArgs)) {
906 auto iattr = cast<ActivityAttr>(old_act);
907 auto old_val = iattr.getValue();
908 auto new_val = new_act.getValue();
909
910 if (old_val == new_val) {
911 // don't index into op if no primal is returned
912 if (old_val == Activity::enzyme_constnoneed ||
913 old_val == Activity::enzyme_activenoneed ||
914 old_val == Activity::enzyme_dupnoneed) {
915 continue;
916 }
917 // replace current Primal
918 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
919 newOp.getOutputs()[newIdx++]);
920 } else {
921 // handle all substitutions
922 if (new_val == Activity::enzyme_activenoneed &&
923 old_val == Activity::enzyme_active) {
924 ++oldIdx; // skip active primal
925 } else if (new_val == Activity::enzyme_constnoneed &&
926 old_val == Activity::enzyme_const) {
927 ++oldIdx; // skip const primal
928 } else if (old_val == Activity::enzyme_active &&
929 new_val == Activity::enzyme_const) {
930 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
931 newOp.getOutputs()[newIdx++]);
932 } else if (old_val == Activity::enzyme_active &&
933 new_val == Activity::enzyme_constnoneed) {
934 ++oldIdx;
935 } else if (old_val == Activity::enzyme_activenoneed &&
936 new_val == Activity::enzyme_constnoneed) {
937 // just skip
938 }
939 }
940 }
941
942 for (auto [idx, old_act, new_act] :
943 llvm::enumerate(inpActivity, newInActivityArgs)) {
944 auto iattr = cast<ActivityAttr>(old_act);
945 auto old_val = iattr.getValue();
946 auto new_val = new_act.getValue();
947
948 if (old_val == new_val) {
949 if (old_val == Activity::enzyme_active ||
950 old_val == Activity::enzyme_activenoneed) {
951 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
952 newOp.getOutputs()[newIdx++]);
953 } else {
954 continue;
955 }
956 } else {
957 if (old_val == Activity::enzyme_active &&
958 new_val == Activity::enzyme_const) {
959 oldIdx++; // skip derivative
960 }
961 }
962 }
963 rewriter.eraseOp(uop);
964 return success();
965 }
966};
967
968template <typename SourceRegionOp>
969class RemoveUnusedArgs final : public OpRewritePattern<SourceRegionOp> {
970
971private:
973
974public:
975 using OpRewritePattern<SourceRegionOp>::OpRewritePattern;
976
977 LogicalResult matchAndRewrite(SourceRegionOp uop,
978 PatternRewriter &rewriter) const override {
979 SmallVector<Value> newInArgs;
980 SmallVector<size_t> argIdxToErase;
981 SmallVector<ActivityAttr> newInActivityArgs;
982 llvm::SmallVector<Value> blockArg(uop.getBody().getArguments());
983 auto in_idx = 0;
984 for (auto [idx, act] : llvm::enumerate(
985 uop.getActivity().template getAsRange<ActivityAttr>())) {
986 auto act_val = act.getValue();
987 Value res = uop.getInputs()[in_idx++];
988
989 if (blockArg[idx].use_empty()) {
990 argIdxToErase.push_back(idx);
991 if (act_val == Activity::enzyme_dup ||
992 act_val == Activity::enzyme_dupnoneed) {
993 in_idx++;
994 }
995 } else {
996 newInActivityArgs.push_back(act);
997 newInArgs.push_back(res);
998 if (act_val == Activity::enzyme_dup ||
999 act_val == Activity::enzyme_dupnoneed) {
1000 res = uop.getInputs()[in_idx++];
1001 newInArgs.push_back(res);
1002 }
1003 }
1004 }
1005
1006 if (argIdxToErase.empty())
1007 return failure();
1008
1009 // only needed for Autodiff region op
1010 if constexpr (std::is_same_v<SourceRegionOp, AutoDiffRegionOp>) {
1011 newInArgs.append(uop.getDifferentialReturns());
1012 }
1013
1014 ArrayAttr newInActivity =
1015 ArrayAttr::get(rewriter.getContext(),
1016 llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
1017 newInActivityArgs.end()));
1018 auto newOp =
1019 SourceOpCreator::create(rewriter, uop, uop.getResultTypes(), newInArgs,
1020 newInActivity, uop.getRetActivity());
1021
1022 for (auto idx : llvm::reverse(argIdxToErase)) {
1023 newOp.getBody().eraseArgument(idx);
1024 }
1025
1026 rewriter.replaceOp(uop, newOp);
1027 return success();
1028 }
1029};
1030
// Registers the reverse-mode canonicalization: return-activity optimization.
void AutoDiffOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<ReverseRetOpt<AutoDiffOp>>(context);
}
1035
1036void AutoDiffRegionOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1037 MLIRContext *context) {
1038 patterns
1040 context);
1041}
1042
1043void ForwardDiffRegionOp::getCanonicalizationPatterns(
1044 RewritePatternSet &patterns, MLIRContext *context) {
1047}
static bool isMutable(Type type)
Definition Ops.cpp:235
Modifies input activites for the FwdDiffOp The activity promotion flow is as follows (depending on va...
Definition Ops.cpp:258
LogicalResult matchAndRewrite(SourceOp uop, PatternRewriter &rewriter) const override
Definition Ops.cpp:262
Modifies return activites for the FwdDiffOp The activity promotion flow is as follows (depending on v...
Definition Ops.cpp:383
LogicalResult matchAndRewrite(SourceOp uop, PatternRewriter &rewriter) const override
Definition Ops.cpp:390
LogicalResult matchAndRewrite(SourceRegionOp uop, PatternRewriter &rewriter) const override
Definition Ops.cpp:977
Modifies activities for the AutoDiffOp.
Definition Ops.cpp:682
LogicalResult matchAndRewrite(SourceOp uop, PatternRewriter &rewriter) const override
Definition Ops.cpp:689
static AutoDiffOp create(PatternRewriter &rewriter, AutoDiffOp uop, TypeRange out_ty, ValueRange in_args, ArrayAttr newInActivity, ArrayAttr newRetActivity)
Definition Ops.cpp:184
static AutoDiffRegionOp create(PatternRewriter &rewriter, AutoDiffRegionOp uop, TypeRange out_ty, ValueRange in_args, ArrayAttr newInActivity, ArrayAttr newRetActivity)
Definition Ops.cpp:195
static ForwardDiffOp create(PatternRewriter &rewriter, ForwardDiffOp uop, TypeRange out_ty, ValueRange in_args, ArrayAttr newInActivity, ArrayAttr newRetActivity)
Definition Ops.cpp:210
static ForwardDiffRegionOp create(PatternRewriter &rewriter, ForwardDiffRegionOp uop, TypeRange out_ty, ValueRange in_args, ArrayAttr newInActivity, ArrayAttr newRetActivity)
Definition Ops.cpp:221