BUG: Static decoding is from left to right with temp=0 #58

@JacobHelwig

Description

There is a bug in the code that makes static decoding with temp=0 or top-k=1 always commit the left-most masked token first. The same bug also breaks dynamic decoding.

The logits are first processed for sampling. With temp=0 or top-k=1, this processing turns each per-position distribution into a one-hot, so the probability of the sampled token at every masked position is exactly 1.

Raw logits per masked position (illustrative):
  pos 0:   [..., -2.1,  5.3,  1.2,  4.8,  3.1, ...]
  pos 1:   [..., -1.0,  0.2, -3.0,  5.5,  2.1, ...]
  pos 2:   [...,  3.1,  1.1,  4.0,  2.8,  4.5, ...]

After top_k=1 (non-argmax logits → -inf), then softmax:
  pos 0:   [...,  0.0,  1.0,  0.0,  0.0,  0.0, ...]   ← one-hot
  pos 1:   [...,  0.0,  0.0,  0.0,  1.0,  0.0, ...]   ← one-hot
  pos 2:   [...,  0.0,  0.0,  0.0,  0.0,  1.0, ...]   ← one-hot

  → seq_x0_p = gather(probs, sampled_token) = 1.0 at every masked position.
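The collapse above can be reproduced with a minimal sketch of the logit processing (function and variable names here are illustrative, not the repo's exact identifiers):

```python
import torch

def process_logits(logits, top_k=1):
    # Keep only the top-k logits per position; set the rest to -inf.
    topk_vals, _ = torch.topk(logits, top_k, dim=-1)
    cutoff = topk_vals[..., -1, None]
    filtered = torch.where(
        logits < cutoff, torch.full_like(logits, float("-inf")), logits
    )
    return torch.softmax(filtered, dim=-1)

# The illustrative logits from above, one row per masked position.
logits = torch.tensor([
    [-2.1, 5.3,  1.2, 4.8, 3.1],   # pos 0
    [-1.0, 0.2, -3.0, 5.5, 2.1],   # pos 1
    [ 3.1, 1.1,  4.0, 2.8, 4.5],   # pos 2
])

probs = process_logits(logits, top_k=1)      # each row is now one-hot
sampled = torch.argmax(probs, dim=-1)        # the "sampled" tokens
seq_x0_p = torch.gather(probs, -1, sampled[:, None]).squeeze(-1)
print(seq_x0_p)  # tensor([1., 1., 1.]) — confidence is 1.0 at every position
```

Whatever the raw logits were, the gathered "confidence" is identically 1.0, which is what poisons both decoding modes below.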

With static sampling, this means top-k is called on a tensor of all ones to pick which positions to commit. PyTorch breaks the tie by returning the first (left-most) position.

Decoding a 4-token block over 4 steps (num_to_transfer=1 each step):

Step 0: block state = [MASK,  MASK,  MASK,  MASK ]
        confidence  = [ 1.0,   1.0,   1.0,   1.0 ]    ← all tied
        topk(k=1) on tied 1.0s → picks index 0 (PyTorch tiebreak)
        commit at pos 0
                ↓
Step 1: block state = [tok_0, MASK,  MASK,  MASK ]
        confidence  = [-inf,   1.0,   1.0,   1.0 ]
        topk(k=1) → pos 1
                ↓
Step 2: block state = [tok_0, tok_1, MASK,  MASK ]
        confidence  = [-inf,  -inf,   1.0,   1.0 ]
        topk(k=1) → pos 2
                ↓
Step 3: block state = [tok_0, tok_1, tok_2, MASK ]
        commit at pos 3.

  → every block is decoded strictly left → right, regardless of which
    positions the model is actually most confident about.

With dynamic decoding at temp=0, the threshold condition is applied to the same processed probabilities (still all 1), so every position satisfies the condition whenever threshold < 1, and all positions are decoded at once.

Decoding a 4-token block, dynamic_threshold = 0.9:

Step 0: block state = [MASK, MASK, MASK, MASK]
        confidence  = [1.0,  1.0,  1.0,  1.0 ]
        confidence > 0.9 = [True, True, True, True]
        → COMMIT ALL 4 POSITIONS IN ONE STEP

  → every block collapses to a single forward pass; each position is
    sampled in isolation, block-diffusion's iterative refinement is
    skipped, and the output is typically garbled, repetitive text.
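The one-step collapse is immediate from the threshold check (variable names illustrative):

```python
import torch

seq_x0_p = torch.ones(4)       # processed probs are all 1.0 at temp=0
dynamic_threshold = 0.9
transfer = seq_x0_p > dynamic_threshold
print(transfer)                # tensor([True, True, True, True])
print(int(transfer.sum()))     # 4 — every position committed in one step
```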

These bugs originate in the original SDAR code, though SDAR has since switched to a corrected sampling pipeline.
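One way to break the tie (a sketch only, not necessarily SDAR's exact fix) is to take confidence from the unmodified softmax of the raw logits, while still sampling via the processed distribution. Argmax sampling is unchanged, but positions are no longer tied:

```python
import torch

logits = torch.tensor([
    [-2.1, 5.3,  1.2, 4.8, 3.1],   # pos 0
    [-1.0, 0.2, -3.0, 5.5, 2.1],   # pos 1
    [ 3.1, 1.1,  4.0, 2.8, 4.5],   # pos 2
])

# Confidence from the *raw* softmax: temp=0 / top-k=1 still selects the
# argmax token, but the gathered probabilities now differ per position.
raw_probs = torch.softmax(logits, dim=-1)
sampled = torch.argmax(logits, dim=-1)
confidence = torch.gather(raw_probs, -1, sampled[:, None]).squeeze(-1)
print(confidence)  # distinct values < 1.0 → meaningful topk ordering
```

With these logits, pos 1 (raw prob ≈ 0.96) would be committed before pos 0 and pos 2, reflecting where the model is actually most confident.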
