Enzyme main
Loading...
Searching...
No Matches
CoreDialectsAutoDiffImplementations.cpp
Go to the documentation of this file.
1//===- CoreDialectsAutoDiffImplementations.cpp ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains common utilities for the external model implementation of
10// the automatic differentiation op interfaces for upstream MLIR dialects.
11//
12//===----------------------------------------------------------------------===//
13
19#include "Passes/Utils.h"
20
21#include "mlir/IR/Matchers.h"
22
23using namespace mlir;
24using namespace mlir::enzyme;
25
26mlir::TypedAttr mlir::enzyme::getConstantAttr(mlir::Type type,
27 llvm::StringRef value) {
28 using namespace mlir;
29 if (value == "0") {
30 auto ATI = cast<AutoDiffTypeInterface>(type);
31 return cast<TypedAttr>(ATI.createNullAttr());
32 }
33 if (auto T = dyn_cast<TensorType>(type)) {
34 auto ET = dyn_cast<FloatType>(T.getElementType());
35 if (!ET) {
36 llvm::errs() << " unsupported eltype: " << ET << " of type " << type
37 << "\n";
38 }
39 APFloat values[] = {APFloat(ET.getFloatSemantics(), value)};
40 return DenseElementsAttr::get(cast<ShapedType>(type),
41 ArrayRef<APFloat>(values));
42 }
43 auto T = cast<FloatType>(type);
44 APFloat apvalue(T.getFloatSemantics(), value);
45 return FloatAttr::get(T, apvalue);
46}
47
49 OpBuilder &builder,
50 MGradientUtils *gutils) {
51 auto newInst = gutils->getNewFromOriginal(inst);
52
53 auto binst = cast<BranchOpInterface>(inst);
54
55 // TODO generalize to cloneWithNewBlockArgs interface
56 SmallVector<Value> newVals;
57
58 SmallVector<int32_t> segSizes;
59 // Keep non-differentiated, non-forwarded operands
60 size_t non_forwarded = 0;
61 for (size_t i = 0; i < newInst->getNumSuccessors(); i++) {
62 auto ops = binst.getSuccessorOperands(i).getForwardedOperands();
63 if (ops.empty())
64 continue;
65 non_forwarded = ops.getBeginOperandIndex();
66 break;
67 }
68
69 for (size_t i = 0; i < non_forwarded; i++)
70 newVals.push_back(gutils->getNewFromOriginal(binst->getOperand(i)));
71
72 segSizes.push_back(newVals.size());
73 for (size_t i = 0; i < newInst->getNumSuccessors(); i++) {
74 size_t cur = newVals.size();
75 auto ops = binst.getSuccessorOperands(i).getForwardedOperands();
76 for (auto &&[idx, op] : llvm::enumerate(ops)) {
77 auto arg =
78 *binst.getSuccessorBlockArgument(ops.getBeginOperandIndex() + idx);
79 newVals.push_back(gutils->getNewFromOriginal(op));
80 if (!gutils->isConstantValue(arg)) {
81 if (!gutils->isConstantValue(op)) {
82 newVals.push_back(gutils->invertPointerM(op, builder));
83 } else {
84 Type retTy = cast<AutoDiffTypeInterface>(arg.getType())
85 .getShadowType(gutils->width);
86 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
87 builder, op.getLoc());
88 newVals.push_back(toret);
89 }
90 }
91 }
92 cur = newVals.size() - cur;
93 segSizes.push_back(cur);
94 }
95
96 SmallVector<NamedAttribute> attrs(newInst->getAttrs());
97 bool has_cases = false;
98 for (auto &attr : attrs) {
99 if (attr.getName() == "case_operand_segments") {
100 has_cases = true;
101 }
102 }
103 for (auto &attr : attrs) {
104 if (attr.getName() == "operandSegmentSizes") {
105 if (!has_cases) {
106 attr.setValue(builder.getDenseI32ArrayAttr(segSizes));
107 } else {
108 SmallVector<int32_t> segSlices2(segSizes.begin(), segSizes.begin() + 2);
109 segSlices2.push_back(0);
110 for (size_t i = 2; i < segSizes.size(); i++)
111 segSlices2[2] += segSizes[i];
112 attr.setValue(builder.getDenseI32ArrayAttr(segSlices2));
113 }
114 }
115 if (attr.getName() == "case_operand_segments") {
116 SmallVector<int32_t> segSlices2(segSizes.begin() + 2, segSizes.end());
117 attr.setValue(builder.getDenseI32ArrayAttr(segSlices2));
118 }
119 }
120
121 gutils->getNewFromOriginal(inst->getBlock())
122 ->push_back(newInst->create(newInst->getLoc(), newInst->getName(),
123 TypeRange(), newVals, attrs,
124 mlir::PropertyRef(), newInst->getSuccessors(),
125 newInst->getNumRegions()));
126 gutils->erase(newInst);
127 return;
128}
129
130static bool contains(ArrayRef<int> ar, int v) {
131 for (auto a : ar) {
132 if (a == v) {
133 return true;
134 }
135 }
136 return false;
137}
138
140 Operation *orig, OpBuilder &builder, MGradientUtils *gutils,
141 ArrayRef<int> storedVals) {
142 auto iface = cast<ActivityOpInterface>(orig);
143
144 SmallVector<Value> newOperands;
145 newOperands.reserve(orig->getNumOperands());
146 SmallVector<bool> inverted(orig->getNumOperands(), false);
147 for (OpOperand &operand : orig->getOpOperands()) {
148 if (iface.isArgInactive(operand.getOperandNumber())) {
149 newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
150 } else {
151 if (gutils->isConstantValue(operand.get())) {
152
153 if (contains(storedVals, operand.getOperandNumber()) ||
154 contains(storedVals, -1)) {
155 if (auto iface =
156 dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
157 if (!iface.isMutable()) {
158 Type retTy = iface.getShadowType(gutils->width);
159 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
160 builder, operand.get().getLoc());
161 newOperands.push_back(toret);
162 continue;
163 }
164 }
165 }
166 orig->emitError()
167 << "Unsupported constant arg to memory identity forward "
168 "handler(opidx="
169 << operand.getOperandNumber() << ", op=" << operand.get() << ")\n";
170 return failure();
171 }
172 inverted[newOperands.size()] = true;
173 newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
174 }
175 }
176
177 // Assuming shadows following the originals are fine.
178 // TODO: consider extending to have a ShadowableTerminatorOpInterface
179 Operation *primal = gutils->getNewFromOriginal(orig);
180 SmallVector<Operation *, 1> shadows;
181 if (gutils->width == 1) {
182 Operation *shadow = builder.clone(*primal);
183 shadow->setOperands(newOperands);
184 shadows.push_back(shadow);
185 } else {
186 for (size_t w = 0; w < gutils->width; w++) {
187 SmallVector<Value> newOperands2(newOperands);
188 for (size_t i = 0; i < newOperands.size(); i++) {
189 if (!inverted[i])
190 continue;
191 newOperands2[i] = enzyme::getExtractValue(
192 builder, orig->getLoc(), orig->getOperands()[i].getType(),
193 newOperands2[i], w);
194 }
195 Operation *shadow = builder.clone(*primal);
196 shadow->setOperands(newOperands2);
197 shadows.push_back(shadow);
198 }
199 }
200 for (auto &&[i, oval] : llvm::enumerate(orig->getResults())) {
201 Value sval;
202 if (gutils->width == 1) {
203 sval = shadows[0]->getResult(i);
204 } else {
205 SmallVector<Value> shadowRes;
206 for (auto s : shadows) {
207 shadowRes.push_back(s->getResult(i));
208 }
209 sval = enzyme::getConcatValue(builder, orig->getLoc(), shadowRes);
210 }
211 gutils->setDiffe(oval, sval, builder);
212 }
213
214 return success();
215}
216
218 Operation *orig, OpBuilder &builder, MGradientUtils *gutils, bool zero) {
219
220 Operation *primal = gutils->getNewFromOriginal(orig);
221 Operation *shadow = builder.clone(*primal);
222
223 Value shadowRes = shadow->getResult(0);
224
225 gutils->setDiffe(orig->getResult(0), shadowRes, builder);
226 gutils->eraseIfUnused(orig);
227
228 if (zero) {
229 // Fill with zeros
230 if (auto iface = dyn_cast<AutoDiffTypeInterface>(shadowRes.getType())) {
231 return iface.zeroInPlace(builder, orig->getLoc(), shadowRes);
232 } else {
233 orig->emitError() << "Type " << shadowRes.getType()
234 << " does not implement "
235 "AutoDiffTypeInterface";
236 return failure();
237 }
238 }
239 return success();
240}
241
243 OpBuilder &builder,
244 MGradientUtilsReverse *gutils) {
245 size_t num_out = 0;
246 for (auto act : gutils->RetDiffeTypes) {
247 if (act == DIFFE_TYPE::OUT_DIFF)
248 num_out++;
249 }
250
251 size_t idx = 0;
252 auto args = gutils->newFunc->getRegions().begin()->begin()->getArguments();
253
254 for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) {
255 if (act == DIFFE_TYPE::OUT_DIFF) {
256 if (!gutils->isConstantValue(op)) {
257 auto d_out = args[args.size() - num_out + idx];
258 gutils->addToDiffe(op, d_out, builder);
259 }
260 idx++;
261 }
262 }
263}
264
266 Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) {
267 auto parentOp = origTerminator->getParentOp();
268
269 llvm::SmallDenseSet<unsigned> operandsToShadow;
270 auto termIface = dyn_cast<RegionBranchTerminatorOpInterface>(origTerminator);
271 auto regionBranchOp =
272 dyn_cast<RegionBranchOpInterface>(origTerminator->getParentOp());
273 if (termIface && regionBranchOp) {
274
275 SmallVector<RegionSuccessor> successors;
276 termIface.getSuccessorRegions(
277 SmallVector<Attribute>(origTerminator->getNumOperands(), Attribute()),
278 successors);
279
280 for (auto &successor : successors) {
281 OperandRange operandRange = termIface.getSuccessorOperands(successor);
282 ValueRange targetValues =
283 successor.isParent() ? parentOp->getResults()
284 : regionBranchOp.getSuccessorInputs(successor);
285 assert(operandRange.size() == targetValues.size());
286 for (auto &&[i, target] : llvm::enumerate(targetValues)) {
287 if (!gutils->isConstantValue(target))
288 operandsToShadow.insert(operandRange.getBeginOperandIndex() + i);
289 }
290 }
291 } else {
292 assert(parentOp->getNumResults() == origTerminator->getNumOperands());
293 for (auto res : parentOp->getResults()) {
294 if (!gutils->isConstantValue(res))
295 operandsToShadow.insert(res.getResultNumber());
296 }
297 }
298
299 SmallVector<Value> newOperands;
300 newOperands.reserve(origTerminator->getNumOperands() +
301 operandsToShadow.size());
302 for (OpOperand &operand : origTerminator->getOpOperands()) {
303 newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
304 if (operandsToShadow.contains(operand.getOperandNumber()))
305 newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
306 }
307
308 // Assuming shadows following the originals are fine.
309 // TODO: consider extending to have a ShadowableTerminatorOpInterface
310 Operation *replTerminator = gutils->getNewFromOriginal(origTerminator);
311 replTerminator->setOperands(newOperands);
312}
313
315 Operation *op, OpBuilder &builder, MGradientUtils *gutils) {
316
317 // For all operands that are forwarded to the body, if they are active, also
318 // add the shadow as operand.
319 auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
320 if (!regionBranchOp) {
321 op->emitError() << " RegionBranchOpInterface not implemented for " << *op
322 << "\n";
323 return failure();
324 }
325 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
326 if (!iface) {
327 op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
328 << *op << "\n";
329 return failure();
330 }
331
332 // TODO: we may need to record, for every successor, which of its inputs
333 // need a shadow to recreate the body correctly.
334 llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
335 llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
336
337 SmallVector<RegionSuccessor> entrySuccessors;
338 regionBranchOp.getEntrySuccessorRegions(
339 SmallVector<Attribute>(op->getNumOperands(), Attribute()),
340 entrySuccessors);
341
342 for (const RegionSuccessor &successor : entrySuccessors) {
343
344 OperandRange operandRange =
345 iface.getSuccessorOperands(regionBranchOp, successor);
346
347 ValueRange targetValues =
348 successor.isParent() ? op->getResults()
349 : regionBranchOp.getSuccessorInputs(successor);
350
351 // Need to know which of the arguments are being forwarded to from
352 // operands.
353 for (auto &&[i, regionValue, operand] :
354 llvm::enumerate(targetValues, operandRange)) {
355 if (gutils->isConstantValue(regionValue))
356 continue;
357 operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
358 if (successor.isParent())
359 resultPositionsToShadow.insert(i);
360 }
361 }
362
363 for (auto res : op->getResults())
364 if (!gutils->isConstantValue(res))
365 resultPositionsToShadow.insert(res.getResultNumber());
366
368 op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
369}
370
372 Operation *op, OpBuilder &builder, MGradientUtils *gutils,
373 const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
374 const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow) {
375 // For all active results, add shadow types.
376 // For now, assuming all results are relevant.
377 Operation *newOp = gutils->getNewFromOriginal(op);
378 SmallVector<Type> newOpResultTypes;
379 newOpResultTypes.reserve(op->getNumResults() * 2);
380 for (auto result : op->getResults()) {
381 // TODO only if used (can we DCE the primal after having done the
382 // derivative).
383 newOpResultTypes.push_back(result.getType());
384 if (!gutils->isConstantValue(result)) {
385 assert(resultPositionsToShadow.count(result.getResultNumber()));
386 }
387 if (!resultPositionsToShadow.count(result.getResultNumber()))
388 continue;
389 auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
390 if (!typeIface) {
391 op->emitError() << " AutoDiffTypeInterface not implemented for "
392 << result.getType() << "\n";
393 return failure();
394 }
395 newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
396 }
397
398 SmallVector<Value> newOperands;
399 newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size());
400 for (OpOperand &operand : op->getOpOperands()) {
401 newOperands.push_back(gutils->getNewFromOriginal(operand.get()));
402 if (operandPositionsToShadow.contains(operand.getOperandNumber()))
403 newOperands.push_back(gutils->invertPointerM(operand.get(), builder));
404 }
405 // We are assuming the op can forward additional operands, listed
406 // immediately after the original operands, to the same regions.
407 // ^^
408 // Our interface guarantees this.
409 // We also assume that the region-holding op returns all of the values
410 // yielded by terminators, and only those values.
411
412 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
413 if (!iface) {
414 op->emitError() << " ControlFlowAutoDiffOpInterface not implemented for "
415 << *op << "\n";
416 return failure();
417 }
418 Operation *replacement = iface.createWithShadows(
419 builder, gutils, op, newOperands, newOpResultTypes);
420 assert(replacement->getNumResults() == newOpResultTypes.size());
421 for (auto &&[region, replacementRegion] :
422 llvm::zip(newOp->getRegions(), replacement->getRegions())) {
423 replacementRegion.takeBody(region);
424 }
425
426 // Inject the mapping for the new results into GradientUtil's shadow
427 // table.
428 SmallVector<Value> reps;
429 size_t idx = 0;
430 for (OpResult r : op->getResults()) {
431 // TODO only if used
432 reps.push_back(replacement->getResult(idx));
433 idx++;
434 if (!gutils->isConstantValue(r)) {
435 assert(resultPositionsToShadow.count(r.getResultNumber()));
436 auto inverted = gutils->invertedPointers.lookupOrNull(r);
437 assert(inverted);
438 gutils->invertedPointers.map(r, replacement->getResult(idx));
439 inverted.replaceAllUsesWith(replacement->getResult(idx));
440 gutils->erase(inverted.getDefiningOp());
441 idx++;
442 } else if (resultPositionsToShadow.count(r.getResultNumber())) {
443 idx++;
444 }
445 }
446
447 // Differentiate body.
448 for (auto &origRegion : op->getRegions()) {
449 for (auto &origBlock : origRegion) {
450 for (Operation &o : origBlock) {
451 if (failed(gutils->visitChild(&o))) {
452 return failure();
453 }
454 }
455 }
456 }
457
458 // Replace all uses of original results
459 gutils->replaceOrigOpWith(op, reps);
460 gutils->erase(newOp);
461 gutils->originalToNewFnOps[op] = replacement;
462
463 return success();
464}
465
static bool contains(ArrayRef< int > ar, int v)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
ArrayRef< DIFFE_TYPE > RetDiffeTypes
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
void eraseIfUnused(Operation *op, bool erase=true, bool check=true)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
FunctionOpInterface newFunc
bool isConstantValue(mlir::Value v) const
LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, ArrayRef< int > storedVals)
void returnReverseHandler(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
void branchingForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, bool zero)
void registerComplexDialectAutoDiffInterface(DialectRegistry &registry)
Value getConcatValue(OpBuilder &builder, Location loc, ArrayRef< Value > argList)
Definition Utils.cpp:155
void registerArithDialectAutoDiffInterface(DialectRegistry &registry)
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry)
Value getExtractValue(OpBuilder &builder, Location loc, Type argTy, Value val, int64_t index)
Definition Utils.cpp:163
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry)
void registerMathDialectAutoDiffInterface(DialectRegistry &registry)
void registerEnzymeDialectAutoDiffInterface(DialectRegistry &registry)
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry)
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry)
mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value)
void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry)
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry)
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry)
void registerCFDialectAutoDiffInterface(DialectRegistry &registry)
void registerLLVMExtDialectAutoDiffInterface(DialectRegistry &registry)
void registerAffineDialectAutoDiffInterface(DialectRegistry &registry)
void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry)
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry)