DEV Community

TECH SCHOOL

Implement bidirectional streaming gRPC - Go

Hi everyone!

Today we’re gonna learn how to implement the last of the 4 types of gRPC: bidirectional streaming, or bidi-streaming. With this type of streaming, the client and the server can send multiple requests and multiple responses to each other concurrently, over the same stream.

Here's the link to the full gRPC course playlist on YouTube
GitHub repository: pcbook-go and pcbook-java
GitLab repository: pcbook-go and pcbook-java

In this lecture, we will write an API that lets the client rate a stream of laptops with scores from 1 to 10, while the server responds with a stream of average scores, one for each rated laptop.

Alright, let’s start!

1. Define bidi-streaming gRPC protobuf

The first thing we need to do is to define a new bidi-streaming RPC in the laptop_service.proto file.

We define the RateLaptopRequest with 2 fields: the laptop ID and the score.

message RateLaptopRequest {
  string laptop_id = 1;
  double score = 2;
}

Then the RateLaptopResponse with 3 fields: the laptop ID, the number of times this laptop has been rated, and the average rated score.

message RateLaptopResponse {
  string laptop_id = 1;
  uint32 rated_count = 2;
  double average_score = 3;
}

Now we define the RateLaptop RPC, whose input is a stream of RateLaptopRequest and whose output is a stream of RateLaptopResponse.

service LaptopService {
  ...
  rpc RateLaptop(stream RateLaptopRequest) returns (stream RateLaptopResponse) {};
}

After we run make gen in the terminal to regenerate the code, we see an error in the server/main.go file. This is because the LaptopServiceServer interface now requires one more method, RateLaptop, which the LaptopServer struct doesn't implement yet. We can find the signature of this method in the pb/laptop_service.pb.go file.

type LaptopServiceServer interface {
    CreateLaptop(context.Context, *CreateLaptopRequest) (*CreateLaptopResponse, error)
    SearchLaptop(*SearchLaptopRequest, LaptopService_SearchLaptopServer) error
    UploadImage(LaptopService_UploadImageServer) error
    RateLaptop(LaptopService_RateLaptopServer) error
}

So let’s copy it and paste it into the laptop_server.go file. For now, we just return nil. We will come back to implement this method later.

func (server *LaptopServer) RateLaptop(stream pb.LaptopService_RateLaptopServer) error {
    return nil
}

2. Implement the rating store

Now we need to create a new rating store to save the laptop ratings.

I will define a RatingStore interface. It has one method, Add, which takes a laptop ID and a score as input and returns the updated laptop rating or an error.

type RatingStore interface {
    Add(laptopID string, score float64) (*Rating, error)
}

The Rating struct has 2 fields: Count, the number of times the laptop has been rated, and Sum, the sum of all its rated scores.

type Rating struct {
    Count uint32
    Sum   float64
}

Then we will write an in-memory rating store that implements the interface. Similar to the in-memory laptop store, we need a mutex to handle concurrent access. We also have a rating map whose key is the laptop ID and whose value is the rating object.

type InMemoryRatingStore struct {
    mutex  sync.RWMutex
    rating map[string]*Rating
}

Then we define a function to create a new in-memory rating store. In this function, we just need to initialize the rating map.

func NewInMemoryRatingStore() *InMemoryRatingStore {
    return &InMemoryRatingStore{
        rating: make(map[string]*Rating),
    }
}

OK, now let’s implement the Add function!

As we’re going to change the internal data of the store, we have to acquire the write lock. Then we get the rating of the laptop ID from the map.

If the rating is not found, we create a new object with Count set to 1 and Sum set to the input score. Otherwise, we increase the rating count by 1 and add the score to the sum.

func (store *InMemoryRatingStore) Add(laptopID string, score float64) (*Rating, error) {
    store.mutex.Lock()
    defer store.mutex.Unlock()

    rating := store.rating[laptopID]
    if rating == nil {
        // first score for this laptop
        rating = &Rating{
            Count: 1,
            Sum:   score,
        }
    } else {
        rating.Count++
        rating.Sum += score
    }

    store.rating[laptopID] = rating

    // Return a copy: the caller reads it after the lock is released,
    // while other streams may keep updating the stored object.
    return &Rating{Count: rating.Count, Sum: rating.Sum}, nil
}

Finally, we put the updated rating back into the map and return it to the caller.
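To convince ourselves that the mutex really protects the map, here is a small self-contained sketch that hammers Add from many goroutines. It restates the store from above (with creation and increment merged, which behaves the same way), adds a compile-time check that InMemoryRatingStore satisfies RatingStore, and uses a hypothetical helper, addConcurrently, that is not part of the course code. Running it with `go run -race` lets the race detector confirm there is no unsynchronized access.

```go
package main

import (
	"fmt"
	"sync"
)

type Rating struct {
	Count uint32
	Sum   float64
}

type RatingStore interface {
	Add(laptopID string, score float64) (*Rating, error)
}

type InMemoryRatingStore struct {
	mutex  sync.RWMutex
	rating map[string]*Rating
}

// Compile-time check that InMemoryRatingStore satisfies RatingStore.
var _ RatingStore = (*InMemoryRatingStore)(nil)

func NewInMemoryRatingStore() *InMemoryRatingStore {
	return &InMemoryRatingStore{rating: make(map[string]*Rating)}
}

// Add takes the write lock because it mutates the map and the stored rating.
func (store *InMemoryRatingStore) Add(laptopID string, score float64) (*Rating, error) {
	store.mutex.Lock()
	defer store.mutex.Unlock()

	rating := store.rating[laptopID]
	if rating == nil {
		rating = &Rating{}
		store.rating[laptopID] = rating
	}
	rating.Count++
	rating.Sum += score
	return &Rating{Count: rating.Count, Sum: rating.Sum}, nil // copy, not the shared pointer
}

// addConcurrently (hypothetical helper) calls Add n times from n goroutines
// and waits for all of them to finish.
func addConcurrently(store RatingStore, laptopID string, score float64, n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			store.Add(laptopID, score)
		}()
	}
	wg.Wait()
}

func main() {
	store := NewInMemoryRatingStore()
	addConcurrently(store, "laptop-1", 5, 100)
	last, _ := store.Add("laptop-1", 5)
	fmt.Println(last.Count, last.Sum) // 101 505
}
```

If any update were lost, the final count would come out below 101.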

And we’re done with the store. Now let’s go back and implement the server.

3. Implement the bidi-streaming gRPC server

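We add a new ratingStore field to the LaptopServer struct and a matching ratingStore parameter to the NewLaptopServer() function. Since the constructor now takes 3 arguments, its callers, such as server/main.go, must pass a rating store too, for example one created with NewInMemoryRatingStore().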

type LaptopServer struct {
    laptopStore LaptopStore
    imageStore  ImageStore
    ratingStore RatingStore
}

func NewLaptopServer(laptopStore LaptopStore, imageStore ImageStore, ratingStore RatingStore) *LaptopServer {
    return &LaptopServer{laptopStore, imageStore, ratingStore}
}

Now let’s implement the RateLaptop function!

Since we will receive multiple requests from the stream, we must use a for loop here. Similar to what we did for the client-streaming RPC, before doing anything, we check the context error to see whether the request has already been canceled or its deadline has been exceeded.

func (server *LaptopServer) RateLaptop(stream pb.LaptopService_RateLaptopServer) error {
    for {
        err := contextError(stream.Context())
        if err != nil {
            return err
        }

        req, err := stream.Recv()
        if err == io.EOF {
            log.Print("no more data")
            break
        }
        if err != nil {
            return logError(status.Errorf(codes.Unknown, "cannot receive stream request: %v", err))
        }
        ...
    }

    return nil
}

Then we call stream.Recv() to get a request from the stream. If the error is io.EOF, there’s no more data (the client has closed its sending side), so we simply break the loop. Otherwise, if the error is not nil, we log it and return it to the client with status code Unknown.

Otherwise, we can get the laptop ID and the score from the request. Let’s write a log here saying that we have received a request with this laptop ID and score.

func (server *LaptopServer) RateLaptop(stream pb.LaptopService_RateLaptopServer) error {
    for {
        ...
        laptopID := req.GetLaptopId()
        score := req.GetScore()

        log.Printf("received a rate-laptop request: id = %s, score = %.2f", laptopID, score)

        found, err := server.laptopStore.Find(laptopID)
        if err != nil {
            return logError(status.Errorf(codes.Internal, "cannot find laptop: %v", err))
        }
        if found == nil {
            return logError(status.Errorf(codes.NotFound, "laptopID %s is not found", laptopID))
        }

        ...
    }

    return nil
}

Then we check whether this laptop ID actually exists by calling laptopStore.Find(). If an error occurs, we return it with status code Internal. If the laptop is not found, we return status code NotFound to the client.

If everything goes well, we call ratingStore.Add() to add the new laptop score to the store and get back the updated rating object.

func (server *LaptopServer) RateLaptop(stream pb.LaptopService_RateLaptopServer) error {
    for {
        ...

        rating, err := server.ratingStore.Add(laptopID, score)
        if err != nil {
            return logError(status.Errorf(codes.Internal, "cannot add rating to the store: %v", err))
        }

        res := &pb.RateLaptopResponse{
            LaptopId:     laptopID,
            RatedCount:   rating.Count,
            AverageScore: rating.Sum / float64(rating.Count),
        }

        err = stream.Send(res)
        if err != nil {
            return logError(status.Errorf(codes.Unknown, "cannot send stream response: %v", err))
        }
    }

    return nil
}

If there’s an error, we return status code Internal. Otherwise, we create a RateLaptopResponse whose laptop ID is the input laptop ID, whose rated count is taken from the rating object, and whose average score is the rating's sum divided by its count.

Then we call stream.Send() to send the response to the client. If the error is not nil, we log it and return status code Unknown.

And that’s it! We’re done with the server.

4. Implement the bidi-streaming gRPC client

Now, before moving on to the client, I will add a new function to the sample package to generate a random laptop score. To keep it simple, let’s say it’s a random integer between 1 and 10.

func RandomLaptopScore() float64 {
    return float64(randomInt(1, 10))
}
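RandomLaptopScore relies on randomInt from the sample package, which was written in an earlier lecture and isn't shown here. Below is a sketch of what such a helper could look like, assuming it should include both ends of the range: rand.Intn(n) returns a value in [0, n), so the range width needs a +1.

```go
package main

import (
	"fmt"
	"math/rand"
)

// randomInt returns a uniformly distributed integer in [lo, hi], both ends included.
// rand.Intn(n) yields 0..n-1, hence the +1.
func randomInt(lo, hi int) int {
	return lo + rand.Intn(hi-lo+1)
}

// RandomLaptopScore returns an integer score from 1 to 10 as a float64.
func RandomLaptopScore() float64 {
	return float64(randomInt(1, 10))
}

func main() {
	for i := 0; i < 5; i++ {
		fmt.Println(RandomLaptopScore())
	}
}
```

Since Go 1.20, the global math/rand source is seeded automatically; with older Go versions you would call rand.Seed once at startup to avoid getting the same scores on every run.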

Alright, now let’s implement the client!

First we define a rateLaptop() function with 3 input parameters: a laptop client, a list of laptop IDs, and their corresponding scores. In this function, we create a new context that times out after 5 seconds. Then we call laptopClient.RateLaptop() with that context.

func rateLaptop(laptopClient pb.LaptopServiceClient, laptopIDs []string, scores []float64) error {
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    stream, err := laptopClient.RateLaptop(ctx)
    if err != nil {
        return fmt.Errorf("cannot rate laptop: %v", err)
    }

    ...
}

The output is a stream or an error. If the error is not nil, we just return it. Otherwise, we make a channel to wait for the responses from the server. The waitResponse channel will receive an error if one occurs, or nil if all responses are received successfully.

Note that the requests and responses are sent concurrently, so we have to start a new goroutine to receive the responses. In the goroutine, we use a for loop and call stream.Recv() to get a response from the server.

func rateLaptop(laptopClient pb.LaptopServiceClient, laptopIDs []string, scores []float64) error {
    ...

    waitResponse := make(chan error)

    // go routine to receive responses
    go func() {
        for {
            res, err := stream.Recv()
            if err == io.EOF {
                log.Print("no more responses")
                waitResponse <- nil
                return
            }
            if err != nil {
                waitResponse <- fmt.Errorf("cannot receive stream response: %v", err)
                return
            }

            log.Print("received response: ", res)
        }
    }()

    ...
}

If the error is io.EOF, there are no more responses, so we send nil to the waitResponse channel and return. Otherwise, if the error is not nil, we send it to the waitResponse channel and return as well. If no error occurs, we just write a simple log.

OK, now after this goroutine, we can start sending requests to the server. Let’s iterate through the list of laptops and create a new request for each of them, with the laptop ID and its corresponding score.

func rateLaptop(laptopClient pb.LaptopServiceClient, laptopIDs []string, scores []float64) error {
    ...

    // send requests
    for i, laptopID := range laptopIDs {
        req := &pb.RateLaptopRequest{
            LaptopId: laptopID,
            Score:    scores[i],
        }

        err := stream.Send(req)
        if err != nil {
            return fmt.Errorf("cannot send stream request: %v - %v", err, stream.RecvMsg(nil))
        }

        log.Print("sent request: ", req)
    }

    ...
}

Then we call stream.Send() to send the request to the server. If we get an error, we return it. Note that here we call stream.RecvMsg() to get the real error, just like we did in the previous lecture with the client-streaming RPC: when the server has already ended the stream, Send() only returns io.EOF, and the actual status is reported by RecvMsg(). If no error occurs, we write a log saying the request is sent.

One important thing we must do after sending all requests is to call stream.CloseSend() to tell the server that we won’t send any more data. Finally, we read from the waitResponse channel and return the received error.

func rateLaptop(laptopClient pb.LaptopServiceClient, laptopIDs []string, scores []float64) error {
    ...

    err = stream.CloseSend()
    if err != nil {
        return fmt.Errorf("cannot close send: %v", err)
    }

    err = <-waitResponse
    return err
}
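The shape of rateLaptop() (receive in a goroutine, send from the caller, close the sending side, then wait on waitResponse) can be tried out with plain channels. In this sketch, which is a model and not the gRPC API, an unbuffered `in` channel stands in for the client-to-server direction, closing it plays the role of stream.CloseSend(), and the fake server closing `out` plays the role of the client's Recv() returning io.EOF. The names fakeServer and rateLaptops are hypothetical.

```go
package main

import "fmt"

type request struct {
	laptopID string
	score    float64
}

type response struct {
	laptopID     string
	ratedCount   uint32
	averageScore float64
}

// fakeServer sends one response per request and closes `out` when done,
// which the client's receive loop treats like io.EOF.
func fakeServer(in <-chan request, out chan<- response) {
	defer close(out)
	counts := map[string]uint32{}
	sums := map[string]float64{}
	for req := range in { // ends when the client closes `in` (CloseSend)
		counts[req.laptopID]++
		sums[req.laptopID] += req.score
		c := counts[req.laptopID]
		out <- response{laptopID: req.laptopID, ratedCount: c, averageScore: sums[req.laptopID] / float64(c)}
	}
}

// rateLaptops follows the client's structure: a goroutine receives,
// the caller sends, closes its side, then waits on waitResponse.
func rateLaptops(laptopIDs []string, scores []float64) ([]response, error) {
	in := make(chan request)   // unbuffered on purpose
	out := make(chan response) // unbuffered on purpose
	go fakeServer(in, out)

	var received []response
	waitResponse := make(chan error)
	go func() {
		for res := range out {
			received = append(received, res)
		}
		waitResponse <- nil // no more responses
	}()

	for i, id := range laptopIDs {
		in <- request{laptopID: id, score: scores[i]}
	}
	close(in) // stand-in for stream.CloseSend()

	err := <-waitResponse // also makes `received` safe to read
	return received, err
}

func main() {
	res, err := rateLaptops([]string{"a", "b", "a"}, []float64{10, 8, 6})
	fmt.Println(err)
	for _, r := range res {
		fmt.Println(r.laptopID, r.ratedCount, r.averageScore)
	}
	// prints:
	// <nil>
	// a 1 10
	// b 1 8
	// a 2 8
}
```

In this unbuffered model, receiving in a separate goroutine is what prevents a deadlock: if the caller sent every request before reading any response, fakeServer would block sending the first response while the caller blocks sending the second request. A real gRPC stream has flow-control buffers, so small exchanges may not stall, but receiving concurrently avoids depending on them.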

The rateLaptop() function is completed. Now we will write a testRateLaptop() function to call it.

Let’s say we want to rate 3 laptops, so we declare a slice to keep the laptop IDs. We use a for loop to generate a random laptop, save its ID to the slice, and call createLaptop() function to create it on the server.

func testRateLaptop(laptopClient pb.LaptopServiceClient) {
    n := 3
    laptopIDs := make([]string, n)

    for i := 0; i < n; i++ {
        laptop := sample.NewLaptop()
        laptopIDs[i] = laptop.GetId()
        createLaptop(laptopClient, laptop)
    }

    ...
}

Then we also make a slice to keep the scores. I want to rate these 3 laptops in multiple rounds, so I will use a for loop here and ask if we want to do another round of rating or not.

func testRateLaptop(laptopClient pb.LaptopServiceClient) {
    ...

    scores := make([]float64, n)
    for {
        fmt.Print("rate laptop (y/n)? ")
        var answer string
        fmt.Scan(&answer)

        if strings.ToLower(answer) != "y" {
            break
        }

        for i := 0; i < n; i++ {
            scores[i] = sample.RandomLaptopScore()
        }

        err := rateLaptop(laptopClient, laptopIDs, scores)
        if err != nil {
            log.Fatal(err)
        }
    }
}

If the answer is no, we break the loop. Else we generate a new set of scores for the laptops and call rateLaptop() function to rate them with the generated scores.

If an error occurs, we write a fatal log. In the main function, just call testRateLaptop(), and we’re all set.

func main() {
    serverAddress := flag.String("address", "", "the server address")
    flag.Parse()
    log.Printf("dial server %s", *serverAddress)

    conn, err := grpc.Dial(*serverAddress, grpc.WithInsecure())
    if err != nil {
        log.Fatal("cannot dial server: ", err)
    }

    laptopClient := pb.NewLaptopServiceClient(conn)
    testRateLaptop(laptopClient)
}

5. Run the bidi-streaming gRPC server and client

Let’s run the server, then run the client.

(Screenshot: run the client)

3 laptops are created. Let's press y to rate these laptops.

(Screenshot: rate round 1)

As you can see, we sent 3 requests with scores of 10, 8 and 4, and received 3 responses with a rated count of 1 and average scores of 10, 8 and 4.

Let’s do another rating round!

(Screenshot: rate round 2)

This time the scores we sent are 6, 1 and 5, and the responses have a rated count of 2, with the average scores updated to 8, 4.5 and 4.5, which are all correct.
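Checking round 2 by hand: each laptop now has 2 scores, so its average is the sum of its round-1 and round-2 scores divided by 2.

$$\frac{10+6}{2} = 8, \qquad \frac{8+1}{2} = 4.5, \qquad \frac{4+5}{2} = 4.5$$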

6. Test bidi-streaming gRPC

Now I’m gonna show you how to test this bidirectional streaming RPC. Let’s open the service/laptop_client_test.go file.

The test setup will be very similar to the upload image test that we've written in the last lecture.

We just create a new laptop store and a new rating store, generate a random laptop, and save it to the laptop store.

func TestClientRateLaptop(t *testing.T) {
    t.Parallel()

    laptopStore := service.NewInMemoryLaptopStore()
    ratingStore := service.NewInMemoryRatingStore()

    laptop := sample.NewLaptop()
    err := laptopStore.Save(laptop)
    require.NoError(t, err)

    ...
}

Then we start the test laptop server to get the server address, and use it to create a test laptop client.

func TestClientRateLaptop(t *testing.T) {
    ...

    serverAddress := startTestLaptopServer(t, laptopStore, nil, ratingStore)
    laptopClient := newTestLaptopClient(t, serverAddress)

    stream, err := laptopClient.RateLaptop(context.Background())
    require.NoError(t, err)

    ...
}

After that, we call laptopClient.RateLaptop() with a background context to get the stream, and require no error.

For simplicity, we rate just one laptop, but we rate it 3 times with scores of 8, 7.5 and 10 respectively. So the expected average score after each rating should be 8, 7.75 and 8.5.
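The expected averages are running averages: after the k-th rating, the average is the sum of the first k scores divided by k. A small helper (hypothetical, not part of the course code) makes this explicit and reproduces the values hard-coded in the test.

```go
package main

import "fmt"

// runningAverages returns the average after each score, i.e. the
// AverageScore expected after the 1st, 2nd, ... rating of one laptop.
func runningAverages(scores []float64) []float64 {
	averages := make([]float64, len(scores))
	sum := 0.0
	for i, score := range scores {
		sum += score
		averages[i] = sum / float64(i+1)
	}
	return averages
}

func main() {
	fmt.Println(runningAverages([]float64{8, 7.5, 10})) // [8 7.75 8.5]
}
```

Comparing these floats with require.Equal works here because 8, 7.75 and 8.5 (and the intermediate sums 15.5 and 25.5) are exactly representable in binary. With scores such as 0.1, testify's require.InDelta would be the safer assertion.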

We define n as the number of rated times, and use a for loop to send multiple requests.

func TestClientRateLaptop(t *testing.T) {
    ...

    scores := []float64{8, 7.5, 10}
    averages := []float64{8, 7.75, 8.5}

    n := len(scores)
    for i := 0; i < n; i++ {
        req := &pb.RateLaptopRequest{
            LaptopId: laptop.GetId(),
            Score:    scores[i],
        }

        err := stream.Send(req)
        require.NoError(t, err)
    }

    err = stream.CloseSend()
    require.NoError(t, err)

    ...
}

Each time we will create a new request with the same laptop ID and a new score. We call stream.Send() to send the request to the server, and require no errors to be returned. After sending all the rate laptop requests, we call stream.CloseSend() just like what we did in the client code.

To keep it simple, I don't create a separate goroutine to receive the responses. Instead, I simply use a for loop to receive them, with an idx variable counting how many responses we have received.

func TestClientRateLaptop(t *testing.T) {
    ...

    for idx := 0; ; idx++ {
        res, err := stream.Recv()
        if err == io.EOF {
            require.Equal(t, n, idx)
            return
        }

        require.NoError(t, err)
        require.Equal(t, laptop.GetId(), res.GetLaptopId())
        require.Equal(t, uint32(idx+1), res.GetRatedCount())
        require.Equal(t, averages[idx], res.GetAverageScore())
    }
}

Inside the loop, we call stream.Recv() to receive a new response. If the error is io.EOF, we’ve reached the end of the stream, so we just require that the number of responses received equals n, the number of requests we sent, and return immediately.

Else, there should be no error. The response laptop ID should be equal to the input laptop ID. The rated count should be equal to idx + 1. And the average score should be equal to the expected value.

Now let’s run the test.

(Screenshot: run the unit test)

It passed. Excellent!

And that wraps up today’s lecture about implementing bidirectional streaming RPC in Go.

Thank you for reading! Happy coding, and I will see you in the next lecture!