
Why I'm making my app free

I'm an Indie Hacker and I've decided to make my app free. Sounds like a terrible decision, right?

I'm making a browser extension that provides smart subtitles for Chinese videos, partly generated by my OCR subtitle extraction tool. I'm not doing it for the money; for that, there's always the risk-free alternative of regular employment. Is it to become rich, then? Clearly not, or I wouldn't target the fairly saturated niche of language learning. Besides, a one-man project is unlikely to generate that kind of revenue; that is left to the VC-funded startups of the world. Why a one-man project then? Mainly because at this point in our life, with small kids, we want as much freedom as possible. That includes deciding when and where to work. Involving other people means meetings, oh so many meetings, plus fixed work hours, less creative freedom and much greater pressure to turn a profit.

What I really want is to create as much value as possible while covering our expenses as a family. So the calculus is simple: maximize the "value" I can provide to other people, i.e. "number of users" times "value provided per user", subject to the constraint of paying most of our bills at some point in the future.

My conclusion after thinking about it for a long time is that the "normal" path of building an app with subscription or premium features is not what I want to do. It would probably be the fastest way to meet the constraint of covering our expenses, but it would severely impact how much value I could create. People are getting more comfortable paying for software nowadays, but it is still a tiny minority who are willing to do so. Freemium with some features hidden behind a paywall is one way to go, but it creates this tension where you need to inconvenience people just enough so that enough people pay up. I don't like that dynamic, and usually it means hiding some very valuable features behind paywalls.

So here's my simple plan: make anything that can be provided for free, free (some things can't be, like anything that requires a back-end server to run). Then try to make some money at the margins like this:

  1. Patreon: give some perks to patrons like the ability to vote on features, request specific TV shows to import, and even get the binary to run OCR themselves
  2. VPN ads: I suspect people learning Chinese (either foreigners or heritage learners) have a greater than average need for a VPN, both when traveling to China and when accessing Chinese services from abroad. VPN companies also happen to pay out great commissions.
  3. OCR as a service: Perhaps there is some small market for providing OCR on video subtitles as a service. At some point I'll create a separate landing page and buy some keywords to try this out.

This strategy is quite freeing from a technological and UX perspective. When building a subscription service on the web, it's almost a requirement to put most of your code on the server, so as to protect against piracy and hacking. This often makes sense when there is a significant need for centralized coordination between user accounts, like social features, or a need for heavy processing that can't be done client-side.

My feeling though is that in many cases it's purely a user-hostile choice, a way to silo the data and protect against competition. Making the app free simultaneously incentivizes me to make it cheap to host, with a close to zero marginal cost per new user. This also happens to be in the user's best interest: the app is less likely to disappear since it's not highly dependent on a server to run, and the user has complete control over their client-side data! There may be social features incorporated at some point, and there will be a need for some kind of back-end for syncing data between clients, but with very limited server-side functionality this should not be too expensive to run.

We can meet in VR instead of polluting our way across the globe

In the past few years I've noticed an increasing trend of techno-pessimism in the West, whether it's the (perhaps legitimate) fear of Social Media corrupting our youth, critique of billionaires going to space instead of feeding the poor, AI-infused murder bots, or a dystopian VR Metaverse run by Mark Zuckerberg. Out of these I think VR is perhaps the most misunderstood.

It's been interesting to witness the very rapid shift in public sentiment over the last few years. When I bought my Oculus Devkit 2 back in 2014, VR was still mainly a curiosity. The reaction from friends and family at that time was mainly "neat!". By the time they tried the HTC Vive I think they could see a bit more clearly what potential VR has, and there was still not much negative sentiment. But when Meta, then Facebook, started pushing the idea of a Metaverse, the sentiment moved firmly into "dystopian" territory. I think it was at this point that people seriously started considering what a future with this technology would look like, and it's just too alien for most people. But in my view, VR is not dystopian; it's probably a necessity for the survival of the planet.

We live in a finite world, with finite resources. VR essentially creates an infinite world for us to live in. We can meet in VR instead of polluting our way across the globe. We can "own" and use virtual objects instead of objects made out of finite atoms mined in Australia, reshaped in China and finally sent on a huge boat across the Pacific. We can have deep experiences that rival anything the world of atoms can provide. We can actually be more social in VR, because the cost of meeting and hanging out is approximately zero. I don't see it as dystopian at all. What's dystopian is 7 billion people all trying to live the American dream while the planet dies.

I don't think we'll completely stop living in the real world, nor should we, but I think moving some activities to VR can alleviate many of the problems we're facing. At the same time we should definitely invest in transitioning to renewable energy, electrification of transportation and industry, and other environmental technologies. As a self-proclaimed techno-optimist, I also think that technology can enable us to consume fewer resources while maintaining our standard of living. One example is self-driving cars: they would allow fewer cars on the road while maintaining the same transportation capacity, and would also reduce the need for garages and parking lots. However, the phenomenon of induced demand means that in the end there will probably not be any fewer cars on the road. Living more virtually, however, would reduce our need for transportation and our need for stuff. It's not the whole solution, but I think it's an important part of it.

(Note: some of the benefits of VR are also true of AR, but to a lower degree since it's always anchored to the real world)

Solving Minesweeper

Many years ago I took a class in Artificial Intelligence, in which the final group project was to build a program that played Minesweeper (although this would hardly be considered AI today!). Like most group projects, everybody prioritized other classes, me included. Eventually I tried to pull an all-nighter to implement it, in a language I had just "learned": C++. Never doing that again. But I've always wanted to actually finish it, so here goes. This time I'm using Python, numpy, and other libraries which make the task much easier. I'll try to do things without looking at other implementations, because where's the fun in that?

The Algorithm

The problem is how to deduce which squares have mines and which don't, based on the neighboring constraints that tell us how many neighboring mines there are. Most of the closed squares don't have any explicit constraints on them, so unless we're left with no choice, we'll focus on the edges, which do have constraints. If we have an edge region of N squares with interlocking constraints, there are $2^N$ mine configurations that could be considered. Brute-forcing this becomes prohibitively slow very quickly. Luckily, we can use the constraints to limit how many configurations we have to try.

Consider this partially solved board:

First, it would be helpful to display our updated constraints for this partially solved board. Let's do this by showing two pieces of information per square:

  1. The number of remaining neighboring mines, i.e. once we flag a neighbor as a mine, we decrement this number
  2. The number of remaining closed and unflagged neighbors; we decrement this number for the neighbors whenever we open or flag a square.

Here is what that looks like:

In the right-hand image, the upper-left number in each square is the number of remaining neighboring mines, and the lower-right number is the number of remaining unopened, unflagged neighbors.
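
These two counters can be recomputed from boolean board masks with a 3x3 sum convolution. Here's a minimal sketch on a made-up 4x4 position, using scipy (the solver further down does the same thing with cv2.filter2D):

import numpy as np
from scipy.ndimage import convolve

def neighbor_counts(mask):
  """For each square, count how many of its 8 neighbors are True in `mask`."""
  kernel = np.ones((3, 3), dtype=int)
  kernel[1, 1] = 0  # a square is not its own neighbor
  return convolve(mask.astype(int), kernel, mode='constant', cval=0)

# Made-up 4x4 position (True = mine / opened / flagged)
mines   = np.array([[0,1,0,0], [0,0,0,1], [0,0,0,0], [1,0,0,0]], dtype=bool)
opened  = np.array([[0,0,0,0], [0,0,1,0], [1,1,1,0], [0,0,0,0]], dtype=bool)
flagged = np.zeros_like(mines)

# 1. remaining neighboring mines: drops by one for every neighbor we flag
mines_left = neighbor_counts(mines) - neighbor_counts(flagged)
# 2. remaining closed, unflagged neighbors: drops as neighbors get opened or flagged
unresolved_left = neighbor_counts(~opened & ~flagged)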

We can now set up two simple rules:

  1. If a square says there are K neighboring mines left, and K closed and unflagged neighbors, we know that all those neighbors have to have mines
  2. If a square has no neighboring mines left, then we can safely open all its closed neighbors
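
In code, the two rules are just boolean masks over those counters. This sketch mirrors the all_mines / all_empty checks in the full solver below, operating on arrays like the ones from the previous sketch:

def simple_rules(opened, mines_left, unresolved_left):
  """Rule 1: an opened square whose remaining mine count equals its remaining
  closed/unflagged neighbors means all of those neighbors are mines.
  Rule 2: an opened square with no remaining mines means all of its closed
  neighbors can be safely opened.
  Returns boolean maps of the numbered squares satisfying each rule."""
  all_mines = opened & (mines_left > 0) & (mines_left == unresolved_left)
  all_safe  = opened & (mines_left == 0) & (unresolved_left > 0)
  return all_mines, all_safe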

Applying these rules to the board above, we see that the square with the green box can be safely opened, and the square with the red box definitely has a mine:

Surprisingly, repeatedly applying these two simple rules and updating the constraints solves many beginner and intermediate boards all by itself, but sometimes (and more often on expert difficulty) we get stuck. Let's look at such a case that happens later on:

In this case, there is no simple rule we can apply. We have to try all possible configurations and then make a decision based on the result. This can be done with a constraint satisfaction search. Essentially, we go through the squares in order, one by one, and assign each either "mine" or "not mine". Then we check what happens with our constraints:

  1. If we assigned "mine" and a neigboring constraint now has -1 number of mines, the board is inconsistent
  2. If we assigned "not mine" and a neighboring constraint says there are more neigboring mines than there are closed squares, the board is also inconsistent.

If an assignment is consistent, we continue on to the next square.

If both possible assignments are inconsistent, we backtrack and try different assignments on the previous squares. Here's what this looks like in action:

The result of this search can give us many different consistent configurations, possibly with differing numbers of mines.
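
Here's a stripped-down, hypothetical version of that search, with the constraints written as "exactly k mines among this set of squares" (the real solver below works directly on the 2D counter arrays instead):

def consistent_configurations(squares, constraints):
  """
  squares: list of edge-square ids to assign, in search order.
  constraints: list of (cells, k) pairs meaning "exactly k mines among `cells`",
               where `cells` is a subset of `squares`.
  Returns every assignment {square: 0 or 1} that satisfies all constraints.
  """
  solutions = []
  assignment = {}

  def feasible():
    for cells, k in constraints:
      assigned = [assignment[c] for c in cells if c in assignment]
      mines, unassigned = sum(assigned), len(cells) - len(assigned)
      # Inconsistent if we already have too many mines, or can no longer reach k
      if mines > k or mines + unassigned < k:
        return False
    return True

  def search(i):
    if i == len(squares):
      solutions.append(dict(assignment))
      return
    for value in (1, 0):               # try "mine" first, then "not mine"
      assignment[squares[i]] = value
      if feasible():
        search(i + 1)
      del assignment[squares[i]]       # backtrack

  search(0)
  return solutions

# Two "1" squares: one touching A and B, the other touching B and C
print(consistent_configurations(['A', 'B', 'C'],
                                [({'A', 'B'}, 1), ({'B', 'C'}, 1)]))
# [{'A': 1, 'B': 0, 'C': 1}, {'A': 0, 'B': 1, 'C': 0}]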

Since we have all possible configurations, we can look at each individual square and see how many times it was a mine versus not a mine. The search above yielded 3 possible solutions (starting from the top right corner):

|1|0|0|1|0|0|1|0|0|1|0|0|1|0|
|1|0|0|1|0|0|0|1|0|0|1|0|0|1|
|0|1|0|0|1|0|0|0|0|1|0|0|1|0|

If we average the result for each square, we get:

|$\frac{2}{3}$|$\frac{1}{3}$|$\frac{0}{3}$|$\frac{2}{3}$|$\frac{1}{3}$|$\frac{0}{3}$|$\frac{1}{3}$|$\frac{1}{3}$|$\frac{0}{3}$|$\frac{2}{3}$|$\frac{1}{3}$|$\frac{0}{3}$|$\frac{2}{3}$|$\frac{1}{3}$|

The same result shown on the board:

Each square shows the probability of having a mine in percentage points. Highlighted in green are the squares with zero probability, meaning we can safely open them.
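
The tally itself is just a column average over the solution matrix. As a quick check with numpy, using the three solutions from the table above:

import numpy as np

solutions = np.array([
  [1,0,0,1,0,0,1,0,0,1,0,0,1,0],
  [1,0,0,1,0,0,0,1,0,0,1,0,0,1],
  [0,1,0,0,1,0,0,0,0,1,0,0,1,0],
])
prob_mine = solutions.mean(axis=0)       # per-square probability of being a mine
print(prob_mine.round(2))
print(np.where(prob_mine == 0)[0])       # squares 2, 5, 8 and 11 are always safe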

This raises the question of what to do if there are no 0% squares. Then we simply have to take a chance on the lowest-probability square. But at this point we also have to consider squares that are not at the edge; they may also be worth opening. If you look at the bottom row, those squares have a 17% probability; how do we go about calculating this?

Well, we know how many mines there are in total and how many we've flagged; let's call those $M_t$ and $M_f$ respectively. We also have the probability of each edge square being a mine, $P(S_i = mine)$, and the number of "inner" closed squares, $N$.

The expected, or average, number of mines in the edge is simply the sum of the probabilities: $\sum_{i} P(S_i = mine)$. Since different configurations can have different numbers of mines, this estimate doesn't have to be an integer.

The final probability for the "inner" closed squares is then: $$\frac{\max(M_t - M_f - \sum_{i} P(S_i = mine), 0)}{N}$$

Let's calculate it for the case above: $$\frac{10 - 4 - (6*0.33 + 4*0.67)}{8} = 0.1675$$
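
Reproducing that calculation in code (the edge probabilities are the ones from the table above; the tiny difference from 0.1675 comes from rounding 1/3 and 2/3 to 0.33 and 0.67):

import numpy as np

M_t, M_f = 10, 4                         # total mines, flagged mines
edge_probs = np.array([2/3, 1/3, 0, 2/3, 1/3, 0, 1/3,
                       1/3, 0, 2/3, 1/3, 0, 2/3, 1/3])
N = 8                                    # closed squares not touching any numbered square

p_inner = max(M_t - M_f - edge_probs.sum(), 0) / N
print(round(p_inner, 4))                 # ~0.1667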

Using these inner probabilities, we can now pick the lowest probability square overall, not only those at the edge.

Advanced Strategies

What if we have many choices of lowest probability square? Here's where it gets a little bit complicated.

If we have multiple choices, we probably want to choose the one that gives us the most valuable information. One way to look at it is that we'd prefer if the square we pick has a higher chance of being a "zero", meaning it has no constraints at all. What happens in that case is that we automatically expand the whole area underneath. This gives us more information than if we hit a numbered square right away, in which case we probably have to make another risky choice for the next square.

What is the probability of a square being a "zero" then? Well, it depends on the mine probabilities of its neighbors. If a neighbor has a 90% chance of being a mine, then this square has at least a 90% chance of having a constraint. So we want to combine the probabilities of the square and its neighbors. Treating them as independent events (an approximation), we can calculate this as the probability of the whole neighborhood being free of mines: $$P(S_i = zero) = \prod_{j \in N(i)}{P(S_j = \neg mine)} $$ where $N(i)$ is the neighborhood of square $i$, including the square itself.

Let's look at the starting board. The mine probabilities are uniform since we have no information:

But the probability of being a zero is higher along the edges, and highest in the corners, because they have fewer neighbors that could be mines:

This explains the common strategy of opening the corners first, since the difference is quite big for this beginner board: 22% vs 51% chance of being zero. This strategy of course also applies later in the game, but then incorporates the mine probabilities from the configuration search.
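
We can verify those numbers with a small sketch, assuming the same 8x8, 10-mine beginner layout the solver below uses (the solver itself computes this product with a log-space convolution instead of explicit loops):

import numpy as np

def prob_zero(prob_mine):
  """P(square is a 'zero') = product of P(not a mine) over the square and its neighbors."""
  h, w = prob_mine.shape
  out = np.zeros((h, w))
  for y in range(h):
    for x in range(w):
      block = 1 - prob_mine[max(y - 1, 0):y + 2, max(x - 1, 0):x + 2]
      out[y, x] = block.prod()
  return out

p = np.full((8, 8), 10 / 64)             # uniform mine probability on the starting board
pz = prob_zero(p)
print(round(pz[0, 0], 2), round(pz[4, 4], 2))  # corner ~0.51 vs interior ~0.22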

We can test the different strategies by running the solver on a large number of boards and comparing the win ratios:

def estimate_win_ratio(num_runs, **board_kwargs):
  wins, losses = 0, 0
  for i in range(num_runs):
    m = MinesweeperBoard(seed=i, **board_kwargs)
    m.solve()
    # NOTE: solve() doesn't return a value; a death during the first iteration
    # means we lost on the blind opening guess, so exclude that run
    if m.dead and m.iteration <= 1:
      continue
    else:
      if m.dead:
        losses += 1
      else:
        wins += 1
  return wins / (wins + losses)
    
num_runs = 10000
print(f'Win ratios for strategies over {num_runs} runs:')
print(f'Random lowest probability mine: {estimate_win_ratio(num_runs, difficulty="expert", max_solutions=2000, prob_pick_strategy="random")}')
print(f'Max zero probability: {estimate_win_ratio(num_runs, difficulty="expert", max_solutions=2000, prob_pick_strategy="prob_zero")}')
Win ratios for strategies over 10000 runs:
Random lowest probability mine: 0.4014
Max zero probability: 0.4945

Using the max zero probability strategy is clearly better, by almost 10 percentage points. Note that for these win ratios I excluded runs that immediately hit mines.

Now, I don't think the algorithm is optimal quite yet. Have a look at this impasse, where we end up picking the corner square based on the reasoning above:

While opening the corner improves our chances of a zero square, perhaps we would get even more information by picking a square close to the edge, in order to loosen up the deadlock. Intuitively, we'd want to pick the square which reduces the uncertainty about the edge probabilities. We could use all the configurations produced by the search, and figure out if there is a square we can open to disambiguate between them.

We can also see that the difference in probability between the corner square and the 33% edge squares is 1%. This is such a small difference that it could make sense to pick one of those edge squares in order to split the edge region in half, divide and conquer style.

In the end, it will probably boil down to a tradeoff between minimizing mine probability and maximizing expected information gain. Exactly how to go about doing that I'll leave for a later time. For now, I'll leave you with the solving of an "expert" difficulty board:

m = MinesweeperBoard(difficulty='expert', seed=10,
                     visualize_constraints=False, visualize_lvl=1,
                     visualize_output='video')
m.solve()
m.display_video(framerate=10, controls=True, autoplay=True, loop=True, max_width='100%')

And the code for the whole thing:

import os
import math
import glob
import random
import warnings
import subprocess
from time import time
from base64 import b64encode

import cv2
import numpy as np
import PIL.Image
from IPython.display import display, HTML
import IPython.display
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
from skimage.measure import label
from scipy.ndimage.morphology import distance_transform_edt

np.set_printoptions(precision=2)

# Load the minesweeper tiles
IMG_SQUARE_SIZE = 20
IMAGES = {}
for name in list(range(9)) + ['bomb', 'facingDown', 'flagged']:
  img = cv2.imread(f'minesweeper/{name}.png')
  IMAGES[str(name)] = cv2.resize(img, (IMG_SQUARE_SIZE, IMG_SQUARE_SIZE))

    
def draw_text_centered(img, text, size, color, thickness=1):
  font = cv2.FONT_HERSHEY_SIMPLEX
  text_size = cv2.getTextSize(text, font, size, 2)[0]
  
  # get coords based on boundary
  text_x = (img.shape[1] - text_size[0]) // 2
  text_y = (img.shape[0] + text_size[1]) // 2
  return cv2.putText(img, text, (text_x, text_y), font, size, color,
                      thickness=thickness)


def count_neighbors(board):
  """
    For each square, counts the number of squares that are "on" in its 3x3 neighborhood.
    Uses a sum convolution; note that the kernel includes the square itself,
    which is harmless for how the counts are used in this solver.
    board: a boolean matrix or numerical 0/1 matrix
  """
  return cv2.filter2D(board.astype(float), -1, np.ones((3,3)),
                      borderType=cv2.BORDER_CONSTANT).astype(int)


class MinesweeperBoard:
  def __init__(self, width=None, height=None, num_mines=None, difficulty=None,
               seed=None, max_solutions=None, prob_pick_strategy='prob_zero_near_edge', 
               visualize_constraints=True, visualize_lvl=0, visualize_zero_probs=None,
               visualize_output='video', visualize_crop=None, visualize_frames=None,
               visualize_iterations=None, visualize_mask_text=None,
               visualize_mask_text_size=0.7):
    """
      width: width of the board
      height: height of the board
      num_mines: number of mines on the board
      difficulty:
        'beginner', 'intermediate' or 'expert', if this is set, width/height/num_mines are not used
      seed: the seed for randomization
      max_solutions: stop the search once this many solutions have been found
      prob_pick_strategy: 
        'random': pick one of the lowest probability mine squares randomly
        'prob_zero': pick the lowest probability mine with the
                     highest probability of being a zero-numbered square
        'prob_zero_near_edge': same as prob_zero, but pick the square nearest
                               an unresolved edge if there are multiple choices
      visualize_constraints:
        True: Displays number of neighboring mines and empty squares left
        False: Displays the number of neighboring mines
      visualize_lvl: 
        0: no visualization
        1: visualize final square picks 
        2: visualize all search steps
      visualize_zero_probs:
        If True, will visualize the probability of a zero-numbered square
      visualize_output:
        'video': outputs video
        'image': outputs individual frames
      visualize_crop:
        A tuple with (y, x, height, width) crop in squares to visualize
      visualize_frames:
        A tuple with the range of frames to visualize
      visualize_iterations:
        A tuple with the range of iterations to visualize
      visualize_mask_text: 
        Masks out squares using text
      visualize_mask_text_size: 
        Font size of the text mask
    """
    difficulties = {
      'beginner': (8, 8, 10),
      'intermediate': (16, 16, 40),
      'expert': (24, 24, 99)
    }
    assert difficulty in [None, 'beginner', 'intermediate', 'expert'] # TODO: Use enum class?
    assert difficulty is not None or (width is not None and height is not None and num_mines is not None)
    if difficulty in difficulties:
      self.width, self.height, self.num_mines = difficulties[difficulty]
    else:
      self.width, self.height, self.num_mines = width, height, num_mines
    self.num_squares = self.width*self.height
    self.max_solutions = max_solutions
    assert prob_pick_strategy in ['random', 'prob_zero', 'prob_zero_near_edge'] # TODO: Use enum class?
    self.prob_pick_strategy = prob_pick_strategy
    assert self.num_mines < self.num_squares
    assert 0 <= visualize_lvl <= 2
    assert visualize_output in ['video', 'image']
    self.visualize_constraints = visualize_constraints
    self.visualize_lvl = visualize_lvl
    self.visualize_zero_probs = visualize_zero_probs
    self.visualize_output = visualize_output
    self.visualize_frames = visualize_frames
    self.visualize_iterations = visualize_iterations
    self.visualize_crop = visualize_crop
    self.visualize_mask_text = visualize_mask_text
    self.visualize_mask_text_size = visualize_mask_text_size

    # Generate the mine locations
    if seed is not None:
      np.random.seed(seed)
      random.seed(seed)
    self.mines = np.zeros((self.height, self.width), dtype=bool)
    # NOTE: sample indices without replacement
    mine_indices = random.sample(range(self.width*self.height), self.num_mines)
    self.mines.ravel()[mine_indices] = True
    assert self.mines.sum() == self.num_mines, (self.mines.sum(), self.num_mines)

    self.text_mask = None
    if self.visualize_mask_text is not None:
      # Add mines that look like the text
      mask = np.zeros_like(self.mines, dtype='uint8')
      mask = draw_text_centered(mask, self.visualize_mask_text, self.visualize_mask_text_size,
                                color=1, thickness=1)
      # Remove mines within 2 steps of the text, so that we can see the text better
      mask_dilated = cv2.dilate(mask.astype('uint8'),
                                np.ones((3, 3), 'uint8'), iterations=2) > 0
      self.mines[mask_dilated > 0] = False
      self.mines[mask > 0] = True
      self.text_mask = mask > 0

    # Pre-calculate a map of neighboring mine counts
    self.neighboring_mine_count = count_neighbors(self.mines)

    # Create a map of labeled empty regions
    self.empty_regions = label(self.neighboring_mine_count == 0, connectivity=1)

    # Create 2D arrays for keeping track of the state of the board and our belief about it
    self.opened = np.zeros((self.height, self.width), dtype=bool)
    self.flagged = np.zeros((self.height, self.width), dtype=bool)
    self.neighboring_mines_left = self.neighboring_mine_count.copy()
    self.neighboring_unresolved_left = 8*np.ones((self.height, self.width), dtype='uint8')
    self.prob_mine = 2*np.ones((self.height, self.width), dtype=float)
    self.dead = False # Set to true if we hit a mine
    
    # Variables used for visualization and debugging
    self.iteration = 0
    self.frame_idx = 0

  def solve(self):
    """ Solves the board or dies trying """
    self.t0 = time()
    self.visualize() # Draw once to see the empty board

    while not self.solved and not self.dead:
      # Get the next square index to open
      open_indices, mine_indices, *rest = self.solve_step()
      prev_opened = self.opened.copy()

      for open_idx in open_indices:
        # Check if there's an empty region under the square
        region = self.empty_regions[open_idx[0], open_idx[1]]
        if region > 0:
          # Dilate the region to open the neighboring squares with neighboring mine counts
          reveal = cv2.dilate((self.empty_regions == region).astype('uint8'),
                              np.ones((3, 3), 'uint8'), iterations=1)
          self.opened[reveal > 0] = True
        else:
          # There was no empty region underneath, just a mine count, so just reveal this square
          self.opened[open_idx[0], open_idx[1]] = True
  
        # Check if there is a mine there
        if self.mines[open_idx[0], open_idx[1]]:
            self.dead = True
            break
      
      for mine_idx in mine_indices:
        # Remove the mine from neighboring mine counts
        ns = self.neighborhood_slice(mine_idx)

        self.neighboring_mines_left[ns] -= 1
        self.flagged[mine_idx[0], mine_idx[1]] = True

      # Recalculate closed and unflagged neighbors
      self.neighboring_unresolved_left = count_neighbors(self.unresolved)

      # Clear the cached probabilities in the regions where we made picks
      labels, unresolved_edges, unresolved_non_edge = rest
      self.prob_mine[unresolved_non_edge] = 2
      just_opened_dilated = cv2.dilate((self.opened & ~prev_opened).astype('uint8'),
                                       np.ones((3, 3), 'uint8'), iterations=1)
      for i in range(1, labels.max()+1):
        search_map = (labels == i) & unresolved_edges
        if (just_opened_dilated & search_map).any():
          self.prob_mine[search_map] = 2 # Reset
  
      # Reset open/mine_indices and display the new board
      self.visualize()

      self.iteration += 1
    
    self.t1 = time()

  def solve_step(self):
    """ Returns the next board square indices to open and the 
    squares we want to mark as mines"""

    # Now we're trying to find a configuration of mines/not mines that agrees with
    # the neighboring mine counts. We split up the problem by connected components in
    # the squares with non-zero neighbor count. For each component, we find the assignment
    # that agrees with neighboring unopened squares. If there are multiple possible assignments
    # we pick a square to open with the lowest probability of being a mine over all unresolved squares.
    
    # Get the neighbors of the numbered squares (i.e. the squares we want to 
    # find an assignment for)
    has_constraints = (self.neighboring_mines_left > 0) | (self.neighboring_unresolved_left > 0)
    numbered_dilated = cv2.dilate((has_constraints & self.opened).astype('uint8'),
                                  np.ones((3, 3), 'uint8'), iterations=1)
    unresolved_edges = self.unresolved & (numbered_dilated > 0)
    unresolved_non_edge = self.unresolved & (numbered_dilated == 0)
    
    # Find connected components of unresolved edges connected by constraints
    # Since that numbered square represents constraints on both components,
    # we have to join them in the optimization
    unresolved_edges_dilated = cv2.dilate(unresolved_edges.astype('uint8'),
                                          np.ones((3, 3), 'uint8'), iterations=1)
    joined_edges = unresolved_edges | (unresolved_edges_dilated & has_constraints & self.opened)
    labels, num_labels = label(joined_edges, connectivity=2, return_num=True)

    # Check for simple cases, where
    # 1. There are as many neighboring mines left as there are closed (unassigned) squares
    #    This means all neighbors are mines
    # 2. If there are no neighboring mines left, then all neighboring closed squares
    #    must be empty
    all_mines = (self.opened & (self.neighboring_mines_left > 0) &
                 (self.neighboring_mines_left == self.neighboring_unresolved_left))
    all_empty = (self.opened & (self.neighboring_mines_left == 0) &
                 (self.neighboring_unresolved_left > 0))
    if all_mines.any() or all_empty.any():
      all_mines_dilated = cv2.dilate(all_mines.astype('uint8'),
                                     np.ones((3, 3), 'uint8'), iterations=1) > 0
      mine_indices = np.vstack(np.where(all_mines_dilated & self.unresolved)).T

      all_empty_dilated = cv2.dilate(all_empty.astype('uint8'),
                                     np.ones((3, 3), 'uint8'), iterations=1) > 0
      empty_indices = np.vstack(np.where(all_empty_dilated & self.unresolved)).T
      self.visualize(open_indices=empty_indices, mine_indices=mine_indices)
      return empty_indices, mine_indices, labels, unresolved_edges, unresolved_non_edge
    

    for i in range(1, num_labels+1):
      search_map = (labels == i) & unresolved_edges
      if np.count_nonzero(search_map) == 0:
        continue

      if not (self.prob_mine[search_map] == 2).any():
        # Result is cached, so skip
        continue

      search_squares = self.visit_in_order(search_map)
      self.visualize(lvl=1, duplicate_num_frames=5, search_squares=search_squares)
      # Assignment map: -1 = no assignment, 0 = not mine, 1 = mine
      assignments = -1*np.ones((self.height, self.width), 'int')
      search_idx = 0
      solutions = [] # keeps track of all consistent solutions found
      num_search_squares = len(search_squares)
      while search_idx < num_search_squares:
        p = search_squares[search_idx]
        ns = self.neighborhood_slice(p)
  
        prev_assignment = assignments[p[0], p[1]]
        reached_max_solutions = (self.max_solutions is not None
                                 and len(solutions) >= self.max_solutions)
        if prev_assignment == 0 or reached_max_solutions:
          # We've cycled through mine/not mine assignments, so backtrack
          self.neighboring_unresolved_left[ns] += 1
          if prev_assignment == 1:
            self.neighboring_mines_left[ns] += 1

          assignments[p[0], p[1]] = -1
          search_idx -= 1
  
          if search_idx < 0:
            # We've checked all possible solutions, so exit loop
            break
          self.visualize(lvl=2, assignments=assignments, search_squares=search_squares)
          continue
          
        # We cycle through mine -> not mine
        new_assignment = 1 if prev_assignment == -1 else 0
  
        if new_assignment == 0 and prev_assignment == 1:
          # Add back a mine
          self.neighboring_mines_left[ns] += 1
  
        assignments[p[0], p[1]] = new_assignment
        if new_assignment:
          # Decrement mines_left
          self.neighboring_mines_left[ns] -= 1
            
        # No matter whether we think there's a mine or not, we decrement number
        # of closed neighbors
        if prev_assignment == -1:
          self.neighboring_unresolved_left[ns] -= 1
  
        # If number of mines left is negative, we're in an inconsistent state
        inconsistent_mines = (self.neighboring_mines_left[ns] < 0) & self.opened[ns]
        inconsistent_mines_left = inconsistent_mines.any()
        # If the number of closed squares is less than the number of mines left, we're in an inconsistent state
        inconsistent_closed = (self.neighboring_unresolved_left[ns] < self.neighboring_mines_left[ns]) & self.opened[ns]
        inconsistent_closed_left = inconsistent_closed.any()
        
        inconsistent = np.zeros((self.height, self.width), dtype=bool)
        inconsistent[ns] = inconsistent_mines | inconsistent_closed
        self.visualize(lvl=2, inconsistent=inconsistent, assigning_idx=p,
                       assignments=assignments, search_squares=search_squares)
        is_inconsistent = inconsistent_mines_left or inconsistent_closed_left
  
        if not is_inconsistent and search_idx == num_search_squares - 1:
          # We have a possible solution at the last search square, so save it
          solution = assignments[search_squares[:, 0], search_squares[:, 1]]
          self.visualize(lvl=2, duplicate_num_frames=5, solution=solution,
                         assignments=assignments, search_squares=search_squares)
          solutions.append(solution)
        elif not is_inconsistent:
            # Not inconsistent, so progress to the next square
            search_idx += 1
  
      # Tally up the solutions and return the empty squares and mines we're 100% sure about
      # If we have no such squares, we return the lowest mine probability square to be opened.
      num_mine = np.zeros(len(search_squares), dtype=float)
      num_not_mine = np.zeros(len(search_squares), dtype=float)
      solutions_array = np.vstack(solutions)
      num_mine = (solutions_array == 1).sum(axis=0)
      num_not_mine = (solutions_array == 0).sum(axis=0)
  
      num_total = num_mine + num_not_mine
      sq = search_squares
      self.prob_mine[sq[:, 0], sq[:, 1]] = num_mine / num_total

    # Now calculate the mine probability not only the search squares, but also the
    # inside squares. First, we calculate how many mines there are in the edges of
    # all unresolved components, so we can subtract that (and the flagged mines) from
    # the mine count

    # NOTE: num_edge_mines doesn't have to be an integer, since it's possible
    # that different solutions have different number of mines. It's also possible
    # that some solutions have more mines than there are mines left on the board
    num_mines_left = self.mines.sum() - self.flagged.sum()
    num_edge_mines = min(self.prob_mine[(self.prob_mine < 2) & unresolved_edges].sum(),
                         num_mines_left)
    if unresolved_non_edge.any():
      self.prob_mine[unresolved_non_edge] = ((self.mines.sum() - num_edge_mines - self.flagged.sum())
                                              / unresolved_non_edge.sum())

    # If there are squares with 0 probability mine, we return those and the sure mines
    empty_indices = np.vstack(np.where(self.unresolved & (self.prob_mine == 0))).T
    mine_indices = np.vstack(np.where(self.unresolved & (self.prob_mine == 1))).T
    if len(empty_indices) == 0:
      # There are no obvious squares to open, so find the lowest mine probability square
      # and pick the one with the highest probability of being a zero-numbered square
      # NOTE: Although the mine probability would be the same, we prefer zero-numbered
      # squares because they open up 3-8+ new squares with constraints.

      # Calculate the probability that each unresolved square has zero neighboring mines.
      # This is the product of the probabilities that each square in the neighborhood is _not_ a mine.
      # NOTE: we do this by a convolution in the log space, since the sum of logs equals
      # the log of the product. Then we exponentiate to get the probability
      prob_zero = np.zeros((self.height, self.width), dtype=float)
      # NOTE: if 1-prob_mine is zero, then log(1-prob_mine) is NaN, and during the
      # convolution these NaNs will spread to the neighbors. We _want_ this to happen
      # so that we can set all NaNs to 0 probability of being a zero-numbered square, since
      # a NaN means there is a mine neighbor
      warnings.filterwarnings('ignore') # Ignore the NaN warning
      prob_zero = np.exp(cv2.filter2D(
          np.log(1 - self.prob_mine), -1, np.ones((3,3)), borderType=cv2.BORDER_CONSTANT))
      warnings.resetwarnings()
      prob_zero[np.isnan(prob_zero)] = 0
      
      self.visualize(duplicate_num_frames=5)
      if self.visualize_zero_probs:
        self.visualize(lvl=1, duplicate_num_frames=5, prob_zero=prob_zero)

      # Now we pick the lowest probability mine to open, but with the highest
      # probability of a zero underneath (no neighboring mines)
      is_min_mine_prob = self.prob_mine == self.prob_mine[self.unresolved].min()
      is_max_zero_prob = prob_zero == prob_zero[is_min_mine_prob].max()
      is_min_mine_max_zero = is_min_mine_prob & is_max_zero_prob
      if self.prob_pick_strategy == 'random':
        best_choices = np.vstack(np.where(is_min_mine_prob)).T
      elif self.prob_pick_strategy == 'prob_zero': 
        best_choices = np.vstack(np.where(is_min_mine_max_zero)).T
      elif self.prob_pick_strategy == 'prob_zero_near_edge':
        distance_from_edge = distance_transform_edt(~unresolved_edges)
        is_one_square_away = distance_from_edge <= math.sqrt(2)
        if (is_min_mine_max_zero & is_one_square_away).any():
          best_choices = np.vstack(np.where(is_min_mine_max_zero & is_one_square_away)).T
        else:
          best_choices = np.vstack(np.where(is_min_mine_max_zero)).T
        #is_min_dist = distance_from_edge == distance_from_edge[is_min_mine_max_zero].min()
        #best_choices = np.vstack(np.where(is_min_mine_max_zero & is_min_dist)).T

      best_choice = random.choice(best_choices)
      empty_indices = np.array([best_choice])
      self.visualize(duplicate_num_frames=5, open_indices=empty_indices,
                     mine_indices=mine_indices)
    else:
      self.visualize(duplicate_num_frames=10, open_indices=empty_indices,
                    mine_indices=mine_indices)
    return empty_indices, mine_indices, labels, unresolved_edges, unresolved_non_edge

  @property
  def solved(self):
    return (np.array_equal(self.mines, self.flagged)
            and (self.opened | self.flagged).all())

  def visit_in_order(self, search_map):
    """ Returns the indices of `search_map > 0` in order, starting at the edges
    of the components if there are any, and then in a BFS fashion """
    left_to_visit = search_map.copy()
    ordered_search_squares = []
    while left_to_visit.any():
      search_squares = np.vstack(np.where(left_to_visit)).T
      queue = []
      for square in search_squares:
        ns = self.neighborhood_slice(square)
        num = search_map[ns].sum()
        # NOTE: if there are <= 2 squares in this slice, then it has to be a start/end point
        # If there are 3 or more, it means `square` is in the center of a path
        if num <= 2:
          queue.append(square)
          break
  
      if len(queue) == 0:
        # There were no start/end points, so must be a loop. Pick any square
        # to start with
        queue.append(search_squares[0])

      while len(queue) > 0:
        s = queue.pop(0)
        left_to_visit[s[0], s[1]] = False
        ordered_search_squares.append(s)
        ns, offset = self.neighborhood_slice(s, return_offset=True)
        neighbors = np.vstack(np.where(search_map[ns])).T + offset
        l1_neighbors, l2_neighbors = [], []
        for n in neighbors:
          if not left_to_visit[n[0], n[1]]: continue
          if np.linalg.norm(n - s) > 1: l2_neighbors.append(n)
          else: l1_neighbors.append(n)

          left_to_visit[n[0], n[1]] = False
        for n in l1_neighbors: queue.append(n)
        for n in l2_neighbors: queue.append(n)

    return np.array(ordered_search_squares)

  def draw_frame(self, **kwargs):
    """ Draw the current board / solver state as an image """
    if not self.visualize_constraints:
      return self._draw_board(**kwargs)

    # Draw normal board and constraints board side-by-side with a white space between
    normal = self._draw_board(**kwargs)
    constraints = self._draw_board(visualize_constraints=True, **kwargs)
    I = IMG_SQUARE_SIZE
    crop = self.visualize_crop or (0, 0, self.height, self.width)
    img = 255*np.ones((I*crop[2], I*(crop[3]*2+1), 3), dtype='uint8')
    img[:, :I*crop[3]] = normal
    img[:, I*(crop[3]+1):] = constraints
    return img

  @property
  def unresolved(self):
    """ Returns the unresolved squares, i.e. the ones that have neither been
    opened nor flagged as mines """
    return ~self.opened & ~self.flagged

  def neighborhood_slice(self, p, return_offset=False):
    """
    Convenience function to get a numpy slice object for the neighborhood around a
    point `p` _within_ the board dimensions. Can return the offset to the board origin
    """
    ns = np.s_[max(p[0]-1, 0):min(p[0]+2, self.height),
                max(p[1]-1, 0):min(p[1]+2, self.width)]
    if return_offset:
      offset = np.array([p[0]+(0 if p[0] == 0 else -1),
                          p[1]+(0 if p[1] == 0 else -1)])
      return ns, offset
    return ns

  def visualize(self, lvl=1, duplicate_num_frames=1, **kwargs):
    """
    Visualizes current state of the solver and either outputs it as an HTML image
    element, or saves the frame for future video creation.
    """
    if self.visualize_lvl < lvl:
      return

    not_in_frame_range = (self.visualize_frames is not None and
                          (self.frame_idx < self.visualize_frames[0] or
                           self.frame_idx >= self.visualize_frames[1]))
    not_in_iteration_range = (self.visualize_iterations is not None and
                              (self.iteration < self.visualize_iterations[0] or
                               self.iteration >= self.visualize_iterations[1]))

    if not_in_frame_range or not_in_iteration_range:
      self.frame_idx += 1
      return
    

    img = self.draw_frame(**kwargs)
    if self.visualize_output == 'image':
      if self.visualize_frames is None:
        print(f'Iteration {self.iteration} frame {self.frame_idx}')
      IPython.display.display(PIL.Image.fromarray(img[..., ::-1]))
    else:
      for i in range(duplicate_num_frames):
        cv2.imwrite(f'board{self.frame_idx:04}-{i:02}.png', img)

    self.frame_idx += 1

  def _draw_board(self, visualize_constraints=None, open_indices=[], mine_indices=[],
                  search_squares=[], solution=None, assignments=None, assigning_idx=None,
                  inconsistent=None, prob_zero=None):
    I = IMG_SQUARE_SIZE
    crop = self.visualize_crop or (0, 0, self.height, self.width)
    img = np.zeros((crop[2] * I, crop[3] * I, 3), 'uint8')

    def _highlight_square(y, x, color=0, thickness=2):
      y_crop = y - crop[0]
      x_crop = x - crop[1]
      s = np.s_[y_crop*I:(y_crop+1)*I, x_crop*I:(x_crop+1)*I]
      t = thickness
      smaller = np.s_[y_crop*I+t:(y_crop+1)*I-t, x_crop*I+t:(x_crop+1)*I-t]
      mask = np.zeros(img.shape[:2], dtype=bool)
      mask[s] = True
      mask[smaller] = False
      img[mask, :] = color

    def _mark_img(img, color=0):
      img = img.copy()
      img[I//3:I-I//3, I//3:I-I//3, :] = color
      return img

    for y in range(crop[0], crop[0]+crop[2]):
      for x in range(crop[1], crop[1]+crop[3]):
        y_crop = y - crop[0]
        x_crop = x - crop[1]
        s = np.s_[y_crop*I:(y_crop+1)*I,
                  x_crop*I:(x_crop+1)*I]

        if self.opened[y, x]:
          if self.flagged[y, x]: img[s] = IMAGES['flagged']
          elif self.mines[y, x]: img[s] = IMAGES['bomb']
          else:
            # Display mines left in upper left corner, and closed left in lower right
            if visualize_constraints:
              square = IMAGES['0'].copy()
              if not (self.neighboring_mines_left[y, x] == 0 and
                      self.neighboring_unresolved_left[y, x] == 0):
                tl_crop = np.s_[:I//2, :I//2] 
                br_crop = np.s_[I//2:, I//2:] 
                mines_left = self.neighboring_mines_left[y, x]
                size = 0.2 if mines_left < 0 else 0.4 # make it smaller for -1 so it fits
                square[tl_crop] = draw_text_centered(
                    square[tl_crop], str(mines_left), size, color=(0, 0, 160))
                square[br_crop] = draw_text_centered(
                    square[br_crop], str(self.neighboring_unresolved_left[y, x]), 0.4, color=(255, 0, 0))
              img[s] = square
            else:
              img[s] = IMAGES[str(self.neighboring_mine_count[y, x])]

          if inconsistent is not None and inconsistent[y, x]:
            _highlight_square(y, x, color=(0, 0, 255))
        else:
          if self.flagged[y, x] or (assignments is not None and assignments[y, x] == 1):
            if self.text_mask is not None and self.text_mask[y, x]:
              # If it's a text mine, flip the channels so that the flags are blue
              # instead of red, to increase visual contrast
              img[s] = IMAGES['flagged'][..., ::-1]
            else:
              img[s] = IMAGES['flagged']
          elif assignments is not None and assignments[y, x] == 0:
            img[s] = _mark_img(IMAGES['facingDown'])
          else:
            if prob_zero is not None and prob_zero[y, x] != 0:
              text = str(int(round(100*prob_zero[y, x])))
              img[s] = draw_text_centered(IMAGES['facingDown'].copy(), text,
                                           size=0.3, color=0)
            elif self.prob_mine is not None and self.prob_mine[y, x] < 2:
              text = str(int(round(100*self.prob_mine[y, x])))
              img[s] = draw_text_centered(IMAGES['facingDown'].copy(), text,
                                           size=0.3, color=0)
            else:
              img[s] = IMAGES['facingDown']

    # Draw various highlights
    for open_idx in open_indices:
      _highlight_square(*open_idx, color=(0, 255, 0))

    for mine_idx in mine_indices:
      _highlight_square(*mine_idx, color=(0, 0, 255))

    for search_idx in search_squares:
      color = 255 if solution is None else (0, 255, 0)
      _highlight_square(*search_idx, color=color, thickness=2)

    if assigning_idx is not None:
      _highlight_square(*assigning_idx, color=0)

    return img

  @property
  def time(self):
    if None in [self.t0, self.t1]: return None
    return self.t1 - self.t0

  def display_video(self, framerate=2, autoplay=False, loop=False, controls=False,
                    width=None, max_width=None):
    """
    Displays a video of the step by step board solution by using the images
    writted to disk and combining them with ffmpeg, finally displaying it as an HTML
    video element
    """
    if self.visualize_output != 'video':
      return

    # Call ffmpeg to create video
    command = [
      'ffmpeg', '-y', '-framerate', str(framerate), '-pattern_type', 'glob', '-i', '"board*.png"',
      '-c:v', 'libx264', '-r', '30', '-pix_fmt', 'yuv420p', 'tmp.mp4'
    ]
    subprocess.call(' '.join(command), shell=True, stderr=subprocess.STDOUT)
    
    # Remove the images
    for filename in glob.glob('board*.png'):
        os.remove(filename)

    # Create a data url with the video data
    with open('tmp.mp4', 'rb') as f:
      data_url = "data:video/mp4;base64," + b64encode(f.read()).decode()
    # Remove the video
    os.remove('tmp.mp4')

    width_str = 'width="'+width+'"' if width is not None else ""
    max_width_str = 'style="max-width:'+max_width+'"' if max_width is not None else ""
    # Display the video HTML element
    html = (f'<video {"controls" if controls else ""}'
            f'       {width_str}'
            f'       {max_width_str}'
            f'       {"autoplay" if autoplay else ""}'
            f'       {"loop" if loop else ""}>'
            f'   <source src="{data_url}" type="video/mp4">'
            f'</video>')
    display(HTML(html))

Are parents picking less common names?

I think we all have the gut feeling that parents nowadays try to pick more unique names, or at least not names that are too common. Instead of relying on a gut feeling or anecdata, let's check!

Being Swedish, I'm of course mostly interested in Swedish name statistics, and luckily SCB (the central statistics bureau) provides an Excel spreadsheet with the frequency statistics of the top 100 names for boys and girls from 1998 to 2017. Let's download it:

! wget -O stats.xlsx https://www.scb.se/hitta-statistik/statistik-efter-amne/befolkning/amnesovergripande-statistik/namnstatistik/pong/tabell-och-diagram/nyfodda--efter-namngivningsar-och-tilltalsnamn-topp-100/Namn-1998-/

Now we can read the spreadsheet with Python and pandas/numpy (I bet it would be easier to just do this in Excel, but I'm not an Excel-ninja...) and collect the statistics.

Let's load the Excel file and look at the data:

import pandas as pd
import numpy as np

START_YEAR, END_YEAR = 1998, 2017
GENDERS = ['girls', 'boys']

data = pd.DataFrame(columns=['year', 'name', 'gender', 'count', 'percent'])

f = open('stats.xlsx', 'rb')
for year in range(START_YEAR, END_YEAR+1):
    for gender_idx, gender in enumerate(GENDERS):
        sheet_idx = 2*(year - START_YEAR) + gender_idx + 1
        df = pd.read_excel(f, sheet_name=sheet_idx, header=6, usecols='B:D')
        count_col = 'Antal' if 'Antal' in df.columns else 'Antal bärare'        
        counts = df[count_col].values[1:]
        name_col_idx = list(df.columns).index(count_col) - 1
        names = df.iloc[1:, name_col_idx].str.strip() # NOTE: need to strip spaces
        # When there are ties in count, there will be more than 100 rows. Keep the top 100
        counts, names = counts[:100], names[:100]
        years = [year]*len(counts)
        genders = [gender]*len(counts)
        data = data.append(
            pd.DataFrame(data={'year': years, 'gender': genders, 'name': names, 'count': counts, 
                               'percent': counts.astype(float) / counts.sum()}),
            sort=True)

print(data)
    count gender    name   percent  year
1    1468  girls    Emma  0.043336  1998
2    1171  girls   Julia  0.034568  1998
3    1043  girls    Elin  0.030790  1998
4    1037  girls  Amanda  0.030613  1998
5    1006  girls   Hanna  0.029697  1998
..    ...    ...     ...       ...   ...
96    145   boys    Thor  0.004173  2017
97    142   boys  Milian  0.004087  2017
98    141   boys    Levi  0.004058  2017
99    141   boys    Vide  0.004058  2017
100   139   boys     Neo  0.004000  2017

[4000 rows x 5 columns]

Let's plot what the top 100 names looked like in 1998 and 2017:

import seaborn as sns
import matplotlib.pyplot as plt

for year in [1998, 2017]:
    fig, ax = plt.subplots(figsize=(15, 5))
    girls_year = data[(data['year'] == year) & (data['gender'] == 'girls')]
    ax = sns.barplot(x='name', y='percent', data=girls_year,
                     color='blue')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90);
    ax.set(ylim=(0, 0.05));
    ax.set_title(f'Girls {year}');

These distributions look quite different, but the question is what's a good measure of the unevenness of the name distribution, so that we can compare different years. I'm sure there are many options for this, but I'll pick perplexity, which is the exponentiated entropy. The entropy of the distribution measures its disorder, i.e. we get the highest entropy if all names have equal probability, and low entropy if everyone picks the same name. Perplexity can be seen as the average number of "choices" available in the random variable, e.g. a fair 6-sided die has a perplexity of 6, while a loaded die that always lands on one number has a perplexity of 1. This is a bit easier to interpret than entropy, which is the average number of bits (or nats, depending on the log base) needed to encode an event from that distribution.
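
As a quick sanity check of that interpretation, here's the die example in code:

import numpy as np

def perplexity(probs):
    probs = np.asarray(probs, dtype=float)
    entropy = -(probs * np.log(probs)).sum()  # natural log; perplexity is base-independent
    return np.exp(entropy)

print(perplexity([1/6] * 6))              # fair die: 6.0
print(perplexity([0.96] + [0.01] * 4))    # loaded die: ~1.25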

Let's calculate the perplexity over the name distribution for each year and gender and plot it over time:

perplexity_data = pd.DataFrame(columns=['year', 'perplexity', 'gender'])
for year in range(START_YEAR, END_YEAR+1):
    for gender in GENDERS:
        counts = data.loc[(data['year'] == year) & (data['gender'] == gender), 'count'].astype(float)
        probs = counts / counts.sum()
        entropy = -(probs * np.log(probs)).sum()
        perplexity = np.exp(entropy)
        perplexity_data = perplexity_data.append(
            {'year': year, 'perplexity': perplexity, 'gender': gender}, ignore_index=True)

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xlim(START_YEAR, END_YEAR)
ax.set_xticks(range(START_YEAR, END_YEAR+1))
ax = sns.lineplot(x='year', y='perplexity', hue='gender', data=perplexity_data)

It seems like within the top 100 names, the answer is an unequivocal yes! Parents are spreading out their name choices more and more, approaching 90/100 which is quite even.

Another interesting thing to check is whether names come into, and go out of, fashion more quickly over the years. My intuition tells me that with naming lists being published online etc., this might be the case. One way to check is to measure some kind of difference between the name distributions of adjacent years. One measure we could use is the KL divergence, which tells us, in a sense, how different one distribution is from another expected distribution:

kl_divs = []
for year1 in range(START_YEAR, END_YEAR):
    year2 = year1 + 1
    data_year1 = data.loc[data['year'] == year1]
    data_year2 = data.loc[data['year'] == year2]
    names = list(set(data_year1['name']) & set(data_year2['name']))
    name_stats_year1 = data_year1.loc[data_year1['name'].isin(names), ['name', 'count']]
    name_stats_year2 = data_year2.loc[data_year2['name'].isin(names), ['name', 'count']]
    name_stats_year1 = name_stats_year1.sort_values('name')
    name_stats_year2 = name_stats_year2.sort_values('name')
    q = name_stats_year1['count'].to_numpy().astype(float)
    q = q / q.sum()
    p = name_stats_year2['count'].to_numpy().astype(float)
    p = p / p.sum()
    kl_div = -(p*np.log(q/p)).sum()
    kl_divs.append(kl_div)

fig, ax = plt.subplots(figsize=(10, 5))
ax.set_xlim(START_YEAR, END_YEAR)
ax.set_xticks(range(START_YEAR, END_YEAR+1))
kl_div_data = pd.DataFrame(data={
    'year': list(range(START_YEAR, END_YEAR)), 'KL divergence': kl_divs})
ax = sns.lineplot(x='year', y='KL divergence', data=kl_div_data)

Surprisingly, for me at least, it seems the year-to-year KL divergence has gone down, but come to think of it, it makes sense. In 1998, when the distribution was more uneven, naming trends (e.g. a popular name waning) would have a large impact on the name distribution, while today's more even distribution means smaller changes between years.

"Can I Say This In Chinese" with BERT

Introduction

As it turns out, most people are not very inclined to teach. I'm learning Chinese, my wife is Chinese: seems like a match made in heaven. Except that she has no patience whatsoever with my broken Chinese (though she's wonderful in many other ways). Whenever I ask how to say something in Chinese, she answers with either "I don't know" or "you can't say that" (followed by no explanation). The only way I can get anything out of her is by trying to say something in Chinese and asking whether it sounds right or not. This is less mentally taxing for her than actually having to translate from English, which I understand, especially for two languages so dissimilar.

Now I'm thinking, with the recent advances in Natural Language Processing with Deep Learning, maybe I can create something to replace my unwilling wife. The academic name for this task seems to be "Linguistic Acceptability". Exactly what this includes seems to be up for debate. For example, "the mouse ate the cat" is perfectly grammatical, although highly unlikely. Then there are sentences which are grammatical but seem logically impossible, like "the cat is a bus". This sentence makes no sense unless you've watched the movie Totoro, which features a... cat that is also a bus. Since this seems like a very difficult problem, I'll be focusing more on distinguishing grammatical vs. ungrammatical rather than sensical vs. nonsensical.

Defining the problem

Recent Deep Learning architectures like BERT and GPT-2 basically train a language model, or LM, i.e. given the surrounding context, they try to predict the missing word. In GPT-2's case, it predicts the next word given all the previous words in the sentence, while BERT predicts a missing word (a cloze) given both the words before and after it (the B in BERT stands for bidirectional). As such, GPT-2 works better as a language model, defining the joint probability over a sequence of words, while BERT's masked LM is less straightforward to use as such. As a reminder, the joint probability can be factored recursively using the chain rule:

$$P(w_{1:n}) = P(w_n | w_{1:n-1})P(w_{1:n-1}) = P(w_n | w_{1:n-1}) \cdot \ldots \cdot P(w_2 | w_1)P(w_1)$$

Each of these factors is exactly what we get out of GPT-2, which means that if we run inference and multiply the factors we get the joint probability, or really more of an unnormalized likelihood, of the whole sentence. BERT, on the other hand, gives us $P(w_k | w_{1:k-1}, w_{k+1:n})$, which is harder to interpret. There is research exploring ways of getting a joint probability model out of BERT using MRFs (Markov Random Fields), but I'd like to keep things simple for this little project.
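
To make the chain-rule factorization concrete, here's a sketch of scoring a sentence with a pretrained (English) GPT-2. It's purely illustrative since, as noted below, there's no pretrained Chinese GPT-2 to use, and I'm writing it against the current transformers API rather than the older pytorch-transformers package:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

def sentence_log_prob(sentence):
    """Sum of log P(w_i | w_{1:i-1}) for i >= 2, from a single forward pass."""
    ids = tokenizer.encode(sentence, return_tensors='pt')
    with torch.no_grad():
        logits = model(ids).logits                          # (1, seq_len, vocab_size)
    log_probs = torch.log_softmax(logits[0, :-1], dim=-1)   # position i predicts token i+1
    targets = ids[0, 1:]
    return log_probs[torch.arange(len(targets)), targets].sum().item()

# Higher (less negative) means the sentence is more likely under the LM
print(sentence_log_prob('The cat sat on the mat.'))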

Using GPT-2 will be difficult, since training its 1.5 billion parameters from scratch requires a cluster of GPUs and roughly $50k. So I'm constrained to pre-trained versions, of which there is none for Chinese AFAIK. The Python library pytorch-transformers does however have a pre-trained BERT for Chinese.

How can we use BERT?

Being constrained by time and money leaves me no option but to use BERT at this point. While BERT can't be used as a language model per se, we can perhaps use its output in some useful way.

We'd like to get a binary decision on whether a sentence is acceptable or not. We could try to use the masked probability for each word in the sentence, but again, it would be difficult to find an absolute threshold that distinguishes unlikely sentences from unacceptable ones. What we can do instead is train a classifier on top of BERT with a dataset of positive and negative examples. While such datasets exist for other languages (CoLA, the Corpus of Linguistic Acceptability, for English), I have not found one for Chinese.

I was, however, able to crawl some examples from the AllSet grammar wiki (licensed under CC-NC), for a total of 436 negative and 461 positive examples, split into grammar groups based on the page they appear on (note: the crawl below will take some time to run):

! wget --quiet --mirror --convert-links --adjust-extension --follow-tags=a --no-parent resources.allsetlearning.com/chinese/grammar/
! grep -r -e 'class="x"' resources.allsetlearning.com/chinese/**/* |\
  sed -e 's/<li class="x">//g' -e 's/<span .*//g' -e 's/<\/*[a-z]*>//g' -e 's/ //g' -e 's/:.*→/:/g' \
  > "$cache_path/allset_negative_examples.txt"
! grep -r -e 'class="o"' resources.allsetlearning.com/chinese/**/* |\
  sed -e 's/<li class="o">//g' -e 's/<span .*//g' -e 's/<\/*[a-z]*>//g' -e 's/ //g' -e 's/:.*→/:/g' \
  > "$cache_path/allset_positive_examples.txt"

While it's putting the cart before the horse a bit, I suspected (correctly) that this small dataset would not be enough to train a classifier that generalizes well to arbitrary sentences. There are just too few examples to cover all the ways sentences can be correct or wrong, although they do contain many important and subtle errors that learners commit.

Self-supervised learning

Instead of only training on the small dataset, the idea is to pre-train a classifier in a self-supervised way by generating negative examples from positive ones. While the masked probabilities of all the words in a sentence are not enough to tell whether the sentence is acceptable, we can assume there is useful information in the relative scores, or losses, between sentences.

Using relative losses, we can generate negative samples from positive ones by finding a mutation that significantly increases the loss. Let's define the loss for a sentence as the average (since we're possibly comparing sentences of differing lengths) Cross-Entropy loss for each word: $$ L(S) = -\frac{1}{N}\sum_{i=1}^{N}{\log(P(w_i | w_{1:i-1}, w_{i+1:N}))} $$ Then we can take a correct sentence $S_c$ and apply a random mutation to get $S_m$. If $L(S_m) - L(S_c) > \epsilon$ we consider $S_m$ to be unacceptable.

Note that even if we could use the bidirectional probabilities/losses to do classification directly, this is something we'd like to avoid, since calculating this loss requires a forward pass for every token in the sentence. Using these expensively generated examples to train a classifier instead lets us bypass this problem at inference time.

Hard Negatives

This way we can generate unacceptable sentences from any acceptable one. Since there are many possible mutations that increase the loss by more than $\epsilon$, we can pick the smallest one that passes the threshold. This is similar to hard negative mining, where, if you already have a model, you can improve it by sampling hard negatives and retraining. This is common in image classification and localization, where any part of an image not containing the specified object is a potential negative example. There it makes sense to pick the parts that are misclassified or get high losses from the initial model.

Mutations

For the actual sentences, we could use the original corpus, but I prefer using sentences from Tatoeba since it is a good source of informal language suitable for learners.

For mutating the sentences, there are a few things we can do:

  • Permute the words
  • Swap two words
  • Insert word (sampled based on corpus frequency)
  • Replace word (sampled based on corpus frequency)
  • Delete word

While we want to mutate the sentences to get unacceptable ones, there are degrees of unacceptability, and we want to generate mutations that are hard, i.e. just barely unacceptable. Therefore I exclude full random permutations, since they are very unlikely to produce something close to acceptable.

Similarly for insertions and word replacements, it makes more sense to sample common words more frequently than rare words since the language has a very long tail of very infrequent words.

Below is the code for loading the Tatoeba dataset and generating hard negatives (NOTE: this is a lot of not particularly interesting code, but it runs as-is in a Jupyter Notebook or Google Colab environment). Also worth mentioning is that the starting point for the PyTorch training code was this Colab Notebook, which serves as a good tutorial for fine-tuning BERT for sequence classification.

First, installing some pip packages:

!pip install --quiet pytorch-transformers pytorch-nlp hanziconv jieba sympy

Import a pre-trained Masked LM BERT model and define functions for preparing data for this model, as well as functions for predicting based on it, and calculating losses for whole sentences:

import io
import os
import re
import torch
import jieba
import random
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm, trange
from hanziconv import HanziConv
from sympy.ntheory import factorint
from functools import lru_cache
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split, GroupKFold
from pytorch_transformers import BertTokenizer, BertConfig, BertModel
from pytorch_transformers import AdamW, BertForSequenceClassification, BertForMaskedLM
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
from sklearn.metrics import matthews_corrcoef, precision_score, recall_score, accuracy_score

%matplotlib inline

device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
torch.cuda.get_device_name(0)

def gpu_usage(print_stats=False):
  """ Convenience function to check GPU memory usage. Returns free memory in GB """
  nvmlInit()
  handle = nvmlDeviceGetHandleByIndex(0)
  info = nvmlDeviceGetMemoryInfo(handle)
  if print_stats:
    print(f"Total memory: {info.total/1e9:.2f} GB")
    print(f"Free memory: {info.free/1e9:.2f} GB")
    print(f"Used memory: {info.used/1e9:.2f} GB")
  return info.free/1e9

# Make sure we have enough memory
if gpu_usage(print_stats=True) < 8:
  raise SystemError('Not enough memory')

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', do_lower_case=True)

# Load pre-trained model (weights)
masked_lm_model = BertForMaskedLM.from_pretrained('bert-base-chinese')
masked_lm_model.cuda()

def prepare_data(df, test_size=0.1, batch_size=32, shuffle=True, add_cls_sep=True):
  sentences = df.sentence.values
  # We need to add special tokens at the beginning and end of each sentence for BERT to work properly
  if add_cls_sep:
    sentences = ["[CLS] " + sentence + " [SEP]" for sentence in sentences]
  has_labels = 'label' in df.columns
  if has_labels:
    labels = df.label.values
  else:
    labels = np.zeros(len(sentences))

  tokenized_texts = [tokenizer.tokenize(sent) for sent in sentences]

  # Set the maximum sequence length. The longest sequence in our training set is 47, but we'll leave room on the end anyway. 
  # In the original paper, the authors used a length of 512.
  MAX_LEN = 128

  # Use the BERT tokenizer to convert the tokens to their index numbers in the BERT vocabulary
  input_ids = [tokenizer.convert_tokens_to_ids(x) for x in tokenized_texts]

  # Pad our input tokens
  input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype="long", truncating="post", padding="post")

  # Create attention masks
  attention_masks = []

  # Create a mask of 1s for each token followed by 0s for padding
  for seq in input_ids:
    seq_mask = [float(i > 0) for i in seq]
    attention_masks.append(seq_mask)

  # Use train_test_split to split our data into train and validation sets for training
  # but if test_size is zero then only generate training sets
  if test_size > 0.0:
    train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(
        input_ids, labels, random_state=2018, test_size=test_size, shuffle=shuffle)
    train_masks, validation_masks, _, _ = train_test_split(
        attention_masks, input_ids, random_state=2018, test_size=test_size, shuffle=shuffle)
  else:
    train_inputs = input_ids
    train_labels = labels
    train_masks = attention_masks
    validation_inputs = []
    validation_labels = []
    validation_masks = []
    
  # Convert all of our data into torch tensors, the required datatype for our model
  train_inputs = torch.tensor(train_inputs)
  validation_inputs = torch.tensor(validation_inputs)
  train_labels = torch.tensor(train_labels)
  validation_labels = torch.tensor(validation_labels)
  train_masks = torch.tensor(train_masks)
  validation_masks = torch.tensor(validation_masks)

  # Create an iterator of our data with torch DataLoader. This helps save on memory during training because, unlike a for loop, 
  # with an iterator the entire dataset does not need to be loaded into memory
  train_data = TensorDataset(train_inputs, train_masks, *([train_labels] if has_labels else []))
  if shuffle:
    train_sampler = RandomSampler(train_data)
  else:
    train_sampler = SequentialSampler(train_data)
  train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

  validation_data = TensorDataset(validation_inputs, validation_masks, *([validation_labels] if has_labels else []))
  validation_sampler = SequentialSampler(validation_data)
  validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
  return train_dataloader, validation_dataloader


def predict(dataloader, model, has_labels=True):
  """
  Evaluates data from a data loader on a model and returns either a tuple of
  predicted probability and true label if has_labels=True otherwise it returns
  the raw logits
  """
  # Put model in evaluation mode
  model.eval()

  # Predict 
  for i, batch in enumerate(dataloader):
    # Add batch to GPU
    batch = tuple(t.to(device) for t in batch)
    # Unpack the inputs from our dataloader
    if has_labels:
      b_input_ids, b_input_mask, b_labels = batch
    else:
      b_input_ids, b_input_mask = batch

    # Telling the model not to compute or store gradients, saving memory and speeding up prediction
    with torch.no_grad():
      # Forward pass, calculate logit predictions
      logits, *_ = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

    # Move logits and labels to CPU
    logits = logits.detach().cpu().numpy()
    if has_labels:
      softmax_probs = np.exp(logits[:, 1]) / np.exp(logits).sum(axis=1)
      label_ids = b_labels.to('cpu').numpy()
      for prob, label in zip(softmax_probs, label_ids):
        yield prob, label
    else:
      yield logits

  
def eval_loss_sentences(sentences, masking='char'):
  """
  Evaluate the loss for a list of sentences
  sentences: the list of sentences
  masking: 'word' for whole word, and 'char' for single character masking
  """
  assert masking in ['word', 'char']
  masking_words = masking == 'word'
  indexed_sentence_tokens = []
  tokenized_sentences = []
  sentence_mask_indices = []
  all_examples = []
  for sentence in sentences:
    # NOTE: the tokenizer removes spaces
    tokenized_sentence = tokenizer.tokenize(sentence)
    tokenized_sentence = tokenized_sentence[:128]
    indexed_sentence_tokens.append(tokenizer.convert_tokens_to_ids(tokenized_sentence))
    if masking_words:
      tokenized_sentence = list(t[0] for t in jieba.tokenize(''.join(tokenized_sentence)))
    tokenized_sentences.append(tokenized_sentence)
    mask_indices = []
    char_idx = 0
    for i in range(len(tokenized_sentence)):
      mask_token = tokenized_sentence[i]
      mask_token_parts = len(tokenizer.tokenize(mask_token)) if masking_words else 1
      all_examples.append(''.join(tokenized_sentence[:i]) +
                          ''.join(mask_token_parts*['[MASK]']) +
                          ''.join(tokenized_sentence[i+1:]))
      mask_indices.append((char_idx, char_idx+mask_token_parts))
      char_idx += mask_token_parts
    mask_indices.append('[SEP]')
    sentence_mask_indices.append(mask_indices)

  df = pd.DataFrame(data={'sentence': all_examples})
  dataloader, _ = prepare_data(df, test_size=0.0, batch_size=32, shuffle=False)

  sentence_losses = []
  curr_sentence_loss = 0
  curr_sentence = 0
  curr_mask_idx = 0
  curr_example = 0
  for batch_logits in predict(dataloader, masked_lm_model, has_labels=False):
    for i in range(batch_logits.shape[0]):
      mask_start, mask_end = sentence_mask_indices[curr_sentence][curr_mask_idx]
      for m in range(mask_start, mask_end):
        mask_logits = batch_logits[i][m+1]
        mask_logits_exp = np.exp(mask_logits)
        mask_token_probs = mask_logits_exp / mask_logits_exp.sum()
        mask_entropy = -(mask_token_probs * np.log(mask_token_probs)).sum()
        masked_token_index = indexed_sentence_tokens[curr_sentence][m]
        # Cross-Entropy Loss
        curr_sentence_loss += -np.log(mask_token_probs[masked_token_index])

      curr_mask_idx += 1
      curr_example += 1
      if curr_mask_idx == len(tokenized_sentences[curr_sentence]):
        # We've reached a new sentence, reset and append log prob
        # Normalize sentence loss by number of tokens
        curr_sentence_loss /= len(tokenized_sentences[curr_sentence])
        sentence_losses.append(curr_sentence_loss)
        curr_sentence_loss = 0
        curr_mask_idx = 0
        curr_sentence += 1
  return sentence_losses

Download example sentences from Tatoeba and word frequency dataset:

! wget http://downloads.tatoeba.org/exports/sentences.tar.bz2
! bzip2 -dc sentences.tar.bz2 > "$cache_path/sentences.txt"
! wget https://www.plecoforums.com/download/weibo_wordfreq-release_utf-8-txt.2603 -O "$cache_path/weibo.txt"

Below is the code for reading the Tatoeba and Weibo frequency datasets and generating hard negatives:

orig_sentences = []
with open(cache_path+'/sentences.txt', 'r') as f:
  for line in f:
      splits = line.split('\t')
      if len(splits) < 3:
        continue
      _, lang, zh = line.split('\t')
      if lang != 'cmn': continue
      zh = HanziConv.toSimplified(zh.strip())
      orig_sentences.append(zh)

words = []
counts = []
with open(cache_path+'/weibo.txt', 'r', encoding='utf-8-sig') as f: 
    for line in f.readlines():
        word, count = line.split('\t')
        tokenized_word = tokenizer.tokenize(word)
        if len(tokenized_word) == 0:
          continue
        
        # Skip [UNK] or other garbage unknown to the BERT tokenizer
        skip = False
        for t in tokenized_word:
          if len(t) > 1:
            skip = True
            break
        if skip: continue
        words.append(word)
        counts.append(int(count))

# Calculate the probability and cumulative probability function for words over
# the frequency
counts = np.array(counts)
word_probs = counts / counts.sum()
cdf = np.cumsum(word_probs)

def sample_word():
  """ Sample a random word based on frequency """
  r = random.random()
  idx = np.searchsorted(cdf, r)
  return words[idx]

@lru_cache(maxsize=128)
def middle_coprime(n):
  """ Find the middle coprime of a number, e.g. of all the
      sorted coprimes of n, pick the middle one """
  factors = list(factorint(n).keys())
  coprimes = [1]
  for i in range(n-2, 1, -1):
    coprime = True
    for f in factors:
      if i % f == 0:
        coprime = False
        break
    if coprime:
      coprimes.append(i)
  return coprimes[len(coprimes) // 2]

def pseudo_random_range(from_idx, to_idx=None):
  """
  Visit all indices in a range pseudo-randomly by visiting (ax + b) mod n, 
  where a and n are co-prime. Small and large coprimes tend to not look random,
  so pick the middle one.
  """
  if to_idx is None:
    from_idx, to_idx = 0, from_idx

  n = to_idx - from_idx
  coprime = middle_coprime(n)
  offset = random.randint(0, n-1) if n > 1 else 0
  for i in range(0, n):
    yield from_idx + (coprime*i + offset) % n 


IGNORE = set(['。', '」', '「', ',', ' ', '!', '?', '?', '!', '.', ','])
# Swaps that usually produce acceptable sentences:
POSITIVE_SWAP_GROUPS = [set(['我', '你', '他', '她']), # personal pronouns
                       set(['我们', '你们', '他们', '她们'])] # plural personal pronouns
def is_positive_swap(from_token, to_token):
  swap_set = set([from_token, to_token])
  # Check if both tokens are in a positive swap group, if so we don't swap
  for swap_group in POSITIVE_SWAP_GROUPS:
    if len(swap_set & swap_group) == 2:
      return True
  return False

def generate_delete(sentence, tokens):
  for idx in pseudo_random_range(len(tokens)):    
    token = tokens[idx][0]
    if token in IGNORE:
      continue
    tokens_deleted = tokens[:idx] + tokens[idx+1:]
    yield ''.join(t[0] for t in tokens_deleted)

def generate_insert(sentence, tokens):
  for idx in pseudo_random_range(len(tokens)):    
    word = sample_word()
    tokens_inserted = tokens[:idx] + [(word,)] + tokens[idx:]
    yield ''.join(t[0] for t in tokens_inserted)

def generate_replace(sentence, tokens):
  for idx in pseudo_random_range(len(tokens)):    
    token = tokens[idx][0]
    if token in IGNORE:
      continue
    # Sample words until it's not equal to the token we're replacing
    word = token
    while word == token:
      word = sample_word()
    tokens_replaced = tokens[:idx] + [(word,)] + tokens[idx+1:]
    yield ''.join(t[0] for t in tokens_replaced)

def generate_swap(sentence, tokens):
  token_set = set([t[0] for t in tokens])
  for from_idx in pseudo_random_range(len(tokens)-1):    
    from_token = tokens[from_idx][0]
    if from_token in IGNORE:
      continue

    for to_idx in pseudo_random_range(from_idx, len(tokens)):
      to_token = tokens[to_idx][0]
      if (from_token == to_token or
          to_token in IGNORE):
          continue

      if is_positive_swap(from_token, to_token):
        continue
        
      # Swap the tokens and return the new string
      mtokens = list(tokens)
      mtokens[to_idx], mtokens[from_idx] = mtokens[from_idx], mtokens[to_idx]
      yield ''.join(t[0] for t in mtokens)

def generate_mutated(sentence):
  tokens = list(jieba.tokenize(sentence))
  generators = [#generate_delete(sentence, tokens),
                generate_insert(sentence, tokens),
                generate_replace(sentence, tokens),
                generate_swap(sentence, tokens)]
  pick_probs = np.array([0.15, 0.15, 0.7])
  while len(generators) > 0:
    gen_idx = np.random.choice(np.arange(len(generators)), p=pick_probs)
    random_gen = generators[gen_idx]
    try:
      yield next(random_gen)
    except StopIteration:
      # The generator is out of sentences to generate, so remove it
      del generators[gen_idx]
      pick_probs = np.delete(pick_probs, gen_idx)
      # Need to normalize so probabilities add up to 1
      pick_probs /= pick_probs.sum()


def generate_hard_negatives(sentences, model, loss_threshold=0.5, generate_max=10,
                            debug_print=False):
  """
  Creates hard negative examples, which are sampled based on mutations that
  increase the loss the least but still significantly enough to very likely be a
  true negative.
  """
  sentence_examples = list(sentences)
  for i, sentence in enumerate(sentence_examples):
    # Skip sentences with unknown words or other garbage
    predict_sentences = [sentence]
    generator = generate_mutated(sentence)
    for _ in range(generate_max):
      try:
        predict_sentences.append(next(generator))
      except StopIteration:
        break
    
    losses = eval_loss_sentences(predict_sentences, masking='char')
    if debug_print:
      print('C: ', sentence)
    for s, l in sorted(zip(predict_sentences[1:], losses[1:]), key=lambda x: x[1]):
      if l - losses[0] > loss_threshold:
        if debug_print:
          print('W: ', s, l, ' +', l-losses[0])
        yield s
        break


negatives_path = cache_path + '/negatives.txt'
if os.path.exists(negatives_path):
  with open(negatives_path, 'r') as f:
    hard_negatives = [l.strip() for l in f.readlines()]
else:
  # NOTE: generating hard negatives takes a long time since to check a single mutation
  # we need to run inference len(sentence) times, and we need to generate a number
  # of mutations for each sentence in order to find a good one
  # So we run a few thousand at a time and store them in case runtime gets recycled
  use_num = 30000
  num_at_a_time = 3000
  use_sentences = orig_sentences[:use_num]
  hard_negatives = []
  for i in range(0, use_num // num_at_a_time):
    if os.path.exists(f'{cache_path}/negatives{i+1}.txt'):
      continue
    with open(f'{cache_path}/negatives{i+1}.txt', 'w') as f:
      sentences = use_sentences[i*num_at_a_time:(i+1)*num_at_a_time]
      for negative in generate_hard_negatives(sentences, masked_lm_model, debug_print=True):
        hard_negatives.append(negative)
        f.write(negative + '\n')

  # Concatenate all files to one
  with open(negatives_path, 'w') as n:
    for i in range(0, use_num // num_at_a_time):
      with open(f'{cache_path}/negatives{i+1}.txt', 'r') as f:
        n.write(f.read())

Fine-tuning BERT

There are plenty of tutorials on how to fine-tune a BERT model. For this experiment I'll use the pre-trained Chinese model in the Python library pytorch-transformers by huggingface. This model is trained with a character-by-character tokenizer, meaning multi-character Chinese words are split into separate word embeddings for each character. This may be suboptimal, unless the model is powerful enough to capture the structure of words, but for now this is what we have to work with.
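
As a quick illustration of what that tokenization looks like (this reuses the `tokenizer` loaded earlier in the notebook; the example sentence is my own):

# bert-base-chinese splits Chinese text into single characters, so a
# two-character word like 明天 becomes two separate tokens
print(tokenizer.tokenize('我们明天开会'))
# expected output (roughly): ['我', '们', '明', '天', '开', '会']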

Below is the code for training and validating the BERT model for classification:

def train(dataloader, epochs=4, model=None, debug_print=False):
  if model is None:
    model = BertForSequenceClassification.from_pretrained("bert-base-chinese", num_labels=2);
    model.cuda()

  param_optimizer = list(model.named_parameters())
  no_decay = ['bias', 'gamma', 'beta']
  optimizer_grouped_parameters = [
      {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
       'weight_decay_rate': 0.01},
      {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
       'weight_decay_rate': 0.0}
  ]
  optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5)
  
  # Set our model to training mode (as opposed to evaluation mode)
  model.train()
  train_loss_set = []

  # trange is a tqdm wrapper around the normal python range which prints progress
  r = trange(epochs, desc="Epoch") if debug_print else range(epochs)
  for _ in r:
    # Tracking variables
    train_loss = 0
    num_examples, num_steps = 0, 0

    # Train the data for one epoch
    for step, batch in enumerate(dataloader):
      if debug_print: print(f'Batch: {step}')
      # Add batch to GPU
      batch = tuple(t.to(device) for t in batch)
      # Unpack the inputs from our dataloader
      b_input_ids, b_input_mask, b_labels = batch
      # Clear out the gradients (by default they accumulate)
      optimizer.zero_grad()
      # Forward pass
      loss, *_ = model(b_input_ids, token_type_ids=None,
                       attention_mask=b_input_mask, labels=b_labels)
      train_loss_set.append(loss.item())    
      # Backward pass
      loss.backward()
      # Update parameters and take a step using the computed gradient
      optimizer.step()

      # Update tracking variables
      train_loss += loss.item()
      num_examples += b_input_ids.size(0)
      num_steps += 1

    if debug_print: print("Train loss: {}".format(train_loss/num_steps))

  return model


def evaluate(model, dataloader, df=None):
    y_true = []
    y_pred = []
    for prob, label in predict(dataloader, model):
      y_true.append(label)
      y_pred.append(1 if prob > 0.5 else 0)
    return y_true, y_pred


def print_stats(y_true, y_pred, sentences=None, label=None):
  tab = ''
  if label is not None:
    print(f'{label}:')
    tab = '\t'
  print(f'{tab}Matthews Correlation Coefficient:', matthews_corrcoef(y_true, y_pred))
  print(f'{tab}Accuracy:', accuracy_score(y_true, y_pred))
  print(f'{tab}Precision:', precision_score(y_true, y_pred))
  print(f'{tab}Recall:', recall_score(y_true, y_pred))

Now we can train our first classification model on positive examples from the Tatoeba dataset and our generated hard negatives. Here I'll train the classifier with an increasing number of examples to see whether we need more data. Training and iterating is slow, so I prefer to keep the dataset as small as possible for now.

model_path = cache_path + '/self_supervised_classification_model.pt'
if os.path.exists(model_path):
  classification_model = torch.load(model_path)
else:
  training_accuracies = []
  validation_accuracies = []
  classification_model = None
  for num in [3000, 6000, 9000, len(hard_negatives)]:
    hard_negatives_df = pd.DataFrame(data={
        'sentence': hard_negatives[:num] + orig_sentences[num:2*num],
        'orig': orig_sentences[:2*num],
        'label': num*[0]+num*[1]})

    train_dataloader, validation_dataloader = prepare_data(
        hard_negatives_df, test_size=0.1, batch_size=32)
    
    classification_model = train(train_dataloader, epochs=4, debug_print=False)
    print('Train accuracy: ', accuracy_score(*evaluate(classification_model, train_dataloader)))
    print('Validation accuracy: ', accuracy_score(*evaluate(classification_model, validation_dataloader)))
  
  # Save to disk, for rerunning and making copies
  torch.save(classification_model, model_path)

df = pd.DataFrame(data={
    'sentence': hard_negatives + orig_sentences[len(hard_negatives):2*len(hard_negatives)],
    'label': len(hard_negatives)*[0] + len(hard_negatives)*[1]})
dataloader, _ = prepare_data(df, test_size=0.0, batch_size=32)
print_stats(*evaluate(classification_model, dataloader), label='Final')

Now let's load the AllSet grammar wiki examples and train models with cross-validation, either from scratch or starting from the model pre-trained on the self-supervised data.

One important difference from the previous dataset is that we want to know how well the model generalizes to new, unseen grammatical rules rather than just unseen examples. Therefore we split the data into training and validation sets based on the grammatical rule/group, such that examples from the same group are never split between the train and test sets.

allset_negative_examples = defaultdict(list)
with open(cache_path+'/allset_negative_examples.txt', 'r') as f:
  for l in f.readlines():
    filename, sentence = l.split(':', 1)
    allset_negative_examples[filename].append(sentence.strip())
allset_positive_examples = defaultdict(list)
with open(cache_path+'/allset_positive_examples.txt', 'r') as f:
  for l in f.readlines():
    filename, sentence = l.split(':', 1)
    allset_positive_examples[filename].append(sentence.strip())

all_files = list(set(allset_negative_examples.keys()) |
                 set(allset_positive_examples.keys()))
allset_sentences = []
allset_labels = []
allset_groups = []
for g, filename in enumerate(all_files):
  negative = allset_negative_examples[filename]
  positive = allset_positive_examples[filename]
  allset_sentences += negative + positive
  allset_labels += [0]*len(negative) + [1]*len(positive)
  allset_groups += (len(negative)+len(positive))*[g]

allset_sentences = np.array(allset_sentences)
allset_labels = np.array(allset_labels)
allset_groups = np.array(allset_groups)

tatoeba_sample = np.random.choice(orig_sentences, 10000)
hard_negative_sample = np.random.choice(hard_negatives, 10000)
self_supervised_df = pd.DataFrame(data={
    'sentence': list(hard_negative_sample) + list(tatoeba_sample),
    'label': len(hard_negative_sample)*[0] + len(tatoeba_sample)*[1]})
self_supervised_dataloader, _ = prepare_data(self_supervised_df, test_size=0.0,
                                             batch_size=32, shuffle=False)

def cross_validate_allset(initial_model_path=None, epochs=4, n_splits=10,
                          print_progress=True):
  train_results = [[], []]
  test_results = [[], []]
  self_supervised_results = [[], []]
  new_model = None
  if n_splits == 1:
    generator = [(np.arange(len(allset_sentences)),
                 np.arange(len(allset_sentences)))]
  else:
    group_kfold = GroupKFold(n_splits=n_splits)
    generator = group_kfold.split(allset_sentences, allset_labels, allset_groups)

  for i, (train_index, test_index) in enumerate(generator):
    train_examples = allset_sentences[train_index]
    train_labels = allset_labels[train_index]
    test_examples = allset_sentences[test_index]
    test_labels = allset_labels[test_index]
  
    train_dataloader, _ = prepare_data(
        pd.DataFrame(data={'sentence': train_examples, 'label': train_labels}),
        test_size=0.0, batch_size=32)
    test_dataloader, _ = prepare_data(
        pd.DataFrame(data={'sentence': test_examples, 'label': test_labels}),
        test_size=0.0, batch_size=32)
  
    model = None
    if initial_model_path is not None:
      model = torch.load(initial_model_path)

    new_model = train(train_dataloader, epochs=epochs, model=model,
                      debug_print=print_progress)
    
    train_result = evaluate(new_model, train_dataloader)
    test_result = evaluate(new_model, test_dataloader)
    self_supervised_result = evaluate(new_model, self_supervised_dataloader)
    if print_progress:
      print_stats(*train_result, label='AllSet Train')
      print_stats(*test_result, label='AllSet Test')
      print_stats(*self_supervised_result, label='Self-Supervised')

    train_results[0] += train_result[0]
    train_results[1] += train_result[1]
    test_results[0] += test_result[0]
    test_results[1] += test_result[1]
    self_supervised_results[0] += self_supervised_result[0]
    self_supervised_results[1] += self_supervised_result[1]
  
  # Print stats accumulated over all folds
  print_stats(*train_results, label='Overall AllSet Train')
  print_stats(*test_results, label='Overall AllSet Test')
  print_stats(*self_supervised_results, label='Overall Self-Supervised')

  # Return the last model
  return new_model

First, let's train a model from scratch on the AllSet data and see how well it does against itself as well as against our self-supervised Tatoeba + hard negative dataset:

cross_validate_allset(initial_model_path=None, epochs=6, n_splits=10, print_progress=False);
Overall AllSet Train:
	Matthews Correlation Coefficient: 0.978298651254621
	Accuracy: 0.9891304347826086
	Precision: 0.9838337182448037
	Recall: 0.9953271028037384
Overall AllSet Test:
	Matthews Correlation Coefficient: 0.9366607354497857
	Accuracy: 0.967391304347826
	Precision: 0.94
	Recall: 1.0
Overall Self-Supervised:
	Matthews Correlation Coefficient: 0.46815654446892113
	Accuracy: 0.7165
	Precision: 0.6568613244457325
	Recall: 0.9066

As you can see, it seems to generalize well on the AllSet data across the folds, meaning it somehow generalizes to unseen grammatical rules. But the performance on the self-supervised dataset is poor. This is probably due to the AllSet data being biased towards easier, illustrative examples, which are substantially different from the average Tatoeba sentence, and it doesn't cover many of the more "obvious" ways sentences can go wrong.

Now let's do the same thing, but starting from the model pre-trained on the self-supervised dataset, in the hope that we can generalize to both datasets:

cross_validate_allset(initial_model_path=model_path, epochs=6, n_splits=10, print_progress=False);
Overall AllSet Train:
	Matthews Correlation Coefficient: 0.9927488225424451
	Accuracy: 0.9963768115942029
	Precision: 0.9976580796252927
	Recall: 0.9953271028037384
Overall AllSet Test:
	Matthews Correlation Coefficient: 0.9784719757905218
	Accuracy: 0.9891304347826086
	Precision: 0.9791666666666666
	Recall: 1.0
Overall Self-Supervised:
	Matthews Correlation Coefficient: 0.8895640148971811
	Accuracy: 0.94365
	Precision: 0.9777107785075912
	Recall: 0.908

The overall results show that the model has generalized relatively well to both datasets, although the scores are lower for the self-supervised data set compared to before.

For training the final model, we can get an even better result on the self-supervised data by training on both datasets together, with the AllSet data upsampled to match the self-supervised set in size, giving both equal importance. Here I'll train it once with a single test set instead of k-fold cross-validation, so I don't time out in Google Colab.

final_model_path = cache_path+'/final_model.pt' 
if os.path.exists(final_model_path):
  final_model = torch.load(final_model_path)
else:
  # Again, need to split AllSet into train/test using GroupKFold
  # GroupKFold.split returns all cross-validation sets, but we'll just use the first
  allset_train_idx, allset_test_idx = next(GroupKFold(n_splits=10).split(allset_sentences, allset_labels, allset_groups))
  allset_train = allset_sentences[allset_train_idx]
  allset_train_labels = allset_labels[allset_train_idx]
  allset_test = allset_sentences[allset_test_idx]
  allset_test_labels = allset_labels[allset_test_idx]
  
  # Next split the self-supervised data set into train/test as well
  ss_train, ss_test, ss_train_labels, ss_test_labels =  train_test_split(
      orig_sentences[len(hard_negatives):2*len(hard_negatives)] + hard_negatives,
      [1]*len(hard_negatives) + [0]*len(hard_negatives), test_size=0.1)
  
  # Then combine both data sets, but with upsampling for AllSet so that they are
  # of equal size
  upsample_times = 2*len(hard_negatives) // len(allset_sentences)
  all_train = (list(ss_train) + upsample_times*list(allset_train))
  all_train_labels = (ss_train_labels + upsample_times*list(allset_train_labels))
  
  all_train_dataloader, _ = prepare_data(
      pd.DataFrame(data={'sentence': all_train, 'label': all_train_labels}),
      test_size=0.0, batch_size=32)
  allset_test_dataloader, _ = prepare_data(
      pd.DataFrame(data={'sentence': allset_test, 'label': allset_test_labels}),
      test_size=0.0, batch_size=32)
  ss_test_dataloader, _ = prepare_data(
      pd.DataFrame(data={'sentence': ss_test, 'label': ss_test_labels}),
      test_size=0.0, batch_size=32)
  
  final_model = train(all_train_dataloader, epochs=4,
                      model=torch.load(model_path),
                      debug_print=True)
  
  train_result = evaluate(final_model, all_train_dataloader)
  allset_test_result = evaluate(final_model, allset_test_dataloader)
  ss_test_result = evaluate(final_model, ss_test_dataloader)
  print_stats(*train_result, label='Train')
  print_stats(*allset_test_result, label='AllSet Test')
  print_stats(*ss_test_result, label='Self-Supervised Test')
  torch.save(final_model, final_model_path)
  

And a sanity check on a few new examples I've found by googling, and some I've come up with myself:

incorrect_sentences = [
  '你有没有车吗?',
  '你是很高',
  '你得包很漂亮',
  '这个车很贵',
  '这本车很贵',
  '我碰到他在公园昨天了',
  '在一家中国饭店,马丽见面了汤姆。',
  '他们在法国见面了对方。',
  '马丽结婚了汤姆。',
  '汤姆结婚了马丽。',
  '我喜欢都学生。',
  '这是我的都。',
  '我们开会在明天上午九点 。',
  '我不有时间。'
]
correct_sentences = [
  '你有没有车',
  '你很高',
  '你的包很漂亮',
  '这辆车很贵',
  '这辆车很贵',
  '我昨天在公园碰到他了',
  '在一家中国饭店,马丽和汤姆见面了。',
  '他们在法国和对方见面了。',
  '马丽嫁了汤姆。',
  '汤姆娶了马丽。',
  '我喜欢所有学生。',
  '这是我的所有。',
  '我们明天上午九点开会。',
  '我没有时间。'
]

incorrect_df = pd.DataFrame(data={'sentence': incorrect_sentences, 'label': len(incorrect_sentences)*[0]})
incorrect_dataloader, _ = prepare_data(incorrect_df, test_size=0.0, batch_size=1, shuffle=False)
correct_df = pd.DataFrame(data={'sentence': correct_sentences, 'label': len(correct_sentences)*[1]})
correct_dataloader, _ = prepare_data(correct_df, test_size=0.0, batch_size=1, shuffle=False)
gen = zip(correct_sentences, predict(correct_dataloader, model=final_model, has_labels=True),
          incorrect_sentences, predict(incorrect_dataloader, model=final_model, has_labels=True))
print('Correct | Incorrect')
for correct, (prob_correct, _), incorrect, (prob_incorrect, _) in gen:
  print(f'{correct}: {prob_correct:.2f} | {incorrect}: {prob_incorrect:.2f}')
Correct | Incorrect
你有没有车: 1.00 | 你有没有车吗?: 0.29
你很高: 1.00 | 你是很高: 0.02
你的包很漂亮: 1.00 | 你得包很漂亮: 0.00
这辆车很贵: 1.00 | 这个车很贵: 1.00
这辆车很贵: 1.00 | 这本车很贵: 1.00
我昨天在公园碰到他了: 0.87 | 我碰到他在公园昨天了: 0.00
在一家中国饭店,马丽和汤姆见面了。: 1.00 | 在一家中国饭店,马丽见面了汤姆。: 0.01
他们在法国和对方见面了。: 1.00 | 他们在法国见面了对方。: 0.00
马丽嫁了汤姆。: 0.84 | 马丽结婚了汤姆。: 0.00
汤姆娶了马丽。: 1.00 | 汤姆结婚了马丽。: 0.00
我喜欢所有学生。: 1.00 | 我喜欢都学生。: 0.00
这是我的所有。: 1.00 | 这是我的都。: 0.02
我们明天上午九点开会。: 1.00 | 我们开会在明天上午九点 。: 0.58
我没有时间。: 1.00 | 我不有时间。: 0.00

For those of you who don't know any Chinese, I'll explain the three false positives among these examples.

The first two false positives use the wrong "measure word" for the noun "car". In English we have measure words for some things, like a pair of shoes or a loaf of bread, but Chinese has loads of them. It seems like the model hasn't managed to learn this, but it's also a simple thing to add more data for: we can just find sentences containing measure words and swap in a wrong one, as sketched below.
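
Here's a rough sketch of that augmentation idea. It's not part of the pipeline above, the measure word list is just a small hand-picked sample, and a real version would want proper tokenization so it doesn't touch characters like 个 inside longer words:

import random

# Small hand-picked sample of common measure words (not exhaustive)
MEASURE_WORDS = ['个', '本', '辆', '只', '张', '条', '件', '杯']

def swap_measure_word(sentence):
  """Replace the first measure word found with a different, most likely wrong, one."""
  for mw in MEASURE_WORDS:
    if mw in sentence:
      replacement = random.choice([m for m in MEASURE_WORDS if m != mw])
      return sentence.replace(mw, replacement, 1)
  return None  # no measure word found, nothing to mutate

print(swap_measure_word('这辆车很贵'))  # e.g. '这本车很贵' or '这条车很贵'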

The last error is one of word order: in Chinese, the time and place come before the verb. Getting this wrong is a bit surprising, but the prediction was 0.58, so at least the model isn't very sure about it.

Chinese Placement Test with Logistic Regression and Active Learning

An interesting problem with language learning apps is how to estimate the prior knowledge, the vocabulary, of a new user. Unless the user is a complete beginner, they will come with at least some level of prior knowledge in Chinese. Specifying exactly what you know is too much effort if you know tons of words.

Many content sites divide up content by HSK level (Hanyu Shuiping Kaoshi). For those unfamiliar with Chinese or the HSK, it's a set of standardized tests where each level comes with a list of required words. From one level to the next, the number of new words roughly doubles (150, 150, 300, 600, 1300 and 2500).

Given the popularity of teaching to the test, the use of these levels in teaching materials and apps is pretty ubiquitous. But dividing content based on HSK level is a pretty crude measure. What happens when you're between levels? Then you're bound to over- or underestimate the number of known words by quite a bit, especially at higher levels. It also says nothing about words outside the HSK lists. This is a problem for people like me who don't follow the HSK very closely, but instead learn words as they come across them in various materials.

So for this little experiment, I'd like to estimate the probability of a user knowing any given word in the Chinese language, based on a much smaller sample of the user's knowledge. Hopefully, the solution I'll present here could be applied to any kind of learning application where properties of the atomic pieces of information can be used to predict knowledge.

What do we know about words?

In order to predict the word probabilities we need some features to base the prediction on. Two features pop out immediately: HSK level (duh...) and word frequency in the language. Roughly speaking, you'd expect higher HSK words to be lower frequency words, but as it turns out, the HSK levels aren't that clear cut. And even though the features are correlated, the HSK level carries additional information, since learners are very likely to use learning material that follow it.

Before analyzing word frequencies and HSK words, we need to load a dictionary, a word frequency list, and the HSK word lists:

import re
import time
import random
from datetime import datetime
from itertools import combinations
from collections import defaultdict
from textwrap import wrap

import jieba
import numpy as np
import seaborn as sns
import pandas as pd
import scipy.stats as st
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
import ipywidgets as widgets
from IPython.display import display

%matplotlib inline

#-----------------------------------------------------------------------
# Load the Cedict dictionary
#-----------------------------------------------------------------------
dictionary = defaultdict(list)
filter_out_patterns = [
    'see ',
    'abbr. for ',
    'see also ',
    'also written ',
    'euphemistic variant of',
    'variant of ',
    'unofficial variant of ',
    'archaic variant of ',
    'old variant of ',
    'ancient variant of ',
    'erhua variant of ',
    'Japanese variant of ',
    'CL:',
    '(old)',
    '(dialect)',
    '(surname)',
    '(Cantonese)',
    'radical',
    'Mandarin equivalent'
]

with open('cedict_ts.u8') as f:
    for i, line in enumerate(f):
        if line.startswith('#'): # Skip comments
            continue

        _, hz, py, transl = re.match(
            r"(\S*) (\S*) \[(.*)\] \/(.*)\/", line).groups()
        
        if py[0].isupper():
            continue # Skip named entities

        t_split = transl.split('/')
        t_filtered = []
        for t in t_split:
            if t.startswith('surname '):
                continue
            not_note_pattern = r'[\w\W]*\([\w\W]* pr\. [\w\W]*\)'
            if ' pr. ' in t and not re.match(not_note_pattern, t):
                # pr. is short for pronunciation. There's:
                # also pr. , Taiwan pr. , Japan pr.
                # However, when the pr. is inside parenthesis, it applies to
                # the translation which comes before
                continue

            filter_out = False
            for pattern in filter_out_patterns:
                if pattern in t:
                    filter_out = True
            if filter_out:
                continue
            t_filtered.append(t)
                    
        if len(t_filtered) > 0:
            dictionary[hz].append((py, t_filtered))

#-----------------------------------------------------------------------
# Load Weibo word frequencies
#-----------------------------------------------------------------------
word_freq = {}
word_freq_all = {}
with open('weibo_wordfreq.release_UTF-8.txt', encoding='utf-8-sig') as f: 
    for line in f.readlines():
        word, count = line.split('\t')
        if word in word_freq:
            continue

        word_freq_all[word] = int(count)
        if len(dictionary[word]) == 0:
            continue

        word_freq[word] = int(count)

words = list(word_freq.keys())
# Correct the frequency of components (component can not have lower frequency than compound)
for w in words:
    f = word_freq[w]
    if f < 50000: continue # skip uncommon words to speed up
    substrs = [w[x:y] for x, y in combinations(range(len(w)+1), r=2)]
    for substr in substrs:
        if substr not in words: continue
        subf = word_freq.get(substr, 0)
        if subf < f:
            word_freq[substr] = f
            
# Calculate word rank, higher for higher frequencies
word_rank = {w: r for r, (w, _) in 
             enumerate(sorted(word_freq.items(), key=lambda x: x[1]))}

#-----------------------------------------------------------------------
# Load the HSK word lists
#-----------------------------------------------------------------------
hsk_words_by_lvl = {}
hsk_lvl_by_word = {}
for lvl in range(1, 7):
    words = []
    with open(f'hsk/HSK{lvl}.txt', encoding='utf-8-sig') as f:
        lines = list(f.readlines())
        for w in lines:
            w = w.strip() # strip away newlines
            if w not in word_freq: continue # some HSK words are not in dict
            words.append((w, word_freq[w]))
            hsk_lvl_by_word[w] = lvl
    hsk_words_by_lvl[lvl] = words
    
#-----------------------------------------------------------------------
# Add components of HSK words as separate HSK words, e.g. 后面 -> 后, 面
# This is because the HSK lists are missing most of these components
#-----------------------------------------------------------------------
# First, build an index over all HSK words for all their substrings
hsk_word_index = {}
for lvl in range(1, 7):
    for w, f in hsk_words_by_lvl[lvl]:
        substrs = [w[x:y] for x, y in combinations(range(len(w)+1), r=2)]
        for substr in substrs:
            _, _, curr_lvl = hsk_word_index.get(substr, (None, None, 7))
            if lvl < curr_lvl:
                hsk_word_index[substr] = (w, f, lvl)

# Next, go through all non-HSK words and see if they are a component in an HSK word
# If so, add them to the corresponding HSK level (if also in dictionary)
component_assignments_per_lvl = defaultdict(list)
for w, f in word_freq.items():
    if w in hsk_lvl_by_word:
        continue # Already is HSK word

    if w in hsk_word_index:
        hsk_w, f, lvl = hsk_word_index[w]
        hsk_words_by_lvl[lvl].append((w, f))
        hsk_lvl_by_word[w] = lvl
        component_assignments_per_lvl[lvl].append((w, hsk_w))

# Some debug prints
if False:
    for lvl in range(1, 4):
        words = ' '.join([w for w, _ in component_assignments_per_lvl[lvl]])
        print(f'Component HSK {lvl} words:')
        print('\n'.join(wrap(words, 80)))

    print('\nAdded for each level: ')
    for lvl in range(1, 7):
        print(f'HSK{lvl}: {len(component_assignments_per_lvl[lvl])}')

We can now make a boxplot of the word frequencies of the different HSK levels. As you can see, higher levels tend to have lower frequency words, and also a much smaller range of frequencies. The smaller ranges make sense, since there are many more ways for a word to have a high frequency (there is no upper bound) than a low one. But ignoring this, we can already see that there is significant overlap between the levels.

freqs = np.array([f for words in hsk_words_by_lvl.values()
                  for _, f in words])
lvls = np.array([lvl for lvl, words in hsk_words_by_lvl.items()
                 for w in words])
fig, ax = plt.subplots(figsize=(10, 5))
sns.boxplot(x=lvls, y=freqs, ax=ax, fliersize=0);
ax.set(ylim=(0, 1.4e7), xlabel='HSK Level', ylabel='Frequency');

Worth noting is that these word frequencies are based on a Weibo (Chinese Twitter) dataset. At first I used a corpus based on all kinds of sources like news, novels, etc., but Weibo is more colloquial, which I think is more representative of the kind of language learners are interested in. Ideally, we'd prefer a dataset of text a learner is actually likely to encounter: course textbooks, TV shows and movies, chats and so on.

In order to visualize where HSK words end up as a function of frequency, here's a bar chart showing the most common 5000 words with the most frequent to the right. Each colored bar is a word in HSK and white gaps are words not in the HSK. It drives home the point that

  1. There is no clear visual separation between levels
  2. HSK really has plenty of gaps in this range

# Find the k most frequent words and their HSK level
K = 5000
common_k_words = sorted(list(word_freq.items()), key=lambda x: x[1])[-K:]
common_k_words_lvls = [(w, hsk_lvl_by_word.get(w, None)) for w, f in common_k_words]

color_map = {None: 'white', 1: 'blue', 2: 'orange', 3: 'green', 4: 'red', 5: 'purple', 6: 'brown'}
colors = [color_map[lvl] for _, lvl in common_k_words_lvls]
plt.figure(figsize=(15, 2))
ax = plt.axes([0,0,1,1], frameon=False)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.bar(np.arange(len(colors)), np.ones(len(colors)), width=1.0, color=colors)
plt.show()

It's also a good idea to have a look at the distribution of words over frequencies and see if we can learn anything from it.

# Restrict to a middle range of frequencies so the histogram is readable
counts = np.array(list(word_freq_all.values()))
counts_subrange = counts[(counts > 1000) & (counts < 100000)]
plt.rcParams['figure.figsize'] = [15, 5] # set larger plots
plt.subplot(1, 2, 1)
plt.hist(counts_subrange, 100)
plt.ylabel('Count')
plt.xlabel('Word Frequency')
plt.subplot(1, 2, 2)
plt.hist(np.log(counts_subrange), 100, log=True)

plt.show()

As it turns out, word frequencies in language corpora tend to follow Zipf's law, which is a discrete version of the Pareto distribution on word rank rather than frequency. The second graph, a histogram of log frequencies with a log-scaled count axis (so effectively a log-log plot), seems to indicate this as well, since for a Pareto distribution such a plot should look roughly linear. It's not quite as linear as I'd hoped, but good enough.
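
As a quick complementary check, we can also plot frequency against rank on log-log axes, which should come out roughly as a straight line if Zipf's law holds (this reuses word_freq_all from the loading code above):

# Rank-frequency plot: under Zipf's law this is roughly a straight line on log-log axes
sorted_counts = np.array(sorted(word_freq_all.values(), reverse=True))
ranks = np.arange(1, len(sorted_counts) + 1)
plt.figure(figsize=(6, 4))
plt.loglog(ranks, sorted_counts)
plt.xlabel('Word rank')
plt.ylabel('Word frequency')
plt.show()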

Advanced modeling

If we want to build a really advanced model, we might want to take word covariances into account. For example, knowing "刀 - knife" makes knowing "叉 - fork" much more likely, because they tend to co-occur and be learned together. This would also include domains of co-occurring words, like "household and child-caring words", a domain I know disproportionately well after spending a month at home caring for my son together with my Chinese mother-in-law (who doesn't speak English, I might add).

In the end, I think adding co-occurrences would have too little ROI for the added complexity, although it might be a topic worth revisiting in the future.

Making a prediction

Now that we know our data a bit better, we can think about how to predict a user's vocabulary. The idea is to create some kind of interactive procedure that iteratively discovers this distribution. This way we can home in on the user's skill level without asking tons of questions up front.

Still, "knowing" a word is a bit of a loose concept since being able to produce a translation given a Hanzi cue doesn't really show whether you can use it well in a sentence. Or being able to write it, or recognize it in speech etc. There are many ways you could imagine probing the user, but this little proof of concetp I'll simplify the problem to just a binary decision for Hanzi+Pinyin+Translation combos.

First, let's consider a single independent variable, frequency, and see what we can do. Predicting a probability based on an independent variable seems like a perfect job for logistic regression. It works by scaling the independent variable, adding a bias, and then putting the result through the sigmoid function, producing a value between 0 and 1:

$$\frac{1}{1+e^{-(x^{L}_1 \beta_1 + \ldots + x^{L}_n \beta_n + \beta_c)}}$$

Logistic regression finds the best values for the $\beta$ coefficients that minimize the Cross Entropy Loss. Since we are essentially training a separate predictor for each individual user with very few data points, it's a good idea to use a simple model like this, with as few variables as possible in order to avoid overfitting and make sampling easier.

Let's try it by picking some random words from different HSK levels, assigning positive labels to levels 1-3, and negative labels to levels 4-6 and see what it looks like. Due to the underlying Pareto distribution, it's better to use frequency rank rather than the frequency as the independent variable, so that we have a uniform discrete distribution over the words:

word_rank_list = sorted(list(word_rank.items()), key=lambda x: x[1])
words = [w for w, _ in word_rank_list]
ranks = np.arange(len(word_rank_list))
in_hsk = np.array([w in hsk_word_index for w in words])
hsk_lvls = np.array([hsk_lvl_by_word.get(w, 0) for w in words])

# -----------------------------------------------------------------------
# Here are a few different versions of the hsk variable
# NOTE: out of these options, quadratic has the lowest mean loss and std
# -----------------------------------------------------------------------
# Unused:
#hsk_data = []
# Quadratic:
hsk_data = [np.array([(7 - hsk_lvl_by_word.get(w, 7))**2/36 for w in words])]
# Linear:
#hsk_data = [np.array([(7 - hsk_lvl_by_word.get(w, 7))/6 for w in words])]
# Single binary:
#hsk_data = [np.array([1.0 if w in hsk_lvl_by_word else 0.0 for w in words])]
# Per level binary:
#hsk_data = [np.array([1.0 if hsk_lvl_by_word.get(w, None) == lvl else 0.0 for w in words])
#            for lvl in range(1, 7)]

data = np.vstack((ranks.astype(float), *hsk_data)).T
data_hsk_mean = data[in_hsk, 0].mean(axis=0)
data_hsk_std = data[in_hsk, 0].std(axis=0)
normalized_data = data.copy()
normalized_data[:, 0] = (data[:, 0] - data_hsk_mean) / data_hsk_std
normalized_word_rank = dict(zip(words, normalized_data[:, 0]))

def sample_HSK_words_per_lvl(K):
    test_words = []
    random.seed(1) # Make sure we randomize reproducibly
    for lvl in range(1, 7):
        test_words = test_words + random.sample(hsk_words_by_lvl[lvl], K)
    return test_words

K = 50
test_words = sample_HSK_words_per_lvl(K)
test_rank = np.array([normalized_word_rank[w] for w, _ in test_words]).reshape(-1, 1)
# Set the lower level words as known, rest as unknown
outcomes = np.zeros(len(test_words))
outcomes[:len(test_words) // 2] = 1

def plot_logistic_regression(X, y, lr, show=True):
    plt.figure()
    plt.scatter(X.ravel(), y, color='black', zorder=20)
    X_test = np.linspace(-0.5, 1.25, 1000)
    y_plot = lr.predict_proba(X_test.reshape(-1, 1))[:, 1]
    plt.plot(X_test, y_plot, color='red', linewidth=3)
    plt.ylabel('Probability')
    plt.xlabel('Normalized Frequency rank')
    plt.xlim(-0.5, 1.25)
    if show: plt.show()

lr = LogisticRegression(random_state=0, solver='lbfgs', C=1e9)
lr.fit(test_rank, outcomes)
plot_logistic_regression(test_rank, outcomes, lr)

Active Learning

How much data would we need to make a good prediction, and which words would we pick? Let's pretend that we already have a previous "best guess" logistic regression fit. This best guess could be completely manual, based on the user's rough HSK level, or a fit on a small sample of data. Could we use this estimate to pick good words to inquire about?

My intuition tells me that words with predicted probabilities close to 0 or 1 are not good picks, since we're already very sure about them. It would be better to pick words we're uncertain about, with probabilities around 0.5.

What we want is to find the word which would be the most surprising if we got it wrong, but is also likely to happen. More concretely, we want the surprisal for being wrong, times the probability of being wrong about it.

Surprisal is an actual term in information theory, although it's more often called "information content". The surprisal/information of an event is $I(X) = -\log(P(X))$. This means that if the probability of X happening is 0, we're infinitely surprised if it does actually happen (since $-\log(0)=\infty$). If it's 1, then we're not surprised at all if it happens (since $-\log(1)=0$).

What we're looking for then is the expected surprisal of being wrong for both classes:

$$ \sum_{c \in \{T, F\}} P(W=c)I(W=c) $$

This is the same as the entropy of the binary probability distribution for a word. As it turns out, the maximum of this function is at a probability of 0.5, following our intuition. It also maximizes the expected Cross Entropy loss, which gives further validity to the idea. There are other ways to sample for active learning, e.g. sampling close to the decision boundary, but this is one of the basic ones and it's what I came up with before knowing what to google for :)

plt.figure()
X = np.linspace(-7.0, 7.0, 1000)
y = 1 / (1 + np.exp(-X))
E = - y * np.log(y) - (1 - y) * np.log(1-y)
plt.plot(X, y, color='red', linewidth=3)
plt.plot(X, E, color='blue', linewidth=2)
plt.ylabel('y')
plt.xlabel('X')
plt.figtext(.15, 0, "Red line: the logistic function, blue line: the entropy", fontsize=16)
plt.show()

Rather than just picking the highest entropy word, we can instead sample a bunch of words according to the entropy curve above, reflecting our underlying uncertainty.

This way, we can iteratively sample batches of words to test using the best current logistic regression fit, until some stopping criterion is met.
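
Here's a minimal sketch of that sampling step, reusing the lr fit above together with words and normalized_data (only the rank column, to match the single-feature fit); the function name is my own:

def sample_next_words(lr, n_samples=20):
    """Sample words to quiz next, weighted by the entropy of the current predictions."""
    probs = lr.predict_proba(normalized_data[:, :1])[:, 1]
    probs = np.clip(probs, 1e-9, 1 - 1e-9)  # avoid log(0)
    entropy = -probs * np.log(probs) - (1 - probs) * np.log(1 - probs)
    idx = np.random.choice(len(words), size=n_samples, replace=False,
                           p=entropy / entropy.sum())
    return [words[i] for i in idx]

print(sample_next_words(lr, n_samples=10))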

This raises the question of how to sample words for the first iteration. It's important to get samples in both classes with varied inputs. For symmetry, we can set the coefficients of a logistic regression manually using a rough fluency/HSK level provided by the user. To determine the coefficients to use for the logistic function, we can use the constraint that we want the probability to be 0.5 at a point $x^{L}_{1}, ..., x^{L}_{n}$ that represents the user's level. Then we can solve for the intercept $\beta_c$ given our best guess for the influence factors for each input variable $\beta_1, \ldots, \beta_n$: $$\frac{1}{1+e^{-(x^{L}_1 \beta_1 + \ldots + x^{L}_n \beta_n + \beta_c)}} = 0.5 \implies \beta_c = -(x^{L}_1 \beta_1 + \ldots + x^{L}_n \beta_n)$$

As for setting $x^{L}_{1}, ..., x^{L}_{n}$: for the frequency rank I found it best to use the upper quartile of the rank within the user's HSK level. For $\beta$ I found it worked best to increase it quadratically with the reversed HSK level, which gives a narrower sampling distribution for lower levels.
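
As a minimal sketch of this initialization (the variable names here are mine, and the actual coefficients and quantiles used in the full listing below differ):

import numpy as np

def initial_intercept(level_point, betas):
    """Intercept that puts P(known) = 0.5 at the user's level point x^L."""
    return -np.dot(np.asarray(level_point, dtype=float), np.asarray(betas, dtype=float))

# Sanity check with made-up numbers: one feature (normalized rank), with the
# level point at the upper-quartile rank of the user's HSK level
level_point, betas = [0.3], [9.0]
b_c = initial_intercept(level_point, betas)
print(1.0 / (1.0 + np.exp(-(np.dot(level_point, betas) + b_c))))   # 0.5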

One practical detail for sampling from this distribution is that if done naively, we will get biased sampling due to there being probability mass outside the valid range of values. The result is that when we sample words, we will be biased towards sampling more difficult words. Here we can also see the reason for using rank rather than frequency: if we get even a little bit of probability assigned to the lower frequencies, we'll end up sampling a vast majority of words from there, due to the underlying Pareto distribution.
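
To see the bias concretely, here is a small self-contained demonstration with made-up coefficients (unrelated to the real fit): when the entropy peak sits close to the edge of the valid rank range, the part of the curve that falls outside it has no words behind it, so the sampled ranks end up skewed to one side of the peak.

import numpy as np

rng = np.random.default_rng(0)

# Made-up logistic fit whose 0.5-probability point (the entropy peak) sits at
# normalized rank 0.05, close to the edge of the valid range [0, 1]
beta, intercept = 20.0, -20.0 * 0.05
ranks = np.linspace(0.0, 1.0, 50_000)          # only these ranks exist as words
p = 1.0 / (1.0 + np.exp(-(beta * ranks + intercept)))
p = np.clip(p, 1e-12, 1 - 1e-12)
entropy = -p * np.log(p) - (1 - p) * np.log(1 - p)

sampled = rng.choice(ranks, size=20_000, p=entropy / entropy.sum())
print('entropy peak at rank:', ranks[np.argmax(entropy)])   # ~0.05
print('mean sampled rank:   ', sampled.mean())              # noticeably above the peak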

HSK variable

At this point it's time to add the HSK level as a predictor. Instead of using a normalized HSK level, it makes more sense to reverse it, so that HSK 1 becomes 6, HSK 6 becomes 1, and "no level" becomes 0. After trying several encodings:

  1. A binary indicator variable (HSK or not)
  2. Binary indicators for each level
  3. The reversed HSK level
  4. The squared reversed HSK level

I found that the squared reversed level worked best, probably because the word count roughly doubles from one HSK level to the next.
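
As a sketch of what I mean by that encoding (ignoring whatever scaling is applied when the full feature matrix is built later):

def reversed_hsk_squared(level):
    # HSK 1 -> 36, HSK 6 -> 1, no HSK level -> 0
    return 0.0 if level is None else float((7 - level) ** 2)

# e.g. with hsk_lvl_by_word mapping a word to its HSK level (or None if it has none):
# hsk_feature = [reversed_hsk_squared(hsk_lvl_by_word.get(w)) for w in words]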

Procedure, Dataset and Result

The procedure so far is then:

  1. Ask for user's rough fluency level
  2. Set the initial logistic regression coefficients based on said level
  3. Until convergence or maximum number of iterations:
    1. Sample N words based on word entropy
    2. Collect binary yes/no answers for those words
    3. Perform a logistic regression fit using all sampled and answered words so far

How do we make sure this procedure actually works? Well, I have one potential dataset: me. I have 3765 notes in Anki (an open-source spaced repetition app), which I've exported to a text file. I also add the earlier HSK levels which I completed but never added to Anki, as well as a list of other words I know, for a total of 5671 words:

HANZI_UTF_RANGES = [
    ('\u4E00', '\u9FFF'),
    ('\u3400', '\u4DBF'),
    ('\uF900', '\uFAFF')
]

def is_hanzi(char):
    for start, end in HANZI_UTF_RANGES:
        if ord(char) >= ord(start) and ord(char) <= ord(end):
            return True
    return False

def filter_hanzi(text):
    return ''.join(char for char in text if is_hanzi(char))

anki_words = set()
with open('anki_notes.txt') as f:
    for line in f:
        hanzi = filter_hanzi(line)
        # Single characters that are words on their own
        for character in hanzi:
            if character in words:
                anki_words.add(character)
        # Multi-character words found by the tokenizer
        for token, *_ in jieba.tokenize(hanzi, mode='search'):
            if token in words:
                anki_words.add(token)
        # The note's full hanzi string, if it is a dictionary entry of its own
        if len(dictionary[hanzi]) != 0:
            anki_words.add(hanzi)
                
# Add words from HSK levels I learned before using Anki
for lvl in range(1, 5):
    for w, _ in hsk_words_by_lvl[lvl]:
        anki_words.add(w)
        
# Add manually gathered extra words
with open('known_words.txt') as f:
    for line in f:
        anki_words.add(line.strip())

With my vocabulary, I can run as many randomized simulations as I'd like by simply providing answers based on the vocabulary, and then checking the final result in terms of loss over the whole vocabulary and other useful metrics. Below is the code for the whole procedure if you'd like to check it out:

seed = datetime.now().microsecond
random.seed(seed)
print('Using seed:', seed)
np.seterr(all='raise')

component_words = set()
for lvl in range(1, 7):
    component_words.update(set(w for w, _ in component_assignments_per_lvl[lvl]))
    
def prevent_nans(probs):
    # Clip probabilities away from exactly 0 and 1 so that p*log(p) stays finite
    np.clip(probs, 1e-20, 1 - 1e-20, out=probs)

def lr_prediction_and_sampling_cdf(lr, plot=False):
    probs = lr.predict_proba(normalized_data)
    low_prob = probs[:, 1] < 0.01
    prevent_nans(probs)
    expected_information = -probs * np.log(probs)
    expected_information = expected_information.sum(axis=1)
    lr_probs = expected_information / expected_information.sum()
    # Clip the long tail, to compensate for bias (very low probability words)
    #lr_probs[low_prob] = 0
    lr_cdf = np.cumsum(lr_probs)
    if plot:
        plt.plot(normalized_data[::100, 0], lr_probs[::100])
        plt.show()
    return probs, lr_cdf
    

def run_iterative_lr(include_hsk_lvl=False, plot_lr=False, num_samples=20,
                     term_absdiff_thres=None, term_max_iter=5, answer_based_on_anki=True, anki_lvl=0):
    # Ask for HSK level
    user_hsk_lvl = anki_lvl
    while user_hsk_lvl == 0:
        print("HSK1: Beginner 1")
        print("HSK2: Beginner 2")
        print("HSK3: Intermediate 1")
        print("HSK4: Intermediate 2")
        print("HSK5: Advanced 1")
        print("HSK6: Advanced 2")

        answer = input('What is roughly your HSK level, 1-6? ')
        try:
            user_hsk_lvl = int(answer)
        except ValueError:
            user_hsk_lvl = 0
        if not (1 <= user_hsk_lvl <= 6):
            user_hsk_lvl = 0
            print('Please choose a number between 1-6\n')

    sampled_words = set()
    sampled_outcomes = []
    sampled_indices = []
    lr_cdf = None
    stats_history = []
    iteration = 0

    # Set a sampling distribution based on the user's specified HSK level
    # Found that cubing the reverse of the hsk level is a good initial frequency rank coefficient
    # Small levels lead to big coefficients, which lead to a narrower sampling distribution
    # Add a little to the lower levels to make sure low frequency words don't dominate
    initial_sampling_freq_coef = (7 - user_hsk_lvl)**3 + (0.2 if user_hsk_lvl <= 2 else 0.0)
    initial_sampling_intercept = -np.quantile(normalized_data[hsk_lvls == user_hsk_lvl, 0], .75) * initial_sampling_freq_coef

    lr = LogisticRegression()
    lr.coef_ = np.array([initial_sampling_freq_coef, *np.zeros(normalized_data.shape[1]-1)]).reshape(1, -1)
    lr.intercept_ = initial_sampling_intercept
    _, lr_cdf = lr_prediction_and_sampling_cdf(lr)
    lr_coeffs, prev_lr_coeffs = None, None
    def terminate():
        if term_max_iter is not None and iteration >= term_max_iter:
            return True
        if term_absdiff_thres is None or prev_lr_coeffs is None:
            return False
        sum_absdiff = np.abs(lr_coeffs - prev_lr_coeffs).sum()
        return sum_absdiff < term_absdiff_thres

    
    while not terminate():
        if not answer_based_on_anki: print(f'Iteration {iteration+1}')

        curr_sample = []
        while len(curr_sample) < num_samples:
            r = random.random()
            idx = np.searchsorted(lr_cdf, r)
            idx = idx - 1 if idx == len(lr_cdf) else idx
            w = words[idx]
            if (idx in curr_sample or
                    w in sampled_words or
                    w in component_words):
                continue
            curr_sample.append(idx)

        outcomes = []
        for idx in curr_sample:
            w, f = words[idx], normalized_data[idx, 0]
            hsk_lvl = hsk_lvl_by_word.get(w, None)

            if answer_based_on_anki:
                outcome = 1.0 if w in anki_words else 0.0
            else:
                print(f'HSK {hsk_lvl}: {w} {f}')
                for py, transls in dictionary[w]:
                    print(f'{py}: {"/".join(transls)}')

                # Small delay so the printed prompt flushes before input() (they can otherwise race)
                time.sleep(0.05)
                answer = input('Do you know it, y/n? ')
                print('') # Give space to the next question
                outcome = 1.0 if answer == 'y' else 0.0

            sampled_words.add(w)
            outcomes.append(outcome)

        sampled_indices = sampled_indices + curr_sample
        sampled_outcomes = sampled_outcomes + outcomes

        if len(np.unique(np.array(sampled_outcomes))) != 2:
            # Need data from both true/false classes before doing regression
            iteration += 1
            continue

        # NOTE: have to set a higher C in order to reduce regularization
        # see https://stackoverflow.com/a/52064154
        # By default C=1.0 which produces very high regularization
        # But we still need some regularization: if, say, an HSK1-3 student
        # doesn't know any words outside HSK, the estimate for the rank coefficient
        # becomes too extreme
        lr = LogisticRegression(random_state=0, solver='lbfgs', C=1e2)
        X = normalized_data[np.array(sampled_indices)]
        y = np.array(sampled_outcomes)
        lr.fit(X, y)
        
        prev_lr_coeffs, lr_coeffs = lr_coeffs, np.append(lr.coef_.ravel(), lr.intercept_)
        probs, lr_cdf = lr_prediction_and_sampling_cdf(lr)

        # Evaluate Cross Entropy Loss on words in Anki
        anki_mask = np.array([w in anki_words for w in words])
        anki_loss = -np.log(probs[anki_mask, 1]).sum()
        not_anki_loss = -np.log(probs[~anki_mask, 0]).sum()
        loss = anki_loss + not_anki_loss
        stats = {'anki_loss': anki_loss, 'not_anki_loss': not_anki_loss, 'loss': loss,
                 'outcomes': np.array(outcomes).sum() / len(outcomes), 'coeffs': lr_coeffs}
        for hsk in [True, False]:
            hsk_mask = in_hsk if hsk else ~in_hsk
            pos_outcomes = probs[:, 1] > 0.5
            neg_outcomes = probs[:, 1] < 0.5
            false_negatives = anki_mask & hsk_mask & neg_outcomes
            false_positives = ~anki_mask & hsk_mask & pos_outcomes
            true_negatives = ~anki_mask & hsk_mask & neg_outcomes
            true_positives = anki_mask & hsk_mask & pos_outcomes
            prefix = "hsk_" if hsk else "not_hsk_"
            stats[prefix+'FN'] = np.count_nonzero(false_negatives)
            stats[prefix+'FP'] = np.count_nonzero(false_positives)
            stats[prefix+'TN'] = np.count_nonzero(true_negatives)
            stats[prefix+'TP'] = np.count_nonzero(true_positives)

        stats_history.append(stats)
        iteration += 1
    return stats_history

NUM_ITER = 10
NUM_RUNS = 100

#run_iterative_lr(include_hsk_lvl=True, plot_lr=False, term_max_iter=NUM_ITER,
#                 answer_based_on_anki=False, anki_lvl=2)

loss = []
iteration = []
outcomes = []
final_losses = []
for run in range(NUM_RUNS):
    print('Run', run)
    stats_history = run_iterative_lr(include_hsk_lvl=True, plot_lr=False, term_max_iter=NUM_ITER,
                                     answer_based_on_anki=True, anki_lvl=5)
    for it, stats in enumerate(stats_history):
        loss.append(stats['loss'])
        outcomes.append(stats['outcomes'])
        iteration.append(it)
    final_losses.append(stats_history[-1]['loss'])
final_losses = np.array(final_losses)
print('Mean final loss:', final_losses.mean(), ' std', final_losses.std())
print('Done')

Here are two key graphs:

  1. The average loss over time for 100 randomized simulation runs. We naturally expect this loss to go down over time.
  2. The percent of words answered correctly each iteration. If the fit is good, we expect this to hover around 0.5. Due to the previously mentioned sampling bias, the actual average comes out lower than that; after trying to correct for it, there is still a gap of about 10% that I cannot account for.

fig, axes = plt.subplots(1, 2)

sns.lineplot(x=np.array(iteration), y=np.array(loss), ax=axes[0]);
axes[0].set(xlabel='Iteration', ylabel='Loss', title=f'Loss over {NUM_RUNS} runs');
sns.lineplot(x=np.array(iteration), y=np.array(outcomes), ax=axes[1]);
axes[1].set(ylim=(0, 1), xlabel='Iteration', ylabel='% known', title=f'% Known words over {NUM_RUNS} runs');
plt.show()

False Positives and False Negatives

Useful statistics for judging how well this system is actually doing are the false positive and false negative counts. We can threshold the predicted probabilities at 0.5 to get a binary classification for each word. Here are the outcomes for the words in my dataset:

from IPython.display import HTML, display
import tabulate

h = random.choice(stats_history)
accuracy = []
precision = []
recall = []
for hsk in [True, False]:
    pre = 'hsk_' if hsk else 'not_hsk_'
    FN = h[pre+'FN']
    FP = h[pre+'FP']
    TP = h[pre+'TP']
    TN = h[pre+'TN']
    table = [["HSK" if hsk else "Non-HSK", "Negatives","Positives"],
             ["False", FN, FP],
             ["True", TN, TP]]
    display(HTML(tabulate.tabulate(table, headers="firstrow", tablefmt='html')))
    accuracy.append((TP+TN)/(TP+TN+FP+FN))
    precision.append(float('nan') if TP+FP == 0 else TP/(TP+FP))
    recall.append(TP/(TP+FN))
    
table = [["HSK","Non-HSK"],
         ["Accuracy", accuracy[0], accuracy[1]],
         ["Precision", precision[0], precision[1]],
         ["Recall", recall[0], recall[1]]]
display(HTML(tabulate.tabulate(table, headers="firstrow", tablefmt='html')))
HSK         Negatives   Positives
False       828         436
True        2924        2813

Non-HSK     Negatives   Positives
False       1637        40
True        46454       169

            HSK         Non-HSK
Accuracy    0.819454    0.96528
Precision   0.865805    0.808612
Recall      0.77259     0.093577

As we can see from the Non-HSK table, most non-HSK words in my vocabulary are misclassified as "don't know" (low recall), even though precision is fairly high. This shows how tough it is to predict vocabulary outside the HSK lists using frequency alone.

Conclusion

Hopefully I've shown that we can build a placement test that would shorten the on-boarding time in learning apps. Whether it generalizes well outside my own vocabulary remains to be seen. If you're reading this and are willing to donate your Anki notes, I'd happily accept them so that I can run more experiments!

While I've applied it specifically to Chinese, the general structure could be applied to any topic where there is some notion of frequency, difficulty, or a general order of study. For Chinese, we could imagine splitting reading, writing, listening, and speaking into separate placement tests, since learners often place different emphasis on them.

If you have any ideas or comments, please use the comment section below.