Enzyme main
Loading...
Searching...
No Matches
Ops.h
Go to the documentation of this file.
1//===- EnzymeOps.h - Enzyme dialect ops -------------------------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef ENZYMEOPS_H
10#define ENZYMEOPS_H
11
12#include <type_traits>
13
14#include "mlir/Dialect/Arith/IR/Arith.h"
15#include "mlir/IR/BuiltinTypes.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/OpDefinition.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/Interfaces/ControlFlowInterfaces.h"
20#include "mlir/Interfaces/MemorySlotInterfaces.h"
21#include "mlir/Interfaces/SideEffectInterfaces.h"
22
23#include "mlir/Bytecode/BytecodeOpInterface.h"
24
25#include "Dialect/EnzymeAttributeInterfaces.h.inc"
26#include "Dialect/EnzymeEnums.h.inc"
27
28#define GET_ATTRDEF_CLASSES
29#include "Dialect/EnzymeAttributes.h.inc"
30
31#define GET_TYPEDEF_CLASSES
32#include "Dialect/EnzymeOpsTypes.h.inc"
33
34// forward declare Enzyme op definitions
35#include "Dialect/EnzymeOps.h.inc"
36
37namespace mlir {
38namespace enzyme {
39namespace detail {
40
41// For any differentiation op, we either return input primal values or selective
42// derivative values. When `filterGrad` is true, `includeShadows` controls
43// whether input shadow arguments (activity `enzyme_dup` / `enzyme_dupnoneed`)
44// are collected, while `includeDifferentialReturns` controls whether
45// reverse-mode output shadows (`enzyme_active` / `enzyme_activenoneed`) are
46// collected.
47template <typename SourceOp, bool filterGrad, bool includeShadows = true,
48 bool includeDifferentialReturns = true>
49llvm::SmallVector<mlir::Value, 2> filterGradInputs(SourceOp uop) {
50 llvm::SmallVector<mlir::Value, 2> outs;
51 size_t in_idx = 0;
52
53 for (auto act : uop.getActivity()) {
54 auto iattr = cast<ActivityAttr>(act);
55 auto act_val = iattr.getValue();
56
57 if constexpr (!filterGrad) {
58 outs.push_back(uop.getInputs()[in_idx]);
59 }
60
61 ++in_idx;
62
63 if (act_val == Activity::enzyme_dup ||
64 act_val == Activity::enzyme_dupnoneed) {
65
66 if constexpr (filterGrad && includeShadows) {
67 outs.push_back(uop.getInputs()[in_idx]);
68 }
69
70 ++in_idx;
71 }
72 }
73
74 // For reverse mode AD, add derivative values corresponding to active outputs
75 // clang-format off
76 if constexpr ((std::is_same_v<SourceOp, AutoDiffOp> ||
77 std::is_same_v<SourceOp, AutoDiffRegionOp>) &&
78 filterGrad && includeDifferentialReturns) {
79 // clang-format on
80 if (in_idx != uop.getInputs().size()) {
81 for (auto act : uop.getRetActivity()) {
82 auto iattr = cast<ActivityAttr>(act);
83 auto act_val = iattr.getValue();
84
85 if (act_val == Activity::enzyme_active ||
86 act_val == Activity::enzyme_activenoneed) {
87 outs.push_back(uop.getInputs()[in_idx]);
88 in_idx++;
89 }
90 }
91 }
92 }
93
94 return outs;
95}
96
97} // namespace detail
98} // namespace enzyme
99} // namespace mlir
100
101#define GET_OP_CLASSES
102#include "Dialect/EnzymeOps.h.inc"
103
104// #include "Dialect/EnzymeTypes.h.inc"
105
106#endif // ENZYMEOPS_H
llvm::SmallVector< mlir::Value, 2 > filterGradInputs(SourceOp uop)
Definition Ops.h:49