Skip to content

Commit 6fa4d22

Browse files
committed
chore: populate connectionlog count using a separate query
1 parent 02a93e0 commit 6fa4d22

File tree

15 files changed

+556
-11
lines changed

15 files changed

+556
-11
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,15 +1353,26 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog
13531353
if err == nil {
13541354
return q.db.CountAuditLogs(ctx, arg)
13551355
}
1356-
13571356
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAuditLog.Type)
13581357
if err != nil {
13591358
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
13601359
}
1361-
13621360
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
13631361
}
13641362

1363+
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
1364+
// Just like the actual query, shortcut if the user is an owner.
1365+
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
1366+
if err == nil {
1367+
return q.db.CountConnectionLogs(ctx, arg)
1368+
}
1369+
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type)
1370+
if err != nil {
1371+
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
1372+
}
1373+
return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep)
1374+
}
1375+
13651376
func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) {
13661377
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil {
13671378
return nil, err
@@ -5389,3 +5400,7 @@ func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.Cou
53895400
func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) {
53905401
return q.GetConnectionLogsOffset(ctx, arg)
53915402
}
5403+
5404+
func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) {
5405+
return q.CountConnectionLogs(ctx, arg)
5406+
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,42 @@ func (s *MethodTestSuite) TestConnectionLogs() {
401401
LimitOpt: 10,
402402
}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
403403
}))
404+
s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
405+
ws := createWorkspace(s.T(), db)
406+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
407+
Type: database.ConnectionTypeSsh,
408+
WorkspaceID: ws.ID,
409+
OrganizationID: ws.OrganizationID,
410+
WorkspaceOwnerID: ws.OwnerID,
411+
})
412+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
413+
Type: database.ConnectionTypeSsh,
414+
WorkspaceID: ws.ID,
415+
OrganizationID: ws.OrganizationID,
416+
WorkspaceOwnerID: ws.OwnerID,
417+
})
418+
check.Args(database.CountConnectionLogsParams{}).Asserts(
419+
rbac.ResourceConnectionLog, policy.ActionRead,
420+
).WithNotAuthorized("nil")
421+
}))
422+
s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) {
423+
ws := createWorkspace(s.T(), db)
424+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
425+
Type: database.ConnectionTypeSsh,
426+
WorkspaceID: ws.ID,
427+
OrganizationID: ws.OrganizationID,
428+
WorkspaceOwnerID: ws.OwnerID,
429+
})
430+
_ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{
431+
Type: database.ConnectionTypeSsh,
432+
WorkspaceID: ws.ID,
433+
OrganizationID: ws.OrganizationID,
434+
WorkspaceOwnerID: ws.OwnerID,
435+
})
436+
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(
437+
rbac.ResourceConnectionLog, policy.ActionRead,
438+
)
439+
}))
404440
}
405441

406442
func (s *MethodTestSuite) TestFile() {

coderd/database/dbauthz/setup_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func hasEmptyResponse(values []reflect.Value) bool {
318318
}
319319
}
320320

321-
// Special case for int64, as it's the return type for count query.
321+
// Special case for int64, as it's the return type for count queries.
322322
if r.Kind() == reflect.Int64 {
323323
if r.Int() == 0 {
324324
return true

coderd/database/dbmetrics/querymetrics.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 30 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/modelqueries.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi
614614

615615
type connectionLogQuerier interface {
616616
GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error)
617+
CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error)
617618
}
618619

619620
func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) {
@@ -701,6 +702,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg
701702
return items, nil
702703
}
703704

705+
func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
706+
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
707+
VariableConverter: regosql.ConnectionLogConverter(),
708+
})
709+
if err != nil {
710+
return 0, xerrors.Errorf("compile authorized filter: %w", err)
711+
}
712+
filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter))
713+
if err != nil {
714+
return 0, xerrors.Errorf("insert authorized filter: %w", err)
715+
}
716+
717+
query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered)
718+
rows, err := q.db.QueryContext(ctx, query,
719+
arg.OrganizationID,
720+
arg.WorkspaceOwner,
721+
arg.WorkspaceOwnerID,
722+
arg.WorkspaceOwnerEmail,
723+
arg.Type,
724+
arg.UserID,
725+
arg.Username,
726+
arg.UserEmail,
727+
arg.StartedAfter,
728+
arg.StartedBefore,
729+
arg.WorkspaceID,
730+
arg.ConnectionID,
731+
arg.Status,
732+
)
733+
if err != nil {
734+
return 0, err
735+
}
736+
defer rows.Close()
737+
var count int64
738+
for rows.Next() {
739+
if err := rows.Scan(&count); err != nil {
740+
return 0, err
741+
}
742+
}
743+
if err := rows.Close(); err != nil {
744+
return 0, err
745+
}
746+
if err := rows.Err(); err != nil {
747+
return 0, err
748+
}
749+
return count, nil
750+
}
751+
704752
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
705753
if !strings.Contains(query, authorizedQueryPlaceholder) {
706754
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")

coderd/database/modelqueries_internal_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ func TestAuditLogsQueryConsistency(t *testing.T) {
7676
}
7777
}
7878

79+
// Same as TestAuditLogsQueryConsistency, but for connection logs.
80+
func TestConnectionLogsQueryConsistency(t *testing.T) {
81+
t.Parallel()
82+
83+
getWhereClause := extractWhereClause(getConnectionLogsOffset)
84+
require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause")
85+
86+
countWhereClause := extractWhereClause(countConnectionLogs)
87+
require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause")
88+
89+
require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause")
90+
}
91+
7992
// extractWhereClause extracts the WHERE clause from a SQL query string
8093
func extractWhereClause(query string) string {
8194
// Find WHERE and get everything after it

coderd/database/querier.go

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier_test.go

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,6 +2166,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
21662166
require.NoError(t, err)
21672167
// Then: No logs returned
21682168
require.Len(t, logs, 0, "no logs should be returned")
2169+
// And: The count matches the number of logs returned
2170+
count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{})
2171+
require.NoError(t, err)
2172+
require.EqualValues(t, len(logs), count)
21692173
})
21702174

