πŸ€— Transformers CFG

Python 3.9+ | MIT License

πŸ’­ Release news

Latest experimental

Features

  • LlamaCPP Python wrapper support (#116)

Bug fixes

  • pip show license (#117)

Latest stable

v0.2.7 (2025-03-02)

Features

  • Types and MLX (#93)
  • Negation (#94)
  • Wildcards (#95)
  • Repetition brackets (#96, #104)
  • Qwen2 and Qwen2.5 (#97)
  • Reusable logits processor (#100)
  • Pytest (#109)
  • GitHub Actions workflow (#110)

Bug fixes

  • Avoided computing full masks and optimized type additions (#101)
  • Refactored grammar encoding to improve structure (#99)
  • EOS token is now correctly masked (#108)
  • Multiple bugs removed and aesthetics improved (#107)

Recent releases

  • Gemma-2 β€” @fillassuncao (2024-08-16)
  • DeepSeek (2024-07-24)
  • LLaMA-3 (2024-07-08)
  • JSON Schema (2024-05-13)
  • Mask optimization (2024-04-25)
  • Phi (2024-04-16)
  • Online demo (2024-04-10)
  • Unicode and foreign text (2024-02-29)
  • Text-Generation-WebUI (2023-12-17)
    • We are pleased to announce that transformers-cfg has been integrated into the Text-Generation-WebUI project, allowing users to leverage CFG capabilities within this widely used text-generation interface (PR).

πŸš€ Introduction

Initially developed as a pull request to the Hugging Face Transformers library (PR), transformers-cfg extends the Hugging Face Transformers library to support constrained decoding through context-free grammars (CFG), offering a Transformers parallel to LlamaCPP's GBNF support, but with stricter generation rules.

πŸ’» Installation

Stable

Install the stable version via pip:

pip install transformers-cfg

Development

For the latest updates, install directly from GitHub:

pip install git+https://github.com/epfl-dlab/transformers-CFG.git@main

πŸ’‘ Why use transformers-cfg?

  • EBNF Grammar Support: Uses Extended Backus-Naur Form (EBNF) for grammar description.
  • Seamless Integration: Drop-in compatible with the llama-cpp project's grammar usage.
  • Broad Model Compatibility: Works with all models in the πŸ€— Transformers library.
  • Multilingual Grammar Support: Enables grammars in various languages, including Chinese (δΈ­ζ–‡), Japanese (ζ—₯本θͺž), Korean (ν•œκ΅­μ–΄), Hindi (ΰ€Ήΰ€Ώΰ€¨ΰ₯ΰ€¦ΰ₯€), Hebrew (Χ’Χ‘Χ¨Χ™Χͺ), Arabic (Ψ§Ω„ΨΉΨ±Ψ¨ΩŠΨ©), and emoji (πŸ€—).
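
Terminals are not limited to ASCII, so grammars can constrain generation to non-English text. A minimal illustrative sketch (not a file shipped with the library):

root     ::= greeting "!"
greeting ::= "δ½ ε₯½" | "γ“γ‚“γ«γ‘γ―" | "μ•ˆλ…•ν•˜μ„Έμš”" | "πŸ€—"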

πŸ€” What is a grammar?

Think of it as an enhanced version of regular expressions.

Valid JSON object

root ::= object
object ::= "{" pair ("," pair)* "}"
pair ::= string ":" value
string ::= '"' [a-zA-Z0-9]* '"'
value ::= string | object | "true" | "false" | "null"
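
For example, this grammar accepts {"name":"cat","likes":"fish"} but rejects {"age":1}, since numbers do not appear in the value rule.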

For advanced grammar debugging, see our debugging guide.

πŸ”§ Grammar quickstart

Let's set up a predictable generation scenario. Left unconstrained, the model would usually reply with "The animal is a dog." Instead, we'll force it to say either "The animal is a cat" or "The animal is a fish," two other common domestic pets that contradict the initial text.

Command-line interface (CLI)

The transformers-cfg-cli tool enables text generation using a model and a specified grammar. Unicode is supported.

transformers-cfg-cli generate \
    -m "facebook/opt-125m" \
    -g "examples/grammars/animal.ebnf" \
    -p 'The text says, "The animal is a dog." The answer is obvious. ' \
    --max_new_tokens 50
# The animal is a cat.

Run transformers-cfg-cli generate --help for available options.
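
The grammar file referenced above ships with the repository; it presumably contains the same grammar used in the Python examples below:

root   ::= "The animal is a " animal "."
animal ::= "cat" | "fish"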

Transformers Torch

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    # Set device: use GPU if available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model identifier
    model_id = "facebook/opt-125m"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Define grammar string
    grammar_str = """
    root   ::= "The animal is a " animal "."
    animal ::= "cat" | "fish"
    """
    
    # Create grammar constraint and logits processor
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Define prompts
    prompts = [
        'The text says, "The animal is a dog." The answer is obvious. ',
        'I\'m going to say "The animal is a dog." Here I go! '
    ]
    
    # Tokenize prompts
    input_ids = tokenizer(prompts, add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"].to(device)

    # Generate constrained text
    output = model.generate(
        input_ids,
        max_length=50,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
    )
    
    # Decode and print generated text
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    for generation in generations:
        print(generation)

# The animal is a cat.

Stream

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    # Set device: use GPU if available, else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Model identifier
    model_id = "facebook/opt-125m"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
    model.generation_config.pad_token_id = model.generation_config.eos_token_id

    # Define grammar string
    grammar_str = """
    root   ::= "The animal is a " animal "."
    animal ::= "cat" | "fish"
    """
    
    # Create grammar constraint and logits processor
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Define prompt
    prompts = [
        'The text says, "The animal is a dog." The answer is obvious. '
    ]
    
    # Tokenize prompt
    input_ids = tokenizer(prompts, add_special_tokens=False, return_tensors="pt", padding=True)["input_ids"].to(device)

    # Set up streaming
    streamer = TextStreamer(tokenizer)

    # Generate constrained text with streaming.
    model.generate(
        input_ids,
        max_length=50,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
        streamer=streamer
    )

# The animal is a cat.

Transformers Pipeline

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Model identifier
model_id = "facebook/opt-125m"

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)

# Define grammar string
grammar_str = """
root   ::= "The animal is a " animal "."
animal ::= "cat" | "fish"
"""

# Create grammar constraint and logits processor
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

# Initialize text generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    max_new_tokens=100,
    batch_size=2,
)

# Define prompts
prompts = [
    'The text says, "The animal is a dog." The answer is obvious. ',
    'I\'m going to say "The animal is a dog." Here I go! '
]

# Generate constrained text using the pipeline.
generations = pipe(
    prompts,
    do_sample=False,
    logits_processor=[grammar_processor],
)

# Print generated texts
for generation_group in generations:
    for generation in generation_group:
        print(generation['generated_text'])

# The animal is a cat.

LlamaCPP Python

Use the llama-cpp-python adapter by passing the adapter parameter to the logits processor, as shown below.

import io
import logging
from contextlib import redirect_stderr
from llama_cpp import Llama
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

# Define grammar string.
grammar_str = """
root   ::= "The animal is a " animal "."
animal ::= "cat" | "fish"
"""

# Load the tokenizer matching the model.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b")

# Redirect stderr and load the model via llama-cpp-python.
with redirect_stderr(io.StringIO()):
    model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False)

# Create grammar constraint and logits processor using the adapter.
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar, adapter="llama-cpp-python")

# Define prompt.
prompt = 'The text says, "The animal is a dog." The answer is obvious. '

# Generate constrained text (non-streaming).
response = model.create_completion(
    prompt=prompt,
    logits_processor=[grammar_processor],
    max_tokens=100,
)

# Print generated text.
print(response["choices"][0]["text"])

# The animal is a cat.

Stream

import io
import logging
from contextlib import redirect_stderr
from llama_cpp import Llama
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)

# Define grammar string
grammar_str = """
root   ::= "The animal is a " animal "."
animal ::= "cat" | "fish"
"""

# Load the tokenizer matching the model
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5b")

# Redirect stderr and load the model via llama-cpp-python
with redirect_stderr(io.StringIO()):
    model = Llama(model_path="qwen2.5-1.5b-q8_0.gguf", n_ctx=8000, verbose=False)

# Create grammar constraint and logits processor using the adapter
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar, adapter="llama-cpp-python")

# Define prompt.
prompt = 'The text says, "The animal is a dog." The answer is obvious. '

# Generate constrained text with streaming
response = model.create_completion(
    stream=True,
    prompt=prompt,
    logits_processor=[grammar_processor],
    max_tokens=100,
)

# Stream and print generated text
for token in response:
    print(token["choices"][0]["text"], end="", flush=True)

# The animal is a cat.

πŸ“œ Grammar collection

We maintain a collection of grammars in examples/grammars, aligned with llama-cpp.
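
A grammar file from the collection can be used in place of an inline grammar string. A minimal sketch, assuming the collection contains a file named examples/grammars/json.ebnf (check the directory for the exact file names):

from transformers import AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

# Any tokenizer from the examples above works here
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Read a grammar file from the collection (file name assumed; see examples/grammars)
with open("examples/grammars/json.ebnf", "r") as f:
    grammar_str = f.read()

# Build the constraint and logits processor exactly as with an inline grammar string
grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
grammar_processor = GrammarConstrainedLogitsProcessor(grammar)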

πŸ›  JSON schema

Learn to create grammars for complex JSON objects in our documentation.
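
As a taste of what the guide covers, here is a minimal sketch of a grammar that pins generation to a fixed JSON shape with two required keys; the field names and types are purely illustrative:

root        ::= "{" key_name ":" string "," key_adopted ":" boolean "}"
key_name    ::= '"' "name" '"'
key_adopted ::= '"' "adopted" '"'
string      ::= '"' [a-zA-Z0-9]* '"'
boolean     ::= "true" | "false"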

βœ… Supported tokenizers

πŸ€– Tested models

  • Qwen (≀ 2.5)
  • LLaMA (≀ 3.3)
  • GPT (≀ 2)
  • Mistral (≀ 0.3)
  • Falcon (≀ 3.0)
  • OPT

If you encounter an unsupported model, please open an issue or submit a pull request.

πŸ“– Citation

If you find this work useful, please cite it with the recommended citation:

@inproceedings{geng-etal-2023-grammar,
  title        = {Grammar-Constrained Decoding for Structured {NLP} Tasks without Finetuning},
  author       = {Geng, Saibo and Josifoski, Martin and Peyrard, Maxime and West, Robert},
  year         = 2023,
  month        = dec,
  booktitle    = {Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing},
  publisher    = {Association for Computational Linguistics},
  address      = {Singapore},
  url          = {https://aclanthology.org/2023.emnlp-main.674},
  editor       = {Bouamor, Houda and Pino, Juan and Bali, Kalika}
}

πŸ“œ License

This project is licensed under the MIT License.

πŸ™Œ Acknowledgements

Derived from torch-grammars, which was based on llama-cpp.
