Enzyme main
Loading...
Searching...
No Matches
CoreDialectsAutoDiffImplementations.h
Go to the documentation of this file.
1//===- CoreDialectsAutoDiffImplementations.h - AD registrations -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains context registration facilities for external model
10// implementations of the automatic differentiation interface for upstream MLIR
11// dialects.
12//
13//===----------------------------------------------------------------------===//
14
15#ifndef ENZYMEMLIR_CORE_IMPL_H_
16#define ENZYMEMLIR_CORE_IMPL_H_
17
19#include "mlir/Support/LogicalResult.h"
20
21#include "llvm/ADT/DenseSet.h"
22
23namespace mlir {
24class DialectRegistry;
25class Operation;
26class OpBuilder;
27class RegionSuccessor;
28
29namespace enzyme {
30class MGradientUtils;
31class MGradientUtilsReverse;
32
33namespace detail {
34// Non-template implementation of
35// AutoDiffUsingControlFlow::createForwardModeTangent.
36
37LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder,
38 MGradientUtils *gutils);
39
40LogicalResult controlFlowForwardHandler(
41 Operation *op, OpBuilder &builder, MGradientUtils *gutils,
42 const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
43 const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow);
44
45// Implements forward-mode differentiation of branching operations.
46// Assumes that successive shadows are legal
47void branchingForwardHandler(Operation *op, OpBuilder &builder,
48 MGradientUtils *gutils);
49
50// Implements forward-mode differentiation of region-terminator operations.
51// Assumes that successive shadows are legal
52void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder,
53 MGradientUtils *gutils);
54
55// Implements reverse-mode differentiation of return operations.
56void returnReverseHandler(Operation *op, OpBuilder &builder,
57 MGradientUtilsReverse *gutils);
58
59// Implements forward-mode differentiation of read-only (including read-none)
60// operations which do not perform computation
61LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder,
62 MGradientUtils *gutils,
63 ArrayRef<int> storedVals);
64
65// Implements shadow initialization for allocation-like operations.
66LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder,
67 MGradientUtils *gutils, bool zero);
68
69// Implements the forward autodiff interface for operations whose derivatives
70// can be inferred by analyzing their control flow and differentiating the
71// nested operations.
72template <typename OpTy>
74 : public AutoDiffOpInterface::ExternalModel<AutoDiffUsingControlFlow<OpTy>,
75 OpTy> {
76public:
77 LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
78 MGradientUtils *gutils) const {
79 return controlFlowForwardHandler(op, builder, gutils);
80 }
81};
82
83// Implements the forward autodiff interface for operations whose derivatives
84// can be inferred by analyzing their branching properties.
85template <typename OpTy>
87 : public AutoDiffOpInterface::ExternalModel<AutoDiffUsingBranch<OpTy>,
88 OpTy> {
89public:
90 LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
91 MGradientUtils *gutils) const {
92 branchingForwardHandler(op, builder, gutils);
93 return success();
94 }
95};
96
97// Implements the forward autodiff interface for operations whose derivatives
98// can be inferred by analyzing their region terminator properties.
99template <typename OpTy>
101 : public AutoDiffOpInterface::ExternalModel<
102 AutoDiffUsingRegionTerminator<OpTy>, OpTy> {
103public:
104 LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
105 MGradientUtils *gutils) const {
106 regionTerminatorForwardHandler(op, builder, gutils);
107 return success();
108 }
109};
110
111template <typename OpTy>
113 : public ReverseAutoDiffOpInterface::ExternalModel<
114 NoopRevAutoDiffInterface<OpTy>, OpTy> {
115public:
116 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
117 MGradientUtilsReverse *gutils,
118 SmallVector<Value> caches) const {
119 return success();
120 }
121
122 SmallVector<Value> cacheValues(Operation *op,
123 MGradientUtilsReverse *gutils) const {
124 return SmallVector<Value>();
125 }
126
127 void createShadowValues(Operation *op, OpBuilder &builder,
128 MGradientUtilsReverse *gutils) const {}
129};
130
131template <typename OpTy>
133 : public ReverseAutoDiffOpInterface::ExternalModel<
134 ReturnRevAutoDiffInterface<OpTy>, OpTy> {
135public:
136 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
137 MGradientUtilsReverse *gutils,
138 SmallVector<Value> caches) const {
139 returnReverseHandler(op, builder, gutils);
140 return success();
141 }
142
143 SmallVector<Value> cacheValues(Operation *op,
144 MGradientUtilsReverse *gutils) const {
145 return SmallVector<Value>();
146 }
147
148 void createShadowValues(Operation *op, OpBuilder &builder,
149 MGradientUtilsReverse *gutils) const {}
150};
151
152// Implements the forward autodiff interface for operations which are
153// read-only and identity-like (e.g. a plain load, but not sin(load)).
154template <typename OpTy, int... storedvals>
156 : public AutoDiffOpInterface::ExternalModel<
157 AutoDiffUsingMemoryIdentity<OpTy, storedvals...>, OpTy> {
158public:
159 LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
160 MGradientUtils *gutils) const {
161
163 op, builder, gutils, std::initializer_list<int>{storedvals...});
164 }
165};
166
167// Implements the forward autodiff interface for operations which are
168// allocation-like, by delegating to allocationForwardHandler with zero = false.
169template <typename OpTy>
170class AutoDiffUsingAllocationFwd : public AutoDiffOpInterface::ExternalModel<
171 AutoDiffUsingAllocationFwd<OpTy>, OpTy> {
172public:
173 LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder,
174 MGradientUtils *gutils) const {
175
176 return allocationForwardHandler(op, builder, gutils, /*zero*/ false);
177 }
178};
179
180// Implements the reverse autodiff interface for operations which are
181// allocation-like. The adjoint is a no-op; only the shadow is created.
182template <typename OpTy>
184 : public ReverseAutoDiffOpInterface::ExternalModel<
185 AutoDiffUsingAllocationRev<OpTy>, OpTy> {
186public:
187 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
188 MGradientUtilsReverse *gutils,
189 SmallVector<Value> caches) const {
190 return success();
191 }
192
193 SmallVector<Value> cacheValues(Operation *op,
194 MGradientUtilsReverse *gutils) const {
195 return SmallVector<Value>();
196 }
197
198 void createShadowValues(Operation *op, OpBuilder &builder,
199 MGradientUtilsReverse *gutils) const {
200 (void)allocationForwardHandler(op, builder, (MGradientUtils *)gutils,
201 /*zero*/ true);
202 }
203};
204} // namespace detail
205
206// Registers AutoDiffUsingControlFlow for the given op.
207template <typename OpTy>
209 OpTy::template attachInterface<detail::AutoDiffUsingControlFlow<OpTy>>(
210 context);
211}
212// Registers AutoDiffUsingBranch and NoopRevAutoDiffInterface for the given op.
213template <typename OpTy>
214void registerAutoDiffUsingBranchInterface(MLIRContext &context) {
215 OpTy::template attachInterface<detail::AutoDiffUsingBranch<OpTy>>(context);
216 OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
217 context);
218}
219// Registers AutoDiffUsingRegionTerminator and NoopRevAutoDiffInterface for the given op.
220template <typename OpTy>
222 OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
223 context);
224 OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
225 context);
226}
227// Registers AutoDiffUsingRegionTerminator and ReturnRevAutoDiffInterface for the given op.
228template <typename OpTy>
229void registerAutoDiffUsingReturnInterface(MLIRContext &context) {
230 OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
231 context);
232 OpTy::template attachInterface<detail::ReturnRevAutoDiffInterface<OpTy>>(
233 context);
234}
235// Registers AutoDiffUsingMemoryIdentity for the given op.
236template <typename OpTy, int... storedvals>
238 OpTy::template attachInterface<
239 detail::AutoDiffUsingMemoryIdentity<OpTy, storedvals...>>(context);
240}
241// Registers AutoDiffUsingAllocationFwd and AutoDiffUsingAllocationRev for the given op.
242template <typename OpTy>
244 OpTy::template attachInterface<detail::AutoDiffUsingAllocationFwd<OpTy>>(
245 context);
246 OpTy::template attachInterface<detail::AutoDiffUsingAllocationRev<OpTy>>(
247 context);
248}
249
250// Interface registration hooks for individual upstream dialects.
251void registerAffineDialectAutoDiffInterface(DialectRegistry &registry);
252void registerArithDialectAutoDiffInterface(DialectRegistry &registry);
253void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
254void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
255void registerLLVMExtDialectAutoDiffInterface(DialectRegistry &registry);
256void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry);
257void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
258void registerComplexDialectAutoDiffInterface(DialectRegistry &registry);
259void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
260void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
261void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
262void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
263void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
264void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);
265void registerEnzymeDialectAutoDiffInterface(DialectRegistry &registry);
266
267void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);
268
269mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value);
270} // namespace enzyme
271} // namespace mlir
272
273#endif
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
LogicalResult createForwardModeTangent(Operation *op, OpBuilder &builder, MGradientUtils *gutils) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
SmallVector< Value > cacheValues(Operation *op, MGradientUtilsReverse *gutils) const
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector< Value > caches) const
void createShadowValues(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) const
LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, ArrayRef< int > storedVals)
void returnReverseHandler(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
void branchingForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils)
LogicalResult allocationForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils, bool zero)
void registerComplexDialectAutoDiffInterface(DialectRegistry &registry)
void registerArithDialectAutoDiffInterface(DialectRegistry &registry)
void registerAutoDiffUsingAllocationInterface(MLIRContext &context)
void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context)
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry)
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry)
void registerAutoDiffUsingBranchInterface(MLIRContext &context)
void registerMathDialectAutoDiffInterface(DialectRegistry &registry)
void registerEnzymeDialectAutoDiffInterface(DialectRegistry &registry)
void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry)
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry)
mlir::TypedAttr getConstantAttr(mlir::Type type, llvm::StringRef value)
void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry)
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry)
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry)
void registerCFDialectAutoDiffInterface(DialectRegistry &registry)
void registerAutoDiffUsingMemoryIdentityInterface(MLIRContext &context)
void registerAutoDiffUsingReturnInterface(MLIRContext &context)
void registerLLVMExtDialectAutoDiffInterface(DialectRegistry &registry)
void registerAffineDialectAutoDiffInterface(DialectRegistry &registry)
void registerAutoDiffUsingControlFlowInterface(MLIRContext &context)
void registerNVVMDialectAutoDiffInterface(DialectRegistry &registry)
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry)