This tutorial will guide you through the steps to fine-tune Llama3 models on Bagel. We will start by creating a raw asset, downloading a dataset, purchasing a Llama3 model from the Bakery marketplace, and finally fine-tuning the model.
Prerequisites
Create a Bakery account at bakery.bagel.net.
Get your user_id & api_key from Bakery.
Google Colab (Chrome browser).
Installation
To install the Bagel Python client, run the following command in your terminal or Colab notebook:
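A minimal sketch is shown below; the package name bagelML and the Client() constructor are assumptions, so check the Bakery documentation if the install or import fails.

# Install the Bagel Python client (package name assumed here)
!pip install bagelML

import bagel

# Create a client instance used to talk to the Bagel server
client = bagel.Client()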
This initializes the Bagel client, which will be used to interact with the Bagel server throughout the tutorial.
import os
from getpass import getpass
# Copy & Paste the API Key from https://bakery.bagel.net/api-key
DEMO_KEY_IN_USE = getpass("Enter your API key: ")
# Set environment variable
os.environ['BAGEL_API_KEY'] = DEMO_KEY_IN_USE
Here, we prompt the user to enter their API key securely (without displaying it on the screen) and then set it as an environment variable named BAGEL_API_KEY.
Step 2: Create Asset on Bagel
This code defines a function create_asset that creates a new dataset/asset using the Bagel client and returns the resulting asset ID after creation.
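A sketch of such a helper is shown below. The create_asset call and the payload fields (dataset_type, title, category, details, tags) are assumptions about the Bakery API, so adjust them to match the current client.

def create_asset(client, user_id):
    # Illustrative payload describing the raw asset; field names are assumptions
    payload = {
        "dataset_type": "RAW",
        "title": "llama3-finetune-dataset",
        "category": "AI",
        "details": "Raw dataset used for Llama3 fine-tuning",
        "tags": [],
        "user_id": user_id,
    }
    # Create the asset on Bagel and return its ID for use in later steps
    response = client.create_asset(payload)
    return response["id"] if isinstance(response, dict) else response

asset_id = create_asset(client, user_id)
print("Asset created with ID:", asset_id)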
To fine-tune a Llama3 model, you first need to purchase it from the Bakery marketplace:
Log in to your Bagel account and navigate to the Model tab on your Bakery homepage.
Browse the available models and purchase the Llama3-8b model for this example.
It is a large language model designed for advanced text generation and dialogue applications. It is part of the Meta Llama 3 family, which includes both an 8 billion parameter and a 70 billion parameter version. This model utilizes an optimized transformer architecture and is fine-tuned with supervised learning and reinforcement learning with human feedback to enhance its helpfulness and safety.
After purchasing, the model will be available under the My Models section.
You can also purchase assets programmatically by running this code:
# Buy the Llama3 model from the Bakery marketplace
llama3_model_id = "3323b6c4-06ef-4949-b239-1a2b220e211d"  # This is the Llama3-8b model

response = client.buy_asset(
    asset_id=llama3_model_id,
    user_id=user_id
)
print(response)
Step 6: Fine-tune the Llama3 Model on Bakery
Now that you have your dataset and model, you can fine-tune the Llama3 model using the following code:
# Getting Column names from selected Raw dataset
client.get_dataset_column_names(asset_id=asset_id, file_name=filename)
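With the input and output column names known, you can submit the fine-tuning job. The sketch below assumes the client exposes a fine_tune method; the parameter names (title, base_model, file_name, epochs, learning_rate) are assumptions, so adjust them to the current Bakery API.

# Submit the fine-tuning job (parameter names are assumptions; see the Bakery docs)
response = client.fine_tune(
    title="llama3-8b-finetuned",
    user_id=user_id,
    asset_id=asset_id,           # raw dataset asset created earlier
    file_name=filename,          # file inside the raw asset
    base_model=llama3_model_id,  # purchased Llama3-8b model
    epochs=3,
    learning_rate=0.01,
)
print(response)
# Note the ID of the fine-tuned model asset (model_id) reported by the job;
# it is used below when unzipping the downloaded model.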
!mkdir adapter_model
# Please provide the model zip file path
!unzip /content/{model_id}.zip -d /content/adapter_model
Step 9: Load Model using transformers
This code loads the pre-trained language model with an adapter using a configuration that minimizes GPU memory usage by quantizing the model to 8-bit precision. It first loads the adapter configuration to determine the base model, then initializes the model with the specified quantization settings. The tokenizer is loaded separately, and the model is only loaded when needed, ensuring efficient resource management.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

# Adapter model path
adapter_path = "adapter_model"

# Load only the adapter config initially
peft_config = PeftConfig.from_pretrained(adapter_path)
base_model_name = "bagelnet/Llama-3-8B"

# Function to load the model with minimal GPU memory usage
def load_model():
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0,
        llm_int8_has_fp16_weight=False,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map="auto",
        trust_remote_code=True
    )
    model = PeftModel.from_pretrained(base_model, adapter_path)
    return model

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

# Load the model only when it is needed
model = None
if model is None:
    print("Loading model... This may take a moment.")
    model = load_model()
    print("Model loaded successfully.")
Step 10: ChatBot
Now let's use a generate_response function to produce a reply from the fine-tuned model, with generation parameters chosen to keep the answer relevant and concise.
def generate_response(conversation_history, max_length=200):
    # Prepare the prompt
    prompt = f"You are a helpful AI assistant. Provide a concise and relevant answer to the user's question. Only reply to the question\n\n{conversation_history}"
    # prompt += conversation_history
    # prompt += "Assistant: "

    # Encode the input prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

    # Generate a response
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=input_ids.shape[1] + max_length,
            num_return_sequences=1,
            # no_repeat_ngram_size=2,
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode and return only the newly generated tokens
    response = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
    return response.strip()
Let's use a chatbot function to run a simple text-based chatbot that interacts with the user in a loop, continuously prompting for input and generating responses. It prints a welcome message, processes the user's input with the generate_response function, and displays the chatbot's reply. The loop terminates when the user types "quit", ending the conversation and printing a farewell message.
def chatbot():
    print("Chatbot: Hello! I'm your AI assistant. How can I help you today? (Type 'quit' to exit)")
    while True:
        user_input = input("You: ").strip()
        if user_input.lower() == 'quit':
            print("Chatbot: Goodbye! Have a great day!")
            break
        conversation = f"{user_input}\n"
        response = generate_response(conversation)
        print("Chatbot:", response)

if __name__ == "__main__":
    chatbot()
You can upload any text, CSV, JSON, or Parquet file. Parquet is the format most widely used for Llama3 fine-tuning; you can generate your own Parquet data as sketched below.
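For example, a CSV file can be converted to Parquet with pandas (the file names here are placeholders, and pyarrow or fastparquet must be installed):

import pandas as pd

# Load your raw data (e.g. a CSV with prompt/response columns)
df = pd.read_csv("my_dataset.csv")

# Write it out as a Parquet file ready for upload to the raw asset
df.to_parquet("my_dataset.parquet", index=False)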
- Please refer to the Google Colab tutorial for interactive development.
Need extra help or just want to connect with other developers? Join our community on 👾