Enzyme main
Loading...
Searching...
No Matches
EnzymeWrapPass.cpp
Go to the documentation of this file.
1//===- EnzymeWrapPass.cpp - Replace calls with their derivatives ------------ //
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to create wrapper functions which differentiate
10// ops.
11//===----------------------------------------------------------------------===//
12
13#include "Dialect/Ops.h"
16#include "PassDetails.h"
17#include "Passes/Passes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/Interfaces/ControlFlowInterfaces.h"
21#include "mlir/Interfaces/FunctionInterfaces.h"
22
23#include "mlir/Dialect/Func/IR/FuncOps.h"
24
25#define DEBUG_TYPE "enzyme"
26
27namespace mlir {
28namespace enzyme {
29#define GEN_PASS_DEF_DIFFERENTIATEWRAPPERPASS
30#include "Passes/Passes.h.inc"
31} // namespace enzyme
32} // namespace mlir
33
34using namespace mlir;
35using namespace mlir::enzyme;
36using namespace enzyme;
37
38std::vector<DIFFE_TYPE> parseActivityString(StringRef inp) {
39 if (inp.size() == 0)
40 return {};
41 std::vector<DIFFE_TYPE> ArgActivity;
42 SmallVector<StringRef, 1> split;
43 StringRef(inp.data(), inp.size()).split(split, ',');
44 for (auto &str : split) {
45 if (str == "enzyme_dup")
46 ArgActivity.push_back(DIFFE_TYPE::DUP_ARG);
47 else if (str == "enzyme_const")
48 ArgActivity.push_back(DIFFE_TYPE::CONSTANT);
49 else if (str == "enzyme_dupnoneed")
50 ArgActivity.push_back(DIFFE_TYPE::DUP_NONEED);
51 else if (str == "enzyme_active")
52 ArgActivity.push_back(DIFFE_TYPE::OUT_DIFF);
53 else {
54 llvm::errs() << "unknown activity to parse, found: '" << str << "'\n";
55 assert(0 && " unknown constant");
56 }
57 }
58 return ArgActivity;
59}
60
61namespace {
62struct DifferentiateWrapperPass
63 : public enzyme::impl::DifferentiateWrapperPassBase<
64 DifferentiateWrapperPass> {
65 using DifferentiateWrapperPassBase::DifferentiateWrapperPassBase;
66
67 void runOnOperation() override {
68 MEnzymeLogic Logic;
69 SymbolTableCollection symbolTable;
70 symbolTable.getSymbolTable(getOperation());
71
72 Operation *symbolOp = nullptr;
73 if (infn != "")
74 symbolOp = symbolTable.lookupSymbolIn<Operation *>(
75 getOperation(), StringAttr::get(getOperation()->getContext(), infn));
76 else {
77 for (auto &op : getOperation()->getRegion(0).front()) {
78 auto fn = dyn_cast<FunctionOpInterface>(symbolOp);
79 if (!fn)
80 continue;
81 assert(symbolOp == nullptr);
82 symbolOp = &op;
83 }
84 }
85 if (!symbolOp) {
86 llvm::errs() << " Could not find function '" << infn
87 << "' to differentiate\n";
88 signalPassFailure();
89 return;
90 }
91 auto fn = cast<FunctionOpInterface>(symbolOp);
92 bool omp = false;
93 std::string postpasses = "";
94 bool verifyPostPasses = true;
95 bool strongZero = false;
96
97 std::vector<DIFFE_TYPE> ArgActivity =
98 parseActivityString(argTys.getValue());
99
100 if (ArgActivity.size() != fn.getNumArguments()) {
101 fn->emitError()
102 << "Incorrect number of arg activity states for function, found "
103 << ArgActivity.size() << " expected "
104 << fn.getFunctionBody().front().getNumArguments();
105 return;
106 }
107
108 std::vector<DIFFE_TYPE> RetActivity =
109 parseActivityString(retTys.getValue());
110 if (RetActivity.size() != fn.getNumResults()) {
111 fn->emitError()
112 << "Incorrect number of ret activity states for function, found "
113 << RetActivity.size() << " expected " << fn.getNumResults();
114 return;
115 }
116 std::vector<bool> returnPrimal;
117 std::vector<bool> returnShadow;
118 for (auto act : RetActivity) {
119 returnPrimal.push_back(act == DIFFE_TYPE::DUP_ARG);
120 returnShadow.push_back(false);
121 }
122
123 MTypeAnalysis TA;
124 auto type_args = TA.getAnalyzedTypeInfo(fn);
125
126 bool freeMemory = true;
127 size_t width = 1;
128
129 std::vector<bool> volatile_args;
130 for (auto &a : fn.getFunctionBody().getArguments()) {
131 (void)a;
132 volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
133 }
134
135 FunctionOpInterface newFunc;
136 if (mode == DerivativeMode::ForwardMode) {
137 newFunc = Logic.CreateForwardDiff(
138 fn, RetActivity, ArgActivity, TA, returnPrimal, mode, freeMemory,
139 width,
140 /*addedType*/ nullptr, type_args, volatile_args,
141 /*augmented*/ nullptr, omp, postpasses, verifyPostPasses, strongZero);
142 } else {
143 newFunc = Logic.CreateReverseDiff(
144 fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode,
145 freeMemory, width,
146 /*addedType*/ nullptr, type_args, volatile_args,
147 /*augmented*/ nullptr, omp, postpasses, verifyPostPasses, strongZero);
148 }
149 if (!newFunc) {
150 signalPassFailure();
151 return;
152 }
153 if (outfn == "") {
154 fn->erase();
155 SymbolTable::setSymbolVisibility(newFunc,
156 SymbolTable::Visibility::Public);
157 SymbolTable::setSymbolName(cast<FunctionOpInterface>(newFunc),
158 (std::string)infn);
159 } else {
160 SymbolTable::setSymbolName(cast<FunctionOpInterface>(newFunc),
161 (std::string)outfn);
162 }
163 }
164};
165
166} // end anonymous namespace
std::vector< DIFFE_TYPE > parseActivityString(StringRef inp)
static std::string str(AugmentedStruct c)
Definition EnzymeLogic.h:62
FunctionOpInterface CreateReverseDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, std::vector< bool > returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
FunctionOpInterface CreateForwardDiff(FunctionOpInterface fn, std::vector< DIFFE_TYPE > retType, std::vector< DIFFE_TYPE > constants, MTypeAnalysis &TA, std::vector< bool > returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector< bool > volatile_args, void *augmented, bool omp, llvm::StringRef postpasses, bool verifyPostPasses, bool strongZero)
MFnTypeInfo getAnalyzedTypeInfo(FunctionOpInterface op) const
Definition EnzymeLogic.h:24