15#ifndef ENZYMEMLIR_CORE_IMPL_H_
16#define ENZYMEMLIR_CORE_IMPL_H_
19#include "mlir/Support/LogicalResult.h"
21#include "llvm/ADT/DenseSet.h"
31class MGradientUtilsReverse;
38 MGradientUtils *gutils);
41 Operation *op, OpBuilder &builder, MGradientUtils *gutils,
42 const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
43 const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow);
48 MGradientUtils *gutils);
53 MGradientUtils *gutils);
57 MGradientUtilsReverse *gutils);
62 MGradientUtils *gutils,
63 ArrayRef<int> storedVals);
67 MGradientUtils *gutils,
bool zero);
72template <
typename OpTy>
74 :
public AutoDiffOpInterface::ExternalModel<AutoDiffUsingControlFlow<OpTy>,
85template <
typename OpTy>
87 :
public AutoDiffOpInterface::ExternalModel<AutoDiffUsingBranch<OpTy>,
99template <
typename OpTy>
101 :
public AutoDiffOpInterface::ExternalModel<
102 AutoDiffUsingRegionTerminator<OpTy>, OpTy> {
111template <
typename OpTy>
113 :
public ReverseAutoDiffOpInterface::ExternalModel<
114 NoopRevAutoDiffInterface<OpTy>, OpTy> {
118 SmallVector<Value> caches)
const {
124 return SmallVector<Value>();
131template <
typename OpTy>
133 :
public ReverseAutoDiffOpInterface::ExternalModel<
134 ReturnRevAutoDiffInterface<OpTy>, OpTy> {
138 SmallVector<Value> caches)
const {
145 return SmallVector<Value>();
154template <
typename OpTy,
int... storedvals>
156 :
public AutoDiffOpInterface::ExternalModel<
157 AutoDiffUsingMemoryIdentity<OpTy, storedvals...>, OpTy> {
163 op, builder, gutils, std::initializer_list<int>{storedvals...});
169template <
typename OpTy>
171 AutoDiffUsingAllocationFwd<OpTy>, OpTy> {
182template <
typename OpTy>
184 :
public ReverseAutoDiffOpInterface::ExternalModel<
185 AutoDiffUsingAllocationRev<OpTy>, OpTy> {
189 SmallVector<Value> caches)
const {
195 return SmallVector<Value>();
207template <
typename OpTy>
209 OpTy::template attachInterface<detail::AutoDiffUsingControlFlow<OpTy>>(
213template <
typename OpTy>
215 OpTy::template attachInterface<detail::AutoDiffUsingBranch<OpTy>>(context);
216 OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
220template <
typename OpTy>
222 OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
224 OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
228template <
typename OpTy>
230 OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
232 OpTy::template attachInterface<detail::ReturnRevAutoDiffInterface<OpTy>>(
236template <
typename OpTy,
int... storedvals>
238 OpTy::template attachInterface<
242template <
typename OpTy>
244 OpTy::template attachInterface<detail::AutoDiffUsingAllocationFwd<OpTy>>(
246 OpTy::template attachInterface<detail::AutoDiffUsingAllocationRev<OpTy>>(
269mlir::TypedAttr
getConstantAttr(mlir::Type type, llvm::StringRef value);
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, ArrayRef< int > storedVals)
void returnReverseHandler(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
void branchingForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, bool zero)
void registerComplexDialectAutoDiffInterface(DialectRegistry ®istry)
void registerArithDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAutoDiffUsingAllocationInterface(MLIRContext &context)
void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context)
void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry)
void registerLLVMDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAutoDiffUsingBranchInterface(MLIRContext &context)
void registerMathDialectAutoDiffInterface(DialectRegistry ®istry)
void registerEnzymeDialectAutoDiffInterface(DialectRegistry ®istry)
void registerBuiltinDialectAutoDiffInterface(DialectRegistry ®istry)
void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry)
mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value)
void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry)
void registerSCFDialectAutoDiffInterface(DialectRegistry ®istry)
void registerMemRefDialectAutoDiffInterface(DialectRegistry ®istry)
void registerCFDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAutoDiffUsingMemoryIdentityInterface(MLIRContext &context)
void registerAutoDiffUsingReturnInterface(MLIRContext &context)
void registerLLVMExtDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAffineDialectAutoDiffInterface(DialectRegistry ®istry)
void registerAutoDiffUsingControlFlowInterface(MLIRContext &context)
void registerNVVMDialectAutoDiffInterface(DialectRegistry ®istry)
void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry)