
Implementing a Custom LLM Provider

rig-core provides a flexible and modular architecture for building LLM applications. While it includes native support for several major LLM providers (e.g., OpenAI, Anthropic, Cohere), you can easily extend Rig by implementing your own client for a custom or a less common provider.

This guide will walk you through the process of creating a new LLM provider client by implementing the core Rig traits.

Prerequisites

Before you begin, ensure you have a basic understanding of Rust and rig-core’s fundamental concepts, such as CompletionModel and EmbeddingModel.

Defining Your Client

First, create a new struct to represent your provider’s client. This struct acts as the base from which your model types are created. It should store any necessary configuration, such as an API key, the base URL and an HTTP client. Accepting a custom HTTP client lets users configure settings such as request timeouts themselves.

const BASE_URL: &str = "https://example.com";

#[derive(Clone)]
pub struct MyProviderClient {
    api_key: String,
    http_client: reqwest::Client,
    base_url: String,
    // Add any other provider-specific fields here
}

impl MyProviderClient {
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::builder().api_key(api_key).build()
    }

    pub fn builder() -> MyProviderClientBuilder {
        MyProviderClientBuilder::default()
    }
}

// The builder must be `pub` because the public `builder()` method returns it.
#[derive(Default)]
pub struct MyProviderClientBuilder {
    api_key: Option<String>,
    http_client: Option<reqwest::Client>,
    base_url: Option<String>,
}

impl MyProviderClientBuilder {
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Use a preconfigured reqwest client (e.g. one with a custom timeout).
    pub fn custom_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = Some(client);
        self
    }

    pub fn build(self) -> MyProviderClient {
        // In practical usage you may wish to return a Result here instead of panicking.
        let api_key = self.api_key.expect("an API key is required");
        let http_client = self.http_client.unwrap_or_default();
        // BASE_URL is a &str, so convert it into an owned String for the fallback.
        let base_url = self.base_url.unwrap_or_else(|| BASE_URL.to_string());
        MyProviderClient { api_key, http_client, base_url }
    }
}

You may notice that the snippet above uses the builder pattern to construct the client. The builder pattern is a common Rust idiom. It avoids constructors that take an unreasonably large number of arguments and lets optional settings fall back to sensible defaults. All of the model providers in the core Rig library use this pattern, because each provider may accept different settings. For example, the Anthropic client may take a beta version.
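The build method above panics when no API key has been set. If you would rather return an error to the caller, a fallible variant might look like the sketch below. It keeps only the std-only parts of the client, dropping the reqwest field so the example compiles on its own. BuildError is a hypothetical name.

```rust
use std::fmt;

// Default endpoint, mirroring BASE_URL above.
const BASE_URL: &str = "https://example.com";

/// Hypothetical error type that `build()` returns instead of panicking.
#[derive(Debug, PartialEq)]
pub enum BuildError {
    MissingApiKey,
}

impl fmt::Display for BuildError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BuildError::MissingApiKey => write!(f, "an API key is required"),
        }
    }
}

impl std::error::Error for BuildError {}

// Simplified client: the reqwest::Client field is omitted to stay std-only.
#[derive(Clone)]
pub struct MyProviderClient {
    api_key: String,
    base_url: String,
}

#[derive(Default)]
pub struct MyProviderClientBuilder {
    api_key: Option<String>,
    base_url: Option<String>,
}

impl MyProviderClientBuilder {
    pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = Some(base_url.into());
        self
    }

    /// Reports a missing required field as an error; optional fields get defaults.
    pub fn build(self) -> Result<MyProviderClient, BuildError> {
        let api_key = self.api_key.ok_or(BuildError::MissingApiKey)?;
        let base_url = self.base_url.unwrap_or_else(|| BASE_URL.to_string());
        Ok(MyProviderClient { api_key, base_url })
    }
}
```

Callers can then propagate the error with `?` instead of having the program abort on a misconfiguration.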

Client Safety

Implementing Debug for the client struct is useful, but it must not come at the cost of safety. Users should never be able to accidentally leak their API key into the terminal, the shell or whatever logging they use. Therefore, rather than deriving Debug, we implement it manually.

The Debug implementation below shows the base URL and HTTP client using their own Debug implementations and replaces the API key with a placeholder.

impl std::fmt::Debug for MyProviderClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MyProviderClient")
            .field("base_url", &self.base_url)
            .field("http_client", &self.http_client)
            // Never print the real key; show a placeholder instead.
            .field("api_key", &"<REDACTED>")
            .finish()
    }
}

This can of course be adjusted to your needs, depending on how much detail needs to be removed from the Debug implementation.
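An alternative to hand-writing Debug for the whole client is to wrap the key in a small newtype whose Debug output is always redacted. The client can then safely derive Debug. The Secret type below is hypothetical, and crates such as secrecy provide a ready-made equivalent. The reqwest field is omitted to keep the sketch dependency-free.

```rust
use std::fmt;

/// Hypothetical wrapper for secrets: its Debug output never reveals the value.
#[derive(Clone)]
pub struct Secret(String);

impl Secret {
    pub fn new(value: impl Into<String>) -> Self {
        Secret(value.into())
    }

    /// Explicit accessor, used only where the real value is needed
    /// (for example, when building an Authorization header).
    pub fn expose(&self) -> &str {
        &self.0
    }
}

impl fmt::Debug for Secret {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("<REDACTED>")
    }
}

// Because the key is wrapped, deriving Debug no longer leaks it.
#[derive(Clone, Debug)]
pub struct MyProviderClient {
    base_url: String,
    api_key: Secret,
}
```

The trade-off is that every place that needs the raw key must call expose() explicitly, which also makes those uses easy to audit.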

Client Traits

Rig defines a number of client traits in the rig::client module:

  • CompletionClient - create a type that can generate Chat Completions.
  • EmbeddingsClient - create a type that can generate embeddings.
  • TranscriptionClient - create a type that can generate transcriptions.
  • AudioGenerationClient - create a type that can generate audio (requires the audio feature flag).
  • ImageGenerationClient - create a type that can generate images (requires the image feature flag).

Each trait represents a different kind of behaviour that a client or model in Rig can carry out. These traits are generally straightforward to implement, and we will use CompletionClient as an example.

Implementing ProviderClient

Before we start: as you may have noticed, every one of these traits requires the ProviderClient trait as a supertrait. ProviderClient lets your client be constructed from environment variables instead of through ::new().
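Environment-based construction boils down to something like the following. This is a standalone sketch, not Rig’s ProviderClient trait. The variable name MY_PROVIDER_API_KEY and the from_lookup helper are hypothetical, and this version returns None instead of panicking when the variable is missing.

```rust
/// Hypothetical environment variable name; real providers document their own.
pub const API_KEY_VAR: &str = "MY_PROVIDER_API_KEY";

#[derive(Clone)]
pub struct MyProviderClient {
    api_key: String,
}

impl MyProviderClient {
    pub fn new(api_key: impl Into<String>) -> Self {
        MyProviderClient { api_key: api_key.into() }
    }

    /// Reads the API key from the process environment.
    pub fn from_env() -> Option<Self> {
        Self::from_lookup(|name| std::env::var(name).ok())
    }

    /// Same logic with an injectable lookup function, which keeps it
    /// testable without mutating the real process environment.
    pub fn from_lookup<F: Fn(&str) -> Option<String>>(lookup: F) -> Option<Self> {
        lookup(API_KEY_VAR).map(|key| Self::new(key))
    }
}
```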

Unfortunately, ProviderClient comes with one catch: it requires a trait implementation for every possible generation type, including the ones your provider does not support.

Fortunately, there is a way around this. Rig provides a set of conversion traits:

  • AsCompletion - lets you opt out of writing a completion client.
  • AsEmbeddings - lets you opt out of writing an embeddings client.

There are equivalent traits for the other generation types: AsTranscription, AsImageGeneration and AsAudioGeneration.

You can use the rig::impl_conversion_traits! macro to implement these conversion traits for your Client, listing every generation type your provider does not support. For example, if the model provider only supports embeddings and chat completions, you would write the following (taken from the ollama module):

impl_conversion_traits!(
    AsTranscription,
    AsImageGeneration,
    AsAudioGeneration for Client
);

This is required by rig::DynClientBuilder, which makes it convenient to create clients without excessive boilerplate.
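To see why listing a trait in the macro is enough, here is a simplified, self-contained model of the pattern. It assumes that each conversion trait has a default method reporting the capability as unsupported (returning None). Under that assumption, the macro only has to emit empty impl blocks. The trait bodies and method names (as_transcription, transcribe and so on) are hypothetical stand-ins, not Rig’s actual definitions.

```rust
// Hypothetical stand-ins for dyn-compatible client traits (not Rig's real definitions).
pub trait TranscriptionClientDyn {
    fn transcribe(&self, audio: &[u8]) -> String;
}

pub trait ImageGenerationClientDyn {
    fn generate_image(&self, prompt: &str) -> Vec<u8>;
}

// Conversion traits: the default method means "this modality is not supported".
pub trait AsTranscription {
    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
        None
    }
}

pub trait AsImageGeneration {
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        None
    }
}

// A macro in the spirit of impl_conversion_traits!: one empty impl per listed trait.
macro_rules! impl_conversion_traits {
    ($($trait_:ident),+ for $ty:ty) => {
        $(impl $trait_ for $ty {})+
    };
}

#[derive(Clone)]
pub struct MyProviderClient;

// Opt out of the modalities this provider does not offer.
impl_conversion_traits!(AsTranscription, AsImageGeneration for MyProviderClient);
```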

You may also notice several ClientDyn traits (such as CompletionClientDyn) that must be implemented before ProviderClient can be implemented. You do not need to write these yourself, because each one is implemented automatically once its typed equivalent is. For example, CompletionClientDyn is automatically implemented for any type that implements CompletionClient whose model type also implements CompletionModel.
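Automatic implementations like this are usually provided by a blanket impl. The sketch below models the idea with simplified, synchronous stand-in traits. Rig’s real traits are async and have more methods, and complete and completion_model_dyn are hypothetical names. Any type implementing CompletionClient gets a CompletionClientDyn impl that boxes its model, which lets clients be used behind dyn at runtime.

```rust
// Simplified stand-in for the CompletionModel trait.
pub trait CompletionModel {
    fn complete(&self, prompt: &str) -> String;
}

// Typed client trait: the model type is an associated type.
pub trait CompletionClient: Clone {
    type CompletionModel: CompletionModel + 'static;
    fn completion_model(&self, model: &str) -> Self::CompletionModel;
}

/// Dyn-compatible counterpart: returns a boxed trait object instead of an associated type.
pub trait CompletionClientDyn {
    fn completion_model_dyn(&self, model: &str) -> Box<dyn CompletionModel>;
}

// Blanket impl: every CompletionClient gets CompletionClientDyn for free.
impl<T: CompletionClient> CompletionClientDyn for T {
    fn completion_model_dyn(&self, model: &str) -> Box<dyn CompletionModel> {
        Box::new(self.completion_model(model))
    }
}

#[derive(Clone)]
pub struct MyProviderClient;

pub struct MyProviderCompletionModel {
    model_name: String,
}

impl CompletionModel for MyProviderCompletionModel {
    fn complete(&self, prompt: &str) -> String {
        // A real model would call the provider's API; this just echoes for illustration.
        format!("[{}] {}", self.model_name, prompt)
    }
}

impl CompletionClient for MyProviderClient {
    type CompletionModel = MyProviderCompletionModel;

    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        MyProviderCompletionModel { model_name: model.to_string() }
    }
}
```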

Next, we will look at implementing CompletionClient and CompletionModel. The other traits are implemented in a similar manner (and therefore won’t be covered), although the return type will differ depending on the trait being implemented.

Implementing a Completion Client

The CompletionClient trait looks like this:

pub trait CompletionClient: ProviderClient + Clone {
    /// The type of CompletionModel used by the client.
    type CompletionModel: CompletionModel;
 
    fn completion_model(&self, model: &str) -> Self::CompletionModel;
 
    // there are other methods that exist on this client, but we can ignore those for now
}

So, we can implement the client like so:

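impl CompletionClient for MyProviderClient {
    type CompletionModel = MyProviderCompletionModel;

    fn completion_model(&self, model: &str) -> Self::CompletionModel {
        MyProviderCompletionModel {
            client: self.clone(),
            model_name: model.to_string(),
        }
    }
}

// Named differently from rig's CompletionModel trait to avoid a name clash.
// The CompletionModel trait requires Clone (plus Send + Sync), hence the derive.
#[derive(Clone)]
pub struct MyProviderCompletionModel {
    client: MyProviderClient,
    model_name: String,
}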

However, our work is not done here yet. Let’s have a look at how to implement the CompletionModel trait.

Implementing a Completion Model

For each client trait there is also a matching model trait. Completion is no different: there is a trait for completion clients (CompletionClient) and a trait for completion models (CompletionModel).

The trait looks like this:

pub trait CompletionModel:
    Clone
    + Send
    + Sync {
    type Response: Send + Sync + Serialize + DeserializeOwned;
    type StreamingResponse: Clone + Unpin + Send + Sync + Serialize + DeserializeOwned + GetTokenUsage;
 
    // Required methods
    fn completion(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<CompletionResponse<Self::Response>, CompletionError>> + Send;
 
    fn stream(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>> + Send;
}

For most completion models that wrap an HTTP API, the usual approach is to convert the CompletionRequest into a JSON body and send it to the provider’s API. The shape of this body, however, depends on the API. If you don’t expect the API to change much, the recommended approach is to create a type that represents the completion request and implement TryFrom<(String, CompletionRequest)> for it. Here the String is the model name, which CompletionRequest does not carry. Alternatively, you can build the body with the serde_json::json!() macro, but you will need to be careful not to make mistakes when assembling the JSON by hand. Either way, you will also need a completion response type (typically a struct) for the HTTP response to be deserialized into.
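A minimal sketch of the TryFrom approach follows. CompletionRequest here is a simplified stand-in with only three fields, whereas Rig’s real type carries more (chat history, documents, tools and so on). MyProviderRequest, RequestError and the 0.0 to 2.0 temperature limit are hypothetical. In a real provider you would also derive serde::Serialize on the request type and send it as the JSON body.

```rust
use std::convert::TryFrom;

// Simplified stand-in for rig's CompletionRequest (the real type has more fields).
pub struct CompletionRequest {
    pub preamble: Option<String>,
    pub prompt: String,
    pub temperature: Option<f64>,
}

/// Hypothetical request body for a provider's chat endpoint.
#[derive(Debug, PartialEq)]
pub struct MyProviderRequest {
    pub model: String,
    pub system: Option<String>,
    pub messages: Vec<String>,
    pub temperature: Option<f64>,
}

#[derive(Debug, PartialEq)]
pub enum RequestError {
    EmptyPrompt,
    TemperatureOutOfRange(f64),
}

// The model name travels alongside the request because CompletionRequest lacks it.
impl TryFrom<(String, CompletionRequest)> for MyProviderRequest {
    type Error = RequestError;

    fn try_from((model, req): (String, CompletionRequest)) -> Result<Self, Self::Error> {
        if req.prompt.trim().is_empty() {
            return Err(RequestError::EmptyPrompt);
        }
        // Hypothetical provider limit, validated before any network call is made.
        if let Some(t) = req.temperature {
            if !(0.0..=2.0).contains(&t) {
                return Err(RequestError::TemperatureOutOfRange(t));
            }
        }
        Ok(MyProviderRequest {
            model,
            system: req.preamble,
            messages: vec![req.prompt],
            temperature: req.temperature,
        })
    }
}
```

A benefit of this design over json!() is that the compiler checks the body’s shape, and invalid requests are rejected before they reach the network.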

If you decide to build raw JSON instead of converting through a TryFrom<T> implementation, you may need some helper functions for composing it. The json_utils module in Rig is internal, since it is not intended to be a public API, but you can copy the functions you need from it into your own crate.

For inspiration on writing your own streaming response handling, look at the streaming implementations of the providers that already ship with the Rig library.
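Many HTTP providers stream completions as server-sent events (SSE), where each line starting with data: carries one chunk. The sketch below shows the core parsing and accumulation loop. It assumes a hypothetical provider that sends plain-text deltas and ends the stream with an OpenAI-style [DONE] sentinel. Real payloads are usually JSON objects that you would deserialize into your StreamingResponse type.

```rust
/// One decoded event from a hypothetical SSE stream.
#[derive(Debug, PartialEq)]
pub enum StreamEvent {
    /// A fragment of generated text.
    Delta(String),
    /// The provider's end-of-stream sentinel.
    Done,
}

/// Parses one SSE line. Lines without a `data:` field (blank keep-alive lines,
/// `:` comments, `event:` lines) yield None.
pub fn parse_sse_line(line: &str) -> Option<StreamEvent> {
    let data = line.strip_prefix("data:")?;
    // In SSE, a single space after the colon is not part of the value.
    let data = data.strip_prefix(' ').unwrap_or(data);
    if data == "[DONE]" {
        Some(StreamEvent::Done)
    } else {
        Some(StreamEvent::Delta(data.to_string()))
    }
}

/// Concatenates text deltas until the Done sentinel, ignoring non-data lines.
pub fn collect_text<'a, I: IntoIterator<Item = &'a str>>(lines: I) -> String {
    let mut text = String::new();
    for line in lines {
        match parse_sse_line(line) {
            Some(StreamEvent::Delta(chunk)) => text.push_str(&chunk),
            Some(StreamEvent::Done) => break,
            None => {}
        }
    }
    text
}
```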

Converting messages

When writing your own integration, you will most likely need to write your own message type, unless the provider is fully OpenAI-compatible and can reuse the OpenAI message type. Each message in Rig is either a User or an Assistant message, composed of one or more user or assistant content parts respectively. For example, a user message may contain text content as well as an image URL (or a base64-encoded image) for the LLM to interpret.
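As a concrete illustration, here is how a multi-part user message might map onto a provider that accepts only one content part per message. UserContent and RigUserMessage are simplified stand-ins for Rig’s message types, and MyMessage is a hypothetical provider type. A real conversion would usually return a Result to reject unsupported content.

```rust
// Simplified stand-ins for Rig's user content types.
pub enum UserContent {
    Text(String),
    ImageUrl(String),
}

pub struct RigUserMessage {
    pub content: Vec<UserContent>,
}

/// Hypothetical provider message type holding exactly one content part.
#[derive(Debug, PartialEq)]
pub enum MyMessage {
    UserText(String),
    UserImage { url: String },
}

/// One Rig message with several parts becomes several provider messages.
pub fn split_user_message(msg: RigUserMessage) -> Vec<MyMessage> {
    msg.content
        .into_iter()
        .map(|part| match part {
            UserContent::Text(text) => MyMessage::UserText(text),
            UserContent::ImageUrl(url) => MyMessage::UserImage { url },
        })
        .collect()
}
```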

In the official Rig providers, this conversion is most commonly done by implementing TryFrom<rig::completion::Message> for Vec<T>, where T is the provider’s message type.

A Vec<T> is used because each message in Rig’s core Message abstraction can contain multiple content parts. This suits providers such as OpenAI, Anthropic and Gemini, but many other providers cannot accept more than one content part per message, so one Rig message may need to become several provider messages. If the provider’s messages can also hold multiple parts, you can simply return a single-element Vec<T>.

However, if you try to write this implementation in your own crate, you will find that it is rejected because of Rust’s orphan rule: TryFrom, Vec and rig::completion::Message are all defined outside your crate.

There are a couple of ways around this. The easiest is to use a free-standing function:

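// `Message` is your provider's message type and `MyError` your crate's error type.
fn convert_rig_message_to_my_message(message: rig::completion::Message) -> Result<Vec<Message>, MyError> {
    // Map the content parts of `message` to one or more provider messages here.
    todo!()
}

// Practical usage: convert every message (stopping at the first error),
// then flatten the per-message Vecs into a single list.
let messages: Vec<Message> = messages
    .into_iter()
    .map(convert_rig_message_to_my_message)
    .collect::<Result<Vec<Vec<Message>>, MyError>>()?
    .into_iter()
    .flatten()
    .collect();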
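Another way around the orphan rule is to wrap the Vec in a newtype defined in your own crate. Because the newtype is local, implementing TryFrom<rig::completion::Message> for it is allowed. In the sketch below, RigMessage stands in for rig::completion::Message and is defined locally only so the example is self-contained. MyMessages, MyMessage and MyError are hypothetical.

```rust
use std::convert::TryFrom;

// Stand-in for the foreign type rig::completion::Message.
pub enum RigMessage {
    User(Vec<String>),
    Assistant(String),
}

#[derive(Debug, PartialEq)]
pub enum MyMessage {
    User(String),
    Assistant(String),
}

#[derive(Debug, PartialEq)]
pub struct MyError(pub String);

/// Local newtype: implementing TryFrom<rig::completion::Message> for it satisfies
/// the orphan rule, whereas implementing it for Vec<MyMessage> does not.
#[derive(Debug, PartialEq)]
pub struct MyMessages(pub Vec<MyMessage>);

impl TryFrom<RigMessage> for MyMessages {
    type Error = MyError;

    fn try_from(message: RigMessage) -> Result<Self, Self::Error> {
        match message {
            RigMessage::User(parts) => {
                if parts.is_empty() {
                    return Err(MyError("user message has no content".to_string()));
                }
                // One provider message per content part.
                Ok(MyMessages(parts.into_iter().map(MyMessage::User).collect()))
            }
            RigMessage::Assistant(text) => Ok(MyMessages(vec![MyMessage::Assistant(text)])),
        }
    }
}

/// Converts a whole history, failing on the first invalid message.
pub fn convert_history(history: Vec<RigMessage>) -> Result<Vec<MyMessage>, MyError> {
    let nested = history
        .into_iter()
        .map(|m| MyMessages::try_from(m))
        .collect::<Result<Vec<MyMessages>, MyError>>()?;
    Ok(nested.into_iter().flat_map(|m| m.0).collect())
}
```

Compared with the free function, the newtype lets callers use the familiar try_from/try_into syntax, at the cost of unwrapping the inner Vec with .0.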