3#include "mlir/IR/IRMapping.h"
4#include "mlir/Interfaces/FunctionInterfaces.h"
7#include "../../Utils.h"
35 bool pointerIntSame =
false)
const {
36 if (isa<IntegerType, IndexType>(val.getType())) {
40 llvm_unreachable(
"something happened");
72 if (std::lexicographical_compare(
76 if (std::lexicographical_compare(
146 if (std::lexicographical_compare(rhs.
retActivity.begin(),
155 if (std::lexicographical_compare(rhs.
argActivity.begin(),
222 bool freeMemory,
size_t width, mlir::Type addedType,
223 MFnTypeInfo type_args, std::vector<bool> volatile_args,
224 void *augmented,
bool omp, llvm::StringRef postpasses,
225 bool verifyPostPasses,
bool strongZero);
230 std::vector<bool> returnPrimals,
232 bool freeMemory,
size_t width, mlir::Type addedType,
233 MFnTypeInfo type_args, std::vector<bool> volatile_args,
234 void *augmented,
bool omp, llvm::StringRef postpasses,
235 bool verifyPostPasses,
bool strongZero);
243 llvm::function_ref<buildReturnFunction> buildReturnOp);
246 LogicalResult
visitChild(Operation *op, OpBuilder &builder,
253 llvm::function_ref<buildReturnFunction> buildFuncRetrunOp,
254 std::function<std::pair<Value, Value>(Type)> cacheCreator);
Concrete SubType of a given value.
Class representing the underlying types of values as sequences of offsets to a ConcreteType.
void initializeShadowValues(SmallVector< mlir::Block * > &dominatorToposortBlocks, MGradientUtilsReverse *gutils)
void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils)
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
LogicalResult visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
LogicalResult visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils)
std::map< MReverseCacheKey, FunctionOpInterface > ReverseCachedFunctions
void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, llvm::function_ref< buildReturnFunction > buildReturnOp)
std::map< MForwardCacheKey, FunctionOpInterface > ForwardCachedFunctions
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, llvm::function_ref< buildReturnFunction > buildFuncRetrunOp, std::function< std::pair< Value, Value >(Type)> cacheCreator)
bool operator<(const MFnTypeInfo &rhs) const
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const
TypeTree query(Value) const
ConcreteType intType(size_t num, Value val, bool errIfNotFound=true, bool pointerIntSame=false) const
TypeTree getReturnAnalysis()
void buildReturnFunction(OpBuilder &, mlir::Block *)
const MFnTypeInfo typeInfo
FunctionOpInterface todiff
mlir::Type additionalType
const std::vector< DIFFE_TYPE > retType
const std::vector< DIFFE_TYPE > constant_args
bool operator<(const MForwardCacheKey &rhs) const
std::vector< bool > returnUsed
const std::vector< DIFFE_TYPE > argActivity
const std::vector< bool > returnPrimals
const std::vector< bool > volatileArgs
const std::vector< DIFFE_TYPE > retActivity
mlir::Type additionalType
const MFnTypeInfo typeInfo
bool operator<(const MReverseCacheKey &rhs) const
FunctionOpInterface todiff
const std::vector< bool > returnShadows