## Summary

`flashinfer.quantization.fp4_quantize` cannot be called from inside a `torch.compile`'d region. This blocks any inference framework that uses `torch.compile` / piecewise CUDA graphs with NVFP4-quantized models. Hit on a Qwen3-235B-NVFP4 deployment via sgl-project/sglang#21419, where the only workaround is `--disable-piecewise-cuda-graph`, costing roughly 2–3× decode throughput.
Happy to submit a PR with the fix once maintainers confirm the swizzled scale-factor shape rules — sketch at the end.
## Reproduction

```python
import torch
from flashinfer.quantization import fp4_quantize

@torch.compile(fullgraph=True)
def f(x, scale):
    return fp4_quantize(x, scale)

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], device="cuda")
f(x, scale)  # crashes
```
## Two stacked failures

### (1) Lazy JIT inside the trace

The first call to `fp4_quantize` invokes `get_fp4_quantization_module(...).build_and_load()`, which spawns a `subprocess.Popen` and constructs a `threading.Lock()`. Dynamo cannot trace `_thread.lock.__new__` (a skip-listed C builtin) or `subprocess.Popen.__init__`:

```
torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
Explanation: Dynamo does not know how to trace `lock.__new__` ...
```
Pre-warming `get_fp4_quantization_module` outside the trace does not help — dynamo still walks `Path.exists` → `os.stat` inside `is_aot` on every traced call:

```
File ".../flashinfer/jit/core.py", line 261, in is_aot
    return self.aot_path.exists()
File ".../pathlib/_local.py", line 515, in stat
    return os.stat(self, follow_symlinks=follow_symlinks)
```
### (2) FakeTensor incompatibility

Marking `fp4_quantize` with `torch._dynamo.allow_in_graph` (treat as opaque) gets past (1), but during fake-tensor propagation dynamo calls the function with `FakeTensor` inputs, and the underlying TVM kernel calls `.data_ptr()`, which `FakeTensor` rejects:

```
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor).
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing
into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op.
```
## Why no obvious workaround is enough

| Approach | Blocker |
| --- | --- |
| `@torch._dynamo.disable` on `fp4_quantize` | sglang's piecewise compile uses `fullgraph=True` → "Skip calling `torch.compiler.disable()`d function" hard error |
| `@torch._dynamo.allow_in_graph` | dynamo executes the function with `FakeTensor`s → `data_ptr` crash |
| Pre-warming the JIT module | doesn't help — `is_aot` still does filesystem I/O on every call, which dynamo traces |
The fundamental issue is that `fp4_quantize` is a regular Python function that does (a) lazy JIT side effects and (b) C++ kernel calls that touch raw storage. Neither is compatible with `torch.compile`'s tracing model. The same applies to `block_scale_interleave`.
## Proposed fix

Wrap `fp4_quantize` (and `block_scale_interleave`) as a `torch.library.custom_op` with proper `register_fake` meta kernels. This makes them fully opaque to dynamo: the meta kernel runs during tracing, the real kernel runs at execution time, and no JIT / lock / `data_ptr` access ever happens inside a compile region. The public Python API stays unchanged.
### Sketch

```python
@torch.library.custom_op(
    "flashinfer::fp4_quantize_op",
    mutates_args=(),
    device_types="cuda",
)
def _fp4_quantize_op(
    x: torch.Tensor,
    global_scale: torch.Tensor,
    sf_vec_size: int,
    sf_use_ue8m0: bool,
    is_sf_swizzled_layout: bool,
    is_sf_8x4_layout: bool,
    enable_pdl: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    major, minor = get_compute_capability(x.device)
    return get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
        x, global_scale, sf_vec_size, sf_use_ue8m0,
        is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl,
    )


@_fp4_quantize_op.register_fake
def _(x, global_scale, sf_vec_size, sf_use_ue8m0,
      is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl):
    K = x.shape[-1]
    M = x.numel() // K
    x_q = torch.empty((*x.shape[:-1], K // 2), dtype=torch.uint8, device=x.device)
    # See open question 1 below — this is the shape I need confirmed.
    if is_sf_swizzled_layout:
        M_padded = ((M + 127) // 128) * 128
        K_sf = ((K // sf_vec_size + 3) // 4) * 4
    else:
        M_padded = M
        K_sf = K // sf_vec_size
    sf = torch.empty((M_padded, K_sf), dtype=torch.uint8, device=x.device)
    return x_q, sf


# Public wrapper — unchanged signature, dispatches to the registered op so
# torch.compile sees an opaque op rather than a Python function.
def fp4_quantize(input, global_scale=None, sf_vec_size=16, sf_use_ue8m0=False,
                 is_sf_swizzled_layout=True, is_sf_8x4_layout=False, enable_pdl=None):
    ...
    x_q, sf = torch.ops.flashinfer.fp4_quantize_op(
        input, global_scale, sf_vec_size, sf_use_ue8m0,
        is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl,
    )
    sf = sf.reshape((-1, input.shape[-1] // sf_vec_size))
    ...
```
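The padding arithmetic the fake kernel assumes can be sanity-checked standalone — a pure-Python sketch of the shape rule (assuming the 128-row / 4-column swizzle padding from open question 1, which still needs maintainer confirmation):

```python
def round_up(n: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= n.
    return ((n + multiple - 1) // multiple) * multiple

def sf_shape(M: int, K: int, sf_vec_size: int, swizzled: bool) -> tuple[int, int]:
    # Scale-factor tensor shape as assumed in the register_fake sketch above.
    if swizzled:
        return round_up(M, 128), round_up(K // sf_vec_size, 4)
    return M, K // sf_vec_size

# Shapes for the reproduction input: (32, 4096) with sf_vec_size=16.
print(sf_shape(32, 4096, 16, True))   # (128, 256): 32 rows pad to 128; 4096/16 = 256 is already a multiple of 4
print(sf_shape(32, 4096, 16, False))  # (32, 256): no padding in the linear layout
```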
A test that fails on main and should pass with this fix:
```python
import torch
from flashinfer.quantization import fp4_quantize

x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], device="cuda")
eager_q, eager_sf = fp4_quantize(x, scale)

@torch.compile(fullgraph=True)
def compiled(x, scale):
    return fp4_quantize(x, scale)

comp_q, comp_sf = compiled(x, scale)
torch.testing.assert_close(eager_q, comp_q)
torch.testing.assert_close(eager_sf, comp_sf)
```
## Open questions for maintainers

- Swizzled SF shape. Is `M_padded = round_up(M, 128)` and `K_sf = round_up(K / sf_vec_size, 4)` correct for all swizzle modes, or does `is_sf_8x4_layout=True` use different padding (8×4 blocks)? This is the critical thing to get right — the meta kernel must match the real kernel exactly, or fake-tensor propagation explodes downstream.
- `enable_pdl=None` handling. OK to resolve `None` → `device_support_pdl(...)` outside the op so the op input is always a concrete `bool`?
- `global_scale=None` handling. Same — default to `torch.tensor([1.0])` outside the op so the op input is always a `Tensor`?
- `mutates_args=()`. Confirming the kernel never writes through input pointers.
Once these are settled I'll send a PR with the meta kernel plus a parametrized test covering `sf_vec_size ∈ {16, 32}` × `is_sf_swizzled_layout ∈ {True, False}` × `is_sf_8x4_layout ∈ {True, False}`.
## Environment
- flashinfer 0.6.7
- torch 2.9.1+cu128
- CUDA 12.8
- GPU: 2× RTX PRO 6000 (sm_120)
- Model: Qwen3-235B-A22B-Instruct-2507-NVFP4 via sglang TP2