Enzyme main
Loading...
Searching...
No Matches
InstructionBatcher.cpp
Go to the documentation of this file.
1//===- InstructionBatcher.cpp
2//--------------------------------------------------===//
3//
4// Enzyme Project
5//
6// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
7// Exceptions. See https://llvm.org/LICENSE.txt for license information.
8// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9//
10// If using this code in an academic setting, please cite the following:
11// @incollection{enzymeNeurips,
12// title = {Instead of Rewriting Foreign Code for Machine Learning,
13// Automatically Synthesize Fast Gradients},
14// author = {Moses, William S. and Churavy, Valentin},
15// booktitle = {Advances in Neural Information Processing Systems 33},
16// year = {2020},
17// note = {To appear in},
18// }
19//
20//===----------------------------------------------------------------------===//
21//
22// This file contains an instruction visitor InstructionBatcher that generates
23// batched versions of all LLVM instructions.
24//
25//===----------------------------------------------------------------------===//
26
27#include "InstructionBatcher.h"
28
29#include "llvm/IR/InstVisitor.h"
30
31#include "llvm/ADT/ArrayRef.h"
32#include "llvm/ADT/SmallVector.h"
33
34#include "llvm/Support/Casting.h"
35
36#include "llvm/IR/Constants.h"
37#include "llvm/IR/IRBuilder.h"
38#include "llvm/IR/Value.h"
39
40#include "llvm/Transforms/Utils/BasicBlockUtils.h"
41#include "llvm/Transforms/Utils/Cloning.h"
42#include "llvm/Transforms/Utils/ValueMapper.h"
43
44#include "DiffeGradientUtils.h"
45#include "GradientUtils.h"
46
47using namespace llvm;
48
50 Function *oldFunc, Function *newFunc, unsigned width,
51 ValueMap<const Value *, std::vector<Value *>> &vectorizedValues,
52 ValueToValueMapTy &originalToNewFn, SmallPtrSetImpl<Value *> &toVectorize,
53 EnzymeLogic &Logic)
54 : hasError(false), vectorizedValues(vectorizedValues),
55 originalToNewFn(originalToNewFn), toVectorize(toVectorize), width(width),
56 Logic(Logic) {}
57
58Value *InstructionBatcher::getNewOperand(unsigned int i, llvm::Value *op) {
59 if (auto meta = dyn_cast<MetadataAsValue>(op)) {
60 auto md = meta->getMetadata();
61 if (auto val = dyn_cast<ValueAsMetadata>(md))
62 return MetadataAsValue::get(
63 op->getContext(),
64 ValueAsMetadata::get(getNewOperand(i, val->getValue())));
65 }
66
67 if (isa<ConstantData>(op)) {
68 return op;
69 } else if (isa<Function>(op)) {
70 return op;
71 } else if (isa<GlobalValue>(op)) {
72 llvm::errs() << "unimplelemented GlobalValue!\n";
73 llvm_unreachable("unimplelemented GlobalValue!");
74 // TODO: !!!
75 } else if (toVectorize.count(op) != 0) {
76 auto found = vectorizedValues.find(op);
77 assert(found != vectorizedValues.end());
78 return found->second[i];
79 } else {
80 auto found = originalToNewFn.find(op);
81 assert(found != originalToNewFn.end());
82 return found->second;
83 }
84}
85
/// Batch a generic instruction. Lane 0 already exists as a placeholder in
/// the new function; for every further lane i in [1, width) the placeholder
/// is cloned, its operands are remapped to their lane-i values, and the
/// clone is spliced into place.
void InstructionBatcher::visitInstruction(llvm::Instruction &inst) {
  auto found = vectorizedValues.find(&inst);
  assert(found != vectorizedValues.end());
  // Copy of the per-lane placeholder list; entry 0 is the lane-0 value.
  auto placeholders = found->second;
  Instruction *placeholder = cast<Instruction>(placeholders[0]);

  for (unsigned i = 1; i < width; ++i) {
    ValueToValueMapTy vmap;
    Instruction *new_inst = placeholder->clone();
    vmap[placeholder] = new_inst;

    for (unsigned j = 0; j < inst.getNumOperands(); ++j) {
      Value *op = inst.getOperand(j);

      // Don't allow writing vectors to global memory, loading and splatting a
      // global is fine though.
      if (isa<GlobalValue>(op) && !isa<ConstantData>(op) &&
          inst.mayWriteToMemory() && toVectorize.count(op) != 0) {
        // TODO: handle buffer access
        hasError = true;
        EmitFailure("GlobalValueCannotBeVectorized", inst.getDebugLoc(), &inst,
                    "global variables have to be scalar values", inst);
        return;
      }

      // Metadata operands that do not wrap a value have no per-lane
      // replacement; leave them untouched.
      if (auto meta = dyn_cast<MetadataAsValue>(op))
        if (!isa<ValueAsMetadata>(meta->getMetadata()))
          continue;

      // Record the lane-i operand; the clone is rewritten in one pass by
      // RemapInstruction below.
      Value *new_op = getNewOperand(i, op);
      vmap[placeholder->getOperand(j)] = new_op;
    }

    if (placeholders.size() == width) {
      // Instructions which return a value
      // (this `placeholder` intentionally shadows the outer one: each lane
      // has its own placeholder to replace).
      Instruction *placeholder = cast<Instruction>(placeholders[i]);
      assert(!placeholder->getType()->isVoidTy());

      ReplaceInstWithInst(placeholder, new_inst);
      vectorizedValues[&inst][i] = new_inst;
    } else if (placeholders.size() == 1) {
      // Instructions which don't return a value
      assert(placeholder->getType()->isVoidTy());

      // Insert before the placeholder's successor (i.e. right after the
      // placeholder), or before the placeholder itself if it has no
      // successor in the block.
      Instruction *insertionPoint =
          placeholder->getNextNode() ? placeholder->getNextNode() : placeholder;
      IRBuilder<> Builder2(insertionPoint);
      Builder2.SetCurrentDebugLocation(DebugLoc());
      Builder2.Insert(new_inst);
      vectorizedValues[&inst].push_back(new_inst);
    } else {
      llvm_unreachable("Unexpected number of values in mapping");
    }

    // Rewrite the clone's operands using the collected map.
    RemapInstruction(new_inst, vmap, RF_NoModuleLevelChanges);

    if (!inst.getType()->isVoidTy() && inst.hasName())
      new_inst->setName(inst.getName() + Twine(i));
  }
}
146
  // (Body of InstructionBatcher::visitPHINode — the signature line, original
  // source line 147, is not visible in this listing.)
  // Lane 0 of the phi already exists as a placeholder; clone it for every
  // further lane and remap each incoming value to its lane-i counterpart.
  // Incoming blocks are scalar (control flow is not vectorized), so they map
  // to themselves.
  PHINode *placeholder = cast<PHINode>(vectorizedValues[&phi][0]);

  for (unsigned i = 1; i < width; ++i) {
    ValueToValueMapTy vmap;
    Instruction *new_phi = placeholder->clone();
    vmap[placeholder] = new_phi;

    for (unsigned j = 0; j < phi.getNumIncomingValues(); ++j) {
      Value *orig_block = phi.getIncomingBlock(j);
      BasicBlock *new_block = cast<BasicBlock>(originalToNewFn[orig_block]);
      Value *orig_val = phi.getIncomingValue(j);
      Value *new_val = getNewOperand(i, orig_val);

      // Map the incoming value to its lane-i replacement; the identity
      // mapping for the block keeps RemapInstruction from altering it.
      vmap[placeholder->getIncomingValue(j)] = new_val;
      vmap[new_block] = new_block;
    }

    RemapInstruction(new_phi, vmap, RF_NoModuleLevelChanges);
    // Swap the lane-i placeholder for the remapped clone and record it
    // (this `placeholder` intentionally shadows the lane-0 one above).
    Instruction *placeholder = cast<Instruction>(vectorizedValues[&phi][i]);
    ReplaceInstWithInst(placeholder, new_phi);
    new_phi->setName(phi.getName());
    vectorizedValues[&phi][i] = new_phi;
  }
}
172
173void InstructionBatcher::visitSwitchInst(llvm::SwitchInst &inst) {
174 // TODO: runtime check
175 hasError = true;
176 EmitFailure("SwitchConditionCannotBeVectorized", inst.getDebugLoc(), &inst,
177 "switch conditions have to be scalar values", inst);
178 return;
179}
180
181void InstructionBatcher::visitBranchInst(llvm::BranchInst &branch) {
182 // TODO: runtime check
183 hasError = true;
184 EmitFailure("BranchConditionCannotBeVectorized", branch.getDebugLoc(),
185 &branch, "branch conditions have to be scalar values", branch);
186 return;
187}
188
189void InstructionBatcher::visitReturnInst(llvm::ReturnInst &ret) {
190 auto found = originalToNewFn.find(ret.getParent());
191 assert(found != originalToNewFn.end());
192 BasicBlock *nBB = dyn_cast<BasicBlock>(&*found->second);
193 IRBuilder<> Builder2 = IRBuilder<>(nBB);
194 Builder2.SetCurrentDebugLocation(DebugLoc());
195 ReturnInst *placeholder = cast<ReturnInst>(nBB->getTerminator());
196 SmallVector<Value *, 4> rets;
197
198 for (unsigned j = 0; j < ret.getNumOperands(); ++j) {
199 Value *op = ret.getOperand(j);
200 for (unsigned i = 0; i < width; ++i) {
201 Value *new_op = getNewOperand(i, op);
202 rets.push_back(new_op);
203 }
204 }
205
206 if (ret.getNumOperands() != 0) {
207#if LLVM_VERSION_MAJOR > 22
208 auto ret = Builder2.CreateAggregateRet(rets);
209#else
210 auto ret = Builder2.CreateAggregateRet(rets.data(), width);
211#endif
212 ret->setDebugLoc(placeholder->getDebugLoc());
213 placeholder->eraseFromParent();
214 }
215}
216
/// Batch a call instruction. Calls to functions with a definition are batched
/// recursively through EnzymeLogic::CreateBatch; calls to mere declarations
/// fall back to the generic per-lane cloning in visitInstruction.
void InstructionBatcher::visitCallInst(llvm::CallInst &call) {
  auto found = vectorizedValues.find(&call);
  assert(found != vectorizedValues.end());
  auto placeholders = found->second;
  Instruction *placeholder = cast<Instruction>(placeholders[0]);
  IRBuilder<> Builder2(placeholder);
  Builder2.SetCurrentDebugLocation(DebugLoc());
  auto orig_func = getFunctionFromCall(&call);

  bool isDefined = !orig_func->isDeclaration();

  // Declarations cannot be batched recursively; clone them per lane instead.
  if (!isDefined)
    return visitInstruction(call);

  SmallVector<Value *, 4> args;
  SmallVector<BATCH_TYPE, 4> arg_types;
#if LLVM_VERSION_MAJOR >= 14
  for (unsigned j = 0; j < call.arg_size(); ++j) {
#else
  for (unsigned j = 0; j < call.getNumArgOperands(); ++j) {
#endif
    Value *op = call.getArgOperand(j);

    if (toVectorize.count(op) != 0) {
      // Vectorized argument: pack the per-lane values into one aggregate.
      Type *aggTy = GradientUtils::getShadowType(op->getType(), width);
      Value *agg = UndefValue::get(aggTy);
      for (unsigned i = 0; i < width; i++) {
        auto found = vectorizedValues.find(op);
        assert(found != vectorizedValues.end());
        Value *new_op = found->second[i];
        // NOTE(review): CreateInsertValue returns the updated aggregate as a
        // new SSA value; the result is not assigned back to `agg` here, so
        // `agg` would remain undef. The original source presumably reads
        // `agg = Builder2.CreateInsertValue(...)` — verify (this listing may
        // have dropped the assignment).
        Builder2.CreateInsertValue(agg, new_op, {i});
      }
      args.push_back(agg);
      arg_types.push_back(BATCH_TYPE::VECTOR);
    } else if (isa<ConstantData>(op)) {
      // Constants are uniform across lanes and passed through unchanged.
      args.push_back(op);
      arg_types.push_back(BATCH_TYPE::SCALAR);
    } else {
      // Scalar value coming from the unvectorized clone of the function.
      auto found = originalToNewFn.find(op);
      assert(found != originalToNewFn.end());
      Value *arg = found->second;
      args.push_back(arg);
      arg_types.push_back(BATCH_TYPE::SCALAR);
    }
  }

  BATCH_TYPE ret_type = orig_func->getReturnType()->isVoidTy()
  // NOTE(review): the branches of this conditional initializer (original
  // source lines 264-265, presumably "? BATCH_TYPE::SCALAR
  // : BATCH_TYPE::VECTOR;") are missing from this listing — confirm against
  // the original source.

  Function *new_func = Logic.CreateBatch(RequestContext(&call, &Builder2),
                                         orig_func, width, arg_types, ret_type);
  CallInst *new_call = Builder2.CreateCall(new_func->getFunctionType(),
                                           new_func, args, call.getName());

  new_call->setDebugLoc(placeholder->getDebugLoc());

  if (!call.getType()->isVoidTy()) {
    // Unpack the aggregate return: each lane's placeholder is replaced by an
    // extractvalue of the batched call's result.
    for (unsigned i = 0; i < width; ++i) {
      Instruction *placeholder = dyn_cast<Instruction>(placeholders[i]);
      ExtractValueInst *ret = ExtractValueInst::Create(
          new_call, {i},
          "unwrap" + (call.hasName() ? "." + call.getName() + Twine(i) : ""));
      ReplaceInstWithInst(placeholder, ret);
      vectorizedValues[&call][i] = ret;
    }
  } else {
    // Void call: the single placeholder stands in for all lanes.
    placeholder->replaceAllUsesWith(new_call);
    placeholder->eraseFromParent();
  }
}
static Operation * getFunctionFromCall(CallOpInterface iface)
BATCH_TYPE
Definition Utils.h:385
void EmitFailure(llvm::StringRef RemarkName, const llvm::DiagnosticLocation &Loc, const llvm::Instruction *CodeRegion, Args &...args)
Definition Utils.h:203
llvm::Function * CreateBatch(RequestContext context, llvm::Function *tobatch, unsigned width, llvm::ArrayRef< BATCH_TYPE > arg_types, BATCH_TYPE ret_type)
Create a function batched in its inputs.
static llvm::Type * getShadowType(llvm::Type *ty, unsigned width)
void visitReturnInst(llvm::ReturnInst &ret)
void visitSwitchInst(llvm::SwitchInst &inst)
void visitBranchInst(llvm::BranchInst &branch)
void visitInstruction(llvm::Instruction &inst)
void visitPHINode(llvm::PHINode &phi)
InstructionBatcher(llvm::Function *oldFunc, llvm::Function *newFunc, unsigned width, llvm::ValueMap< const llvm::Value *, std::vector< llvm::Value * > > &vectorizedValues, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn, llvm::SmallPtrSetImpl< llvm::Value * > &toVectorize, EnzymeLogic &Logic)
void visitCallInst(llvm::CallInst &call)