diff --git a/docker-compose.yml b/docker-compose.yml index 267f3c1..96658c6 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,8 +4,10 @@ services: environment: BOT_TOKEN: "" CHAT_ID: "" - OPENAI_BASE_URL: "http://localhost:8080/v1" + CHANNEL_ID: -1003290014225 + OPENAI_BASE_URL: http://localhost:8080/v1 SYSTEM_PROMPT_PATH: /etc/sysprompt.txt + MAX_CONCURRENT_REQUESTS: 2 volumes: - ./sysprompt.txt:/etc/sysprompt.txt:ro restart: unless-stopped diff --git a/main.go b/main.go index 7781981..fcb8cc4 100644 --- a/main.go +++ b/main.go @@ -23,11 +23,16 @@ type OpenAIPrompter struct { cfg *Config } -func (p *OpenAIPrompter) Prompt(ctx context.Context, question string) (*HighlyTrustedResponse, error) { +func (p *OpenAIPrompter) Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error) { + instructions := p.cfg.SystemPrompt + if req.OriginalPostContent != nil { + instructions += fmt.Sprintf("\nThis is the channel post that user is mentioning: %v", req.OriginalPostContent) + } + resp, err := p.cli.Responses.New(ctx, responses.ResponseNewParams{ Instructions: openai.String(p.cfg.SystemPrompt), Input: responses.ResponseNewParamsInputUnion{ - OfString: openai.String(question), + OfString: openai.String(req.Question), }, Reasoning: shared.ReasoningParam{ Effort: shared.ReasoningEffortXhigh, @@ -55,8 +60,13 @@ type HighlyTrustedResponse struct { Text string } +type PromptRequest struct { + Question string + OriginalPostContent *string +} + type Proompter interface { - Prompt(ctx context.Context, question string) (*HighlyTrustedResponse, error) + Prompt(ctx context.Context, req PromptRequest) (*HighlyTrustedResponse, error) } type App struct { @@ -93,7 +103,16 @@ func (a *App) handleMessage(ctx context.Context, msg *tgbotapi.Message) error { "transformed_text", question, ) - response, err := a.proompter.Prompt(ctx, question) + var ogPostContent *string + if msg.ReplyToMessage.From.ID == a.config.ChannelID { + a.log.Info("message was a reply to channel post") + ogPostContent = &msg.ReplyToMessage.Text + } + + response, err := a.proompter.Prompt(ctx, PromptRequest{ + Question: question, + OriginalPostContent: ogPostContent, + }) if err != nil { return fmt.Errorf("prompting: %w", err) } @@ -166,6 +185,7 @@ type Config struct { BotToken string MaxConcurrentRequests uint ChatID int64 + ChannelID int64 } func LoadConfig(cfg *Config) error { @@ -197,6 +217,15 @@ func LoadConfig(cfg *Config) error { cfg.ChatID = chatID + channelID, err := strconv.ParseInt(os.Getenv("CHANNEL_ID"), 10, 64) + if err != nil { + return err + } + if chatID == 0 { + slog.Warn("channel id is not set") + } + cfg.ChannelID = channelID + sysPromptPath := os.Getenv("SYSTEM_PROMPT_PATH") promptBytes, err := os.ReadFile(sysPromptPath) if err != nil {