Enzyme main
Loading...
Searching...
No Matches
LLVMAutoDiffOpInterfaceImpl.cpp
Go to the documentation of this file.
1//===- LLVMAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the external model implementation of the automatic
10// differentiation op interfaces for the upstream LLVM dialect.
11//
12//===----------------------------------------------------------------------===//
13
20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
23
24using namespace mlir;
25using namespace mlir::enzyme;
26
27namespace {
28#include "Implementations/LLVMDerivatives.inc"
29
30struct InlineAsmActivityInterface
31 : public ActivityOpInterface::ExternalModel<InlineAsmActivityInterface,
32 LLVM::InlineAsmOp> {
33 bool isInactive(Operation *op) const {
34 auto asmOp = cast<LLVM::InlineAsmOp>(op);
35 auto str = asmOp.getAsmString();
36 return str.contains("cpuid") || str.contains("exit");
37 }
38 bool isArgInactive(Operation *op, size_t) const { return isInactive(op); }
39};
40
41class PointerTypeInterface
42 : public AutoDiffTypeInterface::ExternalModel<PointerTypeInterface,
43 LLVM::LLVMPointerType> {
44public:
45 mlir::Attribute createNullAttr(mlir::Type self) const {
46 llvm::errs() << " unsupported: createNullAttribute of pointertype\n";
47 return nullptr;
48 }
49
50 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
51 Location loc) const {
52 return LLVM::ZeroOp::create(builder, loc, self);
53 }
54
55 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
56 Value b) const {
57 llvm_unreachable("TODO");
58 }
59
60 Value createConjOp(Type self, OpBuilder &builder, Location loc,
61 Value a) const {
62 llvm_unreachable("TODO");
63 }
64
65 Type getShadowType(Type self, unsigned width) const {
66 assert(width == 1 && "unsupported width != 1");
67 return self;
68 }
69
70 bool isMutable(Type self) const { return true; }
71
72 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
73 Value val) const {
74 // TODO inspect val and memset corresponding size
75 return failure();
76 }
77
78 bool isZero(Type self, Value val) const { return false; }
79 bool isZeroAttr(Type self, Attribute attr) const { return false; }
80};
81
82struct GEPOpInterfaceReverse
83 : public ReverseAutoDiffOpInterface::ExternalModel<GEPOpInterfaceReverse,
84 LLVM::GEPOp> {
85 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
87 SmallVector<Value> caches) const {
88 return success();
89 }
90
91 SmallVector<Value> cacheValues(Operation *op,
92 MGradientUtilsReverse *gutils) const {
93 return {};
94 }
95
96 void createShadowValues(Operation *op, OpBuilder &builder,
97 MGradientUtilsReverse *gutils) const {
98 auto gep = cast<LLVM::GEPOp>(op);
99 auto newGep = cast<LLVM::GEPOp>(gutils->getNewFromOriginal(op));
100 auto base = gep.getBase();
101 if (!gutils->isConstantValue(base)) {
102 auto baseShadow = gutils->invertPointerM(base, builder);
103 auto shadowGep = cast<LLVM::GEPOp>(builder.clone(*newGep));
104 shadowGep.getBaseMutable().assign(baseShadow);
105 gutils->setInvertedPointer(gep.getRes(), shadowGep->getResult(0));
106 }
107 }
108};
109
110struct LoadOpInterfaceReverse
111 : public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
112 LLVM::LoadOp> {
113 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
114 MGradientUtilsReverse *gutils,
115 SmallVector<Value> caches) const {
116 auto loadOp = cast<LLVM::LoadOp>(op);
117 Value addr = loadOp.getAddr();
118
119 if (auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
120 if (!gutils->isConstantValue(loadOp) && !gutils->isConstantValue(addr)) {
121 Value gradient = gutils->diffe(loadOp, builder);
122 Value addrGradient = gutils->popCache(caches.front(), builder);
123
124 if (!gutils->AtomicAdd) {
125 Value loadedGradient = LLVM::LoadOp::create(builder, loadOp.getLoc(),
126 iface, addrGradient);
127 Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
128 loadedGradient, gradient);
129
130 LLVM::StoreOp::create(builder, loadOp.getLoc(), addedGradient,
131 addrGradient);
132 } else {
133 LLVM::AtomicRMWOp::create(builder, loadOp.getLoc(),
134 LLVM::AtomicBinOp::fadd, addrGradient,
135 gradient, LLVM::AtomicOrdering::monotonic);
136 }
137 }
138 }
139
140 return success();
141 }
142
143 SmallVector<Value> cacheValues(Operation *op,
144 MGradientUtilsReverse *gutils) const {
145 auto loadOp = cast<LLVM::LoadOp>(op);
146 auto addr = loadOp.getAddr();
147 if (!(isa<AutoDiffTypeInterface>(loadOp.getType()) &&
148 (!gutils->isConstantValue(loadOp) && !gutils->isConstantValue(addr))))
149 return {};
150 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
151 return {gutils->initAndPushCache(gutils->invertPointerM(addr, cacheBuilder),
152 cacheBuilder)};
153 }
154
155 void createShadowValues(Operation *op, OpBuilder &builder,
156 MGradientUtilsReverse *gutils) const {
157 // auto loadOp = cast<LLVM::LoadOp>(op);
158 // Value ptr = loadOp.getAddr();
159 }
160};
161
162struct StoreOpInterfaceReverse
163 : public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
164 LLVM::StoreOp> {
165 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
166 MGradientUtilsReverse *gutils,
167 SmallVector<Value> caches) const {
168 auto storeOp = cast<LLVM::StoreOp>(op);
169 Value val = storeOp.getValue();
170 Value addr = storeOp.getAddr();
171
172 auto iface = cast<AutoDiffTypeInterface>(val.getType());
173
174 if (!gutils->isConstantValue(addr)) {
175 Value addrGradient = gutils->popCache(caches.front(), builder);
176
177 if (!iface.isMutable()) {
178 if (!gutils->isConstantValue(val)) {
179 Value loadedGradient = LLVM::LoadOp::create(
180 builder, storeOp.getLoc(), val.getType(), addrGradient);
181 gutils->addToDiffe(val, loadedGradient, builder);
182 }
183
184 auto zero =
185 cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType()))
186 .createNullValue(builder, op->getLoc());
187
188 LLVM::StoreOp::create(builder, storeOp.getLoc(), zero, addrGradient);
189 }
190 }
191
192 return success();
193 }
194
195 SmallVector<Value> cacheValues(Operation *op,
196 MGradientUtilsReverse *gutils) const {
197 auto storeOp = cast<LLVM::StoreOp>(op);
198 auto addr = storeOp.getAddr();
199 if (gutils->isConstantValue(addr))
200 return {};
201 OpBuilder cacheBuilder(gutils->getNewFromOriginal(op));
202 return {gutils->initAndPushCache(gutils->invertPointerM(addr, cacheBuilder),
203 cacheBuilder)};
204 }
205
206 void createShadowValues(Operation *op, OpBuilder &builder,
207 MGradientUtilsReverse *gutils) const {}
208};
209
210std::optional<Value> findPtrSize(Value ptr) {
211 if (auto allocOp = ptr.getDefiningOp<llvm_ext::AllocOp>())
212 return allocOp.getSize();
213
214 for (auto user : ptr.getUsers()) {
215 if (auto psh = dyn_cast<llvm_ext::PtrSizeHintOp>(user)) {
216 return psh.getSize();
217 }
218 }
219
220 return std::nullopt;
221}
222
223struct PointerClonableTypeInterface
224 : public ClonableTypeInterface::ExternalModel<PointerClonableTypeInterface,
225 LLVM::LLVMPointerType> {
226 mlir::Value cloneValue(Type self, OpBuilder &builder, Value value) const {
227 auto ptrSize = findPtrSize(value);
228 if (!ptrSize) {
229 llvm::errs() << "cannot find size of ptr: " << value << "\n";
230 return nullptr;
231 }
232
233 auto clone = llvm_ext::AllocOp::create(
234 builder, value.getLoc(), LLVM::LLVMPointerType::get(value.getContext()),
235 *ptrSize);
236 LLVM::MemcpyOp::create(builder, value.getLoc(), clone, value, *ptrSize,
237 /*isVolatile*/ false);
238
239 return clone;
240 }
241
242 void freeClonedValue(Type self, OpBuilder &builder, Value value) const {
243 llvm_ext::FreeOp::create(builder, value.getLoc(), value);
244 }
245};
246
247class StructTypeInterface
248 : public AutoDiffTypeInterface::ExternalModel<StructTypeInterface,
249 LLVM::LLVMStructType> {
250public:
251 mlir::Attribute createNullAttr(mlir::Type self) const {
252 llvm::errs() << " unsupported: createNullAttribute of LLVMStructType\n";
253 return nullptr;
254 }
255
256 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
257 Location loc) const {
258 auto structTy = cast<LLVM::LLVMStructType>(self);
259 Value result = LLVM::PoisonOp::create(builder, loc, structTy);
260 for (auto &&[i, elemTy] : llvm::enumerate(structTy.getBody())) {
261 auto elemIface = dyn_cast<AutoDiffTypeInterface>(elemTy);
262 if (!elemIface) {
263 Value zero = LLVM::ZeroOp::create(builder, loc, elemTy);
264 result = LLVM::InsertValueOp::create(builder, loc, result, zero, i);
265 continue;
266 }
267 Value nullElem = elemIface.createNullValue(builder, loc);
268 result = LLVM::InsertValueOp::create(builder, loc, result, nullElem, i);
269 }
270 return result;
271 }
272
273 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
274 Value b) const {
275 auto structTy = cast<LLVM::LLVMStructType>(self);
276 Value result = LLVM::PoisonOp::create(builder, loc, structTy);
277 for (auto &&[i, elemTy] : llvm::enumerate(structTy.getBody())) {
278 Value aElem = LLVM::ExtractValueOp::create(builder, loc, a, i);
279 Value bElem = LLVM::ExtractValueOp::create(builder, loc, b, i);
280 auto elemIface = dyn_cast<AutoDiffTypeInterface>(elemTy);
281 Value sum;
282 if (elemIface) {
283 sum = elemIface.createAddOp(builder, loc, aElem, bElem);
284 } else {
285 sum = aElem;
286 }
287 result = LLVM::InsertValueOp::create(builder, loc, result, sum, i);
288 }
289 return result;
290 }
291
292 Value createConjOp(Type self, OpBuilder &builder, Location loc,
293 Value a) const {
294 llvm_unreachable("TODO");
295 }
296
297 Type getShadowType(Type self, unsigned width) const {
298 assert(width == 1 && "unsupported width != 1");
299 return self;
300 }
301
302 bool isMutable(Type self) const { return false; }
303
304 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
305 Value val) const {
306 return failure();
307 }
308
309 bool isZero(Type self, Value val) const { return false; }
310 bool isZeroAttr(Type self, Attribute attr) const { return false; }
311};
312
313static Value packIntoStruct(ValueRange values, OpBuilder &builder,
314 Location loc) {
315 SmallVector<Type> resultTypes =
316 llvm::map_to_vector(values, [](Value v) { return v.getType(); });
317 auto structType =
318 LLVM::LLVMStructType::getLiteral(builder.getContext(), resultTypes);
319 Value result = LLVM::PoisonOp::create(builder, loc, structType);
320 for (auto &&[i, v] : llvm::enumerate(values))
321 result = LLVM::InsertValueOp::create(builder, loc, result, v, i);
322
323 return result;
324}
325
326class AutoDiffLLVMFuncOpFunctionInterface
327 : public AutoDiffFunctionInterface::ExternalModel<
328 AutoDiffLLVMFuncOpFunctionInterface, LLVM::LLVMFuncOp> {
329public:
330 void transformResultTypes(Operation *self,
331 SmallVectorImpl<Type> &resultTypes) const {
332 auto fn = cast<mlir::FunctionOpInterface>(self);
333 auto FTy = fn.getFunctionType();
334 if (resultTypes.empty()) {
335 // llvm.func ops that return no results need to explicitly return
336 // LLVMVoidType
337 resultTypes.push_back(LLVM::LLVMVoidType::get(FTy.getContext()));
338 } else if (resultTypes.size() > 1) {
339 auto structType =
340 LLVM::LLVMStructType::getLiteral(FTy.getContext(), resultTypes);
341 resultTypes.clear();
342 resultTypes.push_back(structType);
343 }
344 }
345
346 Operation *createCall(Operation *self, OpBuilder &builder, Location loc,
347 ValueRange args) const {
348 return LLVM::CallOp::create(builder, loc, cast<LLVM::LLVMFuncOp>(self),
349 args);
350 }
351
352 Operation *createReturn(Operation *self, OpBuilder &builder, Location loc,
353 ValueRange retargs) const {
354 if (retargs.size() > 1) {
355 Value packedReturns = packIntoStruct(retargs, builder, loc);
356 return LLVM::ReturnOp::create(builder, loc, packedReturns);
357 }
358
359 return LLVM::ReturnOp::create(builder, loc, retargs);
360 }
361};
362
363} // namespace
364
366 DialectRegistry &registry) {
367 registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) {
368 registerInterfaces(context);
369 LLVM::LLVMPointerType::attachInterface<PointerTypeInterface>(*context);
370 LLVM::LLVMPointerType::attachInterface<PointerClonableTypeInterface>(
371 *context);
372 LLVM::LLVMStructType::attachInterface<StructTypeInterface>(*context);
373 LLVM::LoadOp::attachInterface<LoadOpInterfaceReverse>(*context);
374 LLVM::StoreOp::attachInterface<StoreOpInterfaceReverse>(*context);
375 LLVM::GEPOp::attachInterface<GEPOpInterfaceReverse>(*context);
376 LLVM::UnreachableOp::template attachInterface<
378 LLVM::LLVMFuncOp::attachInterface<AutoDiffLLVMFuncOpFunctionInterface>(
379 *context);
380 });
381}
Type getShadowType(Type type, unsigned width)
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static bool isZero(llvm::Constant *cst)
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
static bool isMutable(Type type)
Definition Ops.cpp:235
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
mlir::Type getShadowType(mlir::Type T)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry)