// grok — Telegram bot that forwards chat questions to an OpenAI-compatible API.
package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"strconv"
	"strings"

	tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
	"github.com/joho/godotenv"
	"github.com/openai/openai-go/v3"
	"github.com/openai/openai-go/v3/option"
	"github.com/openai/openai-go/v3/responses"
	"github.com/openai/openai-go/v3/shared"
)
// OpenAIPrompter implements Proompter on top of the OpenAI Responses API.
type OpenAIPrompter struct {
	cli openai.Client // SDK client; base URL comes from cfg
	cfg *Config       // supplies the system prompt and endpoint settings
	log *slog.Logger  // structured logger for request tracing
}
func composeSysPromptWithContext(systemPrompt string, req PromptRequest) string {
|
||
var b strings.Builder
|
||
|
||
b.WriteString(systemPrompt)
|
||
b.WriteString("\nИмя пользователя: ")
|
||
b.WriteString(req.Username)
|
||
|
||
if req.ReplyToContent != nil {
|
||
b.WriteString("\nПользователь отсылается на текст сообщения: ")
|
||
b.WriteString(*req.ReplyToContent)
|
||
}
|
||
|
||
return b.String()
|
||
}
|
||
|
||
func (p *OpenAIPrompter) Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error) {
|
||
p.log.Info("new prompt request",
|
||
"req", req)
|
||
|
||
sysPrompt := composeSysPromptWithContext(p.cfg.SystemPrompt, req)
|
||
|
||
input := []responses.ResponseInputItemUnionParam{
|
||
{
|
||
OfMessage: &responses.EasyInputMessageParam{
|
||
Content: responses.EasyInputMessageContentUnionParam{
|
||
OfString: openai.String(sysPrompt),
|
||
},
|
||
Role: responses.EasyInputMessageRoleSystem,
|
||
},
|
||
},
|
||
{
|
||
OfMessage: &responses.EasyInputMessageParam{
|
||
Content: responses.EasyInputMessageContentUnionParam{
|
||
OfString: openai.String(req.Question),
|
||
},
|
||
Role: responses.EasyInputMessageRoleUser,
|
||
},
|
||
},
|
||
}
|
||
|
||
resp, err := p.cli.Responses.New(ctx, responses.ResponseNewParams{
|
||
Input: responses.ResponseNewParamsInputUnion{
|
||
OfInputItemList: input,
|
||
},
|
||
Reasoning: shared.ReasoningParam{
|
||
Effort: shared.ReasoningEffortXhigh,
|
||
},
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &HighlyTrustedResponse{
|
||
Text: resp.OutputText(),
|
||
}, nil
|
||
}
|
||
|
||
func NewOpenAIProoooompter(cfg *Config, log *slog.Logger) *OpenAIPrompter {
|
||
return &OpenAIPrompter{
|
||
cli: openai.NewClient(
|
||
option.WithBaseURL(cfg.OpenAIBaseURL),
|
||
),
|
||
cfg: cfg,
|
||
log: log,
|
||
}
|
||
}
|
||
|
||
// HighlyTrustedResponse wraps the model's answer as returned by a Proompter.
type HighlyTrustedResponse struct {
	Text string // raw output text from the model
}
// PromptRequest carries one user question plus the context needed to answer it.
type PromptRequest struct {
	Username       string  // Telegram username of the asker (may be empty)
	Question       string  // question text with the "@grok " prefix stripped
	ReplyToContent *string // text of the replied-to message; nil when not a reply
}
// Proompter is the prompting backend consumed by App.
// OpenAIPrompter is the production implementation.
type Proompter interface {
	Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error)
}
// App wires the Telegram bot to a Proompter and bounds concurrent work.
type App struct {
	log       *slog.Logger
	bot       *tgbotapi.BotAPI
	proompter Proompter
	sema      chan struct{} // counting semaphore limiting in-flight prompts
	config    *Config
}
func (a *App) handleMessage(ctx context.Context, msg *tgbotapi.Message) error {
|
||
if msg.Chat == nil {
|
||
return nil
|
||
}
|
||
|
||
chatID := msg.Chat.ID
|
||
if chatID != a.config.ChatID || !strings.HasPrefix(msg.Text, "@grok ") {
|
||
return nil
|
||
}
|
||
|
||
question := msg.Text[len("@grok "):]
|
||
|
||
select {
|
||
case a.sema <- struct{}{}:
|
||
default:
|
||
a.log.Info("concurrency limit hit", "msg", msg)
|
||
return nil
|
||
}
|
||
defer func() { <-a.sema }()
|
||
|
||
a.log.Info(
|
||
"message passed all guards",
|
||
"og_text", msg.Text,
|
||
"transformed_text", question,
|
||
)
|
||
|
||
var repliedToContent *string
|
||
if msg.ReplyToMessage != nil {
|
||
a.log.Info("message was a reply")
|
||
repliedToContent = &msg.ReplyToMessage.Text
|
||
}
|
||
|
||
var username string
|
||
if msg.From != nil {
|
||
username = msg.From.UserName
|
||
}
|
||
|
||
response, err := a.proompter.Prompt(ctx, PromptRequest{
|
||
Question: question,
|
||
ReplyToContent: repliedToContent,
|
||
Username: username,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("prompting: %w", err)
|
||
}
|
||
|
||
_, err = a.bot.Send(tgbotapi.MessageConfig{
|
||
BaseChat: tgbotapi.BaseChat{
|
||
ChatID: chatID,
|
||
ReplyToMessageID: msg.MessageID,
|
||
},
|
||
Text: response.Text,
|
||
})
|
||
if err != nil {
|
||
return fmt.Errorf("responding: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (a *App) HandleUpdates(ctx context.Context) error {
|
||
upds := a.bot.GetUpdatesChan(tgbotapi.UpdateConfig{})
|
||
for {
|
||
select {
|
||
case upd, ok := <-upds:
|
||
if !ok {
|
||
return errors.New("channel closed")
|
||
}
|
||
|
||
if upd.Message != nil {
|
||
a.log.Info("new message", "update", upd)
|
||
go func() {
|
||
if err := a.handleMessage(ctx, upd.Message); err != nil {
|
||
a.log.Error("handling message", "msg", upd.Message, "err", err)
|
||
}
|
||
}()
|
||
}
|
||
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
}
|
||
}
|
||
|
||
}
|
||
|
||
func NewApp(cfg *Config, prompter Proompter) (*App, error) {
|
||
var app App
|
||
var err error
|
||
|
||
app.proompter = prompter
|
||
|
||
app.log = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
|
||
|
||
app.bot, err = tgbotapi.NewBotAPI(cfg.BotToken)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if cfg.MaxConcurrentRequests == 0 {
|
||
return nil, errors.New("concurrency limit not set")
|
||
}
|
||
app.sema = make(chan struct{}, cfg.MaxConcurrentRequests)
|
||
|
||
app.config = cfg
|
||
|
||
return &app, nil
|
||
}
|
||
|
||
// Config carries all runtime settings, populated from the environment by
// LoadConfig.
type Config struct {
	SystemPrompt          string // contents of the file at SYSTEM_PROMPT_PATH; may be empty
	OpenAIBaseURL         string // OPENAI_BASE_URL, required
	BotToken              string // BOT_TOKEN for the Telegram bot
	MaxConcurrentRequests uint   // MAX_CONCURRENT_REQUESTS; must be > 0 (checked in NewApp)
	ChatID                int64  // CHAT_ID the bot answers in; must be non-zero
	ChannelID             int64  // CHANNEL_ID; zero only produces a warning
}
func LoadConfig(cfg *Config) error {
|
||
if err := godotenv.Load(".env"); err != nil {
|
||
slog.Warn("no env file loaded", "err", err)
|
||
}
|
||
|
||
cfg.OpenAIBaseURL = os.Getenv("OPENAI_BASE_URL")
|
||
if cfg.OpenAIBaseURL == "" {
|
||
return errors.New("openai base url not set")
|
||
}
|
||
|
||
cfg.BotToken = os.Getenv("BOT_TOKEN")
|
||
mcg, err := strconv.ParseUint(
|
||
os.Getenv("MAX_CONCURRENT_REQUESTS"),
|
||
10, 64,
|
||
)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
cfg.MaxConcurrentRequests = uint(mcg)
|
||
chatID, err := strconv.ParseInt(os.Getenv("CHAT_ID"), 10, 64)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if chatID == 0 {
|
||
return errors.New("non-zero chat id is unsafe")
|
||
}
|
||
|
||
cfg.ChatID = chatID
|
||
|
||
channelID, err := strconv.ParseInt(os.Getenv("CHANNEL_ID"), 10, 64)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if chatID == 0 {
|
||
slog.Warn("channel id is not set")
|
||
}
|
||
cfg.ChannelID = channelID
|
||
|
||
sysPromptPath := os.Getenv("SYSTEM_PROMPT_PATH")
|
||
promptBytes, err := os.ReadFile(sysPromptPath)
|
||
if err != nil {
|
||
slog.Warn("could not load system prompt", "path", sysPromptPath)
|
||
} else {
|
||
cfg.SystemPrompt = string(promptBytes)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func main() {
|
||
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt)
|
||
defer cancel()
|
||
|
||
log := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{}))
|
||
log.Info("Starting GROK")
|
||
|
||
var cfg Config
|
||
if err := LoadConfig(&cfg); err != nil {
|
||
log.Error("loading config", "err", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
prompter := NewOpenAIProoooompter(&cfg, log)
|
||
|
||
app, err := NewApp(&cfg, prompter)
|
||
if err != nil {
|
||
log.Error("initializing app", "err", err)
|
||
os.Exit(1)
|
||
}
|
||
|
||
go func() {
|
||
if err := app.HandleUpdates(ctx); err != nil {
|
||
log.Error("handleUpdates exited", "err", err)
|
||
}
|
||
cancel()
|
||
}()
|
||
|
||
<-ctx.Done()
|
||
}
|