
Commit 33801ba

Authored by wxgeorge, with SBrandeis and hanouticelina
✨ Support for Featherless.ai as inference provider. (#1310)
Implements support for Featherless.ai as an inference provider: fully for chat, and partially for completions (streaming completions to be covered in a future PR).

Co-authored-by: SBrandeis <simon@huggingface.co>
Co-authored-by: Celina Hanouti <hanouticelina@gmail.com>
1 parent d119c63 commit 33801ba
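For context, a minimal usage sketch of the new provider from the inference client. This is a sketch only: the model ID and environment variable name are illustrative, and the call shape mirrors the tests added in this commit.

import { chatCompletion } from "@huggingface/inference";

// Sketch: FEATHERLESS_API_KEY is an illustrative env var name, not part of this PR.
const res = await chatCompletion({
	accessToken: process.env.FEATHERLESS_API_KEY,
	model: "meta-llama/Llama-3.1-8B-Instruct",
	provider: "featherless-ai",
	messages: [{ role: "user", content: "One plus one is equal to" }],
});
console.log(res.choices[0].message?.content);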

File tree

6 files changed: +134 -1 lines changed

packages/inference/README.md (+2)

@@ -48,6 +48,7 @@ You can send inference requests to third-party providers with the inference client.
 
 Currently, we support the following providers:
 - [Fal.ai](https://fal.ai)
+- [Featherless AI](https://featherless.ai)
 - [Fireworks AI](https://fireworks.ai)
 - [Hyperbolic](https://hyperbolic.xyz)
 - [Nebius](https://studio.nebius.ai)

@@ -78,6 +79,7 @@ When authenticated with a third-party provider key, the request is made directly
 
 Only a subset of models are supported when requesting third-party providers. You can check the list of supported models per pipeline tasks here:
 - [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
+- [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
 - [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
 - [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
 - [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
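The partner endpoints listed above are plain HTTP GET routes. A quick sketch of checking which models Featherless AI serves; the response schema is not shown in this diff, so it is logged as-is.

// Sketch: the URL comes from the README change above; the JSON shape is not specified here.
const res = await fetch("https://huggingface.co/api/partners/featherless-ai/models");
const models = await res.json();
console.log(models);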

packages/inference/src/lib/getProviderHelper.ts (+5 -1)

@@ -2,10 +2,10 @@ import * as BlackForestLabs from "../providers/black-forest-labs";
 import * as Cerebras from "../providers/cerebras";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
+import * as FeatherlessAI from "../providers/featherless-ai";
 import * as Fireworks from "../providers/fireworks-ai";
 import * as Groq from "../providers/groq";
 import * as HFInference from "../providers/hf-inference";
-
 import * as Hyperbolic from "../providers/hyperbolic";
 import * as Nebius from "../providers/nebius";
 import * as Novita from "../providers/novita";

@@ -64,6 +64,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 		"text-to-video": new FalAI.FalAITextToVideoTask(),
 		"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
 	},
+	"featherless-ai": {
+		conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),
+		"text-generation": new FeatherlessAI.FeatherlessAITextGenerationTask(),
+	},
 	"hf-inference": {
 		"text-to-image": new HFInference.HFInferenceTextToImageTask(),
 		conversational: new HFInference.HFInferenceConversationalTask(),
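With the registry entry above in place, resolving a helper is an indexed lookup on PROVIDERS; any task key without an entry comes back undefined. A minimal sketch, using only names declared in this file:

// Inside this module (or wherever PROVIDERS is imported):
const conversational = PROVIDERS["featherless-ai"]?.conversational; // FeatherlessAIConversationalTask
const textGeneration = PROVIDERS["featherless-ai"]?.["text-generation"]; // FeatherlessAITextGenerationTask
const textToImage = PROVIDERS["featherless-ai"]?.["text-to-image"]; // undefined: not registered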

packages/inference/src/providers/consts.ts (+1)

@@ -23,6 +23,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	cerebras: {},
 	cohere: {},
 	"fal-ai": {},
+	"featherless-ai": {},
 	"fireworks-ai": {},
 	groq: {},
 	"hf-inference": {},
packages/inference/src/providers/featherless-ai.ts (new file, +52)

@@ -0,0 +1,52 @@
+import type { ChatCompletionOutput, TextGenerationInput, TextGenerationOutput, TextGenerationOutputFinishReason } from "@huggingface/tasks";
+import { InferenceOutputError } from "../lib/InferenceOutputError";
+import type { BodyParams } from "../types";
+import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";
+
+interface FeatherlessAITextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
+	choices: Array<{
+		text: string;
+		finish_reason: TextGenerationOutputFinishReason;
+		seed: number;
+		logprobs: unknown;
+		index: number;
+	}>;
+}
+
+const FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";
+
+export class FeatherlessAIConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("featherless-ai", FEATHERLESS_API_BASE_URL);
+	}
+}
+
+export class FeatherlessAITextGenerationTask extends BaseTextGenerationTask {
+	constructor() {
+		super("featherless-ai", FEATHERLESS_API_BASE_URL);
+	}
+
+	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
+		return {
+			...params.args,
+			...params.args.parameters,
+			model: params.model,
+			prompt: params.args.inputs,
+		};
+	}
+
+	override async getResponse(response: FeatherlessAITextCompletionOutput): Promise<TextGenerationOutput> {
+		if (
+			typeof response === "object" &&
+			"choices" in response &&
+			Array.isArray(response?.choices) &&
+			typeof response?.model === "string"
+		) {
+			const completion = response.choices[0];
+			return {
+				generated_text: completion.text,
+			};
+		}
+		throw new InferenceOutputError("Expected Featherless AI text generation response format");
+	}
+}
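To make the preparePayload override concrete: the first spread passes the original args through (including the inputs and parameters keys), the second flattens parameters to the top level, and model/prompt are then set explicitly. A worked example with illustrative values:

// Input (illustrative):
//   params.args  = { inputs: "Paris is a city of ", parameters: { max_tokens: 10 } }
//   params.model = "meta-llama/Meta-Llama-3.1-8B"
//
// Resulting request body:
//   {
//     inputs: "Paris is a city of ",         // kept by ...params.args
//     parameters: { max_tokens: 10 },        // kept by ...params.args
//     max_tokens: 10,                        // hoisted by ...params.args.parameters
//     model: "meta-llama/Meta-Llama-3.1-8B", // provider-side model ID
//     prompt: "Paris is a city of ",         // copied from args.inputs
//   }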

packages/inference/src/types.ts (+1)

@@ -42,6 +42,7 @@ export const INFERENCE_PROVIDERS = [
 	"cerebras",
 	"cohere",
 	"fal-ai",
+	"featherless-ai",
 	"fireworks-ai",
 	"groq",
 	"hf-inference",

packages/inference/test/InferenceClient.spec.ts (+73)

@@ -1045,6 +1045,79 @@ describe.skip("InferenceClient", () => {
 		TIMEOUT
 	);
 
+	describe.concurrent(
+		"Featherless",
+		() => {
+			HARDCODED_MODEL_INFERENCE_MAPPING["featherless-ai"] = {
+				"meta-llama/Llama-3.1-8B": {
+					providerId: "meta-llama/Meta-Llama-3.1-8B",
+					hfModelId: "meta-llama/Llama-3.1-8B",
+					task: "text-generation",
+					status: "live",
+				},
+				"meta-llama/Llama-3.1-8B-Instruct": {
+					providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct",
+					hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
+					task: "text-generation",
+					status: "live",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await chatCompletion({
+					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "featherless-ai",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+					temperature: 0.1,
+				});
+
+				expect(res).toBeDefined();
+				expect(res.choices).toBeDefined();
+				expect(res.choices?.length).toBeGreaterThan(0);
+
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toBeDefined();
+					expect(typeof completion).toBe("string");
+					expect(completion).toContain("two");
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = chatCompletionStream({
+					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "featherless-ai",
+					messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+				let out = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						out += chunk.choices[0].delta.content;
+					}
+				}
+				expect(out).toContain("2");
+			});
+
+			it("textGeneration", async () => {
+				const res = await textGeneration({
+					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
+					model: "meta-llama/Llama-3.1-8B",
+					provider: "featherless-ai",
+					inputs: "Paris is a city of ",
+					parameters: {
+						temperature: 0,
+						top_p: 0.01,
+						max_tokens: 10,
+					},
+				});
+				expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" });
+			});
+		},
+		TIMEOUT
+	);
+
 	describe.concurrent(
 		"Replicate",
 		() => {
