Skip to content

Commit bdb59e0

Browse files
feat: add GitHub notifications tools for managing user notifications
1 parent e6b73f7 commit bdb59e0

File tree

2 files changed

+258
-0
lines changed

2 files changed

+258
-0
lines changed

pkg/github/notifications.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/github/github-mcp-server/pkg/translations"
12+
"github.com/google/go-github/v69/github"
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// getNotifications creates a tool to list notifications for the current user.
18+
func getNotifications(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
19+
return mcp.NewTool("get_notifications",
20+
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
21+
mcp.WithBoolean("all",
22+
mcp.Description("If true, show notifications marked as read. Default: false"),
23+
),
24+
mcp.WithBoolean("participating",
25+
mcp.Description("If true, only shows notifications in which the user is directly participating or mentioned. Default: false"),
26+
),
27+
mcp.WithString("since",
28+
mcp.Description("Only show notifications updated after the given time (ISO 8601 format)"),
29+
),
30+
mcp.WithString("before",
31+
mcp.Description("Only show notifications updated before the given time (ISO 8601 format)"),
32+
),
33+
mcp.WithNumber("per_page",
34+
mcp.Description("Results per page (max 100). Default: 30"),
35+
),
36+
mcp.WithNumber("page",
37+
mcp.Description("Page number of the results to fetch. Default: 1"),
38+
),
39+
),
40+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
41+
// Extract optional parameters with defaults
42+
all, err := optionalParamWithDefault[bool](request, "all", false)
43+
if err != nil {
44+
return mcp.NewToolResultError(err.Error()), nil
45+
}
46+
47+
participating, err := optionalParamWithDefault[bool](request, "participating", false)
48+
if err != nil {
49+
return mcp.NewToolResultError(err.Error()), nil
50+
}
51+
52+
since, err := optionalParam[string](request, "since")
53+
if err != nil {
54+
return mcp.NewToolResultError(err.Error()), nil
55+
}
56+
57+
before, err := optionalParam[string](request, "before")
58+
if err != nil {
59+
return mcp.NewToolResultError(err.Error()), nil
60+
}
61+
62+
perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
63+
if err != nil {
64+
return mcp.NewToolResultError(err.Error()), nil
65+
}
66+
67+
page, err := optionalIntParamWithDefault(request, "page", 1)
68+
if err != nil {
69+
return mcp.NewToolResultError(err.Error()), nil
70+
}
71+
72+
// Build options
73+
opts := &github.NotificationListOptions{
74+
All: all,
75+
Participating: participating,
76+
ListOptions: github.ListOptions{
77+
Page: page,
78+
PerPage: perPage,
79+
},
80+
}
81+
82+
// Parse time parameters if provided
83+
if since != "" {
84+
sinceTime, err := time.Parse(time.RFC3339, since)
85+
if err != nil {
86+
return mcp.NewToolResultError(fmt.Sprintf("invalid since time format, should be RFC3339/ISO8601: %v", err)), nil
87+
}
88+
opts.Since = sinceTime
89+
}
90+
91+
if before != "" {
92+
beforeTime, err := time.Parse(time.RFC3339, before)
93+
if err != nil {
94+
return mcp.NewToolResultError(fmt.Sprintf("invalid before time format, should be RFC3339/ISO8601: %v", err)), nil
95+
}
96+
opts.Before = beforeTime
97+
}
98+
99+
// Call GitHub API
100+
notifications, resp, err := client.Activity.ListNotifications(ctx, opts)
101+
if err != nil {
102+
return nil, fmt.Errorf("failed to get notifications: %w", err)
103+
}
104+
defer func() { _ = resp.Body.Close() }()
105+
106+
if resp.StatusCode != http.StatusOK {
107+
body, err := io.ReadAll(resp.Body)
108+
if err != nil {
109+
return nil, fmt.Errorf("failed to read response body: %w", err)
110+
}
111+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notifications: %s", string(body))), nil
112+
}
113+
114+
// Marshal response to JSON
115+
r, err := json.Marshal(notifications)
116+
if err != nil {
117+
return nil, fmt.Errorf("failed to marshal response: %w", err)
118+
}
119+
120+
return mcp.NewToolResultText(string(r)), nil
121+
}
122+
}
123+
124+
// markNotificationRead creates a tool to mark a notification as read.
125+
func markNotificationRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
126+
return mcp.NewTool("mark_notification_read",
127+
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
128+
mcp.WithString("threadID",
129+
mcp.Required(),
130+
mcp.Description("The ID of the notification thread"),
131+
),
132+
),
133+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
134+
threadID, err := requiredParam[string](request, "threadID")
135+
if err != nil {
136+
return mcp.NewToolResultError(err.Error()), nil
137+
}
138+
139+
resp, err := client.Activity.MarkThreadRead(ctx, threadID)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to mark notification as read: %w", err)
142+
}
143+
defer func() { _ = resp.Body.Close() }()
144+
145+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
146+
body, err := io.ReadAll(resp.Body)
147+
if err != nil {
148+
return nil, fmt.Errorf("failed to read response body: %w", err)
149+
}
150+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark notification as read: %s", string(body))), nil
151+
}
152+
153+
return mcp.NewToolResultText("Notification marked as read"), nil
154+
}
155+
}
156+
157+
// markAllNotificationsRead creates a tool to mark all notifications as read.
158+
func markAllNotificationsRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
159+
return mcp.NewTool("mark_all_notifications_read",
160+
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
161+
mcp.WithString("lastReadAt",
162+
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
163+
),
164+
),
165+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
166+
lastReadAt, err := optionalParam[string](request, "lastReadAt")
167+
if err != nil {
168+
return mcp.NewToolResultError(err.Error()), nil
169+
}
170+
171+
var markReadOptions github.Timestamp
172+
if lastReadAt != "" {
173+
lastReadTime, err := time.Parse(time.RFC3339, lastReadAt)
174+
if err != nil {
175+
return mcp.NewToolResultError(fmt.Sprintf("invalid lastReadAt time format, should be RFC3339/ISO8601: %v", err)), nil
176+
}
177+
markReadOptions = github.Timestamp{
178+
Time: lastReadTime,
179+
}
180+
}
181+
182+
resp, err := client.Activity.MarkNotificationsRead(ctx, markReadOptions)
183+
if err != nil {
184+
return nil, fmt.Errorf("failed to mark all notifications as read: %w", err)
185+
}
186+
defer func() { _ = resp.Body.Close() }()
187+
188+
if resp.StatusCode != http.StatusResetContent && resp.StatusCode != http.StatusOK {
189+
body, err := io.ReadAll(resp.Body)
190+
if err != nil {
191+
return nil, fmt.Errorf("failed to read response body: %w", err)
192+
}
193+
return mcp.NewToolResultError(fmt.Sprintf("failed to mark all notifications as read: %s", string(body))), nil
194+
}
195+
196+
return mcp.NewToolResultText("All notifications marked as read"), nil
197+
}
198+
}
199+
200+
// getNotificationThread creates a tool to get a specific notification thread.
201+
func getNotificationThread(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
202+
return mcp.NewTool("get_notification_thread",
203+
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
204+
mcp.WithString("threadID",
205+
mcp.Required(),
206+
mcp.Description("The ID of the notification thread"),
207+
),
208+
),
209+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
210+
threadID, err := requiredParam[string](request, "threadID")
211+
if err != nil {
212+
return mcp.NewToolResultError(err.Error()), nil
213+
}
214+
215+
thread, resp, err := client.Activity.GetThread(ctx, threadID)
216+
if err != nil {
217+
return nil, fmt.Errorf("failed to get notification thread: %w", err)
218+
}
219+
defer func() { _ = resp.Body.Close() }()
220+
221+
if resp.StatusCode != http.StatusOK {
222+
body, err := io.ReadAll(resp.Body)
223+
if err != nil {
224+
return nil, fmt.Errorf("failed to read response body: %w", err)
225+
}
226+
return mcp.NewToolResultError(fmt.Sprintf("failed to get notification thread: %s", string(body))), nil
227+
}
228+
229+
r, err := json.Marshal(thread)
230+
if err != nil {
231+
return nil, fmt.Errorf("failed to marshal response: %w", err)
232+
}
233+
234+
return mcp.NewToolResultText(string(r)), nil
235+
}
236+
}

