Samhita Alla

Fine-Tuning Insights: Using LLMs as Preprocessors to Improve Dataset Quality

Made with DALL·E 2

LLMs for data cleaning: Yay or nay?

In Part 1 of this series, I fine-tuned a RedPajama Large Language Model (LLM) on Flyte Slack data. During this process, I discovered that the fine-tuned model's predictions were suboptimal, primarily due to issues with dataset quality. Dataset quality depends on factors such as the quality of the input-output pairs, the prompts provided for prediction generation, the prompt format, the choice of model itself, and more.

This part of our fine-tuning series takes a closer look at when to choose fine-tuning and how I tackled the task of creating a high-quality dataset. Please note that there's no readily available benchmark for Q&A on Slack-based datasets, as we don't have data labelers to create a gold-standard, clean dataset.

When to fine-tune?

Niels Bantilan provides clear guidance on when to choose prompt engineering or fine-tuning. In the context of semantic embeddings versus fine-tuning, the semantic embeddings strategy is suitable when the objective is to impart knowledge to the model. On the other hand, fine-tuning proves advantageous when the model needs to acquire specialized skills or exhibit specific behaviors. 

Fine-tuning vs prompt engineering vs embeddings (Source)

High-quality Slack dataset

I attribute the failure of my experiment to the quality of my dataset. If the dataset had contained high-quality responses, the performance would have improved. However, as I mentioned earlier, my ultimate goal is not just to mimic the style of the responses, but also to impart knowledge to the model. This creates a tension between the two LLM strategies I outlined earlier: semantic embeddings and fine-tuning.

After careful consideration, I have devised a pipeline that combines semantic embeddings and fine-tuning into a single process.

Fine-tuning + semantic embeddings

My objectives for this pipeline are twofold:

  1. Generate responses that are clear and unambiguous.
  2. Create a dataset containing a limited number of question-response pairs.

I opted to create a single question-response pair from each Slack thread to prevent generating a large dataset. Following Andrej Karpathy's recommendation for effective fine-tuning (10k-100k prompts), quality takes precedence over quantity.

Splitting each thread into multiple question-response pairs (one per sub-dialogue) would have kept the contexts separate, but it would also have inflated the dataset. The trade-off with a single pair per thread is that a thread spanning multiple contexts can yield a response that lacks coherence and relevance to the original question.

Accepting that trade-off, I skipped the logic that parses a thread into multiple question-response pairs and made minor modifications to the thread-generation code, shown below, to generate a single question-response pair from each Slack thread.

import json

for data_file in sorted_list_of_files:
    with open(data_file) as f:
        list_of_messages = json.load(f)

    for message in list_of_messages:
        # A message with replies starts a thread: capture it as the "input"
        # and initialize an empty "output" for the aggregated replies.
        if "reply_count" in message and message["reply_count"] > 0:
            threads.append(
                {
                    "input": message.get("user", "bot")
                    + ": ```"
                    + replace_user_id_with_name(message["text"], user_mapping)
                    + "```",
                    "output": "",
                }
            )
            ...
        else:
            # A reply: append it to the "output" of the thread it belongs to,
            # so each thread yields exactly one question-response pair.
            if (
                "thread_ts" in message
                and message["thread_ts"] in thread_ts_list_index_pairs
            ):
                threads[thread_ts_list_index_pairs[message["thread_ts"]]][
                    "output"
                ] += (
                    message.get("user", "bot")
                    + ": ```"
                    + replace_user_id_with_name(message["text"], user_mapping)
                    + "```\n"
                )

Both the `input` and `output` are in string format, as no post-processing is necessary.
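
For illustration, a single pair ends up looking roughly like the following. The user IDs and message text are placeholders, not real Slack data:

# Hypothetical example of one question-response pair
{
    "input": "U01ABCDE: ```<original question posted in the thread>```",
    "output": "U02FGHIJ: ```<first reply>```\nU03KLMNO: ```<second reply>```\n",
}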

After generating inputs and outputs from the Slack threads, I leveraged an LLM to enhance the quality of the Slack responses, on the assumption that doing so would significantly improve the overall quality of the dataset. While this could be accomplished manually, I wanted to explore whether an LLM could handle the task.

Minimum response length

Before diving into the LLM approach, I decided to address the problem of the model generating short responses by setting a minimum response length.

# Keep only pairs whose response exceeds 180 characters to enforce
# a minimum response length.
output_messages = list(output_messages.values())[0]
if len(output_messages) > 180:
    pairs.append(
        {
            "input": list(input_messages.values())[0],
            "output": output_messages,
        }
    )

Enhancing Slack responses with RedPajama and Mosaic

I employed the RedPajama-INCITE-7B-Chat LLM to enhance the quality of the Slack responses, with the main objective of generating clear, refined outputs.

I used the following prompt to generate the responses:

prompt = f"""
<human>:
Instruction:
You are a helpful slack bot.
You provide answers to user questions on Slack.
Given a user question, your job is to provide an answer to it.
Take help from the context and ensure that your answer appears as if it were provided by a bot, without expressing any personal opinions.
Avoid referencing the context and focus on addressing the user's question directly.
The original user answer consists of responses from multiple users, but your answer has to have a bot-like tone.

User question:
{input}
Context:
{output}\n
<bot>:"""

Prompt provided to the RedPajama model

You can access the complete script here. The `input` represents the question, while the `output` corresponds to the response. Unfortunately, the model hallucinated, which could be attributed to the quality of my prompt. (I’m unsure of other ways to improve it. If you have any ideas, I would love to hear them!) 
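
For reference, here is a minimal sketch of how such a prompt can be run through the RedPajama chat model with Hugging Face `transformers`, assuming the `togethercomputer/RedPajama-INCITE-7B-Chat` checkpoint; the generation parameters are illustrative rather than the exact values from my script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the chat model in half precision so it fits on a single GPU
model_name = "togethercomputer/RedPajama-INCITE-7B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Tokenize the prompt and generate an improved response
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=512,  # illustrative generation parameters
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

# Decode only the newly generated tokens, i.e., everything after the prompt
response = tokenizer.decode(
    outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
)
print(response)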

To explore alternatives, I decided to try the Mosaic MPT-7B-Chat model. Interestingly, the responses obtained using Mosaic were much better than those from RedPajama. 

prompt = f"""
Instruction:
<|im_start|>
You are a helpful slack bot.
You provide answers to user questions on Slack.
Given a user question, your job is to provide an answer to it.
Take help from the context and ensure that your answer appears as if it were provided by a bot, without expressing any personal opinions.
Avoid referencing the context and focus on addressing the user's question directly.
The original user answer consists of responses from multiple users, but your answer has to have a bot-like tone.
<|im_end|>
User question:
<|im_start|>
{input}
<|im_end|>
Context:
<|im_start|>
{output}
<|im_end|>
Your answer:<|im_start|>
"""

{{fine-tuning-01="/blog-component-assets"}}

The responses generated by RedPajama were inadequate, but Mosaic showed promise.

The complete script to improve the Slack responses using the Mosaic model is available here. You can access the actual responses in the dataset by searching for the user question, where you'll find the complete Slack conversation.

Llama 2

I also decided to test the Llama 2 7B model to evaluate response quality. I used the following prompt to generate the responses:

prompt = f"""<s>[INST] <<SYS>>
You are a helpful slack bot.
You provide answers to user questions on Slack.
Given a user question, your job is to provide an answer to it.
Take help from the context and ensure that your answer appears as if it were provided by a bot, without expressing any personal opinions.
Avoid referencing the context and focus on addressing the user's question directly.
The original user answer consists of responses from multiple users, but your answer has to have a bot-like tone.
<</SYS>>

Context: {output}

{input} [/INST]"""

Prompt provided to the Llama 2 7B model

I loaded the Llama-2-7B-chat-hf model in 4-bit to reduce its memory usage and employed the `pipeline()` function for prediction generation. I set `return_full_text` to `False` so that only the generated text is returned, without the prompt.

import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit quantization config to shrink the model's memory footprint
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    quantization_config=bnb_config,
    ...
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
)

sequences = pipeline(
    prompt,
    do_sample=True,
    top_k=10,
    return_full_text=False,  # return only the generated text, not the prompt
    ...,
)

print(f"Result: {sequences[0]['generated_text']}")

Access the complete script here

The outputs generated by the model are as follows:

{{fine-tuning-02="/blog-component-assets"}}

From the outputs, it's clear that Llama uses Unicode characters to include emojis in the text, which is desirable for Slack responses. Furthermore, the outputs are comprehensive. However, I observed that the 7B model often hallucinates, so I experimented with the Llama 2 13B chat model using the following configuration:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map={
        "model.layers": 0,
        "lm_head": "cpu",
        "model.norm": 0,
        "model.embed_tokens": 0,
    },
    trust_remote_code=True,
    token=token,
)

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf",
    token=token,
)

I enabled `llm_int8_enable_fp32_cpu_offload` to facilitate offloading between the CPU and GPU, since the 13B model doesn't fit on a g4dn.metal instance, and I initialized a custom `device_map` to offload `lm_head` to the CPU.

The following are some responses that the model generated:

{{fine-tuning-03="/blog-component-assets"}}

I have observed that the 13B model hallucinates much less than the 7B model. However, even when I explicitly instruct it in the prompt not to reference the given context, it occasionally still does. (Interestingly, the 13B chat model doesn't include emojis the way the 7B model does, even when given the same prompt.)

GPT-3.5 Turbo

I also evaluated the GPT-3.5 Turbo model, just to see how the outputs compare with the previous models’ predictions.

import flytekit
import openai


def improve_slack_response(input, output):
    prompt = f"""
### Instruction:

You are a helpful slack bot. You provide answers to user questions on Slack. Given a user question, your job is to provide an answer to it. Take help from the context and ensure that your answer appears as if it were provided by a bot, without expressing any personal opinions. Avoid referencing the context and focus on addressing the user's question directly. The original user answer consists of responses from multiple users, but your answer has to have a bot-like tone.

### User question:
{input}

### Context:
{output}
    """
    # Fetch the OpenAI API key from a Flyte secret
    # (SECRET_GROUP_OPENAI and SECRET_NAME_OPENAI are defined in the complete script).
    OPENAI_API_KEY = flytekit.current_context().secrets.get(
        SECRET_GROUP_OPENAI, SECRET_NAME_OPENAI
    )
    openai.api_key = OPENAI_API_KEY
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return response.choices[0].message["content"]

Access the complete script here

{{fine-tuning-04="/blog-component-assets"}}

The responses generated by GPT-3.5 Turbo are remarkably close to the actual intended responses.

The outputs generated by GPT-3.5 are comparable to those generated by the Llama 2 13B model, despite the substantial size difference between GPT-3.5 (reportedly around 175 billion parameters) and the Llama 2 model (13 billion parameters). Overall, however, GPT's outputs outperform Llama 2's from a qualitative perspective.

The response in the third row seems to have inferred from the actual Slack response that the original question asker made an attempt to get xxx to work, even though that aspect wasn't explicitly mentioned in the question. I wouldn't necessarily attribute this foresight to the model's performance, but rather to the nature of our data and how it is presented.

Context length is an important factor to consider. GPT-3.5 and Llama 2 have a context length of 4,096 tokens (with one GPT variant offering an extended context of 16,385 tokens and Llama 2 variants going up to 32,000 tokens), while Mosaic and RedPajama have a context length of 2,048 tokens (configurable for Mosaic). Context length is the maximum number of tokens the model can handle across the prompt and the generated completion, so it's crucial to ensure that the prompt stays within the limit.
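
As a rough guard, the prompt can be tokenized first and the context trimmed whenever the total would exceed the model's limit. Here is a minimal sketch using the Llama 2 tokenizer; the 512 tokens reserved for the generated answer are an assumed value, not one from my scripts:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

MAX_CONTEXT = 4096         # Llama 2 context window
RESERVED_FOR_ANSWER = 512  # room left for the generated response (assumed)


def truncate_context(input: str, output: str) -> str:
    """Trim the Slack-thread context so the prompt and answer fit in the window.

    For simplicity, the tokens of the instruction template itself are ignored here.
    """
    question_tokens = len(tokenizer.encode(input))
    budget = MAX_CONTEXT - RESERVED_FOR_ANSWER - question_tokens

    context_tokens = tokenizer.encode(output)
    if len(context_tokens) > budget:
        context_tokens = context_tokens[:budget]
    return tokenizer.decode(context_tokens, skip_special_tokens=True)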

Infrastructure

I executed the pipelines mentioned above on Union Cloud, using a g4dn.metal instance for inference. Running inference was as simple as allocating a GPU to the relevant task, and the pipeline can be run in parallel to process multiple messages concurrently and build the dataset through batch inference.
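
To give a sense of what that looks like, here is a minimal Flyte sketch of a GPU-backed task mapped over the question-response pairs. The task and helper names, as well as the resource sizes, are illustrative rather than taken from the actual pipeline:

from typing import Dict, List

from flytekit import Resources, map_task, task, workflow


@task(requests=Resources(gpu="1", cpu="4", mem="32Gi"))
def clean_pair(pair: Dict[str, str]) -> Dict[str, str]:
    # Run one of the LLMs shown above on a single question-response pair;
    # `enhance_with_llm` is a hypothetical helper wrapping the model call.
    improved = enhance_with_llm(pair["input"], pair["output"])
    return {"input": pair["input"], "output": improved}


@workflow
def clean_dataset(pairs: List[Dict[str, str]]) -> List[Dict[str, str]]:
    # map_task fans the GPU-backed task out over all pairs, giving
    # batch inference with one task execution per Slack thread.
    return map_task(clean_pair)(pair=pairs)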

Conclusion

Overall, GPT-3.5 and Llama 2 13B models emerge as the top performers. While GPT-3.5's outputs are slightly superior to Llama 2 13B, I lean towards using Llama for data-cleaning purposes, primarily because it is freely available for research use. There's potential to produce much cleaner responses with the Llama 2 70B model, which is yet to be explored.

Here is my take on the quality of the responses produced by different LLMs:

{{fine-tuning-05="/blog-component-assets"}}

Next steps

I believe there is room for improvement in prompt engineering to enhance the responses generated by LLMs. One approach could involve considering subsequent questions and responses within each Slack thread. This would help avoid mixing contexts between different dialogues within the same thread. However, this approach would result in a substantial dataset due to the potentially large number of question-response pairs within each Slack thread (the issue I mentioned earlier).

There are three viable options to consider:

  1. Continue improving Slack responses (and questions?) using the LLM by tweaking prompts, perhaps with few-shot prompting (see the sketch after this list).
  2. Consider only Slack threads in which the original poster acknowledges that the question has been answered. (However, this approach may sometimes result in huge prompts, which may need to be shortened to fit the context length.)
  3. Introduce manual intervention to select or write suitable responses. This can be built on top of the existing model as a feedback mechanism to improve it by creating a gold-standard dataset. This approach allows more precise control over the generated responses, ensuring they align closely with the desired outcome.
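
To make the first option concrete, a few-shot version of the prompt could prepend a couple of hand-picked question-answer pairs before the question to be cleaned. The template below is hypothetical; `example_pairs` would have to be curated manually:

# Hypothetical few-shot template: example_pairs holds hand-curated
# (question, ideal bot-style answer) tuples drawn from the Slack data.
few_shot_examples = "\n\n".join(
    f"User question:\n{q}\nYour answer:\n{a}" for q, a in example_pairs
)

prompt = f"""
### Instruction:
You are a helpful slack bot. Given a user question and context, provide a clear,
bot-like answer without referencing the context or expressing personal opinions.

### Examples:
{few_shot_examples}

### User question:
{input}

### Context:
{output}

### Your answer:
"""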

I am also thinking about using a different model for fine-tuning because the current prompt format doesn't seem to align well with the RedPajama model's expectations. 

While generating responses, another aspect that I considered is the issue of data staleness. It raises questions about how much older data should be included for fine-tuning the model. If the data used for fine-tuning is too outdated, it may contain stale or irrelevant information.

My team and I have been actively brainstorming to determine the best course of action for the next steps. I would greatly appreciate your feedback and any ideas you may have regarding improvements or further steps that I could consider. 

To summarize my experimentation journey and key takeaways: I experimented with LLMs to create a high-quality dataset. During this process, I evaluated various models, and Llama 2 and GPT outperformed both RedPajama and Mosaic. The next steps involve dedicating more time to dataset cleanup and preparation with batch inference before proceeding with fine-tuning and semantic embeddings.

The code for this project is available in our GitHub repository: https://github.com/unionai-oss/llm-fine-tuning/tree/main/redpajama-lora.
