-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathprod-generation.ts
110 lines (98 loc) · 2.55 KB
/
prod-generation.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import { cookies } from 'next/headers'
import { createMistral, MistralProvider, mistral as originalMistral } from '@ai-sdk/mistral'
import { createOpenAI, OpenAIProvider } from '@ai-sdk/openai'
import { customProvider, generateObject, generateText, LanguageModelV1 } from 'ai'
import { AIProvider } from '@/types/ai'
import { COOKIE_NAME } from '@/constants'
/**
 * Resolves the language model used for AI completions.
 *
 * - No `provider`: uses the server-configured Mistral provider (`originalMistral`)
 *   with the `mistral-small-latest` model.
 * - With a `provider`: reads the user's API key from the `COOKIE_NAME` cookie.
 *   `'Mistral'` builds a Mistral client serving `mistral-small-latest`; any other
 *   value builds an OpenAI client serving `gpt-4o`.
 *
 * @param provider - Provider chosen by the user, or `undefined` for the default.
 * @returns The resolved model, or `undefined` when a provider was requested but the
 *   API-key cookie is missing or blank.
 */
export function getAIModel({ provider }: { provider: AIProvider | undefined }) {
  let defaultProvider: MistralProvider | OpenAIProvider = originalMistral
  let defaultModel = 'mistral-small-latest'

  if (provider) {
    // Trim once and use the trimmed value: the blank check ignores surrounding
    // whitespace, so a key pasted with a stray space/newline must not be sent as-is.
    const userEnteredApiKey = cookies().get(COOKIE_NAME)?.value?.trim()
    if (!userEnteredApiKey) return

    if (provider === 'Mistral') {
      defaultProvider = createMistral({ apiKey: userEnteredApiKey })
    } else {
      defaultProvider = createOpenAI({ apiKey: userEnteredApiKey })
      defaultModel = 'openai-gpt4'
    }
  }

  // Both aliases are instantiated on the selected provider, but only `defaultModel`
  // (always one of the two registered aliases) is looked up, so the fallback
  // provider should not be needed for this lookup.
  const customModel = customProvider({
    languageModels: {
      'mistral-small-latest': defaultProvider('mistral-small-latest'),
      'openai-gpt4': defaultProvider('gpt-4o')
    },
    fallbackProvider: originalMistral
  })
  return customModel.languageModel(defaultModel)
}
/**
 * Asks `model` for a plain-text completion of `prompt`.
 *
 * The conversation is primed with a fixed "tech lead" system message; sampling
 * uses temperature 0.7 with at most 1024 output tokens.
 *
 * @param model - Language model to query.
 * @param prompt - User message content.
 * @returns The generated text.
 */
export const generateCompletion = async ({
  model,
  prompt
}: {
  model: LanguageModelV1
  prompt: string
}) => {
  const result = await generateText({
    model,
    temperature: 0.7,
    maxTokens: 1024,
    messages: [
      { role: 'system', content: 'You are a tech lead.' },
      { role: 'user', content: prompt }
    ]
  })
  return result.text
}
/**
 * Asks `model` for a structured object matching `schema`, then returns the
 * object's `data` property.
 *
 * Uses the same "tech lead" system message and sampling settings as the plain-text
 * completion helper (temperature 0.7, at most 1024 output tokens).
 *
 * @param model - Language model to query.
 * @param prompt - User message content.
 * @param schema - Schema passed straight to `generateObject`; presumably it
 *   describes an object with a `data` field (TODO confirm against callers).
 * @returns The `data` property of the generated object.
 */
export const generateCompletionWithSchema = async ({
  model,
  prompt,
  schema
}: {
  model: LanguageModelV1
  prompt: string
  schema: any
}) => {
  const result = await generateObject({
    model,
    schema,
    temperature: 0.7,
    maxTokens: 1024,
    messages: [
      { role: 'system', content: 'You are a tech lead.' },
      { role: 'user', content: prompt }
    ]
  })
  // @ts-expect-error -- the object's static type does not declare `data`; the schema is expected to provide it
  return result.object.data
}
/**
 * Converts a value thrown during AI generation into an error payload plus a
 * response-init-style `{ status }` object.
 *
 * Reads `name`, `status` and `headers` from the thrown value when present. AI SDK
 * `APICallError`s expose the HTTP status as `statusCode` and the headers as
 * `responseHeaders`, so those are used as fallbacks. A 401 status yields an
 * "incorrect API key" message; anything else yields a generic message.
 *
 * @param error - The caught value; may be anything, including `null` or a string.
 * @returns `errorData` (fields from the error, a user-facing `error` message and
 *   `data: undefined`) and `status` wrapping the HTTP status, if any.
 */
export const handleGenerationErrors = ({ error }: { error: unknown }) => {
  // Fields are kept as `any` because they are passed through untouched and the
  // exported return type has always exposed them as `any`; callers may rely on that.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  type ErrorFields = { name?: any; status?: any; statusCode?: any; headers?: any; responseHeaders?: any }

  // null/undefined would throw on property access; treat them as "no details".
  const fields: ErrorFields = error == null ? {} : (error as ErrorFields)
  const { name } = fields
  const status = fields.status ?? fields.statusCode
  const headers = fields.headers ?? fields.responseHeaders

  const errorMessage =
    status === 401
      ? 'Incorrect API Key provided. Please enter a new one.'
      : 'An error has occurred with API Completions. Please try again.'

  return {
    errorData: { name, status, headers, error: errorMessage, data: undefined },
    status: { status }
  }
}