
# RLlama: Teaching LLMs to Learn and Remember with Memory-Augmented RL

Empowering LLMs with Memory-Augmented Reinforcement Learning
[GitHub Repository](https://github.com/ch33nchan/RLlama) • [PyPI Package](https://pypi.org/project/rllama/)
## RLlama
RLlama is an enhanced fork of LlamaGym, supercharging it with memory-augmented learning capabilities and additional RL algorithms. While LlamaGym pioneered the integration of LLMs with reinforcement learning, RLlama takes it further by introducing episodic memory, working memory, and a broader suite of RL algorithms.
## Features

- Memory-Augmented Learning with Episodic and Working Memory
- Multiple RL Algorithms (PPO, DQN, A2C, SAC, REINFORCE, GRPO)
- Online Learning Support
- Seamless Integration with Gymnasium
- Multi-Modal Support (Coming Soon)
## Quick Start

Get started with RLlama in seconds:

```bash
pip install rllama
```
## Usage

### Blackjack Agent Example

```python
from rllama import RLlamaAgent


class BlackjackAgent(RLlamaAgent):
    def get_system_prompt(self) -> str:
        return """You are an expert blackjack player. Follow these rules:
1. ALWAYS hit if your total is 11 or below
2. With 12-16: hit if dealer shows 7+, stay if 6 or lower
3. ALWAYS stay if your total is 17+ without an ace
4. With a usable ace: hit if total is 17 or below"""

    def format_observation(self, observation) -> str:
        return (
            f"Current hand total: {observation[0]}\n"
            f"Dealer's card: {observation[1]}\n"
            f"Usable ace: {'yes' if observation[2] else 'no'}"
        )

    def extract_action(self, response: str):
        return 0 if "stay" in response.lower() else 1
```
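As a quick sanity check, the formatting and parsing logic above can be exercised standalone. The free functions below are hypothetical mirrors of the agent's methods, runnable without RLlama or a model:

```python
def format_observation(observation) -> str:
    # Mirrors BlackjackAgent.format_observation: observation is the
    # (player total, dealer card, usable ace) tuple from Blackjack-v1.
    total, dealer_card, usable_ace = observation
    return (f"Current hand total: {total}\n"
            f"Dealer's card: {dealer_card}\n"
            f"Usable ace: {'yes' if usable_ace else 'no'}")


def extract_action(response: str) -> int:
    # Mirrors BlackjackAgent.extract_action: 0 = stay/stick, 1 = hit,
    # defaulting to hit when "stay" does not appear in the response.
    return 0 if "stay" in response.lower() else 1


print(format_observation((14, 10, False)))
print(extract_action("I will stay on this hand."))  # 0
print(extract_action("Hit me."))                    # 1
```

Note that the parser defaults to hitting whenever the model's reply does not contain the word "stay", so ambiguous responses are treated as a hit.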
### Text World Agent Example

```python
import re

from rllama import RLlamaAgent


class TextWorldAgent(RLlamaAgent):
    def get_system_prompt(self) -> str:
        return """You will be playing a text-based game. Here are some example commands:
'go west', 'inventory', 'drop teacup', 'examine broom', 'open door', 'look'."""

    def format_observation(self, observation) -> str:
        return observation.split("$$$$$$$ \n\n")[-1].strip()

    def extract_action(self, response: str) -> str:
        command_match = re.search(r"command: (.+?)(?=\n|$)", response, re.IGNORECASE)
        return command_match.group(1) if command_match else "look"
```
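The regex in `extract_action` anchors on a `command:` prefix (case-insensitive), captures up to the end of that line, and falls back to the safe `"look"` command when the model emits no parseable command. A standalone check of that behavior, independent of RLlama:

```python
import re


def extract_action(response: str) -> str:
    # Same pattern as TextWorldAgent.extract_action: capture the text after
    # "command:" up to the end of the line; default to "look" on no match.
    match = re.search(r"command: (.+?)(?=\n|$)", response, re.IGNORECASE)
    return match.group(1) if match else "look"


print(extract_action("Reasoning...\nCommand: open door\nDone."))  # open door
print(extract_action("I am not sure what to do."))                # look
```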
## Training Examples

### Basic Training Loop

```python
import gymnasium as gym
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# Initialize the model, tokenizer, and agent
model = AutoModelForCausalLMWithValueHead.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
agent = BlackjackAgent(model, tokenizer, "cuda", algorithm="ppo")

# Training loop
env = gym.make("Blackjack-v1")
for episode in range(1000):
    observation, info = env.reset()
    done = False
    while not done:
        action = agent.act(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)
        done = terminated or truncated
    agent.terminate_episode()
```

Note that `AutoModelForCausalLMWithValueHead` is provided by the `trl` library, not `transformers`.
## Example Implementations

Check out our complete examples:

- Blackjack Agent - classic card game environment
- Text World Agent - text-based adventure game with memory augmentation
- Multi-Modal Agent (Coming Soon)
## Memory-Augmented Learning

RLlama implements two types of memory systems:

- **Episodic Memory**: stores and retrieves past experiences
- **Working Memory**: maintains context for current decision-making

These systems allow agents to:

- Learn from past experiences
- Maintain context across multiple steps
- Make more informed decisions
- Handle complex, long-term dependencies
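To illustrate how these two systems might fit together, here is a minimal, self-contained sketch. This is not RLlama's actual implementation: the class names, the `(observation, action, reward)` episode format, and the text-similarity retrieval metric are all assumptions made for this example.

```python
from collections import deque
from difflib import SequenceMatcher


class EpisodicMemory:
    """Stores past experiences and retrieves the ones most similar to a query."""

    def __init__(self):
        self.episodes = []  # list of (observation, action, reward) tuples

    def store(self, observation: str, action, reward) -> None:
        self.episodes.append((observation, action, reward))

    def retrieve(self, query: str, k: int = 3):
        # Rank stored episodes by crude text similarity to the query
        # (a real system would use learned embeddings instead).
        ranked = sorted(
            self.episodes,
            key=lambda e: SequenceMatcher(None, query, e[0]).ratio(),
            reverse=True,
        )
        return ranked[:k]


class WorkingMemory:
    """Keeps a rolling window of recent steps as context for the next decision."""

    def __init__(self, capacity: int = 5):
        self.buffer = deque(maxlen=capacity)  # old entries fall off automatically

    def add(self, entry: str) -> None:
        self.buffer.append(entry)

    def context(self) -> str:
        return "\n".join(self.buffer)
```

In a training loop, the agent would prepend `WorkingMemory.context()` and the top `EpisodicMemory.retrieve(...)` hits to its prompt before acting, which is how past experience can inform the current decision.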
## Contributing

We welcome contributions! Here's how:

1. Fork the repository
2. Create your feature branch (`git checkout -b feature/AmazingFeature`)
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`)
4. Push to the branch (`git push origin feature/AmazingFeature`)
5. Open a Pull Request
## Relevant Work

- LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning
- Grounding Large Language Models with Online Reinforcement Learning
- Lamorel: Language Models for Reinforcement Learning
## Citation

```bibtex
@misc{ch33nchan2024rllama,
  title     = {RLlama: Memory-Augmented Reinforcement Learning Framework for LLMs},
  author    = {Ch33nchan},
  year      = {2024},
  publisher = {GitHub},
  url       = {https://github.com/ch33nchan/RLlama}
}
```
## License

This project is licensed under the MIT License - see the LICENSE file for details.