pkg/github/server.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@ func NewServer(client *github.Client, readOnly bool, t translations.TranslationH
7777
// Add GitHub tools - Code Scanning
7878
s.AddTool(getCodeScanningAlert(client, t))
7979
s.AddTool(listCodeScanningAlerts(client, t))
80+
81+
// Add GitHub tools - Notifications
82+
s.AddTool(getNotifications(client, t))
83+
s.AddTool(getNotificationThread(client, t))
84+
if !readOnly {
85+
s.AddTool(markNotificationRead(client, t))
86+
s.AddTool(markAllNotificationsRead(client, t))
87+
}
8088
return s
8189
}
8290

@@ -189,6 +197,20 @@ func optionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
189197
return int(v), nil
190198
}
191199

200+
// optionalParamWithDefault is a generic helper function that can be used to fetch a requested parameter from the request
201+
// with a default value if the parameter is not provided or is zero value.
202+
func optionalParamWithDefault[T comparable](r mcp.CallToolRequest, p string, d T) (T, error) {
203+
var zero T
204+
v, err := optionalParam[T](r, p)
205+
if err != nil {
206+
return zero, err
207+
}
208+
if v == zero {
209+
return d, nil
210+
}
211+
return v, nil
212+
}
213+
192214
// optionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
193215
// similar to optionalIntParam, but it also takes a default value.
194216
func optionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {

0 commit comments

Comments
 (0)