Refactor/cube/vectorization #1781

Merged: 24 commits merged on May 19, 2024
Changes from 1 commit
Ben WIP
nathanielsimard committed May 16, 2024
commit 91624ff3e4d5f3261efda546e0856b29d414f4a8
14 changes: 6 additions & 8 deletions crates/burn-cube/src/codegen/dialect/scope.rs
@@ -1,5 +1,3 @@
-use crate::PrimitiveVariable;
-
 use super::{
     cpa, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Operation, Operator,
     Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization,
@@ -69,13 +67,13 @@ impl Scope {
     }
 
     /// Create a variable initialized at some value.
-    pub fn create_with_value<E: PrimitiveVariable, I: Into<Item> + Copy>(
-        &mut self,
-        value: E,
-        item: I,
-    ) -> Variable {
+    pub fn create_with_value<E, I>(&mut self, value: E, item: I) -> Variable
+    where
+        E: Into<f64>,
+        I: Into<Item> + Copy,
+    {
         let local = self.create_local(item);
-        let value = Variable::ConstantScalar(value.to_f64(), item.into().elem());
+        let value = Variable::ConstantScalar(value.into(), item.into().elem());
         cpa!(self, local = value);
         local
     }
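A minimal usage sketch of the relaxed signature, for readers of this hunk. It is not part of the PR; the `Item`/`Elem` constructors are taken from the binary.rs and pod.rs hunks below, and it assumes the concrete `item` value satisfies the existing `Into<Item> + Copy` bound.

```rust
// Illustration only: any numeric type with a lossless Into<f64> conversion can now seed
// the constant, instead of types implementing the removed PrimitiveVariable trait.
fn demo_constants(scope: &mut Scope) {
    let item = Item::Scalar(Elem::Float(FloatKind::F32)); // constructors as used elsewhere in this diff
    let _half = scope.create_with_value(0.5f32, item); // E = f32, via Into<f64>
    let _two = scope.create_with_value(2u32, item); // E = u32, via Into<f64>
}
```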
2 changes: 1 addition & 1 deletion crates/burn-cube/src/codegen/dialect/variable.rs
@@ -33,7 +33,7 @@ pub enum Variable {
 }
 
 impl Variable {
-    pub(crate) fn index(&self) -> Option<u16> {
+    pub fn index(&self) -> Option<u16> {
         match self {
             Variable::GlobalInputArray(idx, _) => Some(*idx),
             Variable::GlobalScalar(idx, _) => Some(*idx),
6 changes: 3 additions & 3 deletions crates/burn-cube/src/codegen/kernel.rs
@@ -330,9 +330,9 @@ fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeE
         Elem::Bool => panic!("Bool scalars are not supported"),
     };
     let scalar_priorities: [usize; 3] = [
-        element_priority(E1::elem()),
-        element_priority(E2::elem()),
-        element_priority(E3::elem()),
+        element_priority(E1::cube_elem()),
+        element_priority(E2::cube_elem()),
+        element_priority(E3::cube_elem()),
     ];
 
     let mut handles_scalars = Vec::new();
2 changes: 1 addition & 1 deletion crates/burn-cube/src/lib.rs
@@ -37,7 +37,7 @@ pub trait GpuComputeShaderPhase: Send + Sync + 'static {
     }
 }
 
-pub(crate) fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
+pub fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
     let num_elem_per_invocation = workgroup_size * workgroup_size;
     let workgroups = f32::ceil(num_elems as f32 / num_elem_per_invocation as f32);
     let workgroup_x = f32::ceil(f32::sqrt(workgroups));
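A quick worked example of the launch-size arithmetic above; illustration only, not part of the diff, and the wrapper function is hypothetical.

```rust
// With workgroup_size = 32, each workgroup covers 32 * 32 = 1024 elements.
// For 1_000_000 elements:
//   workgroups  = ceil(1_000_000 / 1024) = 977
//   workgroup_x = ceil(sqrt(977))        = 32
// (the rest of the body, truncated above, builds the returned WorkGroup from these values)
// Now that the function is `pub`, downstream crates can call it directly:
fn demo_launch_size() -> WorkGroup {
    elemwise_workgroup(1_000_000, 32)
}
```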
12 changes: 6 additions & 6 deletions crates/burn-cube/src/pod.rs
@@ -12,7 +12,7 @@ pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemu
     /// Convert a slice of bytes to a slice of elements.
     fn from_bytes(bytes: &[u8]) -> &[Self];
     /// Element representation for `cubecl`.
-    fn elem() -> Elem;
+    fn cube_elem() -> Elem;
     /// Highest possible value
     fn maximum_value() -> Self;
     /// Lowest possible value
@@ -29,7 +29,7 @@ impl CubeElement for u32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
-    fn elem() -> Elem {
+    fn cube_elem() -> Elem {
         Elem::UInt
     }
     fn maximum_value() -> Self {
@@ -50,7 +50,7 @@ impl CubeElement for i32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
-    fn elem() -> Elem {
+    fn cube_elem() -> Elem {
         Elem::Int(IntKind::I32)
     }
     fn maximum_value() -> Self {
@@ -73,7 +73,7 @@ impl CubeElement for f32 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
-    fn elem() -> Elem {
+    fn cube_elem() -> Elem {
         Elem::Float(FloatKind::F32)
     }
     fn maximum_value() -> Self {
@@ -94,7 +94,7 @@ impl CubeElement for half::f16 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
-    fn elem() -> Elem {
+    fn cube_elem() -> Elem {
         Elem::Float(FloatKind::F16)
     }
     fn maximum_value() -> Self {
@@ -115,7 +115,7 @@ impl CubeElement for half::bf16 {
     fn from_bytes(bytes: &[u8]) -> &[Self] {
         bytemuck::cast_slice(bytes)
     }
-    fn elem() -> Elem {
+    fn cube_elem() -> Elem {
         Elem::Float(FloatKind::BF16)
     }
     fn maximum_value() -> Self {
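A small sketch of the rename's effect on callers; illustration only, not from the PR, and the helper name is hypothetical.

```rust
// Hypothetical helper: call sites that previously used E::elem() now use E::cube_elem().
fn scalar_elem<E: CubeElement>() -> Elem {
    E::cube_elem()
}

// e.g. scalar_elem::<f32>() returns Elem::Float(FloatKind::F32), per the impl above.
```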
1 change: 1 addition & 0 deletions crates/burn-jit/Cargo.toml
@@ -29,6 +29,7 @@ export_tests = [
 [dependencies]
 burn-common = { path = "../burn-common", version = "0.15.0" }
 burn-tensor = { path = "../burn-tensor", version = "0.15.0" }
+burn-cube = { path = "../burn-cube", version = "0.15.0" }
 burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
 
 bytemuck = { workspace = true }
3 changes: 2 additions & 1 deletion crates/burn-jit/src/backend.rs
@@ -1,5 +1,5 @@
 use crate::{tensor::JitTensor, FloatElement, IntElement, PrecisionBridge, Runtime};
-use burn_tensor::backend::Backend;
+use burn_tensor::backend::{Backend, DeviceOps};
 use rand::{rngs::StdRng, SeedableRng};
 use std::{marker::PhantomData, sync::Mutex};
 
@@ -16,6 +16,7 @@ pub struct JitBackend<R: Runtime, F: FloatElement, I: IntElement> {
 impl<R, F, I> Backend for JitBackend<R, F, I>
 where
     R: Runtime,
+    R::Device: DeviceOps,
     F: FloatElement,
     I: IntElement,
 {
3 changes: 2 additions & 1 deletion crates/burn-jit/src/bridge.rs
@@ -2,7 +2,7 @@ use crate::{
     kernel, ops::to_device, tensor::JitTensor, FloatElement, IntElement, JitBackend, Runtime,
 };
 use burn_tensor::{
-    backend::BackendBridge,
+    backend::{BackendBridge, DeviceOps},
     ops::{FloatElem, FloatTensor},
 };
 use core::marker::PhantomData;
@@ -19,6 +19,7 @@ impl<R, FOrigin, IOrigin, FTarget, ITarget> BackendBridge<JitBackend<R, FOrigin,
     for PrecisionBridge<R, FTarget, ITarget>
 where
     R: Runtime,
+    R::Device: DeviceOps,
     FOrigin: FloatElement,
     IOrigin: IntElement,
     FTarget: FloatElement,
131 changes: 9 additions & 122 deletions crates/burn-jit/src/element.rs
@@ -1,136 +1,23 @@
-use crate::codegen::dialect::gpu;
-use burn_tensor::Element;
+use burn_cube::CubeElement;
 
 /// The base element trait for the jit backend.
-pub trait JitElement:
-    burn_tensor::Element + core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod
-{
-    /// TODO: Remove when all wgsl static kernels are migrated.
-    fn type_name() -> &'static str;
-    /// Convert a slice of elements to a slice of bytes.
-    fn as_bytes(slice: &[Self]) -> &[u8];
-    /// Convert a slice of bytes to a slice of elements.
-    fn from_bytes(bytes: &[u8]) -> &[Self];
-    /// Element representation for `gpu`.
-    fn gpu_elem() -> gpu::Elem;
-    /// Highest possible value
-    fn maximum_value() -> Self;
-    /// Lowest possible value
-    fn minimum_value() -> Self;
-}
+pub trait JitElement: burn_tensor::Element + CubeElement {}
 
 /// The float element type for the jit backend.
-pub trait FloatElement: JitElement + Element {}
+pub trait FloatElement: JitElement {}
 
 /// The int element type for the jit backend.
-pub trait IntElement: JitElement + Element {}
+pub trait IntElement: JitElement {}
 
-impl JitElement for u32 {
-    fn type_name() -> &'static str {
-        "u32"
-    }
-    fn as_bytes(slice: &[Self]) -> &[u8] {
-        bytemuck::cast_slice(slice)
-    }
-    fn from_bytes(bytes: &[u8]) -> &[Self] {
-        bytemuck::cast_slice(bytes)
-    }
-    fn gpu_elem() -> gpu::Elem {
-        gpu::Elem::UInt
-    }
-    fn maximum_value() -> Self {
-        u32::MAX
-    }
-    fn minimum_value() -> Self {
-        u32::MIN
-    }
-}
+impl JitElement for u32 {}
 
-impl JitElement for i32 {
-    fn type_name() -> &'static str {
-        "i32"
-    }
-    fn as_bytes(slice: &[Self]) -> &[u8] {
-        bytemuck::cast_slice(slice)
-    }
-    fn from_bytes(bytes: &[u8]) -> &[Self] {
-        bytemuck::cast_slice(bytes)
-    }
-    fn gpu_elem() -> gpu::Elem {
-        gpu::Elem::Int(gpu::IntKind::I32)
-    }
-    fn maximum_value() -> Self {
-        // Seems to cause problem for some GPU
-        i32::MAX - 1
-    }
-    fn minimum_value() -> Self {
-        // Seems to cause problem for some GPU
-        i32::MIN + 1
-    }
-}
+impl JitElement for i32 {}
 
-impl JitElement for f32 {
-    fn type_name() -> &'static str {
-        "f32"
-    }
-    fn as_bytes(slice: &[Self]) -> &[u8] {
-        bytemuck::cast_slice(slice)
-    }
-    fn from_bytes(bytes: &[u8]) -> &[Self] {
-        bytemuck::cast_slice(bytes)
-    }
-    fn gpu_elem() -> gpu::Elem {
-        gpu::Elem::Float(gpu::FloatKind::F32)
-    }
-    fn maximum_value() -> Self {
-        f32::MAX
-    }
-    fn minimum_value() -> Self {
-        f32::MIN
-    }
-}
+impl JitElement for f32 {}
 
-impl JitElement for half::f16 {
-    fn type_name() -> &'static str {
-        "f16"
-    }
-    fn as_bytes(slice: &[Self]) -> &[u8] {
-        bytemuck::cast_slice(slice)
-    }
-    fn from_bytes(bytes: &[u8]) -> &[Self] {
-        bytemuck::cast_slice(bytes)
-    }
-    fn gpu_elem() -> gpu::Elem {
-        gpu::Elem::Float(gpu::FloatKind::F16)
-    }
-    fn maximum_value() -> Self {
-        half::f16::MAX
-    }
-    fn minimum_value() -> Self {
-        half::f16::MIN
-    }
-}
+impl JitElement for half::f16 {}
 
-impl JitElement for half::bf16 {
-    fn type_name() -> &'static str {
-        "bf16"
-    }
-    fn as_bytes(slice: &[Self]) -> &[u8] {
-        bytemuck::cast_slice(slice)
-    }
-    fn from_bytes(bytes: &[u8]) -> &[Self] {
-        bytemuck::cast_slice(bytes)
-    }
-    fn gpu_elem() -> gpu::Elem {
-        gpu::Elem::Float(gpu::FloatKind::BF16)
-    }
-    fn maximum_value() -> Self {
-        half::bf16::MAX
-    }
-    fn minimum_value() -> Self {
-        half::bf16::MIN
-    }
-}
+impl JitElement for half::bf16 {}
 impl FloatElement for f32 {}
 impl FloatElement for half::bf16 {}
 impl FloatElement for half::f16 {}
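A minimal sketch (not part of the PR) of what the slimmed-down traits mean for generic jit code; the helper name is hypothetical and the import path for `Elem` inside burn-jit is assumed.

```rust
// Hypothetical helper: JitElement no longer declares methods of its own, so generic code
// bound on `E: JitElement` reaches element metadata through the burn_cube::CubeElement
// supertrait (cube_elem() where it previously called the removed gpu_elem()).
fn describe<E: JitElement>() -> (Elem, E) {
    (E::cube_elem(), E::maximum_value())
}
```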
25 changes: 0 additions & 25 deletions crates/burn-jit/src/kernel/base.rs

This file was deleted.

8 changes: 4 additions & 4 deletions crates/burn-jit/src/kernel/binary.rs
@@ -63,21 +63,21 @@ macro_rules! binary {
         let mut scope = $crate::codegen::dialect::gpu::Scope::root();
         let position = $crate::codegen::dialect::gpu::Variable::Id;
 
-        let op = $ops(&mut scope, I::gpu_elem(), position);
+        let op = $ops(&mut scope, I::cube_elem(), position);
         scope.register(op);
 
         let local = scope.last_local_index().unwrap().index().unwrap();
 
         let lhs = $crate::codegen::InputInfo::Array {
-            item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
+            item: $crate::codegen::dialect::gpu::Item::Scalar(I::cube_elem()),
             visibility: $crate::codegen::dialect::gpu::Visibility::Read,
         };
         let rhs = $crate::codegen::InputInfo::Array {
-            item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
+            item: $crate::codegen::dialect::gpu::Item::Scalar(I::cube_elem()),
             visibility: $crate::codegen::dialect::gpu::Visibility::Read,
         };
         let out = $crate::codegen::OutputInfo::ArrayWrite {
-            item: $crate::codegen::dialect::gpu::Item::Scalar(O::gpu_elem()),
+            item: $crate::codegen::dialect::gpu::Item::Scalar(O::cube_elem()),
             local,
             position,
         };
11 changes: 6 additions & 5 deletions crates/burn-jit/src/kernel/cast/base.rs
@@ -1,11 +1,12 @@
+use burn_cube::cpa;
 use std::{any::TypeId, marker::PhantomData};
 
 use crate::{
     codegen::{
         Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo,
         OutputInfo, WorkgroupLaunch,
     },
-    gpu::{gpu, ComputeShader, Scope, Variable, Visibility},
+    gpu::{ComputeShader, Scope, Variable, Visibility},
     kernel::GpuComputeShaderPhase,
     tensor::JitTensor,
     JitElement, Runtime,
@@ -64,8 +65,8 @@ impl<R: Runtime, EI: JitElement, EO: JitElement> GpuComputeShaderPhase
 {
     fn compile(&self) -> ComputeShader {
         let mut scope = Scope::root();
-        let item_input = EI::gpu_elem().into();
-        let item_output = EO::gpu_elem().into();
+        let item_input = EI::cube_elem().into();
+        let item_output = EO::cube_elem().into();
 
         let tensor = Variable::GlobalInputArray(0, item_input);
         let output = Variable::GlobalOutputArray(0, item_output);
@@ -103,7 +104,7 @@ impl CastShader {
         let output = self.output;
 
         let value = scope.create_local(output.item());
-        gpu!(scope, value = tensor[id]);
-        gpu!(scope, output[id] = value);
+        cpa!(scope, value = tensor[id]);
+        cpa!(scope, output[id] = value);
     }
 }