Skip to content

Commit

Permalink
Reduction cache (pytorch#487)
Browse files (browse the repository at this point in the history)
* enable CSE (common-subexpression elimination) for reductions

* keep the reduction cache on self.cse
  • Loading branch information
ngimel committed Jun 30, 2022
1 parent b3c3b9e commit b22b6a5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
2 changes: 2 additions & 0 deletions torchinductor/codegen/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,12 +429,14 @@ def __init__(
name_prefix="tmp",
iter_buffers=None,
store_cache=None,
reduction_cache=None,
):
self.prefix = prefix
self.suffix = suffix
self.cache = {}
self.name_prefix = name_prefix
self.store_cache = store_cache or {}
self.reduction_cache = reduction_cache or {}
self.iter_buffer_ids = iter_buffers or itertools.count()

def invalidate(self, keep_vars: typing.Set[str]):
Expand Down
45 changes: 25 additions & 20 deletions torchinductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,29 +592,34 @@ def reduction(self, name, dtype, reduction_type, index, value):

dim = len(self.range_trees) - 1
result_var = self.cse.newvar()
accumulator = f"_{result_var}"
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(dtype)}) + {default}"
)
if (dtype, reduction_type, value) not in self.cse.reduction_cache:
self.cse.reduction_cache[(dtype, reduction_type, value)] = result_var
accumulator = f"_{result_var}"
self.body.writeline(
f"{accumulator} = tl.zeros({self.dense_size_str()}, {triton_compute_type(dtype)}) + {default}"
)

updated = value
if reduction_type == "min":
masks.append(f"({accumulator} > {value})")
elif reduction_type == "max":
masks.append(f"({accumulator} < {value})")
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")
updated = value
if reduction_type == "min":
masks.append(f"({accumulator} > {value})")
elif reduction_type == "max":
masks.append(f"({accumulator} < {value})")
elif reduction_type == "sum":
updated = f"{accumulator} + {value}"
else:
raise NotImplementedError(f"reduction_type {reduction_type}")

cond = " & ".join(masks)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)
cond = " & ".join(masks)
self.compute.writeline(
f"{accumulator} = tl.where({cond}, {updated}, {accumulator})"
)

self.suffix.writeline(
f"{result_var} = tl.reshape(tl.{reduction_type}({accumulator}, {dim}), [{', '.join(sizes)}])"
)
self.suffix.writeline(
f"{result_var} = tl.reshape(tl.{reduction_type}({accumulator}, {dim}), [{', '.join(sizes)}])"
)
else:
var_name = self.cse.reduction_cache[(dtype, reduction_type, value)]
self.suffix.writeline(f"{result_var} = {var_name}")
self.inside_reduction = False
index, mask = self.indexing(index, result_var)
assert "rmask" not in index
Expand Down

0 comments on commit b22b6a5

Please sign in to comment.