21712175
t.Run("SiteWideAuditor", func(t *testing.T) {
@@ -2184,6 +2188,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
21842188
require.NoError(t, err)
21852189
// Then: All logs are returned
21862190
require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs))
2191+
// And: The count matches the number of logs returned
2192+
count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{})
2193+
require.NoError(t, err)
2194+
require.EqualValues(t, len(logs), count)
21872195
})
21882196

21892197
t.Run("SingleOrgAuditor", func(t *testing.T) {
@@ -2203,6 +2211,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22032211
require.NoError(t, err)
22042212
// Then: Only the logs for the organization are returned
22052213
require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs))
2214+
// And: The count matches the number of logs returned
2215+
count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{})
2216+
require.NoError(t, err)
2217+
require.EqualValues(t, len(logs), count)
22062218
})
22072219

22082220
t.Run("TwoOrgAuditors", func(t *testing.T) {
@@ -2223,6 +2235,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22232235
require.NoError(t, err)
22242236
// Then: All logs for both organizations are returned
22252237
require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs))
2238+
// And: The count matches the number of logs returned
2239+
count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{})
2240+
require.NoError(t, err)
2241+
require.EqualValues(t, len(logs), count)
22262242
})
22272243

22282244
t.Run("ErroneousOrg", func(t *testing.T) {
@@ -2241,9 +2257,71 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) {
22412257
require.NoError(t, err)
22422258
// Then: No logs are returned
22432259
require.Len(t, logs, 0, "no logs should be returned")
2260+
// And: The count matches the number of logs returned
2261+
count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{})
2262+
require.NoError(t, err)
2263+
require.EqualValues(t, len(logs), count)
22442264
})
22452265
}
22462266

2267+
func TestCountConnectionLogs(t *testing.T) {
2268+
t.Parallel()
2269+
ctx := testutil.Context(t, testutil.WaitLong)
2270+
2271+
db, _ := dbtestutil.NewDB(t)
2272+
2273+
orgA := dbfake.Organization(t, db).Do()
2274+
userA := dbgen.User(t, db, database.User{})
2275+
tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID})
2276+
wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID})
2277+
2278+
orgB := dbfake.Organization(t, db).Do()
2279+
userB := dbgen.User(t, db, database.User{})
2280+
tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID})
2281+
wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID})
2282+
2283+
// Create logs for two different orgs.
2284+
for i := 0; i < 20; i++ {
2285+
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
2286+
OrganizationID: wsA.OrganizationID,
2287+
WorkspaceOwnerID: wsA.OwnerID,
2288+
WorkspaceID: wsA.ID,
2289+
Type: database.ConnectionTypeSsh,
2290+
})
2291+
}
2292+
for i := 0; i < 10; i++ {
2293+
dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{
2294+
OrganizationID: wsB.OrganizationID,
2295+
WorkspaceOwnerID: wsB.OwnerID,
2296+
WorkspaceID: wsB.ID,
2297+
Type: database.ConnectionTypeSsh,
2298+
})
2299+
}
2300+
2301+
// Count with a filter for orgA.
2302+
countParams := database.CountConnectionLogsParams{
2303+
OrganizationID: orgA.Org.ID,
2304+
}
2305+
totalCount, err := db.CountConnectionLogs(ctx, countParams)
2306+
require.NoError(t, err)
2307+
require.Equal(t, int64(20), totalCount)
2308+
2309+
// Get a paginated result for the same filter.
2310+
getParams := database.GetConnectionLogsOffsetParams{
2311+
OrganizationID: orgA.Org.ID,
2312+
LimitOpt: 5,
2313+
OffsetOpt: 10,
2314+
}
2315+
logs, err := db.GetConnectionLogsOffset(ctx, getParams)
2316+
require.NoError(t, err)
2317+
require.Len(t, logs, 5)
2318+
2319+
// The count with the filter should remain the same, independent of pagination.
2320+
countAfterGet, err := db.CountConnectionLogs(ctx, countParams)
2321+
require.NoError(t, err)
2322+
require.Equal(t, int64(20), countAfterGet)
2323+
}
2324+
22472325
func TestConnectionLogsOffsetFilters(t *testing.T) {
22482326
t.Parallel()
22492327
ctx := testutil.Context(t, testutil.WaitLong)
@@ -2482,7 +2560,24 @@ func TestConnectionLogsOffsetFilters(t *testing.T) {
24822560
t.Parallel()
24832561
logs, err := db.GetConnectionLogsOffset(ctx, tc.params)
24842562
require.NoError(t, err)
2563+
count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{
2564+
OrganizationID: tc.params.OrganizationID,
2565+
WorkspaceOwner: tc.params.WorkspaceOwner,
2566+
Type: tc.params.Type,
2567+
UserID: tc.params.UserID,
2568+
Username: tc.params.Username,
2569+
UserEmail: tc.params.UserEmail,
2570+
StartedAfter: tc.params.StartedAfter,
2571+
StartedBefore: tc.params.StartedBefore,
2572+
WorkspaceID: tc.params.WorkspaceID,
2573+
ConnectionID: tc.params.ConnectionID,
2574+
Status: tc.params.Status,
2575+
WorkspaceOwnerID: tc.params.WorkspaceOwnerID,
2576+
WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail,
2577+
})
2578+
require.NoError(t, err)
24852579
require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs))
2580+
require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)")
24862581
})
24872582
}
24882583
}

0 commit comments

Comments
 (0)