Skip to content

Commit d755c2b

Browse files
committed
improve RowDataSubscription
1 parent 4d7cc14 commit d755c2b

File tree

2 files changed

+168
-36
lines changed

2 files changed

+168
-36
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/util/RowDataSubscription.scala

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,39 @@ import org.reactivestreams.{Subscriber, Subscription}
2525
import scala.annotation.tailrec
2626
import scala.concurrent.ExecutionContext
2727

28+
trait RowDataSubscriptionDelegate {
29+
def start(subscription: RowDataSubscription)
30+
def cancel(subscription: RowDataSubscription)
31+
def pause(subscription: RowDataSubscription)
32+
def continue(subscription: RowDataSubscription)
33+
}
2834

29-
class RowDataSubscription(val subscriber: Subscriber[RowData])(implicit executionContext: ExecutionContext) extends Subscription with Runnable {
35+
final class RowDataSubscription(val subscriber: Subscriber[_ >: RowData], val delegate: RowDataSubscriptionDelegate, val bufferSize : Int)
36+
(implicit executionContext: ExecutionContext) extends Subscription with Runnable
37+
{
3038
private val on = new AtomicBoolean(false)
31-
private var stopped = false
39+
private val needToCallCancel = new AtomicBoolean(false)
40+
private var started = false
41+
private var paused = false
42+
@volatile private var stopped = false
3243
private var completed = false
3344
private[util] val rows = new ConcurrentLinkedQueue[RowData]()
3445
private val demand = new AtomicLong(0)
3546
subscriber.onSubscribe(this)
3647

3748
override def cancel() {
49+
needToCallCancel.set(true)
3850
stopped = true
3951
tryScheduleToExecute()
4052
}
4153

42-
override final def request(n: Long) {
43-
demand.addAndGet(n)
44-
tryScheduleToExecute()
54+
override def request(n: Long) {
55+
if (n <= 0) {
56+
terminateDueTo(new IllegalArgumentException(s"Requested number $n <= 0"))
57+
} else {
58+
demand.addAndGet(n)
59+
tryScheduleToExecute()
60+
}
4561
}
4662

4763
// Should be called by one thread only
@@ -50,6 +66,10 @@ class RowDataSubscription(val subscriber: Subscriber[RowData])(implicit executio
5066
if (!stopped) {
5167
if (demand.get() > 0) {
5268
if (rows.isEmpty) {
69+
if (!started) {
70+
delegate.start(this)
71+
started = true
72+
}
5373
subscriber.onNext(rowData)
5474
demand.decrementAndGet()
5575
on.set(false)
@@ -59,9 +79,16 @@ class RowDataSubscription(val subscriber: Subscriber[RowData])(implicit executio
5979
}
6080
} else {
6181
rows.offer(rowData)
82+
if (!paused && rows.size() >= bufferSize) {
83+
paused = true
84+
delegate.pause(this)
85+
}
6286
on.set(false)
6387
}
6488
}
89+
if (stopped && needToCallCancel.compareAndSet(true, false)) {
90+
delegate.cancel(this)
91+
}
6592
} else if (!stopped) {
6693
rows.offer(rowData)
6794
tryScheduleToExecute()
@@ -114,8 +141,16 @@ class RowDataSubscription(val subscriber: Subscriber[RowData])(implicit executio
114141
}
115142
}
116143

117-
override final def run() {
144+
override def run() {
118145
if (demand.get() > 0) {
146+
if (!started) {
147+
delegate.start(this)
148+
started = true
149+
}
150+
if (paused) {
151+
delegate.continue(this)
152+
paused = false
153+
}
119154
sendRows()
120155
}
121156
if (completed && !stopped && rows.isEmpty) {
@@ -124,9 +159,14 @@ class RowDataSubscription(val subscriber: Subscriber[RowData])(implicit executio
124159
}
125160
on.set(false)
126161
if (stopped) {
162+
if (needToCallCancel.compareAndSet(true, false)) {
163+
delegate.cancel(this)
164+
}
127165
rows.clear()
128-
} else if ((demand.get() > 0 && !rows.isEmpty) || (completed && rows.isEmpty)) {
129-
tryScheduleToExecute()
166+
} else {
167+
if ((demand.get() > 0 && !rows.isEmpty) || (completed && rows.isEmpty)) {
168+
tryScheduleToExecute()
169+
}
130170
}
131171
}
132172

db-async-common/src/test/scala/com/github/mauricio/async/db/util/RowDataSubscriptionSpec.scala

Lines changed: 120 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,26 @@ class RowDataSubscriptionSpec extends Specification {
3131
sequential
3232

3333
"Positive flow" >> {
34+
val delegate = new TestRowDataSubscriptionDelegate()
3435
val subscriber = new TestSubscriber()
35-
val subscription = newSubscription(subscriber)
36+
val subscription = newSubscription(subscriber, delegate, bufferSize = 10)
3637
"Subscription should call onSubscribe" in {
3738
subscriber.subscribed must beTrue
3839
}
39-
"nextRow should be preserved" in {
40-
subscription.nextRow(0)
40+
"request should not affect on rows" in {
41+
subscription.request(1)
4142
subscriber.lastRow must_== -1
42-
subscription.rows must haveSize(1)
4343
}
44-
"and send when it is requested by calling onNext" in {
45-
subscription.request(1)
44+
"nextRow should be sent" in {
45+
subscription.nextRow(0)
4646
subscriber.lastRow must_== 0
4747
}
48-
"next rows should be again preserved" in {
48+
"next rows should be preserved" in {
4949
subscription.nextRow(1)
5050
subscription.nextRow(2)
5151
subscription.nextRow(3)
5252
subscriber.lastRow must_== 0
53+
subscription.rows must haveSize(3)
5354
}
5455
"and send when it is requested but not more than the total number requested" in {
5556
subscription.request(2)
@@ -66,16 +67,49 @@ class RowDataSubscriptionSpec extends Specification {
6667
}
6768
}
6869

70+
"Delegate" >> {
71+
val delegate = new TestRowDataSubscriptionDelegate()
72+
val subscriber = new TestSubscriber()
73+
val subscription = newSubscription(subscriber, delegate, bufferSize = 2)
74+
"delegate should not be started or paused" >> {
75+
delegate.started must beFalse
76+
}
77+
"request should not cause pause" >> {
78+
subscription.request(1)
79+
delegate.paused must beFalse
80+
}
81+
"sending 2 rows should not cause pause" >> {
82+
subscription.nextRow(0)
83+
subscription.nextRow(1)
84+
delegate.paused must beFalse
85+
}
86+
"sending 3nd row should not cause pause" >> {
87+
subscription.nextRow(1)
88+
delegate.paused must beTrue
89+
}
90+
"continue should be called after request" >> {
91+
subscription.request(1)
92+
delegate.paused must beFalse
93+
}
94+
"cancel should be called after cancel" >> {
95+
subscription.cancel()
96+
delegate.cancelled must beTrue
97+
}
98+
}
99+
69100
"When it is canceled it should stop sending and preserving rows" >> {
101+
val delegate = new TestRowDataSubscriptionDelegate()
70102
val subscriber = new TestSubscriber()
71-
val subscription = newSubscription(subscriber)
103+
val subscription = newSubscription(subscriber, delegate, bufferSize = 10)
72104
subscription.nextRow(0)
73105
subscription.request(1)
74106
subscription.nextRow(1)
75107
subscription.cancel()
76108
subscription.nextRow(2)
77109
subscription.rows must haveSize(0)
78110
subscriber.lastRow must_== 0
111+
delegate.started must beTrue
112+
delegate.cancelled must beTrue
79113
}
80114

81115
val attemptsCount = 1000
@@ -90,16 +124,21 @@ class RowDataSubscriptionSpec extends Specification {
90124
"Thread safety" ! attempts {_ =>
91125
val subscriber = new TestSubscriber()
92126
implicit val context = ExecutionContext.global
93-
val subscription = new RowDataSubscription(subscriber)
94127
val count = 1000
95-
context.execute(new Runnable {
96-
override def run(): Unit = {
97-
for (row <- 0 until count) {
98-
subscription.nextRow(row)
99-
}
100-
subscription.complete()
128+
val delegate = new TestRowDataSubscriptionDelegate() {
129+
override def start(subscription: RowDataSubscription): Unit = {
130+
super.start(subscription)
131+
context.execute(new Runnable {
132+
override def run(): Unit = {
133+
for (row <- 0 until count) {
134+
subscription.nextRow(row)
135+
}
136+
subscription.complete()
137+
}
138+
})
101139
}
102-
})
140+
}
141+
val subscription = new RowDataSubscription(subscriber, delegate, bufferSize = 10)
103142
context.execute(new Runnable {
104143
override def run(): Unit = {
105144
for (row <- 0 until count) {
@@ -117,39 +156,54 @@ class RowDataSubscriptionSpec extends Specification {
117156
subscriber.lastRowError must beFalse
118157
subscriber.completed must beTrue
119158
subscriber.lastRow mustEqual count - 1
159+
delegate.started must beTrue
160+
delegate.cancelled must beFalse
120161
}
121162

122163
"Thread safety for cancel" ! attempts {_ =>
123164
val subscriber = new TestSubscriber()
124165
implicit val context = ExecutionContext.global
125-
val subscription = new RowDataSubscription(subscriber)
126-
val count = 1000
127-
context.execute(new Runnable {
128-
override def run(): Unit = {
129-
for (row <- 0 until count + 10) {
130-
subscription.nextRow(row)
131-
}
132-
subscription.complete()
133-
}
134-
})
166+
val count = 10
135167
val canceledPromise = Promise[Unit]()
168+
169+
val delegate = new TestRowDataSubscriptionDelegate() {
170+
override def start(subscription: RowDataSubscription): Unit = {
171+
super.start(subscription)
172+
context.execute(new Runnable {
173+
override def run(): Unit = {
174+
for (row <- 0 until count + 10) {
175+
subscription.nextRow(row)
176+
}
177+
subscription.complete()
178+
}
179+
})
180+
}
181+
182+
override def cancel(subscription: RowDataSubscription): Unit = {
183+
super.cancel(subscription)
184+
canceledPromise.success()
185+
}
186+
}
187+
val subscription = new RowDataSubscription(subscriber, delegate, bufferSize = 1)
136188
context.execute(new Runnable {
137189
override def run(): Unit = {
138190
for (row <- 0 until count) {
139191
subscription.request(1)
140192
}
141193
subscription.cancel()
142-
canceledPromise.success()
143194
}
144195
})
145196
Await.ready(canceledPromise.future, Duration(10, TimeUnit.SECONDS))
146197

147198
subscriber.lastRowError must beFalse
148199
subscriber.completed must beFalse
149200
subscriber.lastRow must beLessThanOrEqualTo(count - 1)
201+
delegate.started must beTrue
202+
delegate.cancelled must beTrue
150203
}
151204

152-
def newSubscription(subscriber : TestSubscriber) = new RowDataSubscription(subscriber)(SameThreadExecutionContext)
205+
def newSubscription(subscriber : TestSubscriber, delegate : RowDataSubscriptionDelegate, bufferSize : Int) =
206+
new RowDataSubscription(subscriber, delegate, bufferSize)(SameThreadExecutionContext)
153207

154208
implicit def intToTestRowData(rowNumber : Int) : RowData = TestRowData(rowNumber)
155209
case class TestRowData(rowNumber : Int) extends RowData {
@@ -202,4 +256,42 @@ class RowDataSubscriptionSpec extends Specification {
202256

203257
}
204258
}
259+
260+
class TestRowDataSubscriptionDelegate extends RowDataSubscriptionDelegate {
261+
var cancelled = false
262+
override def cancel(subscription: RowDataSubscription): Unit = {
263+
if (!started) {
264+
throw new IllegalStateException("Not started")
265+
}
266+
if (cancelled) {
267+
throw new IllegalStateException("Already canceled")
268+
}
269+
cancelled = true
270+
}
271+
var paused = false
272+
override def pause(subscription: RowDataSubscription): Unit = {
273+
if (paused) {
274+
throw new IllegalStateException("Already paused")
275+
}
276+
if (!started) {
277+
throw new IllegalStateException("Not started")
278+
}
279+
paused = true
280+
}
281+
282+
override def continue(subscription: RowDataSubscription): Unit = {
283+
if (!paused) {
284+
throw new IllegalStateException("Not paused")
285+
}
286+
if (!started) {
287+
throw new IllegalStateException("Not started")
288+
}
289+
paused = false
290+
}
291+
292+
var started = false
293+
override def start(subscription: RowDataSubscription): Unit = {
294+
started = true
295+
}
296+
}
205297
}

0 commit comments

Comments
 (0)