// Constructor fragment: parameter list followed by the member-initializer
// list. The listing skips some original lines, so the start of the signature,
// any trailing parameter and the constructor body are not visible here.
50 Function *oldFunc, Function *newFunc,
unsigned width,
51 ValueMap<
const Value *, std::vector<Value *>> &vectorizedValues,
52 ValueToValueMapTy &originalToNewFn, SmallPtrSetImpl<Value *> &toVectorize,
// hasError starts out false; vectorizedValues, originalToNewFn, toVectorize
// and width are each initialized from the parameter of the same name.
54 : hasError(false), vectorizedValues(vectorizedValues),
55 originalToNewFn(originalToNewFn), toVectorize(toVectorize), width(width),
// Fragment of the handler for a generic instruction `inst` (its header is not
// part of this listing). For every index i in [1, width) it clones the first
// placeholder recorded for `inst`, rewires the clone's operands to the values
// returned by getNewOperand(i, ...), and installs the clone.
87 auto found = vectorizedValues.find(&inst);
88 assert(found != vectorizedValues.end());
// `auto` makes this a by-value copy of the std::vector, so the push_back on
// vectorizedValues[&inst] further down does not change placeholders.size().
89 auto placeholders = found->second;
90 Instruction *placeholder = cast<Instruction>(placeholders[0]);
// Index 0 is skipped: the loop only produces clones for indices 1..width-1.
92 for (
unsigned i = 1; i < width; ++i) {
// Per-clone remapping table; the placeholder itself maps to the clone.
93 ValueToValueMapTy vmap;
94 Instruction *new_inst = placeholder->clone();
95 vmap[placeholder] = new_inst;
97 for (
unsigned j = 0; j < inst.getNumOperands(); ++j) {
98 Value *op = inst.getOperand(j);
// Report a failure when a global operand that is marked for vectorization is
// used by an instruction that may write to memory.
// NOTE(review): in LLVM's class hierarchy a GlobalValue is not a
// ConstantData, so the !isa<ConstantData> test looks redundant - confirm.
102 if (isa<GlobalValue>(op) && !isa<ConstantData>(op) &&
103 inst.mayWriteToMemory() && toVectorize.count(op) != 0) {
106 EmitFailure(
"GlobalValueCannotBeVectorized", inst.getDebugLoc(), &inst,
107 "global variables have to be scalar values", inst);
// Metadata-as-value operand whose metadata is not a ValueAsMetadata; the
// statement controlled by this test is not visible in the listing.
111 if (
auto meta = dyn_cast<MetadataAsValue>(op))
112 if (!isa<ValueAsMetadata>(meta->getMetadata()))
// Map the placeholder's j-th operand to the replacement chosen for index i.
115 Value *new_op = getNewOperand(i, op);
116 vmap[placeholder->getOperand(j)] = new_op;
// Case 1: one placeholder per index - swap the i-th placeholder for the clone.
119 if (placeholders.size() == width) {
// This declaration shadows the outer `placeholder` (placeholders[0]).
121 Instruction *placeholder = cast<Instruction>(placeholders[i]);
122 assert(!placeholder->getType()->isVoidTy());
124 ReplaceInstWithInst(placeholder, new_inst);
125 vectorizedValues[&inst][i] = new_inst;
126 }
// Case 2: a single placeholder, asserted to be void-typed - the clone is
// inserted through an IRBuilder positioned at the placeholder's next node, or
// at the placeholder itself when it has no next node, and appended to the
// recorded values.
else if (placeholders.size() == 1) {
128 assert(placeholder->getType()->isVoidTy());
130 Instruction *insertionPoint =
131 placeholder->getNextNode() ? placeholder->getNextNode() : placeholder;
132 IRBuilder<> Builder2(insertionPoint);
// Clear the builder's debug location before inserting the clone.
133 Builder2.SetCurrentDebugLocation(DebugLoc());
134 Builder2.Insert(new_inst);
135 vectorizedValues[&inst].push_back(new_inst);
// Any other placeholder count is treated as an internal error.
137 llvm_unreachable(
"Unexpected number of values in mapping");
// Apply the operand mapping collected in vmap to the clone.
140 RemapInstruction(new_inst, vmap, RF_NoModuleLevelChanges);
// Named, non-void results get the original name with the index appended.
142 if (!inst.getType()->isVoidTy() && inst.hasName())
143 new_inst->setName(inst.getName() + Twine(i));
// Fragment of the handler for a PHI node `phi` (header not in this listing).
// For every index i in [1, width) it clones the first placeholder PHI,
// remaps the incoming values to the ones for index i, and replaces the i-th
// placeholder with the clone.
// NOTE(review): this lookup uses operator[] with no prior find/assert, unlike
// the find + assert pattern used elsewhere in this file; if `phi` had no entry
// the [0] access would presumably be out of range - confirm callers guarantee
// the entry exists.
148 PHINode *placeholder = cast<PHINode>(vectorizedValues[&phi][0]);
150 for (
unsigned i = 1; i < width; ++i) {
151 ValueToValueMapTy vmap;
152 Instruction *new_phi = placeholder->clone();
153 vmap[placeholder] = new_phi;
155 for (
unsigned j = 0; j < phi.getNumIncomingValues(); ++j) {
// Translate the original incoming block to its counterpart in the new
// function; cast<> requires the mapped value to be a BasicBlock.
156 Value *orig_block = phi.getIncomingBlock(j);
157 BasicBlock *new_block = cast<BasicBlock>(originalToNewFn[orig_block]);
158 Value *orig_val = phi.getIncomingValue(j);
159 Value *new_val = getNewOperand(i, orig_val);
// The placeholder's j-th incoming value maps to the value for index i; the
// block maps to itself (presumably so remapping leaves it unchanged).
161 vmap[placeholder->getIncomingValue(j)] = new_val;
162 vmap[new_block] = new_block;
165 RemapInstruction(new_phi, vmap, RF_NoModuleLevelChanges);
// This declaration shadows the outer PHINode `placeholder`.
166 Instruction *placeholder = cast<Instruction>(vectorizedValues[&phi][i]);
167 ReplaceInstWithInst(placeholder, new_phi);
// Every clone is given the original PHI's name (no index suffix here).
168 new_phi->setName(phi.getName());
169 vectorizedValues[&phi][i] = new_phi;
// Fragment of the handler for a return instruction `ret` (header not in this
// listing). It gathers the replacement values for every operand and index,
// emits an aggregate return in the mapped block, and erases the block's
// previous terminator.
190 auto found = originalToNewFn.find(ret.getParent());
191 assert(found != originalToNewFn.end());
// NOTE(review): dyn_cast may yield null, yet nBB is used unconditionally
// below; cast<> would state the expectation explicitly - confirm.
192 BasicBlock *nBB = dyn_cast<BasicBlock>(&*found->second);
193 IRBuilder<> Builder2 = IRBuilder<>(nBB);
194 Builder2.SetCurrentDebugLocation(DebugLoc());
// The existing terminator of the mapped block must be a ReturnInst.
195 ReturnInst *placeholder = cast<ReturnInst>(nBB->getTerminator());
196 SmallVector<Value *, 4> rets;
// rets is filled operand-major: for each operand, one entry per index.
198 for (
unsigned j = 0; j < ret.getNumOperands(); ++j) {
199 Value *op = ret.getOperand(j);
200 for (
unsigned i = 0; i < width; ++i) {
201 Value *new_op = getNewOperand(i, op);
202 rets.push_back(new_op);
// Nothing is emitted by the visible code when the return has no operands.
206 if (ret.getNumOperands() != 0) {
// Two spellings of the same call selected by LLVM version; the #else/#endif
// lines are not visible in this listing. In both, the local `ret` shadows the
// outer `ret` that the code above reads.
// NOTE(review): the second spelling passes only the first `width` entries of
// rets - equivalent only if there is a single operand; confirm.
207#if LLVM_VERSION_MAJOR > 22
208 auto ret = Builder2.CreateAggregateRet(rets);
210 auto ret = Builder2.CreateAggregateRet(rets.data(), width);
// Carry the old terminator's debug location over, then remove it.
212 ret->setDebugLoc(placeholder->getDebugLoc());
213 placeholder->eraseFromParent();
// Fragment of the handler for a call instruction `call` (header and several
// interior lines are missing from this listing). It builds the argument list
// for a batched callee, emits the new call before the first placeholder, and
// replaces the placeholders with the call's results.
218 auto found = vectorizedValues.find(&call);
219 assert(found != vectorizedValues.end());
// By-value copy of the recorded placeholder vector.
220 auto placeholders = found->second;
221 Instruction *placeholder = cast<Instruction>(placeholders[0]);
222 IRBuilder<> Builder2(placeholder);
223 Builder2.SetCurrentDebugLocation(DebugLoc());
// orig_func is declared on a line that is not part of this listing.
226 bool isDefined = !orig_func->isDeclaration();
231 SmallVector<Value *, 4> args;
232 SmallVector<BATCH_TYPE, 4> arg_types;
// Two alternative loop headers selected by LLVM version (arg_size vs.
// getNumArgOperands); the #else/#endif lines are not visible here.
233#if LLVM_VERSION_MAJOR >= 14
234 for (
unsigned j = 0; j < call.arg_size(); ++j) {
236 for (
unsigned j = 0; j < call.getNumArgOperands(); ++j) {
238 Value *op = call.getArgOperand(j);
// Argument marked for vectorization: start from an undef aggregate (aggTy is
// declared on a line not shown) and insert one element per index.
240 if (toVectorize.count(op) != 0) {
242 Value *agg = UndefValue::get(aggTy);
243 for (
unsigned i = 0; i < width; i++) {
// This `found` shadows the outer one; the lookup does not depend on i.
244 auto found = vectorizedValues.find(op);
245 assert(found != vectorizedValues.end());
246 Value *new_op = found->second[i];
// NOTE(review): the value returned by CreateInsertValue is discarded and
// `agg` is not reassigned on any visible line, so `agg` appears to remain the
// initial undef value. Likely intended: agg = Builder2.CreateInsertValue(...).
// Confirm against the lines missing from this listing.
247 Builder2.CreateInsertValue(agg, new_op, {i});
251 }
// Constant-data argument: the handling lines are not visible in this listing.
else if (isa<ConstantData>(op)) {
// Remaining arguments are translated through originalToNewFn.
255 auto found = originalToNewFn.find(op);
256 assert(found != originalToNewFn.end());
257 Value *arg = found->second;
// The return batch type depends on whether the callee returns void; the rest
// of this expression and the start of the next call are not visible
// (presumably a CreateBatch call producing new_func - confirm).
263 BATCH_TYPE ret_type = orig_func->getReturnType()->isVoidTy()
268 orig_func, width, arg_types, ret_type);
// Emit the call to new_func with the collected args, reusing the call's name.
269 CallInst *new_call = Builder2.CreateCall(new_func->getFunctionType(),
270 new_func, args, call.getName());
272 new_call->setDebugLoc(placeholder->getDebugLoc());
// Non-void call: each placeholder is replaced by an ExtractValueInst (its
// operands are on a line not shown) named "unwrap" plus, when the call is
// named, "." + name + index.
274 if (!call.getType()->isVoidTy()) {
275 for (
unsigned i = 0; i < width; ++i) {
// Shadows the outer `placeholder`.
// NOTE(review): dyn_cast may yield null but the result is passed on
// unchecked; cast<> would assert instead - confirm.
276 Instruction *placeholder = dyn_cast<Instruction>(placeholders[i]);
277 ExtractValueInst *ret = ExtractValueInst::Create(
279 "unwrap" + (call.hasName() ?
"." + call.getName() + Twine(i) :
""));
280 ReplaceInstWithInst(placeholder, ret);
281 vectorizedValues[&call][i] = ret;
// Redirect the placeholder's uses to the new call and delete it; the branch
// enclosing these two statements is not visible (presumably the void case).
284 placeholder->replaceAllUsesWith(new_call);
285 placeholder->eraseFromParent();
llvm::Function * CreateBatch(RequestContext context, llvm::Function *tobatch, unsigned width, llvm::ArrayRef< BATCH_TYPE > arg_types, BATCH_TYPE ret_type)
Create a function batched in its inputs.
InstructionBatcher(llvm::Function *oldFunc, llvm::Function *newFunc, unsigned width, llvm::ValueMap< const llvm::Value *, std::vector< llvm::Value * > > &vectorizedValues, llvm::ValueMap< const llvm::Value *, llvm::WeakTrackingVH > &originalToNewFn, llvm::SmallPtrSetImpl< llvm::Value * > &toVectorize, EnzymeLogic &Logic)