Metadata-Version: 2.4
Name: autoregressive-language-model-generate
Version: 0.1.0a0
Summary: A generator-based, stateless autoregressive inference loop for language models compatible with HuggingFace's Transformers API.
Author-email: Jifeng Wu <jifengwu2k@gmail.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/jifengwu2k/autoregressive-language-model-generate
Project-URL: Bug Tracker, https://github.com/jifengwu2k/autoregressive-language-model-generate/issues
Classifier: Programming Language :: Python :: 2
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: OS Independent
Requires-Python: >=2
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch
Requires-Dist: typing; python_version < "3.5"
Dynamic: license-file

# `autoregressive-language-model-generate`

A generator-based, stateless autoregressive inference loop for language models compatible with HuggingFace's Transformers API. At each step, it yields the model's logits and expects the caller to send back the next token chosen for each sequence in the batch. This makes it easy to plug in custom sampling strategies (greedy, beam search, top-k, top-p, etc.).
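To make the yield/send protocol concrete, here is a minimal pure-Python sketch. `toy_generate` is a hypothetical stand-in for `autoregressive_language_model_generate`: it scores the next token with a made-up rule instead of calling a model, uses nested lists instead of tensors, and for brevity yields only the last position's logits, shape `(batch_size, vocab_size)`. `greedy_decode` is likewise a hypothetical driver. Only the control flow (prime with `next()`, then alternate between reading logits and calling `send()`) is meant to mirror the description above.

```python
from typing import Generator, List

VOCAB_SIZE = 5


def toy_generate(input_ids: List[List[int]]) -> Generator[List[List[float]], List[int], None]:
    """Hypothetical stand-in for the real generator (not the library's code).

    Yields next-token logits of shape (batch_size, vocab_size) and expects
    the caller to send back one token id per sequence.
    """
    sequences = [list(row) for row in input_ids]
    while True:
        # Toy "model": favour the token (last_token + 1) % VOCAB_SIZE
        logits = []
        for seq in sequences:
            favoured = (seq[-1] + 1) % VOCAB_SIZE
            logits.append([1.0 if tok == favoured else 0.0 for tok in range(VOCAB_SIZE)])
        next_tokens = yield logits
        # Append the caller's choices so the next step conditions on them
        for seq, tok in zip(sequences, next_tokens):
            seq.append(tok)


def greedy_decode(gen, max_new_tokens: int) -> List[List[int]]:
    """Hypothetical driver: pick the argmax token at each step."""
    generated: List[List[int]] = []
    if max_new_tokens <= 0:
        gen.close()
        return generated
    logits = next(gen)  # prime the generator: first logits for the prompt
    while True:
        # Greedy selection: index of the largest logit in each row
        next_tokens = [max(range(len(row)), key=row.__getitem__) for row in logits]
        generated.append(next_tokens)
        if len(generated) >= max_new_tokens:
            break
        try:
            logits = gen.send(next_tokens)
        except StopIteration:
            break  # the generator stopped on its own
    gen.close()  # release the suspended generator
    return generated


print(greedy_decode(toy_generate([[0], [3]]), 3))  # [[1, 4], [2, 0], [3, 1]]
```

Because the sampling decision lives entirely in the caller, swapping greedy selection for another strategy only changes the line that computes `next_tokens`.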

## Usage

Assume you have:

- `model`: a language model compatible with HuggingFace's Transformers API
- `input_ids` and `attention_mask`, each of shape `(batch_size, seq_len)`
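The usage example reads next-token logits from the last position (`logits[:, -1, :]`), which lines up with each sequence's final real token only if a padded batch is padded on the left. The following stdlib-only sketch builds such a batch from plain lists; `left_pad_batch` and `pad_token_id` are hypothetical names, and it assumes the generator does not re-align padded rows itself. A HuggingFace tokenizer configured with `padding_side = "left"` produces the same layout.

```python
from typing import List, Tuple


def left_pad_batch(
    sequences: List[List[int]], pad_token_id: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: left-pad token id lists to a common length.

    Returns (input_ids, attention_mask), each of shape (batch_size, seq_len),
    where the mask is 1 for real tokens and 0 for padding.
    """
    seq_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = seq_len - len(seq)
        # Padding goes first so every row ends with its most recent real token
        input_ids.append([pad_token_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad_batch([[5, 6, 7], [8]], pad_token_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
```

Wrapping each result in `torch.tensor(...)` yields `torch.long` tensors of shape `(batch_size, seq_len)`.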

```python
import torch
from autoregressive_language_model_generate import autoregressive_language_model_generate

model = ...
input_ids = ...
attention_mask = ...

gen = autoregressive_language_model_generate(
    model,
    input_ids,
    attention_mask
)

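# Prime the generator: the first `next()` runs the model on the prompt
# and returns its logits
logits = next(gen)

# Implement your sampling logic here; this example uses top-k sampling.
# Logits for the last position, shape `(batch_size, vocab_size)`
next_token_logits = logits[:, -1, :]
top_k = 50
# Mask out every token whose logit is below the k-th largest logit in its row
indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
next_token_scores = next_token_logits.masked_fill(indices_to_remove, float('-inf'))
probs = torch.nn.functional.softmax(next_token_scores, dim=-1)

# `next_tokens` has shape `(batch_size,)`
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

# Send `next_tokens` to the generator and receive the logits for the next step
logits = gen.send(next_tokens)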
```
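The example above implements top-k sampling. Top-p (nucleus) sampling instead keeps the smallest set of tokens whose cumulative probability reaches `p` and samples among them. Below is a stdlib-only sketch for a single row of logits; `top_p_sample` is a hypothetical helper, not part of this package.

```python
import math
import random
from typing import List, Optional


def top_p_sample(logits: List[float], p: float, rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper: sample one token id from the top-p nucleus."""
    rng = rng or random.Random()
    # Numerically stable softmax
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Token ids sorted by probability, highest first
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, cumulative = [], 0.0
    for tok in order:
        kept.append(tok)
        cumulative += probs[tok]
        if cumulative >= p:
            break  # smallest prefix whose mass reaches p
    # random.choices normalizes the weights, so no explicit renormalization is needed
    return rng.choices(kept, weights=[probs[t] for t in kept], k=1)[0]


print(top_p_sample([0.0, 5.0, 1.0], p=0.5))  # 1
```

With the real generator, `next_token_logits.tolist()` turns the `(batch_size, vocab_size)` tensor into one list per row, and `torch.tensor([...])` turns the chosen ids back into the `(batch_size,)` tensor that `gen.send` expects. For large vocabularies, a vectorized version built on `torch.sort` and `torch.cumsum` will be much faster.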

## Contributing

Contributions are welcome! Please submit pull requests or open issues on the GitHub repository.

## License

This project is licensed under the [MIT License](LICENSE).
