Enzyme main
Loading...
Searching...
No Matches
ImpulseUtils.cpp
Go to the documentation of this file.
1//===- ImpulseUtils.cpp - Utilities for Impulse dialect passes ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "Dialect/Ops.h"
11#include "mlir/IR/Matchers.h"
12#include "mlir/IR/SymbolTable.h"
13#include "mlir/Interfaces/FunctionInterfaces.h"
14
15// TODO: this shouldn't depend on specific dialects except Enzyme.
16#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
17
18#include "CloneFunction.h"
19#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
20#include "mlir/Dialect/Func/IR/FuncOps.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/Dominance.h"
23#include "llvm/ADT/BreadthFirstIterator.h"
24
25using namespace mlir;
26using namespace mlir::impulse;
27
/// Clones `toeval` into a new function whose signature is adapted to `mode`,
/// inserts the clone into the nearest enclosing op that has the SymbolTable
/// trait, and wraps the clone, the original, and the clone mappings in a new
/// ImpulseUtils.
///
/// The clone is named "<toeval name>.<suffix>". Each switch case below picks
/// the suffix ("call", "generate", "regenerate", "simulate") and builds the
/// new operand/result type lists. The extra operands/results are f64
/// RankedTensorTypes, either of shape [1, N] or of rank 0.
///
/// NOTE(review): this listing is a documentation-page scrape. The `case`
/// labels of the switch and the `if` conditions guarding the two entry-block
/// argument insertions near the end are missing from it.
///
/// @param toeval         Function to clone. An empty body is treated as a
///                       programmer error: the op is printed to llvm::errs()
///                       and llvm_unreachable is hit.
/// @param mode           Selects the signature transformation. A value with
///                       no case below hits llvm_unreachable.
/// @param positionSize   N for the [1, N] f64 tensor added to the results in
///                       the generate/regenerate/simulate cases, and to the
///                       operands in the regenerate case.
/// @param constraintSize N for the [1, N] f64 tensor prepended to the
///                       operands in the generate case.
///                       Both sizes default to -1 at the declaration. The
///                       "call" case neither validates nor uses them.
/// @return A `new`-allocated ImpulseUtils as a raw pointer (NOTE(review):
///         confirm the caller is responsible for deleting it). Returns
///         nullptr, after emitting an error on `toeval`, when the size
///         parameters are invalid for the chosen mode. On that path nothing
///         has been cloned or inserted yet.
28ImpulseUtils *ImpulseUtils::CreateFromClone(FunctionOpInterface toeval,
29 ImpulseMode mode,
30 int64_t positionSize,
31 int64_t constraintSize) {
32 if (toeval.getFunctionBody().empty()) {
33 llvm::errs() << toeval << "\n";
34 llvm_unreachable("Creating ImpulseUtils from empty function");
35 }
36
37 OpBuilder builder(toeval.getContext());
38
// Name suffix and the new signature, filled in per mode by the switch below.
39 std::string suffix;
40 auto originalInputs =
41 cast<mlir::FunctionType>(toeval.getFunctionType()).getInputs();
42 auto originalResults =
43 cast<mlir::FunctionType>(toeval.getFunctionType()).getResults();
44 SmallVector<mlir::Type, 4> OperandTypes;
45 SmallVector<mlir::Type, 4> ResultTypes;
46
47 switch (mode) {
// "call": signature unchanged (original inputs and original results).
49 suffix = "call";
50 OperandTypes.append(originalInputs.begin(), originalInputs.end());
51 ResultTypes.append(originalResults.begin(), originalResults.end());
52 break;
// "generate": requires positionSize > 0 and constraintSize >= 0.
// Operands: [1, constraintSize] f64 tensor, then the original inputs.
// Results: [1, positionSize] f64 tensor, rank-0 f64 tensor, then the
// original results.
// NOTE(review): positionSize == 0 is rejected here but accepted by the
// regenerate and simulate cases. Confirm whether that asymmetry is intended.
54 suffix = "generate";
55 if (positionSize <= 0 || constraintSize < 0) {
56 toeval.emitError("Impulse: Unexpected size parameters");
57 return nullptr;
58 }
59 OperandTypes.push_back(
60 RankedTensorType::get({1, constraintSize}, builder.getF64Type()));
61 OperandTypes.append(originalInputs.begin(), originalInputs.end());
62 ResultTypes.push_back(
63 RankedTensorType::get({1, positionSize}, builder.getF64Type()));
64 ResultTypes.push_back(RankedTensorType::get({}, builder.getF64Type()));
65 ResultTypes.append(originalResults.begin(), originalResults.end());
66 break;
// "regenerate": requires positionSize >= 0.
// Operands: [1, positionSize] f64 tensor, then the original inputs.
// Results: same layout as the generate case.
68 suffix = "regenerate";
69 if (positionSize < 0) {
70 toeval.emitError("Impulse: Unexpected size parameters");
71 return nullptr;
72 }
73 OperandTypes.push_back(
74 RankedTensorType::get({1, positionSize}, builder.getF64Type()));
75 OperandTypes.append(originalInputs.begin(), originalInputs.end());
76 ResultTypes.push_back(
77 RankedTensorType::get({1, positionSize}, builder.getF64Type()));
78 ResultTypes.push_back(RankedTensorType::get({}, builder.getF64Type()));
79 ResultTypes.append(originalResults.begin(), originalResults.end());
80 break;
// "simulate": requires positionSize >= 0.
// Operands: original inputs only, so no extra entry-block argument is needed.
// Results: same layout as the generate case.
82 suffix = "simulate";
83 if (positionSize < 0) {
84 toeval.emitError("Impulse: Unexpected size parameters");
85 return nullptr;
86 }
87 OperandTypes.append(originalInputs.begin(), originalInputs.end());
88 ResultTypes.push_back(
89 RankedTensorType::get({1, positionSize}, builder.getF64Type()));
90 ResultTypes.push_back(RankedTensorType::get({}, builder.getF64Type()));
91 ResultTypes.append(originalResults.begin(), originalResults.end());
92 break;
// Any mode without a case above is a programmer error.
93 default:
94 llvm_unreachable("Invalid ImpulseMode\n");
95 }
96
// Create the clone shell (cloneWithoutRegions: no body yet), rename it to
// "<name>.<suffix>", and give it the mode-specific signature built above.
97 auto FTy = builder.getFunctionType(OperandTypes, ResultTypes);
98 auto NewF = cast<FunctionOpInterface>(toeval->cloneWithoutRegions());
99 SymbolTable::setSymbolName(NewF, toeval.getName().str() + "." + suffix);
100 NewF.setType(FTy);
101
// Register the clone in the closest ancestor that has the SymbolTable trait.
// NOTE(review): `parent` is not null-checked. A `toeval` with no SymbolTable
// ancestor is not handled here.
102 Operation *parent = toeval->getParentWithTrait<OpTrait::SymbolTable>();
103 SymbolTable table(parent);
104 table.insert(NewF);
105
// Copy the original body into the clone. cloneInto (declared elsewhere)
// presumably fills originalToNew (values) and originalToNewOps (ops). Both
// maps are handed to the ImpulseUtils constructor below.
106 IRMapping originalToNew;
107 std::map<Operation *, Operation *> originalToNewOps;
108 cloneInto(&toeval.getFunctionBody(), &NewF.getFunctionBody(), originalToNew,
109 originalToNewOps);
110
// Prepend the [1, constraintSize] operand as entry-block argument 0, so the
// cloned body matches the generate-case signature.
// NOTE(review): the guarding `if` line is missing from this listing;
// presumably it tests for the generate mode. Confirm.
112 Block &entry = NewF.getFunctionBody().front();
113 entry.insertArgument(
114 0u, RankedTensorType::get({1, constraintSize}, builder.getF64Type()),
115 toeval.getLoc());
116 }
117
// Same for the [1, positionSize] operand of the regenerate-case signature.
// NOTE(review): the guarding `if` line is missing here too; presumably it
// tests for the regenerate mode. Confirm.
119 Block &entry = NewF.getFunctionBody().front();
120 entry.insertArgument(
121 0u, RankedTensorType::get({1, positionSize}, builder.getF64Type()),
122 toeval.getLoc());
123 }
124
// Raw `new`: ownership is not expressed in the return type (see @return).
125 return new ImpulseUtils(NewF, toeval, originalToNew, originalToNewOps, mode);
126}
void cloneInto(Region *src, Region *dest, IRMapping &mapper, std::map< Operation *, Operation * > &opMap)
static ImpulseUtils * CreateFromClone(FunctionOpInterface toeval, ImpulseMode mode, int64_t positionSize=-1, int64_t constraintSize=-1)
ImpulseUtils(FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, IRMapping &originalToNewFn_, std::map< Operation *, Operation * > &originalToNewFnOps_, ImpulseMode mode_)