@@ -19,10 +19,11 @@ package com.github.mauricio.async.db.postgresql
19
19
import java .util .concurrent .{TimeUnit , TimeoutException }
20
20
21
21
import com .github .mauricio .async .db .util .Log
22
- import com .github .mauricio .async .db .{Configuration , Connection }
22
+ import com .github .mauricio .async .db .{RowData , Configuration , Connection }
23
+ import org .reactivestreams .{Subscription , Subscriber }
23
24
24
25
import scala .concurrent .duration ._
25
- import scala .concurrent .{Await , Future }
26
+ import scala .concurrent .{Promise , Await , Future }
26
27
27
28
object DatabaseTestHelper {
28
29
val log = Log .get[DatabaseTestHelper ]
@@ -105,9 +106,44 @@ trait DatabaseTestHelper {
105
106
} )
106
107
}
107
108
109
+ def executeStream (handler : PostgreSQLConnection , statement : String , windowSize : Int = 1000 , values : Array [Any ] = Array .empty[Any ]) : IndexedSeq [RowData ] = {
110
+ handleTimeout(handler, {
111
+ val subscriber : TestSubscriber = new TestSubscriber
112
+ handler.streamQuery(statement, windowSize).subscribe(subscriber)
113
+ Await .result(subscriber.promise.future, Duration (5 , SECONDS ))
114
+ })
115
+ }
116
+
108
117
def await [T ](future : Future [T ]): T = {
109
118
Await .result(future, Duration (10 , TimeUnit .SECONDS ))
110
119
}
111
120
121
+ class TestSubscriber extends Subscriber [RowData ] {
122
+ val promise = Promise [IndexedSeq [RowData ]]()
123
+ override def onError (t : Throwable ): Unit = {
124
+ promise.failure(t)
125
+ }
126
+
127
+ var subscription : Option [Subscription ] = None
128
+ var requested : Long = 0
129
+ override def onSubscribe (subscription : Subscription ): Unit = {
130
+ this .subscription = Some (subscription)
131
+ requested = 10
132
+ subscription.request(10 )
133
+ }
134
+
135
+ override def onComplete (): Unit = {
136
+ promise.success(rows)
137
+ requested -= 1
138
+ if (requested <= 2 ) {
139
+ subscription.get.request(8 )
140
+ requested += 8
141
+ }
142
+ }
112
143
144
+ var rows = IndexedSeq [RowData ]()
145
+ override def onNext (t : RowData ): Unit = {
146
+ rows = rows :+ t
147
+ }
148
+ }
113
149
}
0 commit comments