Skip to content

Commit 883c780

Browse files
committed
refactor execution
1 parent 05cd33b commit 883c780

15 files changed

+170
-68
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object Configuration {
4343
* OOM or eternal loop attacks the client could have, defaults to 16 MB. You can set this
4444
* to any value you would like but again, make sure you know what you are doing if you do
4545
* change it.
46-
* @param defaultWindowSize how many rows to retrieve in one windows. This count will be retrieved from database
46+
* @param fetchSize how many rows to retrieve in one windows. This count will be retrieved from database
4747
* in one request. So in case of back pressure this amount of information will be preserved.
4848
* It can be specified for a particular query.
4949
*
@@ -59,4 +59,4 @@ case class Configuration(username: String,
5959
allocator: AbstractByteBufAllocator = PooledByteBufAllocator.DEFAULT,
6060
connectTimeout: Duration = 5.seconds,
6161
testTimeout: Duration = 5.seconds,
62-
defaultWindowSize : Int = 1000)
62+
fetchSize : Int = 1000)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/PostgreSQLConnection.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,12 @@ class PostgreSQLConnection
107107
promise.future
108108
}
109109

110-
def streamQuery(query: String, windowSize : Int = configuration.defaultWindowSize): Publisher[RowData] = {
110+
def streamQuery(query: String, values: Seq[Any] = List(), fetchSize : Int = configuration.fetchSize): Publisher[RowData] = {
111111
validateQuery(query)
112112

113113
new Publisher[RowData] {
114114
override def subscribe(s: Subscriber[_ >: RowData]): Unit = {
115-
new RowDataSubscription(s, new SubscriptionDelegate(query), bufferSize = windowSize/2)
115+
new RowDataSubscription(s, new SubscriptionDelegate(query), bufferSize = fetchSize/2)
116116
}
117117
}
118118
}
@@ -136,7 +136,7 @@ class PostgreSQLConnection
136136
processor.columnTypes = holder.columnDatas
137137
write(
138138
if (holder.prepared)
139-
new PreparedStatementExecuteMessage(holder.statementId, holder.realQuery, values, this.encoderRegistry)
139+
new PreparedStatementWholeExecuteMessage(holder.statementId, holder.realQuery, values, this.encoderRegistry)
140140
else {
141141
holder.prepared = true
142142
new PreparedStatementOpeningMessage(holder.statementId, holder.realQuery, values, this.encoderRegistry)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,15 @@
1616

1717
package com.github.mauricio.async.db.postgresql.codec
1818

19+
import java.nio.charset.Charset
20+
1921
import com.github.mauricio.async.db.column.ColumnEncoderRegistry
2022
import com.github.mauricio.async.db.exceptions.EncoderNotAvailableException
2123
import com.github.mauricio.async.db.postgresql.encoders._
22-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
2324
import com.github.mauricio.async.db.postgresql.messages.frontend._
2425
import com.github.mauricio.async.db.util.{BufferDumper, Log}
25-
import java.nio.charset.Charset
26-
import scala.annotation.switch
27-
import io.netty.handler.codec.MessageToMessageEncoder
2826
import io.netty.channel.ChannelHandlerContext
27+
import io.netty.handler.codec.MessageToMessageEncoder
2928

3029
object MessageEncoder {
3130
val log = Log.get[MessageEncoder]
@@ -35,31 +34,34 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
3534

3635
import MessageEncoder.log
3736

38-
private val executeEncoder = new ExecutePreparedStatementEncoder(charset, encoderRegistry)
37+
private val executeWholeEncoder = new PreparedStatementWholeExecuteEncoder(charset, encoderRegistry)
3938
private val openEncoder = new PreparedStatementOpeningEncoder(charset, encoderRegistry)
4039
private val startupEncoder = new StartupMessageEncoder(charset)
4140
private val queryEncoder = new QueryMessageEncoder(charset)
4241
private val credentialEncoder = new CredentialEncoder(charset)
42+
private val closeEncoder = new PreparedStatementCloseEncoder(charset)
43+
private val bindEncoder = new PreparedStatementBindEncoder(charset, encoderRegistry)
44+
private val executeEncoder = new PreparedStatementExecuteEncoder(charset)
4345

4446
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {
4547

4648
val buffer = msg match {
47-
case message: ClientMessage => {
48-
val encoder = (message.kind: @switch) match {
49-
case ServerMessage.Close => CloseMessageEncoder
50-
case ServerMessage.Execute => this.executeEncoder
51-
case ServerMessage.Parse => this.openEncoder
52-
case ServerMessage.Startup => this.startupEncoder
53-
case ServerMessage.Query => this.queryEncoder
54-
case ServerMessage.PasswordMessage => this.credentialEncoder
49+
case message: ClientMessage =>
50+
val encoder = message match {
51+
case CloseMessage => CloseMessageEncoder
52+
case _ : PreparedStatementWholeExecuteMessage => this.executeWholeEncoder
53+
case _ : PreparedStatementOpeningMessage => this.openEncoder
54+
case _ : StartupMessage => this.startupEncoder
55+
case _ : QueryMessage => this.queryEncoder
56+
case _ : CredentialMessage => this.credentialEncoder
57+
case _ : PreparedStatementCloseMessage => this.closeEncoder
58+
case _ : PreparedStatementBindMessage => this.bindEncoder
59+
case _ : PreparedStatementExecuteMessage => this.executeEncoder
5560
case _ => throw new EncoderNotAvailableException(message)
5661
}
57-
5862
encoder.encode(message)
59-
}
60-
case _ => {
63+
case _ =>
6164
throw new IllegalArgumentException("Can not encode message %s".format(msg))
62-
}
6365
}
6466

6567
if (log.isTraceEnabled) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import java.nio.charset.Charset
4+
5+
import com.github.mauricio.async.db.column.ColumnEncoderRegistry
6+
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementBindMessage}
7+
import io.netty.buffer.ByteBuf
8+
9+
class PreparedStatementBindEncoder(charset: Charset, encoder : ColumnEncoderRegistry) extends Encoder with PreparedStatementEncoderHelper {
10+
override def encode(message: ClientMessage): ByteBuf = {
11+
val m = message.asInstanceOf[PreparedStatementBindMessage]
12+
val statementIdBytes = m.statementId.toString.getBytes(charset)
13+
14+
bind(statementIdBytes, m.query, m.values, encoder, charset)
15+
}
16+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import java.nio.charset.Charset
4+
5+
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementOpeningMessage}
6+
import io.netty.buffer.{ByteBuf, Unpooled}
7+
8+
class PreparedStatementCloseEncoder(charset: Charset) extends Encoder with PreparedStatementEncoderHelper {
9+
override def encode(message: ClientMessage): ByteBuf = {
10+
val m = message.asInstanceOf[PreparedStatementOpeningMessage]
11+
12+
val statementIdBytes = m.statementId.toString.getBytes(charset)
13+
val closeBuffer: ByteBuf = closePortal(statementIdBytes)
14+
val syncBuffer: ByteBuf = sync
15+
Unpooled.wrappedBuffer(syncBuffer, closeBuffer)
16+
}
17+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/PreparedStatementEncoderHelper.scala

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,28 @@ trait PreparedStatementEncoderHelper {
4040
writeDescribe: Boolean = false
4141
): ByteBuf = {
4242

43+
val bindBuffer: ByteBuf = bind(statementIdBytes, query, values, encoder, charset)
44+
45+
if (writeDescribe) {
46+
val describeLength = 1 + 4 + 1 + statementIdBytes.length + 1
47+
val describeBuffer = bindBuffer
48+
describeBuffer.writeByte(ServerMessage.Describe)
49+
describeBuffer.writeInt(describeLength - 1)
50+
describeBuffer.writeByte('P')
51+
describeBuffer.writeBytes(statementIdBytes)
52+
describeBuffer.writeByte(0)
53+
}
54+
55+
val executeBuffer: ByteBuf = execute(statementIdBytes, 0)
56+
57+
val closeBuffer: ByteBuf = closePortal(statementIdBytes)
58+
59+
val syncBuffer: ByteBuf = sync
60+
61+
Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer)
62+
}
63+
64+
def bind(statementIdBytes: Array[Byte], query: String, values: Seq[Any], encoder: ColumnEncoderRegistry, charset: Charset): ByteBuf = {
4365
if (log.isDebugEnabled) {
4466
log.debug(s"Preparing execute portal to statement ($query) - values (${values.mkString(", ")}) - ${charset}")
4567
}
@@ -96,39 +118,36 @@ trait PreparedStatementEncoderHelper {
96118
bindBuffer.writeShort(0)
97119

98120
ByteBufferUtils.writeLength(bindBuffer)
121+
bindBuffer
122+
}
99123

100-
if (writeDescribe) {
101-
val describeLength = 1 + 4 + 1 + statementIdBytes.length + 1
102-
val describeBuffer = bindBuffer
103-
describeBuffer.writeByte(ServerMessage.Describe)
104-
describeBuffer.writeInt(describeLength - 1)
105-
describeBuffer.writeByte('P')
106-
describeBuffer.writeBytes(statementIdBytes)
107-
describeBuffer.writeByte(0)
108-
}
109-
124+
def execute(statementIdBytes: Array[Byte], fetchSize: Int): ByteBuf = {
110125
val executeLength = 1 + 4 + statementIdBytes.length + 1 + 4
111126
val executeBuffer = Unpooled.buffer(executeLength)
112127
executeBuffer.writeByte(ServerMessage.Execute)
113128
executeBuffer.writeInt(executeLength - 1)
114129
executeBuffer.writeBytes(statementIdBytes)
115130
executeBuffer.writeByte(0)
116-
executeBuffer.writeInt(0)
131+
executeBuffer.writeInt(fetchSize)
132+
executeBuffer
133+
}
134+
135+
def sync: ByteBuf = {
136+
val syncBuffer = Unpooled.buffer(5)
137+
syncBuffer.writeByte(ServerMessage.Sync)
138+
syncBuffer.writeInt(4)
139+
syncBuffer
140+
}
117141

142+
def closePortal(statementIdBytes: Array[Byte]): ByteBuf = {
118143
val closeLength = 1 + 4 + 1 + statementIdBytes.length + 1
119144
val closeBuffer = Unpooled.buffer(closeLength)
120145
closeBuffer.writeByte(ServerMessage.CloseStatementOrPortal)
121146
closeBuffer.writeInt(closeLength - 1)
122147
closeBuffer.writeByte('P')
123148
closeBuffer.writeBytes(statementIdBytes)
124149
closeBuffer.writeByte(0)
125-
126-
val syncBuffer = Unpooled.buffer(5)
127-
syncBuffer.writeByte(ServerMessage.Sync)
128-
syncBuffer.writeInt(4)
129-
130-
Unpooled.wrappedBuffer(bindBuffer, executeBuffer, syncBuffer, closeBuffer)
131-
150+
closeBuffer
132151
}
133152

134153
def isNull(value: Any): Boolean = value == null || value == None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import java.nio.charset.Charset
4+
5+
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementExecuteMessage}
6+
import io.netty.buffer.ByteBuf
7+
8+
class PreparedStatementExecuteEncoder(charset: Charset) extends Encoder with PreparedStatementEncoderHelper {
9+
override def encode(message: ClientMessage): ByteBuf = {
10+
val m = message.asInstanceOf[PreparedStatementExecuteMessage]
11+
12+
val statementIdBytes = m.statementId.toString.getBytes(charset)
13+
execute(statementIdBytes, m.fetchSize)
14+
}
15+
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ package com.github.mauricio.async.db.postgresql.encoders
1919
import java.nio.charset.Charset
2020

2121
import com.github.mauricio.async.db.column.ColumnEncoderRegistry
22-
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementExecuteMessage}
22+
import com.github.mauricio.async.db.postgresql.messages.frontend.{ClientMessage, PreparedStatementWholeExecuteMessage}
2323
import io.netty.buffer.ByteBuf
2424

25-
class ExecutePreparedStatementEncoder(
25+
class PreparedStatementWholeExecuteEncoder(
2626
charset: Charset,
2727
encoder : ColumnEncoderRegistry)
2828
extends Encoder
@@ -31,7 +31,7 @@ class ExecutePreparedStatementEncoder(
3131

3232
def encode(message: ClientMessage): ByteBuf = {
3333

34-
val m = message.asInstanceOf[PreparedStatementExecuteMessage]
34+
val m = message.asInstanceOf[PreparedStatementWholeExecuteMessage]
3535
val statementIdBytes = m.statementId.toString.getBytes(charset)
3636

3737
writeExecutePortal( statementIdBytes, m.query, m.values, encoder, charset )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
import com.github.mauricio.async.db.column.ColumnEncoderRegistry
4+
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
5+
6+
class PreparedStatementBindMessage(statementId: Int, query: String, values: Seq[Any], encoderRegistry : ColumnEncoderRegistry)
7+
extends PreparedStatementMessage(statementId, ServerMessage.Bind, query, values, encoderRegistry)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
4+
5+
class PreparedStatementCloseMessage(val statementId: Int) extends ClientMessage(ServerMessage.CloseStatementOrPortal)

0 commit comments

Comments
 (0)