Enzyme main
Loading...
Searching...
No Matches
GradientUtilsReverse.cpp
Go to the documentation of this file.
1//===- GradientUtilsReverse.cpp - Utilities for gradient interfaces
2//--------------===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
11#include "Dialect/Ops.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/SymbolTable.h"
16#include "mlir/Interfaces/FunctionInterfaces.h"
17
18// TODO: this shouldn't depend on specific dialects except Enzyme.
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20
21#include "CloneFunction.h"
22#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24#include "mlir/IR/Dominance.h"
25#include "llvm/ADT/BreadthFirstIterator.h"
26
27using namespace mlir;
28using namespace mlir::enzyme;
29
31 MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
32 FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
33 IRMapping invertedPointers_, const llvm::ArrayRef<bool> returnPrimals,
34 const llvm::ArrayRef<bool> returnShadows,
35 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
36 const SmallPtrSetImpl<mlir::Value> &activevals_,
37 ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
38 IRMapping &originalToNewFn_,
39 std::map<Operation *, Operation *> &originalToNewFnOps_,
40 DerivativeMode mode_, unsigned width, bool omp, StringRef postpasses,
41 bool verifyPostPasses, bool strongZero)
42 : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
43 invertedPointers_, returnPrimals, returnShadows,
44 constantvalues_, activevals_, ReturnActivity,
45 ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
46 mode_, width, omp, postpasses, verifyPostPasses,
47 strongZero) {}
48
50 Type indexType = getIndexType();
51 return getCacheType(indexType);
52}
53
55 return mlir::IntegerType::get(initializationBlock->begin()->getContext(), 32);
56}
57
59 OpBuilder builder(initializationBlock, initializationBlock->begin());
60 return enzyme::InitOp::create(builder,
61 (initializationBlock->rbegin())->getLoc(), t);
62}
63
64// Cache
66 Type cacheType =
67 CacheType::get(initializationBlock->begin()->getContext(), t);
68 return cacheType;
69}
70
72 std::function<std::pair<Value, Value>(Type)> hook) {
73 if (hook != nullptr) {
74 cacheCreatorHook.push_back(hook);
75 }
76}
77
79 std::function<std::pair<Value, Value>(Type)> hook) {
80 if (hook != nullptr) {
81 cacheCreatorHook.pop_back();
82 }
83}
84
85std::pair<Value, Value> MGradientUtilsReverse::getNewCache(Type t) {
86 if (cacheCreatorHook.empty()) {
87 Value cache = insertInit(t);
88 return {cache, cache};
89 }
90 return cacheCreatorHook.back()(t);
91}
92
93// We assume that caches will only be written to at one location. The returned
94// cache is (might be) "pop only"
95Value MGradientUtilsReverse::initAndPushCache(Value v, OpBuilder &builder) {
96 auto [pushCache, popCache] = getNewCache(getCacheType(v.getType()));
97 enzyme::PushOp::create(builder, v.getLoc(), pushCache, v);
98 return popCache;
99}
100
101Value MGradientUtilsReverse::popCache(Value cache, OpBuilder &builder) {
102 return enzyme::PopOp::create(
103 builder, cache.getLoc(),
104 cast<enzyme::CacheType>(cache.getType()).getType(), cache);
105}
106
107Operation *
109 Operation *op) {
110 IRMapping map;
111 for (auto operand : op->getOperands())
112 map.map(operand, getNewFromOriginal(operand));
113 return B.clone(*op, map);
114}
115
117 Value addedGradient,
118 OpBuilder &builder) {
119 assert(!isConstantValue(oldGradient));
120 Value operandGradient = diffe(oldGradient, builder);
121 auto iface = cast<AutoDiffTypeInterface>(addedGradient.getType());
122 auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient,
123 addedGradient);
124 setDiffe(oldGradient, added, builder);
125}
126
128 Region &newFunc) {
129 for (auto it = oldFunc.getBlocks().rbegin(); it != oldFunc.getBlocks().rend();
130 ++it) {
131 Block *block = &*it;
132 Block *reverseBlock = new Block();
133 newFunc.getBlocks().insert(newFunc.end(), reverseBlock);
134 mapReverseModeBlocks.map(block, reverseBlock);
135 }
136}
137
139 MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
140 FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
141 const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
142 ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
143 mlir::Type additionalArg, bool omp, llvm::StringRef postpasses,
144 bool verifyPostPasses, bool strongZero) {
145 std::string prefix;
146
147 switch (mode_) {
151 assert(false);
152 break;
155 prefix = "diffe";
156 break;
158 llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n");
159 }
160
161 if (width > 1)
162 prefix += std::to_string(width);
163
164 IRMapping originalToNew;
165 std::map<Operation *, Operation *> originalToNewOps;
166
167 SmallPtrSet<mlir::Value, 1> returnvals;
168 SmallPtrSet<mlir::Value, 1> constant_values;
169 SmallPtrSet<mlir::Value, 1> nonconstant_values;
170 IRMapping invertedPointers;
171 FunctionOpInterface newFunc = CloneFunctionWithReturns(
172 mode_, width, todiff, invertedPointers, constant_args, constant_values,
173 nonconstant_values, returnvals, returnPrimals, returnShadows, retType,
174 prefix + todiff.getName(), originalToNew, originalToNewOps,
175 additionalArg);
176
177 return new MGradientUtilsReverse(
179 returnShadows, constant_values, nonconstant_values, retType,
180 constant_args, originalToNew, originalToNewOps, mode_, width, omp,
182}
FunctionOpInterface CloneFunctionWithReturns(DerivativeMode mode, unsigned width, FunctionOpInterface F, IRMapping &ptrInputs, ArrayRef< DIFFE_TYPE > ArgActivity, SmallPtrSetImpl< mlir::Value > &constants, SmallPtrSetImpl< mlir::Value > &nonconstants, SmallPtrSetImpl< mlir::Value > &returnvals, const std::vector< bool > &returnPrimals, const std::vector< bool > &returnShadows, ArrayRef< DIFFE_TYPE > RetActivity, Twine name, IRMapping &VMap, std::map< Operation *, Operation * > &OpMap, mlir::Type additionalArg)
DerivativeMode
Definition Utils.h:390
Value popCache(Value cache, OpBuilder &builder)
SmallVector< std::function< std::pair< Value, Value >(Type)> > cacheCreatorHook
Value initAndPushCache(Value v, OpBuilder &builder)
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, IRMapping invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivity, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode_, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
void createReverseModeBlocks(Region &oldFunc, Region &newFunc)
void deregisterCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void registerCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::pair< Value, Value > getNewCache(Type t)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
const llvm::ArrayRef< bool > returnShadows
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
FunctionOpInterface newFunc
static LoopCacheType getCacheType(Operation *op)