CAST 0.7B — Speech-to-Speech Model

CAST 0.7B is a speech-to-speech language model built on a 0.7B-parameter Gemma3-style LM. It generates natural continuations of spoken audio.

This repository hosts the final checkpoint files. Encode/decode requires the companion CAST WavTokenizer codec, KrauthammerLab/cast-wavtokenizer-24k-40tps.


Demo

Interactive samples and usage examples: https://mortezaro.github.io/speech-cast/


Paper

Optimizing Speech Language Models for Acoustic Consistency
arXiv: 2509.26276 — https://arxiv.org/abs/2509.26276

We study speech language models that incorporate semantic initialization and planning losses to achieve robust and consistent generation. Our approach initializes speech tokens with self-supervised features, applies a light alignment loss, and trains with thinning and auxiliary objectives that target robustness and content planning. We train three models: a 0.7B speech-only model, a 1.0B speech-only model, and a 1.0B interleaved model with both text and speech. Acoustic studies show that the speech-only models achieve the highest consistency across speaker, gender, sentiment, room, and background factors, surpassing larger systems. Interleaving improves lexical and syntactic probes and semantic–acoustic alignment but reduces consistency. Linear probes show that our initialization biases the model toward content structure while trading off prosody detail. These results show that LM-side design and training mix control the balance between acoustic stability and semantic grounding without changes to the tokenizer or runtime architecture. A demo and model weights are available for exploration.

Installation

pip install torch torchaudio transformers accelerate soundfile
pip install git+https://github.com/jishengpeng/WavTokenizer.git
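
Both examples below assume the following imports. The WavTokenizer import path shown here follows the upstream repository layout (decoder/pretrained.py); adjust it if your install exposes the class elsewhere.

import math
from typing import List, Set

import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          LogitsProcessor, LogitsProcessorList)

from decoder.pretrained import WavTokenizer  # path per the upstream WavTokenizer repo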

1. Resynthesis


WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"

# Download tokenizer ckpt + config from HF
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None  # config is optional in this setup

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load WavTokenizer (codec); wt_cfg may be None if no config was found
wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# Load a 16 kHz prompt
wav16, sr = torchaudio.load("prompt_16k.wav")  # mono recommended
assert sr == 16000, f"Expected 16k input, got {sr}"

# (Optional) ensure mono
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)

# Resample 16k -> 24k before encoding (the codec operates at 24 kHz)
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

# Encode → features, codes
bandwidth_id = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bandwidth_id)  # feats: [1, ?, T], codes: [1, streams?, T] or [1, T]

# Decode back to waveform (24 kHz)
recon24 = wt.decode(feats, bandwidth_id=bandwidth_id)  # [1, T] or [1,1,T]
if recon24.dim() == 3:
    recon24 = recon24.squeeze(0)

# Save 24k round-trip audio
sf.write("recon_24k.wav", recon24.squeeze(0).detach().cpu().numpy(), 24000)
print("Wrote recon_24k.wav")
2. Speech generation


LM_REPO = "KrauthammerLab/cast-0.7b-s2s"
WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"

device = "cuda" if torch.cuda.is_available() else "cpu"
codes_per_second = 40            # ~40 codec tokens per second of audio
codebook_size    = 4096          # speech tokens [Sp1]..[Sp4096]
speech_prefix    = "[Speech]"

# ---------- helpers ----------
def equal_power_crossfade(prev_24k: torch.Tensor, cont_24k: torch.Tensor, fade_ms: int = 40, sr: int = 24000) -> torch.Tensor:
    """Equal-power crossfade between prev and cont (both [1,T] @ 24k)."""
    fade = max(1, int(sr * fade_ms / 1000))
    prev_24k = prev_24k.to(device)
    cont_24k = cont_24k.to(device)
    if prev_24k.size(1) < fade or cont_24k.size(1) < fade:
        return torch.cat([prev_24k, cont_24k], dim=1)
    a = prev_24k[:, -fade:]
    b = cont_24k[:, :fade]
    t = torch.linspace(0, 1, fade, device=device).view(1, -1)
    # cos/sin gains satisfy cos^2 + sin^2 = 1, so summed power stays constant across the fade
    mix = torch.cos(t * 0.5 * math.pi) * a + torch.sin(t * 0.5 * math.pi) * b
    return torch.cat([prev_24k[:, :-fade], mix, cont_24k[:, fade:]], dim=1)

class SpeechOnlyLogitsProcessor(LogitsProcessor):
    """Mask logits so only [Sp#] tokens (and EOS) can be sampled."""
    def __init__(self, allowed: Set[int]):
        super().__init__()
        self.allowed = list(allowed)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(scores, float("-inf"))
        mask[..., self.allowed] = 0.0
        return scores + mask

# ---------- load codec ----------
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None

wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# ---------- load LM + tokenizer ----------
tok = AutoTokenizer.from_pretrained(LM_REPO)
lm  = AutoModelForCausalLM.from_pretrained(LM_REPO, torch_dtype=torch.bfloat16).to(device).eval()

# Build speech token id table: "[Sp1]".."[Sp4096]" must be single tokens
speech_token_ids: List[int] = []
for i in range(1, codebook_size + 1):
    ids = tok(f"[Sp{i}]", add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        raise RuntimeError(f"[Sp{i}] is not a single token; tokenizer mismatch.")
    speech_token_ids.append(ids[0])

# For mapping back: token_id -> code_index (0-based)
id2code = {tid: i for i, tid in enumerate(speech_token_ids)}
eos_id = tok.eos_token_id
# Restrict sampling to speech tokens plus EOS, so generation stays in the codebook but can still stop early
allowed_ids = set(speech_token_ids + ([eos_id] if eos_id is not None else []))

# ---------- load prompt audio (16k), encode to codes ----------
wav16, sr = torchaudio.load("prompt_16k.wav")     # mono
assert sr == 16000
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)

# Resample to 24k before codec
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

bw = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bw)
# Normalize the code tensor to a flat list[int] of length T
if codes.dim() == 3:                     # [1, streams, T]
    codes = codes.squeeze(0)
    codes = codes[0] if codes.size(0) > 1 else codes.squeeze(0)  # keep first stream
elif codes.dim() == 2:                   # [1, T]
    codes = codes.squeeze(0)
codes_list = codes.long().tolist()       # each code in [0..4095]

# ---------- optional: decode round-trip for the stitched prefix ----------
recon24 = wt.decode(feats, bandwidth_id=bw)
if recon24.dim() == 3:
    recon24 = recon24.squeeze(0)

# ---------- build LM prefix string ----------
prefix_text = speech_prefix + "".join(f"[Sp{c+1}]" for c in codes_list)
enc = tok(prefix_text, return_tensors="pt")
input_ids = enc["input_ids"].to(device)
attn_mask = enc.get("attention_mask", None)
if attn_mask is not None:
    attn_mask = attn_mask.to(device)

# ---------- generate continuation (about N seconds) ----------
seconds = 3.0
max_new_tokens = max(1, int(round(seconds * codes_per_second)))

lp = LogitsProcessorList([SpeechOnlyLogitsProcessor(allowed_ids)])

gen = lm.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token_id=eos_id,
    pad_token_id=(tok.pad_token_id if tok.pad_token_id is not None else eos_id),
    logits_processor=lp,
)

# Strip prefix and EOS
gen_tail = gen[0][input_ids.size(1):].tolist()
if eos_id is not None and eos_id in gen_tail:
    gen_tail = gen_tail[:gen_tail.index(eos_id)]

# Map token ids -> code indices (0-based)
new_codes = [id2code[t] for t in gen_tail if t in id2code]

# Keep the last ~1s of prompt codes so the decoder has context and the seam stays smooth
keep_sec = 1.0
keep = max(0, int(round(keep_sec * codes_per_second)))
tail_codes = codes_list[-keep:] if keep > 0 else []
decode_codes = tail_codes + new_codes

# Decode to audio (24 kHz)
tok_tensor = torch.tensor(decode_codes, dtype=torch.long, device=device).view(1, 1, -1)
cont24 = wt.decode(wt.codes_to_features(tok_tensor), bandwidth_id=bw)
if cont24.dim() == 3:
    cont24 = cont24.squeeze(0)

# Stitch with crossfade
stitched = equal_power_crossfade(recon24, cont24, fade_ms=60, sr=24000)

# Save files
sf.write("recon_24k.wav",     recon24.squeeze(0).detach().cpu().numpy(), 24000)
sf.write("continuation.wav",  cont24.squeeze(0).detach().cpu().numpy(),  24000)
sf.write("stitched_24k.wav",  stitched.squeeze(0).detach().cpu().numpy(),24000)
print("Wrote recon_24k.wav, continuation.wav, stitched_24k.wav")