
RLlama: Teaching LLMs to Learn and Remember with Memory-Augmented RL

Published 3 days ago

RLlama

Empowering LLMs with Memory-Augmented Reinforcement Learning


🔗 GitHub Repository • 📦 PyPI Package • RLlama - Memory-Augmented Reinforcement Learning for LLMs • Hacker News

RLlama

RLlama is an enhanced fork of LlamaGym, supercharging it with memory-augmented learning capabilities and additional RL algorithms. While LlamaGym pioneered the integration of LLMs with reinforcement learning, RLlama takes it further by introducing episodic memory, working memory, and a broader suite of RL algorithms.

Features

  • 🧠 Memory-Augmented Learning with Episodic and Working Memory
  • 🎮 Multiple RL Algorithms (PPO, DQN, A2C, SAC, REINFORCE, GRPO)
  • 🔄 Online Learning Support
  • 🎯 Seamless Integration with Gymnasium
  • 🚀 Multi-Modal Support (Coming Soon)

Quick Start

Get started with RLlama in seconds:

pip install rllama

Usage

Blackjack Agent Example

from rllama import RLlamaAgent

class BlackjackAgent(RLlamaAgent):
    def get_system_prompt(self) -> str:
        return """You are an expert blackjack player. Follow these rules:
        1. ALWAYS hit if your total is 11 or below
        2. With 12-16: hit if dealer shows 7+, stay if 6 or lower
        3. ALWAYS stay if your total is 17+ without an ace
        4. With a usable ace: hit if total is 17 or below"""

    def format_observation(self, observation) -> str:
        return f"Current hand total: {observation[0]}\nDealer's card: {observation[1]}\nUsable ace: {'yes' if observation[2] else 'no'}"

    def extract_action(self, response: str):
        return 0 if "stay" in response.lower() else 1
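The action-parsing rule above can be exercised on its own, without loading a model. The sketch below replicates it as a hypothetical standalone function (the name `parse_blackjack_action` is illustrative, not part of the RLlama API): any response containing "stay" maps to action 0 (stand) in Gymnasium's `Blackjack-v1`, everything else to 1 (hit).

```python
def parse_blackjack_action(response: str) -> int:
    """Map a free-form LLM response to a Blackjack-v1 action (0=stand, 1=hit).

    Mirrors BlackjackAgent.extract_action: case-insensitive check for "stay".
    """
    return 0 if "stay" in response.lower() else 1

print(parse_blackjack_action("I will STAY on 18."))  # 0
print(parse_blackjack_action("Hit me!"))             # 1
```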

Text World Agent Example

from rllama import RLlamaAgent
import re

class TextWorldAgent(RLlamaAgent):
    def get_system_prompt(self) -> str:
        return """You will be playing a text-based game. Here are some example commands: 
        'go west', 'inventory', 'drop teacup', 'examine broom', 'open door', 'look'."""

    def format_observation(self, observation) -> str:
        return observation.split("$$$$$$$ \n\n")[-1].strip()

    def extract_action(self, response: str) -> str:
        command_match = re.search(r"command: (.+?)(?=\n|$)", response, re.IGNORECASE)
        return command_match.group(1) if command_match else "look"
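The regex in `extract_action` captures everything after `command: ` up to the end of the line, falling back to `"look"` when no command is found. A minimal standalone sketch (the helper name `parse_textworld_command` is illustrative):

```python
import re

def parse_textworld_command(response: str) -> str:
    """Extract a game command from an LLM response, defaulting to 'look'.

    Same pattern as TextWorldAgent.extract_action: a case-insensitive
    'command: <text>' prefix, with the lookahead stopping at a newline or
    end of string.
    """
    match = re.search(r"command: (.+?)(?=\n|$)", response, re.IGNORECASE)
    return match.group(1) if match else "look"

print(parse_textworld_command("Command: go west\nI head west."))  # go west
print(parse_textworld_command("I am not sure what to do."))       # look
```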

Training Examples

Basic Training Loop

import gymnasium as gym
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead  # value-head model lives in trl, not transformers

# Initialize model and agent
model = AutoModelForCausalLMWithValueHead.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
agent = BlackjackAgent(model, tokenizer, "cuda", algorithm="ppo")

# Training loop
env = gym.make("Blackjack-v1")
for episode in range(1000):
    observation, info = env.reset()
    done = False
    
    while not done:
        action = agent.act(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)
        done = terminated or truncated
    
    agent.terminate_episode()

Example Implementations

Check out our complete examples:

Memory-Augmented Learning

RLlama implements two types of memory systems:

  1. Episodic Memory: Stores and retrieves past experiences
  2. Working Memory: Maintains context for current decision-making

These systems allow agents to:

  • Learn from past experiences
  • Maintain context across multiple steps
  • Make more informed decisions
  • Handle complex, long-term dependencies
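To make the two memory systems concrete, here is an illustrative sketch of how they could be structured. This is an assumption-laden toy implementation, not RLlama's actual classes or API: episodic memory as a bounded store of past experiences, working memory as a short rolling context window.

```python
from collections import deque

class EpisodicMemory:
    """Toy episodic memory: stores (observation, action, reward) tuples."""

    def __init__(self, capacity: int = 1000):
        self.experiences = deque(maxlen=capacity)  # oldest entries evicted first

    def store(self, observation, action, reward):
        self.experiences.append((observation, action, reward))

    def retrieve(self, k: int = 5):
        # Naive recency-based retrieval: return the k most recent experiences.
        # A real system would likely use similarity search instead.
        return list(self.experiences)[-k:]

class WorkingMemory:
    """Toy working memory: a rolling window of recent context messages."""

    def __init__(self, window: int = 8):
        self.context = deque(maxlen=window)

    def add(self, message: str):
        self.context.append(message)

    def as_prompt(self) -> str:
        # Join the retained messages into a context block for the next prompt.
        return "\n".join(self.context)
```

The design intent is the split itself: episodic memory persists across episodes so past outcomes can inform future decisions, while working memory is cheap, per-episode state that keeps the current trajectory in the prompt.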

Contributing

We welcome contributions! Here's how:

  1. Fork the repository
  2. Create your feature branch (git checkout -b feature/AmazingFeature)
  3. Commit your changes (git commit -m 'Add some AmazingFeature')
  4. Push to the branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

Citation

@misc{ch33nchan2024rllama,
    title = {RLlama: Memory-Augmented Reinforcement Learning Framework for LLMs},
    author = {Ch33nchan},
    year = {2024},
    publisher = {GitHub},
    url = {https://github.com/ch33nchan/RLlama}
}

License

This project is licensed under the MIT License - see the LICENSE file for details.