Enzyme main
Loading...
Searching...
No Matches
GradientUtilsReverse.h
Go to the documentation of this file.
1//===- GradientUtilsReverse.h - Utilities for gradient interfaces -------* C++
2//-*-===//
3//
4// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef ENZYME_MLIR_INTERFACES_GRADIENT_UTILS_REVERSE_H
11#define ENZYME_MLIR_INTERFACES_GRADIENT_UTILS_REVERSE_H
12
13#include "mlir/IR/IRMapping.h"
14#include "mlir/Interfaces/FunctionInterfaces.h"
15
16#include "CloneFunction.h"
17#include "EnzymeLogic.h"
18
19#include <functional>
20
21#include "GradientUtils.h"
22
23namespace mlir {
24namespace enzyme {
25
27public:
// Constructor declaration. Judging by the parameter names, it receives the
// original (`oldFunc_`) and new (`newFunc_`) functions, per-result flags for
// returning primals/shadows, sets of constant and active values, return and
// argument activity kinds, mappings from original values/operations to their
// counterparts in the new function, and mode/width/post-pass options.
// `invertedPointers_` is taken by value, whereas `originalToNewFn_` and
// `originalToNewFnOps_` are taken by non-const reference.
// NOTE(review): the class head (listing line 26) is absent from this extract;
// confirm the exact meaning of each argument against the implementation.
28 MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_,
29 FunctionOpInterface oldFunc_, MTypeAnalysis &TA_,
30 IRMapping invertedPointers_,
31 const llvm::ArrayRef<bool> returnPrimals,
32 const llvm::ArrayRef<bool> returnShadows,
33 const SmallPtrSetImpl<mlir::Value> &constantvalues_,
34 const SmallPtrSetImpl<mlir::Value> &activevals_,
35 ArrayRef<DIFFE_TYPE> ReturnActivity,
36 ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
37 IRMapping &originalToNewFn_,
38 std::map<Operation *, Operation *> &originalToNewFnOps_,
39 DerivativeMode mode_, unsigned width, bool omp,
40 llvm::StringRef postpasses, bool verifyPostPasses,
41 bool strongZero);
42
44
// Takes two values (`oldGradient`, `addedGradient`) and an OpBuilder; returns
// nothing. Presumably accumulates `addedGradient` into the derivative tracked
// for `oldGradient`, emitting ops through `builder` -- TODO confirm.
45 void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient,
46 OpBuilder &builder);
47
// Returns a Type and takes no arguments. Presumably the index type used by
// the caching helpers declared further down -- TODO confirm.
48 Type getIndexType();
// Returns a Value produced for the given Type `t`. NOTE(review): where the
// initialisation is inserted is not visible in this header -- confirm.
49 Value insertInit(Type t);
50
// Collection of hooks; each hook is a callable that takes a Type and returns
// a pair of Values.
51 SmallVector<std::function<std::pair<Value, Value>(Type)>> cacheCreatorHook;
// Accepts a hook with the same signature as the elements of
// `cacheCreatorHook` (Type -> pair of Values), passed by value. Presumably
// adds it to that collection -- TODO confirm.
52 void
53 registerCacheCreatorHook(std::function<std::pair<Value, Value>(Type)> hook);
// Counterpart to registerCacheCreatorHook with an identical parameter type.
// NOTE(review): how the hook to remove is identified (std::function objects
// are not equality-comparable) is not visible here -- confirm in the .cpp.
54 void
55 deregisterCacheCreatorHook(std::function<std::pair<Value, Value>(Type)> hook);
// Returns a pair of Values for the given Type `t` -- the same shape as the
// result of a `cacheCreatorHook` entry. Presumably delegates to a registered
// hook when one exists -- TODO confirm.
56 std::pair<Value, Value> getNewCache(Type t);
57
58 // Cache
// Maps a Type `t` to a Type. Presumably the type of a cache holding elements
// of type `t` -- TODO confirm.
59 Type getCacheType(Type t);
// Returns a Type and takes no arguments. Presumably the cache type for index
// values -- TODO confirm.
60 Type getIndexCacheType();
// Takes a Value `v` and an OpBuilder and returns a Value. Presumably creates
// a cache, pushes `v` onto it and returns the cache -- TODO confirm.
61 Value initAndPushCache(Value v, OpBuilder &builder);
62
// Takes an OpBuilder `B` and an Operation pointer and returns an Operation
// pointer. Presumably clones `op` via `B` with its operands remapped to the
// new function -- TODO confirm; ownership of the result is not shown here.
63 Operation *cloneWithNewOperands(OpBuilder &B, Operation *op);
64
// Takes a cache Value and an OpBuilder and returns a Value. Presumably the
// inverse of initAndPushCache: retrieves a stored value from `cache` --
// TODO confirm.
65 Value popCache(Value cache, OpBuilder &builder);
66
// Takes two regions by non-const reference and returns nothing. Presumably
// creates, in `newFunc`, the blocks for the reverse pass corresponding to the
// blocks of `oldFunc` -- TODO confirm.
67 void createReverseModeBlocks(Region &oldFunc, Region &newFunc);
68
// NOTE(review): the opening line of this declaration (listing line 69) is
// missing from this extract. The signature summary at the end of the page
// gives it as `static MGradientUtilsReverse * CreateFromClone(...)`, i.e. a
// static factory returning a raw MGradientUtilsReverse pointer; who owns the
// returned object is not visible here -- confirm before relying on it.
70 MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
71 FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
72 const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
73 llvm::ArrayRef<DIFFE_TYPE> retType,
74 llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg,
75 bool omp, llvm::StringRef postpasses, bool verifyPostPasses,
76 bool strongZero);
77};
78
79} // namespace enzyme
80} // namespace mlir
81
82#endif // ENZYME_MLIR_INTERFACES_GRADIENT_UTILS_REVERSE_H
DerivativeMode
Definition Utils.h:390
Value popCache(Value cache, OpBuilder &builder)
SmallVector< std::function< std::pair< Value, Value >(Type)> > cacheCreatorHook
Value initAndPushCache(Value v, OpBuilder &builder)
MGradientUtilsReverse(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA_, IRMapping invertedPointers_, const llvm::ArrayRef< bool > returnPrimals, const llvm::ArrayRef< bool > returnShadows, const SmallPtrSetImpl< mlir::Value > &constantvalues_, const SmallPtrSetImpl< mlir::Value > &activevals_, ArrayRef< DIFFE_TYPE > ReturnActivity, ArrayRef< DIFFE_TYPE > ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, DerivativeMode mode_, unsigned width, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
void createReverseModeBlocks(Region &oldFunc, Region &newFunc)
void deregisterCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void registerCacheCreatorHook(std::function< std::pair< Value, Value >(Type)> hook)
void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder)
static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef< bool > returnPrimals, const ArrayRef< bool > returnShadows, llvm::ArrayRef< DIFFE_TYPE > retType, llvm::ArrayRef< DIFFE_TYPE > constant_args, mlir::Type additionalArg, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
std::pair< Value, Value > getNewCache(Type t)
Operation * cloneWithNewOperands(OpBuilder &B, Operation *op)
const llvm::ArrayRef< bool > returnShadows
const llvm::ArrayRef< bool > returnPrimals
FunctionOpInterface oldFunc
FunctionOpInterface newFunc