21#include "mlir/IR/Matchers.h"
27 llvm::StringRef value) {
30 auto ATI = cast<AutoDiffTypeInterface>(type);
31 return cast<TypedAttr>(ATI.createNullAttr());
33 if (
auto T = dyn_cast<TensorType>(type)) {
34 auto ET = dyn_cast<FloatType>(T.getElementType());
36 llvm::errs() <<
" unsupported eltype: " << ET <<
" of type " << type
39 APFloat values[] = {APFloat(ET.getFloatSemantics(), value)};
40 return DenseElementsAttr::get(cast<ShapedType>(type),
41 ArrayRef<APFloat>(values));
43 auto T = cast<FloatType>(type);
44 APFloat apvalue(T.getFloatSemantics(), value);
45 return FloatAttr::get(T, apvalue);
53 auto binst = cast<BranchOpInterface>(inst);
56 SmallVector<Value> newVals;
58 SmallVector<int32_t> segSizes;
60 size_t non_forwarded = 0;
61 for (
size_t i = 0; i < newInst->getNumSuccessors(); i++) {
62 auto ops = binst.getSuccessorOperands(i).getForwardedOperands();
65 non_forwarded = ops.getBeginOperandIndex();
69 for (
size_t i = 0; i < non_forwarded; i++)
72 segSizes.push_back(newVals.size());
73 for (
size_t i = 0; i < newInst->getNumSuccessors(); i++) {
74 size_t cur = newVals.size();
75 auto ops = binst.getSuccessorOperands(i).getForwardedOperands();
76 for (
auto &&[idx, op] : llvm::enumerate(ops)) {
78 *binst.getSuccessorBlockArgument(ops.getBeginOperandIndex() + idx);
84 Type retTy = cast<AutoDiffTypeInterface>(arg.getType())
85 .getShadowType(gutils->
width);
86 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
87 builder, op.getLoc());
88 newVals.push_back(toret);
92 cur = newVals.size() - cur;
93 segSizes.push_back(cur);
96 SmallVector<NamedAttribute> attrs(newInst->getAttrs());
97 bool has_cases =
false;
98 for (
auto &attr : attrs) {
99 if (attr.getName() ==
"case_operand_segments") {
103 for (
auto &attr : attrs) {
104 if (attr.getName() ==
"operandSegmentSizes") {
106 attr.setValue(builder.getDenseI32ArrayAttr(segSizes));
108 SmallVector<int32_t> segSlices2(segSizes.begin(), segSizes.begin() + 2);
109 segSlices2.push_back(0);
110 for (
size_t i = 2; i < segSizes.size(); i++)
111 segSlices2[2] += segSizes[i];
112 attr.setValue(builder.getDenseI32ArrayAttr(segSlices2));
115 if (attr.getName() ==
"case_operand_segments") {
116 SmallVector<int32_t> segSlices2(segSizes.begin() + 2, segSizes.end());
117 attr.setValue(builder.getDenseI32ArrayAttr(segSlices2));
122 ->push_back(newInst->create(newInst->getLoc(), newInst->getName(),
123 TypeRange(), newVals, attrs,
124 mlir::PropertyRef(), newInst->getSuccessors(),
125 newInst->getNumRegions()));
126 gutils->
erase(newInst);
141 ArrayRef<int> storedVals) {
142 auto iface = cast<ActivityOpInterface>(orig);
144 SmallVector<Value> newOperands;
145 newOperands.reserve(orig->getNumOperands());
146 SmallVector<bool> inverted(orig->getNumOperands(),
false);
147 for (OpOperand &operand : orig->getOpOperands()) {
148 if (iface.isArgInactive(operand.getOperandNumber())) {
153 if (
contains(storedVals, operand.getOperandNumber()) ||
156 dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
157 if (!iface.isMutable()) {
158 Type retTy = iface.getShadowType(gutils->
width);
159 auto toret = cast<AutoDiffTypeInterface>(retTy).createNullValue(
160 builder, operand.get().getLoc());
161 newOperands.push_back(toret);
167 <<
"Unsupported constant arg to memory identity forward "
169 << operand.getOperandNumber() <<
", op=" << operand.get() <<
")\n";
172 inverted[newOperands.size()] =
true;
173 newOperands.push_back(gutils->
invertPointerM(operand.get(), builder));
180 SmallVector<Operation *, 1> shadows;
181 if (gutils->
width == 1) {
182 Operation *shadow = builder.clone(*primal);
183 shadow->setOperands(newOperands);
184 shadows.push_back(shadow);
186 for (
size_t w = 0; w < gutils->
width; w++) {
187 SmallVector<Value> newOperands2(newOperands);
188 for (
size_t i = 0; i < newOperands.size(); i++) {
192 builder, orig->getLoc(), orig->getOperands()[i].getType(),
195 Operation *shadow = builder.clone(*primal);
196 shadow->setOperands(newOperands2);
197 shadows.push_back(shadow);
200 for (
auto &&[i, oval] : llvm::enumerate(orig->getResults())) {
202 if (gutils->
width == 1) {
203 sval = shadows[0]->getResult(i);
205 SmallVector<Value> shadowRes;
206 for (
auto s : shadows) {
207 shadowRes.push_back(s->getResult(i));
211 gutils->
setDiffe(oval, sval, builder);
218 Operation *orig, OpBuilder &builder,
MGradientUtils *gutils,
bool zero) {
221 Operation *shadow = builder.clone(*primal);
223 Value shadowRes = shadow->getResult(0);
225 gutils->
setDiffe(orig->getResult(0), shadowRes, builder);
230 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(shadowRes.getType())) {
231 return iface.zeroInPlace(builder, orig->getLoc(), shadowRes);
233 orig->emitError() <<
"Type " << shadowRes.getType()
234 <<
" does not implement "
235 "AutoDiffTypeInterface";
252 auto args = gutils->
newFunc->getRegions().begin()->begin()->getArguments();
254 for (
auto &&[op, act] : llvm::zip(op->getOperands(), gutils->
RetDiffeTypes)) {
257 auto d_out = args[args.size() - num_out + idx];
266 Operation *origTerminator, OpBuilder &builder,
MGradientUtils *gutils) {
267 auto parentOp = origTerminator->getParentOp();
269 llvm::SmallDenseSet<unsigned> operandsToShadow;
270 auto termIface = dyn_cast<RegionBranchTerminatorOpInterface>(origTerminator);
271 auto regionBranchOp =
272 dyn_cast<RegionBranchOpInterface>(origTerminator->getParentOp());
273 if (termIface && regionBranchOp) {
275 SmallVector<RegionSuccessor> successors;
276 termIface.getSuccessorRegions(
277 SmallVector<Attribute>(origTerminator->getNumOperands(), Attribute()),
280 for (
auto &successor : successors) {
281 OperandRange operandRange = termIface.getSuccessorOperands(successor);
282 ValueRange targetValues =
283 successor.isParent() ? parentOp->getResults()
284 : regionBranchOp.getSuccessorInputs(successor);
285 assert(operandRange.size() == targetValues.size());
286 for (
auto &&[i, target] : llvm::enumerate(targetValues)) {
288 operandsToShadow.insert(operandRange.getBeginOperandIndex() + i);
292 assert(parentOp->getNumResults() == origTerminator->getNumOperands());
293 for (
auto res : parentOp->getResults()) {
295 operandsToShadow.insert(res.getResultNumber());
299 SmallVector<Value> newOperands;
300 newOperands.reserve(origTerminator->getNumOperands() +
301 operandsToShadow.size());
302 for (OpOperand &operand : origTerminator->getOpOperands()) {
304 if (operandsToShadow.contains(operand.getOperandNumber()))
305 newOperands.push_back(gutils->
invertPointerM(operand.get(), builder));
311 replTerminator->setOperands(newOperands);
319 auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
320 if (!regionBranchOp) {
321 op->emitError() <<
" RegionBranchOpInterface not implemented for " << *op
325 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
327 op->emitError() <<
" ControlFlowAutoDiffOpInterface not implemented for "
334 llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
335 llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
337 SmallVector<RegionSuccessor> entrySuccessors;
338 regionBranchOp.getEntrySuccessorRegions(
339 SmallVector<Attribute>(op->getNumOperands(), Attribute()),
342 for (
const RegionSuccessor &successor : entrySuccessors) {
344 OperandRange operandRange =
345 iface.getSuccessorOperands(regionBranchOp, successor);
347 ValueRange targetValues =
348 successor.isParent() ? op->getResults()
349 : regionBranchOp.getSuccessorInputs(successor);
353 for (
auto &&[i, regionValue, operand] :
354 llvm::enumerate(targetValues, operandRange)) {
357 operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
358 if (successor.isParent())
359 resultPositionsToShadow.insert(i);
363 for (
auto res : op->getResults())
365 resultPositionsToShadow.insert(res.getResultNumber());
368 op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
373 const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
374 const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow) {
378 SmallVector<Type> newOpResultTypes;
379 newOpResultTypes.reserve(op->getNumResults() * 2);
380 for (
auto result : op->getResults()) {
383 newOpResultTypes.push_back(result.getType());
385 assert(resultPositionsToShadow.count(result.getResultNumber()));
387 if (!resultPositionsToShadow.count(result.getResultNumber()))
389 auto typeIface = dyn_cast<AutoDiffTypeInterface>(result.getType());
391 op->emitError() <<
" AutoDiffTypeInterface not implemented for "
392 << result.getType() <<
"\n";
395 newOpResultTypes.push_back(typeIface.getShadowType(gutils->
width));
398 SmallVector<Value> newOperands;
399 newOperands.reserve(op->getNumOperands() + operandPositionsToShadow.size());
400 for (OpOperand &operand : op->getOpOperands()) {
402 if (operandPositionsToShadow.contains(operand.getOperandNumber()))
403 newOperands.push_back(gutils->
invertPointerM(operand.get(), builder));
412 auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);
414 op->emitError() <<
" ControlFlowAutoDiffOpInterface not implemented for "
418 Operation *replacement = iface.createWithShadows(
419 builder, gutils, op, newOperands, newOpResultTypes);
420 assert(replacement->getNumResults() == newOpResultTypes.size());
421 for (
auto &&[region, replacementRegion] :
422 llvm::zip(newOp->getRegions(), replacement->getRegions())) {
423 replacementRegion.takeBody(region);
428 SmallVector<Value> reps;
430 for (OpResult r : op->getResults()) {
432 reps.push_back(replacement->getResult(idx));
435 assert(resultPositionsToShadow.count(r.getResultNumber()));
439 inverted.replaceAllUsesWith(replacement->getResult(idx));
440 gutils->
erase(inverted.getDefiningOp());
442 }
else if (resultPositionsToShadow.count(r.getResultNumber())) {
448 for (
auto &origRegion : op->getRegions()) {
449 for (
auto &origBlock : origRegion) {
450 for (Operation &o : origBlock) {
460 gutils->
erase(newOp);
467 DialectRegistry ®istry) {
static bool contains(ArrayRef< int > ar, int v)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
ArrayRef< DIFFE_TYPE > RetDiffeTypes
IRMapping invertedPointers
void replaceOrigOpWith(Operation *op, ValueRange vals)
std::map< Operation *, Operation * > originalToNewFnOps
LogicalResult visitChild(Operation *op)
void erase(Operation *op)
void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
void eraseIfUnused(Operation *op, bool erase=true, bool check=true)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
FunctionOpInterface newFunc
bool isConstantValue(mlir::Value v) const
LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, ArrayRef< int > storedVals)
void returnReverseHandler(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
void branchingForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, bool zero)
void registerComplexDialectAutoDiffInterface(DialectRegistry ®istry)
Value getConcatValue(OpBuilder &builder, Location loc, ArrayRef< Value > argList)
void registerArithDialectAutoDiffInterface(DialectRegistry ®istry)
void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry)
Value getExtractValue(OpBuilder &builder, Location loc, Type argTy, Value val, int64_t index)
void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry)
void registerMathDialectAutoDiffInterface(DialectRegistry ®istry)
void registerEnzymeDialectAutoDiffInterface(DialectRegistry ®istry)
void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry)
void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry)
mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value)
void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry)
void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry)
void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry)
void registerCFDialectAutoDiffInterface(DialectRegistry ®istry)
void registerLLVMExtDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry)
void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry)
void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry)