Skip to content

Commit 5935252

Browse files
committed
add pretty batching test too
1 parent c32caa5 commit 5935252

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

tests/pretty/autodiff_forward.pp

+31
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
// Make sure, that we add the None for the default return.
3030

31+
3132
::core::panicking::panic("not implemented")
3233
}
3334
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
@@ -127,4 +128,34 @@
127128
::core::hint::black_box(f7(x));
128129
::core::hint::black_box(());
129130
}
131+
#[no_mangle]
132+
#[rustc_autodiff]
133+
#[inline(never)]
134+
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
135+
#[rustc_autodiff(Forward, 4, Dual, Dual)]
136+
#[inline(never)]
137+
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
138+
-> [f32; 5usize] {
139+
unsafe { asm!("NOP", options(pure, nomem)); };
140+
::core::hint::black_box(f8(x));
141+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
142+
::core::hint::black_box(<[f32; 5usize]>::default())
143+
}
144+
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
145+
#[inline(never)]
146+
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
147+
-> [f32; 4usize] {
148+
unsafe { asm!("NOP", options(pure, nomem)); };
149+
::core::hint::black_box(f8(x));
150+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
151+
::core::hint::black_box(<[f32; 4usize]>::default())
152+
}
153+
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
154+
#[inline(never)]
155+
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
156+
unsafe { asm!("NOP", options(pure, nomem)); };
157+
::core::hint::black_box(f8(x));
158+
::core::hint::black_box((bx_0,));
159+
::core::hint::black_box(<f32>::default())
160+
}
130161
fn main() {}

tests/pretty/autodiff_forward.rs

+8
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,12 @@ pub fn f6() -> DoesNotImplDefault {
4646
#[autodiff(df7, Forward, Const)]
4747
pub fn f7(x: f32) -> () {}
4848

49+
#[autodiff(f8_1, Forward, Dual, DualOnly)]
50+
#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
51+
#[autodiff(f8_3, Forward, 4, Dual, Dual)]
52+
#[no_mangle]
53+
fn f8(x: &f32) -> f32 {
54+
unimplemented!()
55+
}
56+
4957
fn main() {}

0 commit comments

Comments
 (0)