fp4_quantize is incompatible with torch.compile (lazy JIT + raw data_ptr access) #2999

@BehindTheCartan

Summary

flashinfer.quantization.fp4_quantize cannot be called from inside a torch.compile'd region. This blocks any inference framework that uses torch.compile / piecewise CUDA graphs with NVFP4-quantized models. Hit on a Qwen3-235B-NVFP4 deployment via sgl-project/sglang#21419, where the only workaround is --disable-piecewise-cuda-graph, costing roughly 2–3× decode throughput.

Happy to submit a PR with the fix once maintainers confirm the swizzled scale-factor shape rules — sketch at the end.

Reproduction

import torch
from flashinfer.quantization import fp4_quantize

@torch.compile(fullgraph=True)
def f(x, scale):
    return fp4_quantize(x, scale)

x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], device="cuda")
f(x, scale)  # crashes

Two stacked failures

(1) Lazy JIT inside the trace

First call to fp4_quantize invokes get_fp4_quantization_module(...).build_and_load(), which spawns a subprocess.Popen and constructs a threading.Lock(). Dynamo cannot trace _thread.lock.__new__ (skip-listed C builtin) or subprocess.Popen.__init__:

torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
  Explanation: Dynamo does not know how to trace `lock.__new__` ...
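
A minimal CPU-only illustration of this failure class, independent of flashinfer: constructing a `threading.Lock` inside a `fullgraph=True` region hits the same skip-listed C builtin.

```python
import threading
import torch

@torch.compile(fullgraph=True)
def make_lock(x):
    # _thread.lock construction is a skip-listed C builtin; with
    # fullgraph=True Dynamo raises instead of graph-breaking here.
    lock = threading.Lock()
    return x + 1

try:
    make_lock(torch.ones(2))
    print("compiled fine")
except Exception as e:
    print("raised:", type(e).__name__)
```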

Pre-warming get_fp4_quantization_module outside the trace does not help: dynamo still walks `Path.exists` → `os.stat` inside `is_aot` on every traced call:

File ".../flashinfer/jit/core.py", line 261, in is_aot
    return self.aot_path.exists()
File ".../pathlib/_local.py", line 515, in stat
    return os.stat(self, follow_symlinks=follow_symlinks)

(2) FakeTensor incompatibility

Marking fp4_quantize with torch._dynamo.allow_in_graph (treat as opaque) gets past (1), but during fake-tensor propagation dynamo calls the function with FakeTensor inputs, and the underlying TVM kernel calls .data_ptr(), which FakeTensor rejects:

RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor).
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing
into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op.

Why no obvious workaround is enough

| Approach | Blocker |
| --- | --- |
| `@torch._dynamo.disable` on `fp4_quantize` | sglang's piecewise compile uses `fullgraph=True` → "Skip calling torch.compiler.disable()d function" hard error |
| `@torch._dynamo.allow_in_graph` | dynamo executes the function with FakeTensors → `data_ptr` crash |
| Pre-warming the JIT module | `is_aot` still does filesystem I/O on every call, which dynamo traces |

The fundamental issue is that fp4_quantize is a regular Python function that does (a) lazy JIT side effects and (b) C++ kernel calls that touch raw storage. Neither is compatible with torch.compile's tracing model. The same applies to block_scale_interleave.

Proposed fix

Wrap fp4_quantize (and block_scale_interleave) as torch.library.custom_op with proper register_fake meta kernels. This makes them fully opaque to dynamo: the meta kernel runs during tracing, the real kernel runs at execution time, no JIT/lock/data_ptr access ever happens inside a compile region. Public Python API stays unchanged.

Sketch

@torch.library.custom_op(
    "flashinfer::fp4_quantize_op",
    mutates_args=(),
    device_types="cuda",
)
def _fp4_quantize_op(
    x: torch.Tensor,
    global_scale: torch.Tensor,
    sf_vec_size: int,
    sf_use_ue8m0: bool,
    is_sf_swizzled_layout: bool,
    is_sf_8x4_layout: bool,
    enable_pdl: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    major, minor = get_compute_capability(x.device)
    return get_fp4_quantization_module(f"{major}{minor}").fp4_quantize_sm100(
        x, global_scale, sf_vec_size, sf_use_ue8m0,
        is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl,
    )

@_fp4_quantize_op.register_fake
def _(x, global_scale, sf_vec_size, sf_use_ue8m0,
      is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl):
    K = x.shape[-1]
    M = x.numel() // K
    x_q = torch.empty((*x.shape[:-1], K // 2), dtype=torch.uint8, device=x.device)
    # See open question 1 below — this is the shape I need confirmed.
    if is_sf_swizzled_layout:
        M_padded = ((M + 127) // 128) * 128
        K_sf = ((K // sf_vec_size + 3) // 4) * 4
    else:
        M_padded = M
        K_sf = K // sf_vec_size
    sf = torch.empty((M_padded, K_sf), dtype=torch.uint8, device=x.device)
    return x_q, sf

# Public wrapper — unchanged signature, dispatches to the registered op so
# torch.compile sees an opaque op rather than a Python function.
def fp4_quantize(input, global_scale=None, sf_vec_size=16, sf_use_ue8m0=False,
                 is_sf_swizzled_layout=True, is_sf_8x4_layout=False, enable_pdl=None):
    ...
    x_q, sf = torch.ops.flashinfer.fp4_quantize_op(
        input, global_scale, sf_vec_size, sf_use_ue8m0,
        is_sf_swizzled_layout, is_sf_8x4_layout, enable_pdl,
    )
    sf = sf.reshape((-1, input.shape[-1] // sf_vec_size))
    ...
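
For concreteness, the swizzled-shape arithmetic the fake kernel assumes (the padding rule pending confirmation in open question 1), evaluated for the repro inputs:

```python
# Worked example of the assumed swizzled scale-factor padding rule
# (round M up to 128, round K // sf_vec_size up to 4) for the repro inputs.
def round_up(v: int, m: int) -> int:
    return (v + m - 1) // m * m

M, K, sf_vec_size = 32, 4096, 16
M_padded = round_up(M, 128)
K_sf = round_up(K // sf_vec_size, 4)
print(M_padded, K_sf)  # 128 256
```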

A test that fails on main and should pass with this fix:

import torch
from flashinfer.quantization import fp4_quantize

x = torch.randn(64, 4096, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], device="cuda")

eager_q, eager_sf = fp4_quantize(x, scale)

@torch.compile(fullgraph=True)
def compiled(x, scale):
    return fp4_quantize(x, scale)

comp_q, comp_sf = compiled(x, scale)
torch.testing.assert_close(eager_q, comp_q)
torch.testing.assert_close(eager_sf, comp_sf)

Open questions for maintainers

  1. Swizzled SF shape. Is M_padded = round_up(M, 128) and K_sf = round_up(K/sf_vec_size, 4) correct for all swizzle modes, or does is_sf_8x4_layout=True use different padding (8×4 blocks)? This is the critical thing to get right — the meta kernel must match the real kernel exactly or fake-tensor propagation explodes downstream.
  2. enable_pdl=None handling. OK to resolve None → device_support_pdl(...) outside the op so the op input is always a concrete bool?
  3. global_scale=None handling. Same — default to torch.tensor([1.0]) outside the op so the op input is always a Tensor?
  4. mutates_args=() — confirming the kernel never writes through input pointers.

Once these are settled I'll send a PR with the meta kernel + a parametrized test covering sf_vec_size ∈ {16, 32} × is_sf_swizzled_layout ∈ {True, False} × is_sf_8x4_layout ∈ {True, False}.

Environment

  • flashinfer 0.6.7
  • torch 2.9.1+cu128
  • CUDA 12.8
  • GPU: 2× RTX PRO 6000 (sm_120)
  • Model: Qwen3-235B-A22B-Instruct-2507-NVFP4 via sglang TP2
