|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | +import ast |
| 6 | +import pathlib |
| 7 | +import re |
| 8 | + |
| 9 | +from executorch.exir.dialects.edge.spec.utils import SAMPLE_INPUT as _SAMPLE_INPUT |
| 10 | + |
| 11 | +# Add all targets and TOSA profiles we support here. |
| 12 | +TARGETS = [ |
| 13 | + "tosa_FP", |
| 14 | + "tosa_INT", |
| 15 | + "tosa_INT+FP", |
| 16 | + "u55_INT", |
| 17 | + "u85_INT", |
| 18 | + "vgf_INT", |
| 19 | + "vgf_FP", |
| 20 | + "vgf_quant", |
| 21 | + "vgf_no_quant", |
| 22 | + "no_target", |
| 23 | +] |
| 24 | + |
| 25 | +# Add edge ops which we lower but which are not included in exir/dialects/edge/edge.yaml here. |
| 26 | +_CUSTOM_EDGE_OPS = [ |
| 27 | + "linspace.default", |
| 28 | + "cond.default", |
| 29 | + "eye.default", |
| 30 | + "expm1.default", |
| 31 | + "vector_norm.default", |
| 32 | + "hardsigmoid.default", |
| 33 | + "hardswish.default", |
| 34 | + "linear.default", |
| 35 | + "maximum.default", |
| 36 | + "mean.default", |
| 37 | + "multihead_attention.default", |
| 38 | + "adaptive_avg_pool2d.default", |
| 39 | + "bitwise_right_shift.Tensor", |
| 40 | + "bitwise_right_shift.Scalar", |
| 41 | + "bitwise_left_shift.Tensor", |
| 42 | + "bitwise_left_shift.Scalar", |
| 43 | + "native_group_norm.default", |
| 44 | + "silu.default", |
| 45 | + "sdpa.default", |
| 46 | + "sum.default", |
| 47 | + "unbind.int", |
| 48 | + "unflatten.int", |
| 49 | + "_native_batch_norm_legit_no_training.default", |
| 50 | + "_native_batch_norm_legit.no_stats", |
| 51 | + "alias_copy.default", |
| 52 | + "pixel_shuffle.default", |
| 53 | + "pixel_unshuffle.default", |
| 54 | + "while_loop.default", |
| 55 | + "matmul.default", |
| 56 | + "upsample_bilinear2d.vec", |
| 57 | + "upsample_nearest2d.vec", |
| 58 | +] |
| 59 | +_ALL_EDGE_OPS = _SAMPLE_INPUT.keys() | _CUSTOM_EDGE_OPS |
| 60 | + |
| 61 | +_NON_ARM_PASSES = ["quantize_io_pass"] |
| 62 | + |
| 63 | +_MODEL_ENTRY_PATTERN = re.compile(r"^\s*(?:[-*]|\d+\.)\s+(?P<entry>.+?)\s*$") |
| 64 | +_NUMERIC_SERIES_PATTERN = re.compile(r"(\d+)(?=[a-z])") |
| 65 | +_CAMEL_BOUNDARY = re.compile( |
| 66 | + r"(?<!^)(?:(?<=[a-z0-9])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z]))" |
| 67 | +) |
| 68 | + |
| 69 | + |
| 70 | +def _collect_arm_passes(init_path: pathlib.Path) -> set[str]: |
| 71 | + names: set[str] = set() |
| 72 | + names.update(_extract_pass_names_from_init(init_path)) |
| 73 | + names.update(_NON_ARM_PASSES) |
| 74 | + return {_separate_numeric_series(_strip_pass_suffix(name)) for name in names} |
| 75 | + |
| 76 | + |
| 77 | +def _extract_pass_names_from_init(init_path: pathlib.Path) -> set[str]: |
| 78 | + source = init_path.read_text(encoding="utf-8") |
| 79 | + module = ast.parse(source, filename=str(init_path)) |
| 80 | + names: set[str] = set() |
| 81 | + |
| 82 | + for node in module.body: |
| 83 | + if not isinstance(node, ast.ImportFrom): |
| 84 | + continue |
| 85 | + for alias in node.names: |
| 86 | + candidate = alias.asname or alias.name |
| 87 | + if not candidate or not candidate.endswith("Pass"): |
| 88 | + continue |
| 89 | + if candidate == "ArmPass": |
| 90 | + continue |
| 91 | + names.add(_camel_to_snake(candidate)) |
| 92 | + return names |
| 93 | + |
| 94 | + |
| 95 | +def _strip_pass_suffix(name: str) -> str: |
| 96 | + return name[:-5] if name.endswith("_pass") else name |
| 97 | + |
| 98 | + |
| 99 | +def _separate_numeric_series(name: str) -> str: |
| 100 | + def repl(match: re.Match[str]) -> str: |
| 101 | + next_index = match.end() |
| 102 | + next_char = match.string[next_index] if next_index < len(match.string) else "" |
| 103 | + if next_char == "d": # Avoid creating patterns like 3_d |
| 104 | + return match.group(1) |
| 105 | + return f"{match.group(1)}_" |
| 106 | + |
| 107 | + return _NUMERIC_SERIES_PATTERN.sub(repl, name) |
| 108 | + |
| 109 | + |
| 110 | +def _collect_arm_models(models_md: pathlib.Path) -> set[str]: |
| 111 | + models: set[str] = set() |
| 112 | + for line in models_md.read_text(encoding="utf-8").splitlines(): |
| 113 | + stripped = line.strip() |
| 114 | + if not stripped or stripped.startswith("#"): |
| 115 | + continue |
| 116 | + match = _MODEL_ENTRY_PATTERN.match(line) |
| 117 | + if not match: |
| 118 | + continue |
| 119 | + base, alias, is_parent = _split_model_entry(match.group("entry")) |
| 120 | + if is_parent: |
| 121 | + continue |
| 122 | + if alias: |
| 123 | + models.add(_normalize_model_entry(alias)) |
| 124 | + else: |
| 125 | + models.add(_normalize_model_entry(base)) |
| 126 | + |
| 127 | + if not models: |
| 128 | + raise RuntimeError(f"No supported models found in {models_md}") |
| 129 | + return models |
| 130 | + |
| 131 | + |
| 132 | +def _collect_arm_ops() -> set[str]: |
| 133 | + """ |
| 134 | + Returns a mapping from names on the form to be used in unittests to edge op: |
| 135 | + 1. Names are in lowercase. |
| 136 | + 2. Overload is ignored if 'default', otherwise it's appended with an underscore. |
| 137 | + 3. Overly verbose name are shortened by removing certain prefixes/suffixes. |
| 138 | +
|
| 139 | + Examples: |
| 140 | + abs.default -> abs |
| 141 | + split_copy.Tensor -> split_tensor |
| 142 | + """ |
| 143 | + ops: set[str] = set() |
| 144 | + for edge_name in _ALL_EDGE_OPS: |
| 145 | + op, overload = edge_name.split(".") |
| 146 | + |
| 147 | + # Normalize names |
| 148 | + op = op.lower() |
| 149 | + op = op.removeprefix("_") |
| 150 | + op = op.removesuffix("_copy") |
| 151 | + op = op.removesuffix("_with_indices") |
| 152 | + overload = overload.lower() |
| 153 | + |
| 154 | + if overload == "default": |
| 155 | + ops.add(op) |
| 156 | + else: |
| 157 | + ops.add(f"{op}_{overload}") |
| 158 | + |
| 159 | + return ops |
| 160 | + |
| 161 | + |
| 162 | +def _split_model_entry(entry: str) -> tuple[str, str | None, bool]: |
| 163 | + entry = entry.strip() |
| 164 | + if not entry: |
| 165 | + return "", None, False |
| 166 | + is_parent = entry.endswith(":") |
| 167 | + if is_parent: |
| 168 | + entry = entry[:-1].rstrip() |
| 169 | + if "(" in entry and entry.endswith(")"): |
| 170 | + base, _, rest = entry.partition("(") |
| 171 | + alias = rest[:-1].strip() |
| 172 | + return base.strip(), alias or None, is_parent |
| 173 | + return entry, None, is_parent |
| 174 | + |
| 175 | + |
| 176 | +def _normalize_model_entry(name: str) -> str: |
| 177 | + cleaned = name.lower() |
| 178 | + cleaned = re.sub(r"[^a-z0-9\s]", "", cleaned) |
| 179 | + cleaned = re.sub(r"\s+", " ", cleaned).strip() |
| 180 | + return cleaned.replace(" ", "_") |
| 181 | + |
| 182 | + |
| 183 | +def _camel_to_snake(name: str) -> str: |
| 184 | + if not name: |
| 185 | + return "" |
| 186 | + name = name.replace("-", "_").replace(" ", "_") |
| 187 | + return _CAMEL_BOUNDARY.sub("_", name).lower() |
| 188 | + |
| 189 | + |
| 190 | +OP_LIST = sorted(_collect_arm_ops()) |
| 191 | +PASS_LIST = sorted( |
| 192 | + _collect_arm_passes(pathlib.Path("backends/arm/_passes/__init__.py")) |
| 193 | +) |
| 194 | +MODEL_LIST = sorted(_collect_arm_models(pathlib.Path("backends/arm/MODELS.md"))) |
0 commit comments