You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
gpt/main.go

170 lines
4.0 KiB

package main
import (
"context"
"flag"
"fmt"
"html/template"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"strings"
"github.com/PullRequestInc/go-gpt3"
)
// indexTemplate is the html/template source for the single page rendered at
// "/". It contains a form that POSTs a "prompt" field back to "/" and a
// paragraph that displays the .Response field of the data passed to render.
const indexTemplate = `
<!DOCTYPE html>
<html lang="en">
<head>
<link rel="stylesheet" href="https://cdn.simplecss.org/simple.min.css">
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>GPT3 Demo</title>
</head>
<body>
<header>
<nav>
<a href="/">Home</a>
<a href="https://git.mills.io/prologic/gpt">Source</a>
</nav>
<h1>GPT3 Demo</h1>
<p>A GPT3 Demo Web Application and API using the <a href="https://beta.openai.com/">Open AI</a> service.</p>
</header>
<main>
<h2>Ask me anything!</h2>
<form action="/" method="POST">
<input type="text" name="prompt" placeholder="Enter your prompt here..." required>
<button type="submit">Go!</button>
</form>
<p>{{ .Response }}</p>
</main>
<footer>
<p>Licensed under the <a href="https://git.mills.io/prologic/gpt/blob/master/LICENSE">WTFPL License</a></p>
</footer>
</body>
</html>
`
// render parses tmpl as an html/template named name and executes it with ctx
// as its data, writing the output to w. It returns the parse error if the
// template source is invalid, otherwise whatever error execution produces.
func render(name, tmpl string, ctx interface{}, w io.Writer) error {
	parsed, parseErr := template.New(name).Parse(tmpl)
	if parseErr == nil {
		return parsed.Execute(w, ctx)
	}
	return parseErr
}
// gptHandler is the http.Handler used in server mode. It serves the demo page
// and forwards submitted prompts to the completion API.
type gptHandler struct {
// cli is the GPT-3 client used to issue completion requests.
cli gpt3.Client
// engine is the engine name passed to every completion request.
engine string
}
// ServeHTTP implements http.Handler.
//
// Routes:
//   - HEAD|GET /health  -> 200 "OK" (text/plain)
//   - HEAD /            -> 200 with no body
//   - GET /             -> the index page with an empty response
//   - POST /            -> the index page with the completion for the
//     submitted "prompt" form value (or an inline error if it is empty)
//   - HEAD|GET|POST on any other path -> 404
//   - any other method  -> 405
//
// Failures from the completion client or from template rendering are logged
// and reported to the caller as a generic 500.
func (h gptHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// The parentheses matter: && binds tighter than ||, so without them every
	// HEAD request, whatever its path, would be answered by the health check.
	if (r.Method == http.MethodHead || r.Method == http.MethodGet) && r.URL.Path == "/health" {
		w.Header().Set("Content-Type", "text/plain")
		fmt.Fprintf(w, "OK\n")
		return
	}

	// The method is checked before the path so that unsupported methods get
	// a 405 even on unknown paths.
	switch r.Method {
	case http.MethodHead, http.MethodGet, http.MethodPost:
	default:
		http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
		return
	}

	if r.URL.Path != "/" {
		http.Error(w, "Not Found", http.StatusNotFound)
		return
	}

	// HEAD on the index only needs the status line; skip rendering.
	if r.Method == http.MethodHead {
		return
	}

	ctx := struct {
		Response string
	}{}

	if r.Method == http.MethodPost {
		prompt := r.FormValue("prompt")
		if prompt == "" {
			ctx.Response = "Error: no prompt entered"
		} else {
			response, err := h.cli.CompletionWithEngine(r.Context(), h.engine, gpt3.CompletionRequest{
				Prompt:           []string{prompt},
				MaxTokens:        gpt3.IntPtr(150),
				Echo:             false,
				Temperature:      gpt3.Float32Ptr(0.7),
				TopP:             gpt3.Float32Ptr(0),
				FrequencyPenalty: 0,
				PresencePenalty:  0,
			})
			if err != nil {
				// Keep the details in the server log; the client only sees
				// a generic 500.
				log.Printf("completion request failed: %v", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}

			// Concatenate the trimmed text of every choice, without a
			// separator.
			var text strings.Builder
			for _, choice := range response.Choices {
				text.WriteString(strings.TrimSpace(choice.Text))
			}
			ctx.Response = text.String()
		}
	}

	if err := render("index", indexTemplate, ctx, w); err != nil {
		log.Printf("rendering index page: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
	}
}
// main is the program entry point.
//
// It parses the -engine, -server and -bind flags and requires the
// OPENAI_API_KEY environment variable. With -server it serves gptHandler on
// the bind address. Otherwise it reads all of standard input, sends the
// trimmed text as a single completion prompt, and prints each non-empty
// trimmed choice on its own line. Any failure terminates the process through
// the log package.
func main() {
	var (
		engineName string
		serverMode bool
		bindAddr   string
	)
	flag.StringVar(&engineName, "engine", "text-davinci-002", "The engine to use")
	flag.BoolVar(&serverMode, "server", false, "Run in server mode")
	flag.StringVar(&bindAddr, "bind", "0.0.0.0:8000", "[interface]:<port> to bind to in server mode")
	flag.Parse()

	key := os.Getenv("OPENAI_API_KEY")
	if len(key) == 0 {
		log.Fatalln("missing openai key")
	}
	client := gpt3.NewClient(key)

	if serverMode {
		// log.Fatal exits the process, so nothing below runs in server mode.
		log.Fatal(http.ListenAndServe(bindAddr, gptHandler{cli: client, engine: engineName}))
	}

	raw, err := ioutil.ReadAll(os.Stdin)
	if err != nil {
		log.Fatalln(err)
	}

	request := gpt3.CompletionRequest{
		Prompt:           []string{strings.TrimSpace(string(raw))},
		MaxTokens:        gpt3.IntPtr(150),
		Echo:             false,
		Temperature:      gpt3.Float32Ptr(0.7),
		TopP:             gpt3.Float32Ptr(0),
		FrequencyPenalty: 0,
		PresencePenalty:  0,
	}
	completion, err := client.CompletionWithEngine(context.Background(), engineName, request)
	if err != nil {
		log.Fatalln(err)
	}

	for _, choice := range completion.Choices {
		if trimmed := strings.TrimSpace(choice.Text); trimmed != "" {
			fmt.Println(trimmed)
		}
	}
}