67 SmallVector<Activity> inActivity;
68 SmallVector<Activity> retActivity;
69 SmallVector<Value> in_args;
73 for (
auto [idx, act] : llvm::enumerate(uop.getActivity())) {
74 auto iattr = cast<ActivityAttr>(act);
75 auto val = iattr.getValue();
76 inActivity.push_back(val);
78 in_args.push_back(uop.getInputs()[in_idx]);
81 if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
86 for (
auto [idx, ract] : llvm::enumerate(uop.getRetActivity())) {
87 auto iattr = cast<ActivityAttr>(ract);
88 auto val = iattr.getValue();
89 retActivity.push_back(val);
102 SourceOp callerOp, FunctionOpInterface innerFnOp,
103 const SmallVector<MemoryEffects::EffectInstance> &innerEffects) {
104 SmallVector<MemoryEffects::EffectInstance> outerEffects;
105 for (
auto &eff : innerEffects) {
107 Value effVal = eff.getValue();
110 outerEffects.push_back(eff);
115 size_t primalArgPos = 0;
116 bool foundPrimal =
false;
117 if (
auto effBA = dyn_cast<BlockArgument>(effVal)) {
118 if (llvm::is_contained(innerFnOp.getArguments(), effBA)) {
120 primalArgPos = effBA.getArgNumber();
133 Value primalVal = callerOp.getPrimalInputs()[primalArgPos];
134 outerEffects.push_back(
143 (cast<ActivityAttr>(callerOp.getActivity()[primalArgPos]).getValue() ==
144 Activity::enzyme_dup) ||
145 (cast<ActivityAttr>(callerOp.getActivity()[primalArgPos]).getValue() ==
146 Activity::enzyme_dupnoneed);
150 for (
auto [idx, act] : llvm::enumerate(callerOp.getActivity())) {
151 auto iattr = cast<ActivityAttr>(act);
152 auto act_val = iattr.getValue();
155 if (idx == primalArgPos)
158 if (act_val == Activity::enzyme_dup ||
159 act_val == Activity::enzyme_dupnoneed) {
164 Value dVal = callerOp.getInputs()[gradArgPos];
166 if constexpr (std::is_same_v<SourceOp, ForwardDiffOp>) {
167 outerEffects.push_back(
171 dVal, MemoryEffects::Write::get(), eff.getResource()));
173 dVal, MemoryEffects::Read::get(), eff.getResource()));
185 SmallVector<SourceOp> &allDiffs) {
186 SmallVector<SourceOp, 2> prunedSources;
191 auto firstDiffOp = allDiffs[0];
192 for (
auto uop : allDiffs) {
194 auto diffArgs = uop.getShadows();
195 if constexpr (std::is_same_v<SourceOp, AutoDiffOp>) {
196 auto diffeRet = uop.getDifferentialReturns();
197 diffArgs.append(diffeRet.begin(), diffeRet.end());
200 bool definedBeforeFirst =
true;
202 for (
auto diffVal : diffArgs) {
203 if (
auto diffValOR = dyn_cast<OpResult>(diffVal)) {
205 auto parentOp = diffValOR.getOwner();
206 if (!parentOp->isBeforeInBlock(firstDiffOp.getOperation())) {
207 definedBeforeFirst =
false;
213 if (definedBeforeFirst) {
214 prunedSources.push_back(uop);
218 return prunedSources;
227 SmallVector<SourceOp> &prunedSources,
228 DenseMap<SourceOp, SmallVector<MemoryEffects::EffectInstance>>
230 llvm::DenseMap<FunctionOpInterface,
231 SmallVector<MemoryEffects::EffectInstance>>
240 if (callerEffectMap.empty()) {
242 return prunedSources;
245 SmallVector<SourceOp> legalMerge;
246 auto lastOp = prunedSources[0];
248 SmallVector<MemoryEffects::EffectInstance, 4> betweenEffects;
249 for (
auto candidateOp : prunedSources) {
252 for (Operation *curr = lastOp.getOperation();
253 curr != candidateOp.getOperation(); curr = curr->getNextNode()) {
254 auto currSourceOp = dyn_cast<SourceOp>(curr);
255 if (currSourceOp && callerEffectMap.contains(currSourceOp)) {
258 betweenEffects.append(callerEffectMap[currSourceOp]);
259 }
else if (
auto currFwdOp = dyn_cast<ForwardDiffOp>(curr)) {
261 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
262 symbolTable.lookupNearestSymbolFrom(currFwdOp,
263 currFwdOp.getFnAttr()));
267 if (!innerEffectCache.contains(fnOp)) {
273 currFwdOp, fnOp, innerEffectCache[fnOp]));
274 }
else if (
auto currBwdOp = dyn_cast<AutoDiffOp>(curr)) {
276 auto fnOp = dyn_cast_or_null<FunctionOpInterface>(
277 symbolTable.lookupNearestSymbolFrom(currBwdOp,
278 currBwdOp.getFnAttr()));
282 if (!innerEffectCache.contains(fnOp)) {
288 currBwdOp, fnOp, innerEffectCache[fnOp]));
293 SmallVector<MemoryEffects::EffectInstance> currOpEffects;
295 betweenEffects.append(currOpEffects);
303 bool foundConflict =
false;
304 for (
auto candidateEffect : callerEffectMap[candidateOp]) {
305 for (
auto prevEffect : betweenEffects) {
310 if ((isa<MemoryEffects::Write>(prevEffect.getEffect()) &&
311 isa<MemoryEffects::Read>(candidateEffect.getEffect())) ||
312 (isa<MemoryEffects::Read>(prevEffect.getEffect()) &&
313 isa<MemoryEffects::Write>(candidateEffect.getEffect())) ||
314 (isa<MemoryEffects::Write>(prevEffect.getEffect()) &&
315 isa<MemoryEffects::Write>(candidateEffect.getEffect()))) {
320 foundConflict =
true;
328 legalMerge.push_back(candidateOp);
331 lastOp = candidateOp;