20#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21#include "mlir/IR/DialectRegistry.h"
22#include "mlir/Support/LogicalResult.h"
28#include "Implementations/LLVMDerivatives.inc"
// External model of ActivityOpInterface for LLVM::InlineAsmOp (isInactive casts
// the op to that type). An inline-asm op is reported inactive when its asm
// string contains the substring "cpuid" or "exit". isArgInactive ignores the
// operand index and returns the same op-level answer for every argument.
// NOTE(review): this listing is missing original lines of this definition (the
// fused source numbers skip 32, 37, 39-40), so the class is incomplete here.
30struct InlineAsmActivityInterface
31 :
public ActivityOpInterface::ExternalModel<InlineAsmActivityInterface,
33 bool isInactive(Operation *op)
const {
34 auto asmOp = cast<LLVM::InlineAsmOp>(op);
35 auto str = asmOp.getAsmString();
36 return str.contains(
"cpuid") ||
str.contains(
"exit");
38 bool isArgInactive(Operation *op,
size_t)
const {
return isInactive(op); }
// AutoDiffTypeInterface external model for LLVM::LLVMPointerType.
// Visible behavior per member:
//  - createNullAttr: writes an "unsupported" message to llvm::errs(); its
//    return value is on a line not shown in this listing.
//  - createNullValue: emits LLVM::ZeroOp of the pointer type itself.
//  - createAddOp / createConjOp: not implemented (llvm_unreachable("TODO")).
//  - the member containing `assert(width == 1 ...)` has its signature missing
//    here; it only accepts width 1 (presumably the shadow-type hook - confirm).
//  - isMutable: always returns true.
//  - zeroInPlace: body not visible here.
//  - isZero / isZeroAttr: always return false, i.e. a pointer value or
//    attribute is never reported as a known zero.
// NOTE(review): original lines are missing throughout this definition
// (e.g. 44, 47-49, 51, 53-54), so it is incomplete in this listing.
41class PointerTypeInterface
42 :
public AutoDiffTypeInterface::ExternalModel<PointerTypeInterface,
43 LLVM::LLVMPointerType> {
45 mlir::Attribute createNullAttr(mlir::Type self)
const {
46 llvm::errs() <<
" unsupported: createNullAttribute of pointertype\n";
50 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
52 return LLVM::ZeroOp::create(builder, loc, self);
55 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
57 llvm_unreachable(
"TODO");
60 Value createConjOp(Type self, OpBuilder &builder, Location loc,
62 llvm_unreachable(
"TODO");
66 assert(width == 1 &&
"unsupported width != 1");
70 bool isMutable(Type self)
const {
return true; }
72 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
78 bool isZero(Type self, Value val)
const {
return false; }
79 bool isZeroAttr(Type self, Attribute attr)
const {
return false; }
// ReverseAutoDiffOpInterface external model for LLVM::GEPOp.
// createReverseModeAdjoint and cacheValues have no visible body in this
// listing. createShadowValues clones `newGep` (defined on a line not shown -
// presumably the primal GEP's counterpart in the differentiated function) and
// reassigns the clone's base operand to `baseShadow` (also defined off-listing).
// The shadow GEP therefore keeps the cloned indices but addresses the shadow
// base pointer.
// NOTE(review): lines 84, 86, 88-90, 92-95, 97, 99, 101-102 and the closing
// lines of this definition are missing from this listing.
82struct GEPOpInterfaceReverse
83 :
public ReverseAutoDiffOpInterface::ExternalModel<GEPOpInterfaceReverse,
85 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
87 SmallVector<Value> caches)
const {
91 SmallVector<Value> cacheValues(Operation *op,
96 void createShadowValues(Operation *op, OpBuilder &builder,
98 auto gep = cast<LLVM::GEPOp>(op);
100 auto base = gep.getBase();
103 auto shadowGep = cast<LLVM::GEPOp>(builder.clone(*newGep));
104 shadowGep.getBaseMutable().assign(baseShadow);
// ReverseAutoDiffOpInterface external model for LLVM::LoadOp.
// createReverseModeAdjoint applies only when the loaded type implements
// AutoDiffTypeInterface. It fetches the gradient of the load's result
// (gutils->diffe) and pops the shadow address from caches.front(). It then
// accumulates that gradient into the shadow memory. Two accumulation forms are
// visible:
//  - non-atomic: load the current shadow value, combine it with the gradient
//    via iface.createAddOp, and store the sum back to the shadow address;
//  - atomic: LLVM::AtomicRMWOp with AtomicBinOp::fadd, monotonic ordering.
// The condition selecting between the two forms is on lines not shown here.
// cacheValues: the visible guard negates
// `isa<AutoDiffTypeInterface>(loadType) && ...`. The rest of the condition and
// what happens when it fires are not visible.
// NOTE(review): the atomic path uses fadd for whatever type reaches it - confirm
// the hidden guard restricts it to floating-point types.
110struct LoadOpInterfaceReverse
111 :
public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
113 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
115 SmallVector<Value> caches)
const {
116 auto loadOp = cast<LLVM::LoadOp>(op);
117 Value addr = loadOp.getAddr();
119 if (
auto iface = dyn_cast<AutoDiffTypeInterface>(loadOp.getType())) {
121 Value gradient = gutils->
diffe(loadOp, builder);
122 Value addrGradient = gutils->
popCache(caches.front(), builder);
125 Value loadedGradient = LLVM::LoadOp::create(builder, loadOp.getLoc(),
126 iface, addrGradient);
127 Value addedGradient = iface.createAddOp(builder, loadOp.getLoc(),
128 loadedGradient, gradient);
130 LLVM::StoreOp::create(builder, loadOp.getLoc(), addedGradient,
133 LLVM::AtomicRMWOp::create(builder, loadOp.getLoc(),
134 LLVM::AtomicBinOp::fadd, addrGradient,
135 gradient, LLVM::AtomicOrdering::monotonic);
143 SmallVector<Value> cacheValues(Operation *op,
145 auto loadOp = cast<LLVM::LoadOp>(op);
146 auto addr = loadOp.getAddr();
147 if (!(isa<AutoDiffTypeInterface>(loadOp.getType()) &&
155 void createShadowValues(Operation *op, OpBuilder &builder,
// ReverseAutoDiffOpInterface external model for LLVM::StoreOp.
// createReverseModeAdjoint:
//  - requires the stored value's type to implement AutoDiffTypeInterface
//    (cast, not dyn_cast);
//  - pops the shadow address from caches.front();
//  - for non-mutable value types, loads the shadow value from that address and
//    adds it to the stored value's gradient (gutils->addToDiffe);
//  - afterwards (in emission order) builds a null value of the shadow type and
//    stores it to the shadow address, clearing that shadow memory.
// Whether the zeroing store sits inside the `!isMutable()` branch is not
// visible (lines 182-184 are missing).
// NOTE(review): the shadow load is typed with val.getType() while the zero uses
// gutils->getShadowType(val.getType()); confirm these agree (they may differ
// when the shadow width is > 1).
// cacheValues / createShadowValues: bodies are mostly not visible here.
162struct StoreOpInterfaceReverse
163 :
public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
165 LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
167 SmallVector<Value> caches)
const {
168 auto storeOp = cast<LLVM::StoreOp>(op);
169 Value val = storeOp.getValue();
170 Value addr = storeOp.getAddr();
172 auto iface = cast<AutoDiffTypeInterface>(val.getType());
175 Value addrGradient = gutils->
popCache(caches.front(), builder);
177 if (!iface.isMutable()) {
179 Value loadedGradient = LLVM::LoadOp::create(
180 builder, storeOp.getLoc(), val.getType(), addrGradient);
181 gutils->
addToDiffe(val, loadedGradient, builder);
185 cast<AutoDiffTypeInterface>(gutils->
getShadowType(val.getType()))
186 .createNullValue(builder, op->getLoc());
188 LLVM::StoreOp::create(builder, storeOp.getLoc(), zero, addrGradient);
195 SmallVector<Value> cacheValues(Operation *op,
197 auto storeOp = cast<LLVM::StoreOp>(op);
198 auto addr = storeOp.getAddr();
206 void createShadowValues(Operation *op, OpBuilder &builder,
// Tries to recover the allocation size associated with `ptr`:
//  1. if `ptr` is defined by an llvm_ext::AllocOp, returns that op's size;
//  2. otherwise returns the size of the first llvm_ext::PtrSizeHintOp found
//     among ptr's users (in ptr.getUsers() iteration order).
// NOTE(review): the result when neither applies is on lines not shown
// (217-221). It is presumably std::nullopt given the std::optional return
// type - confirm.
210std::optional<Value> findPtrSize(Value ptr) {
211 if (
auto allocOp = ptr.getDefiningOp<llvm_ext::AllocOp>())
212 return allocOp.getSize();
214 for (
auto user : ptr.getUsers()) {
215 if (
auto psh = dyn_cast<llvm_ext::PtrSizeHintOp>(user)) {
216 return psh.getSize();
// ClonableTypeInterface external model for LLVM::LLVMPointerType.
// cloneValue:
//  - looks up the pointee size with findPtrSize;
//  - prints "cannot find size of ptr: <value>" to llvm::errs() (the guarding
//    condition, presumably a missing size, is on a line not shown);
//  - allocates a new pointer with llvm_ext::AllocOp;
//  - emits LLVM::MemcpyOp from `value` into the clone, using *ptrSize as the
//    length.
// The remaining AllocOp/MemcpyOp operands and the return statement are not
// visible here.
// freeClonedValue: emits llvm_ext::FreeOp on the given value.
223struct PointerClonableTypeInterface
224 :
public ClonableTypeInterface::ExternalModel<PointerClonableTypeInterface,
225 LLVM::LLVMPointerType> {
226 mlir::Value cloneValue(Type self, OpBuilder &builder, Value value)
const {
227 auto ptrSize = findPtrSize(value);
229 llvm::errs() <<
"cannot find size of ptr: " << value <<
"\n";
233 auto clone = llvm_ext::AllocOp::create(
234 builder, value.getLoc(), LLVM::LLVMPointerType::get(value.getContext()),
236 LLVM::MemcpyOp::create(builder, value.getLoc(),
clone, value, *ptrSize,
242 void freeClonedValue(Type self, OpBuilder &builder, Value value)
const {
243 llvm_ext::FreeOp::create(builder, value.getLoc(), value);
// AutoDiffTypeInterface external model for LLVM::LLVMStructType. Visible
// operations work field-by-field over structTy.getBody():
//  - createNullAttr: writes an "unsupported" message to llvm::errs(); its
//    return value is not visible.
//  - createNullValue: starts from LLVM::PoisonOp of the struct type. For each
//    element i it inserts either LLVM::ZeroOp of the element type or the
//    element interface's createNullValue. The branch choosing between them is
//    on a missing line; presumably it checks whether the element type
//    implements AutoDiffTypeInterface.
//  - createAddOp: extracts field i of `a` and `b`, adds them with the element
//    interface's createAddOp, and inserts the sum into a poison-initialized
//    result. The fallback for non-interface elements is not visible.
//  - createConjOp: not implemented (llvm_unreachable("TODO")).
//  - the member containing `assert(width == 1 ...)` has its signature missing
//    here; it only accepts width 1.
//  - isMutable: always returns false.
//  - zeroInPlace: body not visible here.
//  - isZero / isZeroAttr: always return false.
247class StructTypeInterface
248 :
public AutoDiffTypeInterface::ExternalModel<StructTypeInterface,
249 LLVM::LLVMStructType> {
251 mlir::Attribute createNullAttr(mlir::Type self)
const {
252 llvm::errs() <<
" unsupported: createNullAttribute of LLVMStructType\n";
256 mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
257 Location loc)
const {
258 auto structTy = cast<LLVM::LLVMStructType>(self);
259 Value result = LLVM::PoisonOp::create(builder, loc, structTy);
260 for (
auto &&[i, elemTy] : llvm::enumerate(structTy.getBody())) {
261 auto elemIface = dyn_cast<AutoDiffTypeInterface>(elemTy);
263 Value zero = LLVM::ZeroOp::create(builder, loc, elemTy);
264 result = LLVM::InsertValueOp::create(builder, loc, result, zero, i);
267 Value nullElem = elemIface.createNullValue(builder, loc);
268 result = LLVM::InsertValueOp::create(builder, loc, result, nullElem, i);
273 Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
275 auto structTy = cast<LLVM::LLVMStructType>(self);
276 Value result = LLVM::PoisonOp::create(builder, loc, structTy);
277 for (
auto &&[i, elemTy] : llvm::enumerate(structTy.getBody())) {
278 Value aElem = LLVM::ExtractValueOp::create(builder, loc, a, i);
279 Value bElem = LLVM::ExtractValueOp::create(builder, loc, b, i);
280 auto elemIface = dyn_cast<AutoDiffTypeInterface>(elemTy);
283 sum = elemIface.createAddOp(builder, loc, aElem, bElem);
287 result = LLVM::InsertValueOp::create(builder, loc, result, sum, i);
292 Value createConjOp(Type self, OpBuilder &builder, Location loc,
294 llvm_unreachable(
"TODO");
298 assert(width == 1 &&
"unsupported width != 1");
302 bool isMutable(Type self)
const {
return false; }
304 LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
309 bool isZero(Type self, Value val)
const {
return false; }
310 bool isZeroAttr(Type self, Attribute attr)
const {
return false; }
// Packs `values` into a single literal LLVM struct whose field types are the
// values' types, in order. It starts from LLVM::PoisonOp of that struct type
// and inserts value i at field i.
// NOTE(review): the `loc` parameter (line 314), the `structType` declaration
// (line 317) and the return (presumably `result`) are not visible here.
313static Value packIntoStruct(ValueRange values, OpBuilder &builder,
315 SmallVector<Type> resultTypes =
316 llvm::map_to_vector(values, [](Value v) {
return v.getType(); });
318 LLVM::LLVMStructType::getLiteral(builder.getContext(), resultTypes);
319 Value result = LLVM::PoisonOp::create(builder, loc, structType);
320 for (
auto &&[i, v] : llvm::enumerate(values))
321 result = LLVM::InsertValueOp::create(builder, loc, result, v, i);
// AutoDiffFunctionInterface external model for LLVM::LLVMFuncOp.
//  - transformResultTypes: with no results, pushes the LLVM void type. With
//    more than one, builds a literal struct of all result types and pushes it;
//    the line in between (presumably clearing resultTypes) is not visible. A
//    single result goes through no visible branch.
//  - createCall: emits LLVM::CallOp targeting this function (the remaining
//    call arguments are on lines not shown).
//  - createReturn: more than one return value is packed with packIntoStruct
//    and returned as a single value, matching the struct packing in
//    transformResultTypes. Otherwise retargs are returned directly.
326class AutoDiffLLVMFuncOpFunctionInterface
327 :
public AutoDiffFunctionInterface::ExternalModel<
328 AutoDiffLLVMFuncOpFunctionInterface, LLVM::LLVMFuncOp> {
330 void transformResultTypes(Operation *self,
331 SmallVectorImpl<Type> &resultTypes)
const {
332 auto fn = cast<mlir::FunctionOpInterface>(self);
333 auto FTy = fn.getFunctionType();
334 if (resultTypes.empty()) {
337 resultTypes.push_back(LLVM::LLVMVoidType::get(FTy.getContext()));
338 }
else if (resultTypes.size() > 1) {
340 LLVM::LLVMStructType::getLiteral(FTy.getContext(), resultTypes);
342 resultTypes.push_back(structType);
346 Operation *createCall(Operation *self, OpBuilder &builder, Location loc,
347 ValueRange args)
const {
348 return LLVM::CallOp::create(builder, loc, cast<LLVM::LLVMFuncOp>(self),
352 Operation *createReturn(Operation *self, OpBuilder &builder, Location loc,
353 ValueRange retargs)
const {
354 if (retargs.size() > 1) {
355 Value packedReturns = packIntoStruct(retargs, builder, loc);
356 return LLVM::ReturnOp::create(builder, loc, packedReturns);
359 return LLVM::ReturnOp::create(builder, loc, retargs);
// Body of registerLLVMDialectAutoDiffInterface; its first signature line (365)
// is missing from this listing. It adds a DialectRegistry extension whose
// callback receives the LLVM dialect. The callback calls
// registerInterfaces(context) and attaches these external models defined in
// this file:
//  - PointerTypeInterface and PointerClonableTypeInterface on LLVMPointerType;
//  - StructTypeInterface on LLVMStructType;
//  - the Load/Store/GEP reverse models;
//  - a model on UnreachableOp (template argument not visible);
//  - AutoDiffLLVMFuncOpFunctionInterface on LLVMFuncOp.
// NOTE(review): "®istry" below is an HTML-entity decoding artifact ("&reg" was
// rendered as the (R) sign); the real parameter is `DialectRegistry &registry`.
366 DialectRegistry ®istry) {
367 registry.addExtension(+[](MLIRContext *context, LLVM::LLVMDialect *) {
368 registerInterfaces(context);
369 LLVM::LLVMPointerType::attachInterface<PointerTypeInterface>(*context);
370 LLVM::LLVMPointerType::attachInterface<PointerClonableTypeInterface>(
372 LLVM::LLVMStructType::attachInterface<StructTypeInterface>(*context);
373 LLVM::LoadOp::attachInterface<LoadOpInterfaceReverse>(*context);
374 LLVM::StoreOp::attachInterface<StoreOpInterfaceReverse>(*context);
375 LLVM::GEPOp::attachInterface<GEPOpInterfaceReverse>(*context);
376 LLVM::UnreachableOp::template attachInterface<
378 LLVM::LLVMFuncOp::attachInterface<AutoDiffLLVMFuncOpFunctionInterface>(
Type getShadowType(Type type, unsigned width)
Operation * clone(Operation *src, IRMapping &mapper, Operation::CloneOptions options, std::map< Operation *, Operation * > &opMap)
static bool isZero(llvm::Constant *cst)
static std::string str(AugmentedStruct c)
static bool isMutable(Type type)
mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder)
Value popCache(Value cache, OpBuilder &builder)
Value initAndPushCache(Value v, OpBuilder &builder)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
mlir::Type getShadowType(mlir::Type T)
mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2)
SmallVector< mlir::Value, 1 > getNewFromOriginal(ValueRange originst) const
void setInvertedPointer(mlir::Value origv, mlir::Value newv)
bool isConstantValue(mlir::Value v) const
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry)