Skip to content

Commit 4055dd7

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent c013e9f commit 4055dd7

File tree

15 files changed

+552
-11
lines changed

15 files changed

+552
-11
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,15 +1353,26 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog
13531353
if err == nil {
13541354
return q.db.CountAuditLogs(ctx, arg)
13551355
}
1356-
13571356
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAuditLog.Type)
13581357
if err != nil {
13591358
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
13601359
}
1361-
13621360
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
13631361
}
13641362

1363+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1364+
// Just like the actual query, shortcut if the user is an owner.
1365+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1366+
if err == nil {
1367+
return q.db.CountConnectionLogs(ctx, arg)
1368+
}
1369+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1370+
if err != nil {
1371+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1372+
}
1373+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1374+
}
1375+
13651376
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13661377
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13671378
return nil, err
@@ -5392,3 +5403,7 @@ func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.Cou
53925403
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53935404
return q.GetConnectionLogsOffset(ctx, arg)
53945405
}
5406+
5407+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5408+
return q.CountConnectionLogs(ctx, arg)
5409+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
406406
LimitOpt: 10,
407407
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
408408
}))
409+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
410+
ws := createWorkspace(s.T(), db)
411+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
412+
Type: database.ConnectionTypeSsh,
413+
WorkspaceID: ws.ID,
414+
OrganizationID: ws.OrganizationID,
415+
WorkspaceOwnerID: ws.OwnerID,
416+
})
417+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
418+
Type: database.ConnectionTypeSsh,
419+
WorkspaceID: ws.ID,
420+
OrganizationID: ws.OrganizationID,
421+
WorkspaceOwnerID: ws.OwnerID,
422+
})
423+
check.Args(database.CountConnectionLogsParams{}).Asserts(
424+
rbac.ResourceConnectionLog, policy.ActionRead,
425+
).WithNotAuthorized("nil")
426+
}))
427+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
428+
ws := createWorkspace(s.T(), db)
429+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
430+
Type: database.ConnectionTypeSsh,
431+
WorkspaceID: ws.ID,
432+
OrganizationID: ws.OrganizationID,
433+
WorkspaceOwnerID: ws.OwnerID,
434+
})
435+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
436+
Type: database.ConnectionTypeSsh,
437+
WorkspaceID: ws.ID,
438+
OrganizationID: ws.OrganizationID,
439+
WorkspaceOwnerID: ws.OwnerID,
440+
})
441+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
442+
rbac.ResourceConnectionLog, policy.ActionRead,
443+
)
444+
}))
409445
}
410446

411447
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func hasEmptyResponse(values []reflect.Value) bool {
318318
}
319319
}
320320

321-
// Special case for int64, as it's the return type for count query.
321+
// Special case for int64, as it's the return type for count queries.
322322
if r.Kind() == reflect.Int64 {
323323
if r.Int() == 0 {
324324
return true

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
614614

615615
type connectionLogQuerier interface {
616616
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
617+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
617618
}
618619

619620
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -700,6 +701,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
700701
return items, nil
701702
}
702703

704+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
705+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
706+
VariableConverter: regosql.ConnectionLogConverter(),
707+
})
708+
if err != nil {
709+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
710+
}
711+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
712+
if err != nil {
713+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
714+
}
715+
716+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
717+
rows, err := q.db.QueryContext(ctx, query,
718+
arg.OrganizationID,
719+
arg.WorkspaceOwner,
720+
arg.WorkspaceOwnerID,
721+
arg.WorkspaceOwnerEmail,
722+
arg.Type,
723+
arg.UserID,
724+
arg.Username,
725+
arg.UserEmail,
726+
arg.ConnectedAfter,
727+
arg.ConnectedBefore,
728+
arg.WorkspaceID,
729+
arg.ConnectionID,
730+
arg.Status,
731+
)
732+
if err != nil {
733+
return 0, err
734+
}
735+
defer rows.Close()
736+
var count int64
737+
for rows.Next() {
738+
if err := rows.Scan(&count); err != nil {
739+
return 0, err
740+
}
741+
}
742+
if err := rows.Close(); err != nil {
743+
return 0, err
744+
}
745+
if err := rows.Err(); err != nil {
746+
return 0, err
747+
}
748+
return count, nil
749+
}
750+
703751
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
704752
if !strings.Contains(query, authorizedQueryPlaceholder) {
705753
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ func TestAuditLogsQueryConsistency(t *testing.T) {
7676
}
7777
}
7878

79+
// Same as TestAuditLogsQueryConsistency, but for connection logs.
80+
func TestConnectionLogsQueryConsistency(t *testing.T) {
81+
t.Parallel()
82+
83+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
84+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
85+
86+
countWhereClause := extractWhereClause(countConnectionLogs)
87+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
88+
89+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
90+
}
91+
7992
// extractWhereClause extracts the WHERE clause from a SQL query string
8093
func extractWhereClause(query string) string {
8194
// Find WHERE and get everything after it

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,6 +2168,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
21682168
require.NoError(t, err)
21692169
// Then: No logs returned
21702170
require.Len(t, logs, 0, "no logs should be returned")
2171+
// And: The count matches the number of logs returned
2172+
count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{})
2173+
require.NoError(t, err)
2174+
require.EqualValues(t, len(logs), count)
21712175
})
21722176

21732177
t.Run("SiteWideAuditor", func(t *testing.T) {
@@ -2186,6 +2190,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
21862190
require.NoError(t, err)
21872191
// Then: All logs are returned
21882192
require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs))
2193+
// And: The count matches the number of logs returned
2194+
count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{})
2195+
require.NoError(t, err)
2196+
require.EqualValues(t, len(logs), count)
21892197
})
21902198

21912199
t.Run("SingleOrgAuditor", func(t *testing.T) {
@@ -2205,6 +2213,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22052213
require.NoError(t, err)
22062214
// Then: Only the logs for the organization are returned
22072215
require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs))
2216+
// And: The count matches the number of logs returned
2217+
count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{})
2218+
require.NoError(t, err)
2219+
require.EqualValues(t, len(logs), count)
22082220
})
22092221

22102222
t.Run("TwoOrgAuditors", func(t *testing.T) {
@@ -2225,6 +2237,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22252237
require.NoError(t, err)
22262238
// Then: All logs for both organizations are returned
22272239
require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs))
2240+
// And: The count matches the number of logs returned
2241+
count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{})
2242+
require.NoError(t, err)
2243+
require.EqualValues(t, len(logs), count)
22282244
})
22292245

22302246
t.Run("ErroneousOrg", func(t *testing.T) {
@@ -2243,9 +2259,71 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22432259
require.NoError(t, err)
22442260
// Then: No logs are returned
22452261
require.Len(t, logs, 0, "no logs should be returned")
2262+
// And: The count matches the number of logs returned
2263+
count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{})
2264+
require.NoError(t, err)
2265+
require.EqualValues(t, len(logs), count)
22462266
})
22472267
}
22482268

2269+
func TestCountConnectionLogs(t *testing.T) {
2270+
t.Parallel()
2271+
ctx := testutil.Context(t, testutil.WaitLong)
2272+
2273+
db, _ := dbtestutil.NewDB(t)
2274+
2275+
orgA := dbfake.Organization(t, db).Do()
2276+
userA := dbgen.User(t, db, database.User{})
2277+
tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID})
2278+
wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID})
2279+
2280+
orgB := dbfake.Organization(t, db).Do()
2281+
userB := dbgen.User(t, db, database.User{})
2282+
tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID})
2283+
wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID})
2284+
2285+
// Create logs for two different orgs.
2286+
for i := 0; i < 20; i++ {
2287+
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
2288+
OrganizationID: wsA.OrganizationID,
2289+
WorkspaceOwnerID: wsA.OwnerID,
2290+
WorkspaceID: wsA.ID,
2291+
Type: database.ConnectionTypeSsh,
2292+
})
2293+
}
2294+
for i := 0; i < 10; i++ {
2295+
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
2296+
OrganizationID: wsB.OrganizationID,
2297+
WorkspaceOwnerID: wsB.OwnerID,
2298+
WorkspaceID: wsB.ID,
2299+
Type: database.ConnectionTypeSsh,
2300+
})
2301+
}
2302+
2303+
// Count with a filter for orgA.
2304+
countParams := database.CountConnectionLogsParams{
2305+
OrganizationID: orgA.Org.ID,
2306+
}
2307+
totalCount, err := db.CountConnectionLogs(ctx, countParams)
2308+
require.NoError(t, err)
2309+
require.Equal(t, int64(20), totalCount)
2310+
2311+
// Get a paginated result for the same filter.
2312+
getParams := database.GetConnectionLogsOffsetParams{
2313+
OrganizationID: orgA.Org.ID,
2314+
LimitOpt: 5,
2315+
OffsetOpt: 10,
2316+
}
2317+
logs, err := db.GetConnectionLogsOffset(ctx, getParams)
2318+
require.NoError(t, err)
2319+
require.Len(t, logs, 5)
2320+
2321+
// The count with the filter should remain the same, independent of pagination.
2322+
countAfterGet, err := db.CountConnectionLogs(ctx, countParams)
2323+
require.NoError(t, err)
2324+
require.Equal(t, int64(20), countAfterGet)
2325+
}
2326+
22492327
func TestConnectionLogsOffsetFilters(t *testing.T) {
22502328
t.Parallel()
22512329
ctx := testutil.Context(t, testutil.WaitLong)
@@ -2484,7 +2562,24 @@ func TestConnectionLogsOffsetFilters(t *testing.T) {
24842562
t.Parallel()
24852563
logs, err := db.GetConnectionLogsOffset(ctx, tc.params)
24862564
require.NoError(t, err)
2565+
count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{
2566+
OrganizationID: tc.params.OrganizationID,
2567+
WorkspaceOwner: tc.params.WorkspaceOwner,
2568+
Type: tc.params.Type,
2569+
UserID: tc.params.UserID,
2570+
Username: tc.params.Username,
2571+
UserEmail: tc.params.UserEmail,
2572+
ConnectedAfter: tc.params.ConnectedAfter,
2573+
ConnectedBefore: tc.params.ConnectedBefore,
2574+
WorkspaceID: tc.params.WorkspaceID,
2575+
ConnectionID: tc.params.ConnectionID,
2576+
Status: tc.params.Status,
2577+
WorkspaceOwnerID: tc.params.WorkspaceOwnerID,
2578+
WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail,
2579+
})
2580+
require.NoError(t, err)
24872581
require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs))
2582+
require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)")
24882583
})
24892584
}
24902585
}

0 commit comments

Comments
 (0)